// cp_tor/wire.rs
//! Length-prefixed wire format for Canon search protocol messages.
//!
//! All messages use a 4-byte big-endian u32 length prefix followed by
//! a CBOR-encoded payload. Maximum message size is 1MB. A zero-length
//! frame (the prefix alone, with no payload) is a keepalive probe.
5
6use crate::error::{Result, TorError};
7use crate::types::MAX_MESSAGE_SIZE;
8use serde::{de::DeserializeOwned, Serialize};
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
10
11/// Encode a message to length-prefixed CBOR wire format.
12pub fn encode<T: Serialize>(msg: &T) -> Result<Vec<u8>> {
13    let mut cbor_buf = Vec::new();
14    ciborium::into_writer(msg, &mut cbor_buf)
15        .map_err(|e| TorError::Serialization(e.to_string()))?;
16
17    if cbor_buf.len() > MAX_MESSAGE_SIZE as usize {
18        return Err(TorError::MessageTooLarge {
19            size: cbor_buf.len(),
20            max: MAX_MESSAGE_SIZE as usize,
21        });
22    }
23
24    let len = cbor_buf.len() as u32;
25    let mut wire = Vec::with_capacity(4 + cbor_buf.len());
26    wire.extend_from_slice(&len.to_be_bytes());
27    wire.extend_from_slice(&cbor_buf);
28    Ok(wire)
29}
30
31/// Decode a message from length-prefixed CBOR wire format.
32pub fn decode<T: DeserializeOwned>(data: &[u8]) -> Result<T> {
33    if data.len() < 4 {
34        return Err(TorError::Serialization(
35            "Message too short for length prefix".to_string(),
36        ));
37    }
38
39    let len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]) as usize;
40
41    if len > MAX_MESSAGE_SIZE as usize {
42        return Err(TorError::MessageTooLarge {
43            size: len,
44            max: MAX_MESSAGE_SIZE as usize,
45        });
46    }
47
48    if data.len() < 4 + len {
49        return Err(TorError::Serialization(format!(
50            "Incomplete message: expected {} bytes, got {}",
51            4 + len,
52            data.len()
53        )));
54    }
55
56    ciborium::from_reader(&data[4..4 + len])
57        .map_err(|e| TorError::Serialization(format!("CBOR decode failed: {e}")))
58}
59
60/// Write a length-prefixed CBOR message to an async writer.
61pub async fn write_message<W, T>(writer: &mut W, msg: &T) -> Result<()>
62where
63    W: AsyncWrite + Unpin,
64    T: Serialize,
65{
66    let wire = encode(msg)?;
67    writer.write_all(&wire).await.map_err(TorError::Io)?;
68    writer.flush().await.map_err(TorError::Io)?;
69    Ok(())
70}
71
72/// Read a length-prefixed CBOR message from an async reader.
73pub async fn read_message<R, T>(reader: &mut R) -> Result<T>
74where
75    R: AsyncRead + Unpin,
76    T: DeserializeOwned,
77{
78    // Read 4-byte length prefix
79    let mut len_buf = [0u8; 4];
80    reader
81        .read_exact(&mut len_buf)
82        .await
83        .map_err(TorError::Io)?;
84    let len = u32::from_be_bytes(len_buf) as usize;
85
86    // Zero-length message is a keepalive probe
87    if len == 0 {
88        return Err(TorError::Keepalive);
89    }
90
91    if len > MAX_MESSAGE_SIZE as usize {
92        return Err(TorError::MessageTooLarge {
93            size: len,
94            max: MAX_MESSAGE_SIZE as usize,
95        });
96    }
97
98    // Read payload
99    let mut payload = vec![0u8; len];
100    reader
101        .read_exact(&mut payload)
102        .await
103        .map_err(TorError::Io)?;
104
105    ciborium::from_reader(payload.as_slice())
106        .map_err(|e| TorError::Serialization(format!("CBOR decode failed: {e}")))
107}
108
109/// Send a keepalive probe (zero-length message).
110pub async fn write_keepalive<W: AsyncWrite + Unpin>(writer: &mut W) -> Result<()> {
111    writer
112        .write_all(&0u32.to_be_bytes())
113        .await
114        .map_err(TorError::Io)?;
115    writer.flush().await.map_err(TorError::Io)?;
116    Ok(())
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{RemoteSearchResult, SearchRequest, SearchResponse, SearchStatus};

    #[test]
    fn test_encode_decode_roundtrip() {
        let msg = vec![1u32, 2, 3, 4, 5];
        let wire = encode(&msg).unwrap();

        // First 4 bytes are length prefix
        let len = u32::from_be_bytes([wire[0], wire[1], wire[2], wire[3]]);
        assert_eq!(len as usize, wire.len() - 4);

        let decoded: Vec<u32> = decode(&wire).unwrap();
        assert_eq!(decoded, msg);
    }

    #[test]
    fn test_encode_search_request() {
        let req = SearchRequest {
            request_id: [1u8; 16],
            query_embedding: vec![100, -200, 300],
            query_text: None,
            max_results: 10,
            include_proofs: false,
            model_hash: [2u8; 32],
            timestamp: 1000,
            signature: [3u8; 64],
            public_key: [4u8; 32],
        };

        let wire = encode(&req).unwrap();
        let decoded: SearchRequest = decode(&wire).unwrap();
        assert_eq!(decoded.request_id, req.request_id);
        assert_eq!(decoded.query_embedding, req.query_embedding);
    }

    #[test]
    fn test_encode_search_response() {
        let resp = SearchResponse {
            request_id: [1u8; 16],
            status: SearchStatus::Ok,
            results: vec![RemoteSearchResult {
                chunk_id: [2u8; 16],
                chunk_text: "hello".to_string(),
                document_path: "test.md".to_string(),
                score: 42,
                merkle_proof: None,
            }],
            peer_state_root: [3u8; 32],
            search_latency_ms: 100,
            timestamp: 2000,
            signature: [4u8; 64],
        };

        let wire = encode(&resp).unwrap();
        let decoded: SearchResponse = decode(&wire).unwrap();
        assert_eq!(decoded.results.len(), 1);
        assert_eq!(decoded.results[0].chunk_text, "hello");
    }

    #[test]
    fn test_message_too_large() {
        // Each u8 encodes to at least one CBOR byte, so this payload is
        // guaranteed to exceed MAX_MESSAGE_SIZE once encoded.
        let big = vec![0u8; MAX_MESSAGE_SIZE as usize + 100];
        let result = encode(&big);
        assert!(matches!(result, Err(TorError::MessageTooLarge { .. })));
    }

    #[test]
    fn test_decode_truncated() {
        let msg = vec![1u32, 2, 3];
        let wire = encode(&msg).unwrap();

        // Truncate the payload
        let result: std::result::Result<Vec<u32>, _> = decode(&wire[..wire.len() - 2]);
        assert!(matches!(result, Err(TorError::Serialization(_))));
    }

    #[test]
    fn test_decode_empty() {
        let result: std::result::Result<Vec<u32>, _> = decode(&[]);
        assert!(matches!(result, Err(TorError::Serialization(_))));
    }

    #[test]
    fn test_decode_rejects_oversized_prefix() {
        // The size limit must be enforced from the prefix alone, before any
        // payload bytes are required.
        let result: std::result::Result<Vec<u32>, _> = decode(&u32::MAX.to_be_bytes());
        assert!(matches!(result, Err(TorError::MessageTooLarge { .. })));
    }

    #[tokio::test]
    async fn test_async_write_read_roundtrip() {
        let req = SearchRequest {
            request_id: [5u8; 16],
            query_embedding: vec![1, 2, 3, 4, 5],
            query_text: Some("test".to_string()),
            max_results: 5,
            include_proofs: true,
            model_hash: [6u8; 32],
            timestamp: 3000,
            signature: [7u8; 64],
            public_key: [8u8; 32],
        };

        // Write to buffer
        let mut buf = Vec::new();
        write_message(&mut buf, &req).await.unwrap();

        // Read back
        let mut cursor = &buf[..];
        let decoded: SearchRequest = read_message(&mut cursor).await.unwrap();
        assert_eq!(decoded.request_id, req.request_id);
        assert_eq!(decoded.query_text, Some("test".to_string()));
    }

    #[tokio::test]
    async fn test_keepalive_roundtrip() {
        let mut buf = Vec::new();
        write_keepalive(&mut buf).await.unwrap();

        assert_eq!(buf.len(), 4);
        assert_eq!(buf, vec![0, 0, 0, 0]);

        // Reading a keepalive should return Keepalive error
        let mut cursor = &buf[..];
        let result: std::result::Result<SearchRequest, _> = read_message(&mut cursor).await;
        assert!(matches!(result, Err(TorError::Keepalive)));
    }

    #[tokio::test]
    async fn test_read_rejects_oversized_prefix() {
        // Only the prefix is present: the limit must trip before the reader
        // tries to fill a payload buffer.
        let buf = u32::MAX.to_be_bytes();
        let mut cursor = &buf[..];
        let result: std::result::Result<Vec<u32>, _> = read_message(&mut cursor).await;
        assert!(matches!(result, Err(TorError::MessageTooLarge { .. })));
    }

    #[tokio::test]
    async fn test_read_truncated_payload() {
        let mut buf = Vec::new();
        write_message(&mut buf, &vec![1u32, 2, 3]).await.unwrap();
        // Drop the last payload byte so the stream ends mid-frame.
        buf.pop();

        let mut cursor = &buf[..];
        let result: std::result::Result<Vec<u32>, _> = read_message(&mut cursor).await;
        assert!(matches!(result, Err(TorError::Io(_))));
    }

    #[tokio::test]
    async fn test_multiple_messages_on_stream() {
        let msg1 = "hello".to_string();
        let msg2 = "world".to_string();
        let msg3 = vec![1u32, 2, 3];

        let mut buf = Vec::new();
        write_message(&mut buf, &msg1).await.unwrap();
        write_message(&mut buf, &msg2).await.unwrap();
        write_message(&mut buf, &msg3).await.unwrap();

        let mut cursor = &buf[..];
        let d1: String = read_message(&mut cursor).await.unwrap();
        let d2: String = read_message(&mut cursor).await.unwrap();
        let d3: Vec<u32> = read_message(&mut cursor).await.unwrap();

        assert_eq!(d1, "hello");
        assert_eq!(d2, "world");
        assert_eq!(d3, vec![1, 2, 3]);
    }
}