siphon_protocol/
codec.rs

1use bytes::{Buf, BufMut, BytesMut};
2use serde::{de::DeserializeOwned, Serialize};
3use thiserror::Error;
4use tokio_util::codec::{Decoder, Encoder};
5
/// Upper bound, in bytes, on one frame's JSON payload: 16 MiB.
///
/// Checked by both the encoder and the decoder of `TunnelCodec`. The 4-byte length prefix
/// is not counted against this limit.
const MAX_FRAME_SIZE: usize = 16 << 20;
8
9/// Errors that can occur during encoding/decoding
10#[derive(Debug, Error)]
11pub enum CodecError {
12    #[error("Frame too large: {0} bytes (max {MAX_FRAME_SIZE})")]
13    FrameTooLarge(usize),
14
15    #[error("IO error: {0}")]
16    Io(#[from] std::io::Error),
17
18    #[error("JSON error: {0}")]
19    Json(#[from] serde_json::Error),
20}
21
/// Length-delimited JSON codec for tunnel messages.
///
/// Each frame is a 4-byte big-endian `u32` length prefix followed by exactly that many
/// bytes of JSON. The payload length (prefix excluded) is capped at `MAX_FRAME_SIZE`
/// (16 MiB) in both directions.
///
/// Wire format:
/// ```text
/// +-------------------+------------------------+
/// | Length            | JSON payload           |
/// | (4 bytes,         | (`Length` bytes)       |
/// |  big-endian u32)  |                        |
/// +-------------------+------------------------+
/// ```
///
/// The codec holds no data of its own; `T` only selects the message type it encodes and
/// decodes.
pub struct TunnelCodec<T> {
    // `fn() -> T` rather than `T`: the codec never owns a `T`, so it should be
    // `Send`/`Sync` and free of drop-check obligations regardless of `T`.
    // Variance stays covariant, as with `PhantomData<T>`.
    _phantom: std::marker::PhantomData<fn() -> T>,
}

impl<T> TunnelCodec<T> {
    /// Creates a codec for messages of type `T`.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<T> Default for TunnelCodec<T> {
    fn default() -> Self {
        Self::new()
    }
}

// Hand-written instead of derived: `#[derive(Clone, Debug)]` would add `T: Clone` /
// `T: Debug` bounds even though no `T` is stored.
impl<T> Clone for TunnelCodec<T> {
    fn clone(&self) -> Self {
        Self::new()
    }
}

impl<T> std::fmt::Debug for TunnelCodec<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TunnelCodec").finish()
    }
}
48
49impl<T: DeserializeOwned> Decoder for TunnelCodec<T> {
50    type Item = T;
51    type Error = CodecError;
52
53    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
54        // Need at least 4 bytes for length prefix
55        if src.len() < 4 {
56            return Ok(None);
57        }
58
59        // Peek at the length without consuming
60        let length = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
61
62        // Check frame size limit
63        if length > MAX_FRAME_SIZE {
64            return Err(CodecError::FrameTooLarge(length));
65        }
66
67        // Check if we have the full frame
68        let total_len = 4 + length;
69        if src.len() < total_len {
70            // Reserve space for the full frame
71            src.reserve(total_len - src.len());
72            return Ok(None);
73        }
74
75        // Consume the length prefix
76        src.advance(4);
77
78        // Take the JSON payload
79        let payload = src.split_to(length);
80
81        // Deserialize
82        let message = serde_json::from_slice(&payload)?;
83        Ok(Some(message))
84    }
85}
86
87impl<T: Serialize> Encoder<T> for TunnelCodec<T> {
88    type Error = CodecError;
89
90    fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
91        // Serialize to JSON
92        let json = serde_json::to_vec(&item)?;
93
94        // Check frame size limit
95        if json.len() > MAX_FRAME_SIZE {
96            return Err(CodecError::FrameTooLarge(json.len()));
97        }
98
99        // Write length prefix
100        dst.reserve(4 + json.len());
101        dst.put_u32(json.len() as u32);
102        dst.put_slice(&json);
103
104        Ok(())
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::{ClientMessage, ServerMessage, TunnelType};

    #[test]
    fn test_roundtrip_client_message() {
        let mut codec = TunnelCodec::<ClientMessage>::new();
        let msg = ClientMessage::RequestTunnel {
            subdomain: Some("test".to_string()),
            tunnel_type: TunnelType::Http,
            local_port: 8080,
        };

        // Encode (msg is not needed afterwards, so no clone is required)
        let mut buf = BytesMut::new();
        codec.encode(msg, &mut buf).unwrap();

        // Decode
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        // The whole frame must have been consumed.
        assert!(buf.is_empty());
        match decoded {
            ClientMessage::RequestTunnel {
                subdomain,
                tunnel_type,
                local_port,
            } => {
                assert_eq!(subdomain, Some("test".to_string()));
                assert_eq!(tunnel_type, TunnelType::Http);
                assert_eq!(local_port, 8080);
            }
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_roundtrip_server_message() {
        let mut codec = TunnelCodec::<ServerMessage>::new();
        let msg = ServerMessage::HttpRequest {
            stream_id: 42,
            method: "GET".to_string(),
            uri: "/api/test".to_string(),
            headers: vec![("Host".to_string(), "example.com".to_string())],
            body: vec![],
        };

        // Encode
        let mut buf = BytesMut::new();
        codec.encode(msg, &mut buf).unwrap();

        // Decode
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        assert!(buf.is_empty());
        match decoded {
            ServerMessage::HttpRequest {
                stream_id,
                method,
                uri,
                ..
            } => {
                assert_eq!(stream_id, 42);
                assert_eq!(method, "GET");
                assert_eq!(uri, "/api/test");
            }
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_partial_frame() {
        let mut codec = TunnelCodec::<ClientMessage>::new();
        let msg = ClientMessage::Ping { timestamp: 12345 };

        // Encode
        let mut buf = BytesMut::new();
        codec.encode(msg, &mut buf).unwrap();

        // Split the buffer in half
        let full_len = buf.len();
        let mut partial = buf.split_to(full_len / 2);

        // Should return None (incomplete)
        assert!(codec.decode(&mut partial).unwrap().is_none());

        // Add the rest
        partial.unsplit(buf);

        // Now should decode
        let decoded = codec.decode(&mut partial).unwrap().unwrap();
        assert!(partial.is_empty());
        match decoded {
            ClientMessage::Ping { timestamp } => assert_eq!(timestamp, 12345),
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_short_header_is_incomplete() {
        let mut codec = TunnelCodec::<ClientMessage>::new();
        let mut buf = BytesMut::from(&[0u8, 0, 0][..]);

        assert!(codec.decode(&mut buf).unwrap().is_none());
        // Nothing may be consumed until the full 4-byte prefix has arrived.
        assert_eq!(buf.len(), 3);
    }

    #[test]
    fn test_length_prefix_is_big_endian_payload_length() {
        let mut codec = TunnelCodec::<ClientMessage>::new();
        let mut buf = BytesMut::new();
        codec
            .encode(ClientMessage::Ping { timestamp: 7 }, &mut buf)
            .unwrap();

        let declared = u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
        assert_eq!(declared, buf.len() - 4);
    }

    #[test]
    fn test_multiple_frames_in_one_buffer() {
        let mut codec = TunnelCodec::<ClientMessage>::new();
        let mut buf = BytesMut::new();
        codec
            .encode(ClientMessage::Ping { timestamp: 1 }, &mut buf)
            .unwrap();
        codec
            .encode(ClientMessage::Ping { timestamp: 2 }, &mut buf)
            .unwrap();

        match codec.decode(&mut buf).unwrap().unwrap() {
            ClientMessage::Ping { timestamp } => assert_eq!(timestamp, 1),
            _ => panic!("Wrong variant"),
        }
        match codec.decode(&mut buf).unwrap().unwrap() {
            ClientMessage::Ping { timestamp } => assert_eq!(timestamp, 2),
            _ => panic!("Wrong variant"),
        }

        // Both frames consumed; nothing left to decode.
        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert!(buf.is_empty());
    }

    #[test]
    fn test_oversized_frame_is_rejected() {
        let mut codec = TunnelCodec::<ClientMessage>::new();
        let mut buf = BytesMut::new();
        buf.put_u32((MAX_FRAME_SIZE + 1) as u32);

        match codec.decode(&mut buf) {
            Err(CodecError::FrameTooLarge(len)) => assert_eq!(len, MAX_FRAME_SIZE + 1),
            Err(e) => panic!("unexpected error: {}", e),
            Ok(_) => panic!("oversized frame was accepted"),
        }
    }

    #[test]
    fn test_invalid_json_skips_only_that_frame() {
        let mut codec = TunnelCodec::<ClientMessage>::new();
        let garbage = b"{{{";
        let mut buf = BytesMut::new();
        buf.put_u32(garbage.len() as u32);
        buf.put_slice(garbage);
        codec
            .encode(ClientMessage::Ping { timestamp: 99 }, &mut buf)
            .unwrap();

        assert!(matches!(codec.decode(&mut buf), Err(CodecError::Json(_))));

        // The malformed frame was consumed, so the following frame still decodes.
        match codec.decode(&mut buf).unwrap().unwrap() {
            ClientMessage::Ping { timestamp } => assert_eq!(timestamp, 99),
            _ => panic!("Wrong variant"),
        }
        assert!(buf.is_empty());
    }
}
200}