1use crate::error::{Result, TorError};
7use crate::types::MAX_MESSAGE_SIZE;
8use serde::{de::DeserializeOwned, Serialize};
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
10
11pub fn encode<T: Serialize>(msg: &T) -> Result<Vec<u8>> {
13 let mut cbor_buf = Vec::new();
14 ciborium::into_writer(msg, &mut cbor_buf)
15 .map_err(|e| TorError::Serialization(e.to_string()))?;
16
17 if cbor_buf.len() > MAX_MESSAGE_SIZE as usize {
18 return Err(TorError::MessageTooLarge {
19 size: cbor_buf.len(),
20 max: MAX_MESSAGE_SIZE as usize,
21 });
22 }
23
24 let len = cbor_buf.len() as u32;
25 let mut wire = Vec::with_capacity(4 + cbor_buf.len());
26 wire.extend_from_slice(&len.to_be_bytes());
27 wire.extend_from_slice(&cbor_buf);
28 Ok(wire)
29}
30
31pub fn decode<T: DeserializeOwned>(data: &[u8]) -> Result<T> {
33 if data.len() < 4 {
34 return Err(TorError::Serialization(
35 "Message too short for length prefix".to_string(),
36 ));
37 }
38
39 let len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
40
41 if len > MAX_MESSAGE_SIZE as usize {
42 return Err(TorError::MessageTooLarge {
43 size: len,
44 max: MAX_MESSAGE_SIZE as usize,
45 });
46 }
47
48 if data.len() < 4 + len {
49 return Err(TorError::Serialization(format!(
50 "Incomplete message: expected {} bytes, got {}",
51 4 + len,
52 data.len()
53 )));
54 }
55
56 ciborium::from_reader(&data[4..4 + len])
57 .map_err(|e| TorError::Serialization(format!("CBOR decode failed: {e}")))
58}
59
60pub async fn write_message<W, T>(writer: &mut W, msg: &T) -> Result<()>
62where
63 W: AsyncWrite + Unpin,
64 T: Serialize,
65{
66 let wire = encode(msg)?;
67 writer.write_all(&wire).await.map_err(TorError::Io)?;
68 writer.flush().await.map_err(TorError::Io)?;
69 Ok(())
70}
71
72pub async fn read_message<R, T>(reader: &mut R) -> Result<T>
74where
75 R: AsyncRead + Unpin,
76 T: DeserializeOwned,
77{
78 let mut len_buf = [0u8; 4];
80 reader
81 .read_exact(&mut len_buf)
82 .await
83 .map_err(TorError::Io)?;
84 let len = u32::from_be_bytes(len_buf) as usize;
85
86 if len == 0 {
88 return Err(TorError::Keepalive);
89 }
90
91 if len > MAX_MESSAGE_SIZE as usize {
92 return Err(TorError::MessageTooLarge {
93 size: len,
94 max: MAX_MESSAGE_SIZE as usize,
95 });
96 }
97
98 let mut payload = vec![0u8; len];
100 reader
101 .read_exact(&mut payload)
102 .await
103 .map_err(TorError::Io)?;
104
105 ciborium::from_reader(payload.as_slice())
106 .map_err(|e| TorError::Serialization(format!("CBOR decode failed: {e}")))
107}
108
109pub async fn write_keepalive<W: AsyncWrite + Unpin>(writer: &mut W) -> Result<()> {
111 writer
112 .write_all(&0u32.to_be_bytes())
113 .await
114 .map_err(TorError::Io)?;
115 writer.flush().await.map_err(TorError::Io)?;
116 Ok(())
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{RemoteSearchResult, SearchRequest, SearchResponse, SearchStatus};

    /// A value survives `encode` -> `decode`, and the prefix equals the
    /// payload length.
    #[test]
    fn test_encode_decode_roundtrip() {
        let msg = vec![1u32, 2, 3, 4, 5];
        let wire = encode(&msg).unwrap();

        // The first four bytes announce how many bytes follow them.
        let len = u32::from_be_bytes([wire[0], wire[1], wire[2], wire[3]]);
        assert_eq!(len as usize, wire.len() - 4);

        let decoded: Vec<u32> = decode(&wire).unwrap();
        assert_eq!(decoded, msg);
    }

    /// A `SearchRequest` round-trips through the in-memory codec.
    #[test]
    fn test_encode_search_request() {
        let req = SearchRequest {
            request_id: [1u8; 16],
            query_embedding: vec![100, -200, 300],
            query_text: None,
            max_results: 10,
            include_proofs: false,
            model_hash: [2u8; 32],
            timestamp: 1000,
            signature: [3u8; 64],
            public_key: [4u8; 32],
        };

        let wire = encode(&req).unwrap();
        let decoded: SearchRequest = decode(&wire).unwrap();
        assert_eq!(decoded.request_id, req.request_id);
        assert_eq!(decoded.query_embedding, req.query_embedding);
    }

    /// A `SearchResponse` with one result round-trips through the in-memory
    /// codec.
    #[test]
    fn test_encode_search_response() {
        let resp = SearchResponse {
            request_id: [1u8; 16],
            status: SearchStatus::Ok,
            results: vec![RemoteSearchResult {
                chunk_id: [2u8; 16],
                chunk_text: "hello".to_string(),
                document_path: "test.md".to_string(),
                score: 42,
                merkle_proof: None,
            }],
            peer_state_root: [3u8; 32],
            search_latency_ms: 100,
            timestamp: 2000,
            signature: [4u8; 64],
        };

        let wire = encode(&resp).unwrap();
        let decoded: SearchResponse = decode(&wire).unwrap();
        assert_eq!(decoded.results.len(), 1);
        assert_eq!(decoded.results[0].chunk_text, "hello");
    }

    /// Payloads over the size limit are rejected with `MessageTooLarge`
    /// specifically, not merely with "some error".
    #[test]
    fn test_message_too_large() {
        let big = vec![0u8; MAX_MESSAGE_SIZE as usize + 100];
        let result = encode(&big);
        assert!(matches!(result, Err(TorError::MessageTooLarge { .. })));
    }

    /// A frame cut short of its announced length is reported as a
    /// serialization error.
    #[test]
    fn test_decode_truncated() {
        let msg = vec![1u32, 2, 3];
        let wire = encode(&msg).unwrap();

        let result: std::result::Result<Vec<u32>, _> = decode(&wire[..wire.len() - 2]);
        assert!(matches!(result, Err(TorError::Serialization(_))));
    }

    /// Input too short to hold a length prefix is reported as a
    /// serialization error.
    #[test]
    fn test_decode_empty() {
        let result: std::result::Result<Vec<u32>, _> = decode(&[]);
        assert!(matches!(result, Err(TorError::Serialization(_))));
    }

    /// `write_message` output can be consumed by `read_message`.
    #[tokio::test]
    async fn test_async_write_read_roundtrip() {
        let req = SearchRequest {
            request_id: [5u8; 16],
            query_embedding: vec![1, 2, 3, 4, 5],
            query_text: Some("test".to_string()),
            max_results: 5,
            include_proofs: true,
            model_hash: [6u8; 32],
            timestamp: 3000,
            signature: [7u8; 64],
            public_key: [8u8; 32],
        };

        let mut buf = Vec::new();
        write_message(&mut buf, &req).await.unwrap();

        let mut cursor = &buf[..];
        let decoded: SearchRequest = read_message(&mut cursor).await.unwrap();
        assert_eq!(decoded.request_id, req.request_id);
        assert_eq!(decoded.query_text, Some("test".to_string()));
    }

    /// A keepalive is four zero bytes on the wire and surfaces from
    /// `read_message` as `TorError::Keepalive`.
    #[tokio::test]
    async fn test_keepalive_roundtrip() {
        let mut buf = Vec::new();
        write_keepalive(&mut buf).await.unwrap();

        assert_eq!(buf.len(), 4);
        assert_eq!(buf, vec![0, 0, 0, 0]);

        let mut cursor = &buf[..];
        let result: std::result::Result<SearchRequest, _> = read_message(&mut cursor).await;
        assert!(matches!(result, Err(TorError::Keepalive)));
    }

    /// Back-to-back frames on one stream are read out in order, each call
    /// consuming exactly one frame.
    #[tokio::test]
    async fn test_multiple_messages_on_stream() {
        let msg1 = "hello".to_string();
        let msg2 = "world".to_string();
        let msg3 = vec![1u32, 2, 3];

        let mut buf = Vec::new();
        write_message(&mut buf, &msg1).await.unwrap();
        write_message(&mut buf, &msg2).await.unwrap();
        write_message(&mut buf, &msg3).await.unwrap();

        let mut cursor = &buf[..];
        let d1: String = read_message(&mut cursor).await.unwrap();
        let d2: String = read_message(&mut cursor).await.unwrap();
        let d3: Vec<u32> = read_message(&mut cursor).await.unwrap();

        assert_eq!(d1, "hello");
        assert_eq!(d2, "world");
        assert_eq!(d3, vec![1, 2, 3]);
    }
}