Skip to main content

ironsbe_client/
session.rs

1//! Client session management.
2
3use bytes::BytesMut;
4use futures::{SinkExt, StreamExt};
5use tokio::net::TcpStream;
6use tokio_util::codec::{Decoder, Framed};
7
/// Simple length-prefixed framing codec for SBE messages.
///
/// Every frame on the wire is a little-endian `u32` byte count followed by
/// exactly that many payload bytes. Payloads larger than `max_frame_size`
/// are rejected when encoding and when decoding.
#[derive(Debug, Clone)]
pub struct SbeFrameCodec {
    /// Largest accepted payload, in bytes (the 4-byte prefix is not counted).
    max_frame_size: usize,
}

impl SbeFrameCodec {
    /// Default upper bound on a frame's payload size, in bytes (64 KiB).
    pub const DEFAULT_MAX_FRAME_SIZE: usize = 64 * 1024;

    /// Creates a new frame codec with the default max frame size
    /// ([`Self::DEFAULT_MAX_FRAME_SIZE`]).
    #[must_use]
    pub fn new() -> Self {
        Self::with_max_frame_size(Self::DEFAULT_MAX_FRAME_SIZE)
    }

    /// Creates a new frame codec with a custom max frame size.
    ///
    /// `max_frame_size` bounds the payload only; the 4-byte length prefix is
    /// not included in the limit.
    #[must_use]
    pub fn with_max_frame_size(max_frame_size: usize) -> Self {
        Self { max_frame_size }
    }

    /// Returns the largest payload size, in bytes, this codec accepts.
    #[must_use]
    pub fn max_frame_size(&self) -> usize {
        self.max_frame_size
    }
}

impl Default for SbeFrameCodec {
    /// Equivalent to [`SbeFrameCodec::new`].
    fn default() -> Self {
        Self::new()
    }
}
34
35impl Decoder for SbeFrameCodec {
36    type Item = BytesMut;
37    type Error = std::io::Error;
38
39    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
40        use bytes::Buf;
41
42        if src.len() < 4 {
43            return Ok(None);
44        }
45
46        let length = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
47
48        if length > self.max_frame_size {
49            return Err(std::io::Error::new(
50                std::io::ErrorKind::InvalidData,
51                "frame too large",
52            ));
53        }
54
55        if src.len() < 4 + length {
56            src.reserve(4 + length - src.len());
57            return Ok(None);
58        }
59
60        src.advance(4);
61        Ok(Some(src.split_to(length)))
62    }
63}
64
65impl<T: AsRef<[u8]>> tokio_util::codec::Encoder<T> for SbeFrameCodec {
66    type Error = std::io::Error;
67
68    fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
69        use bytes::BufMut;
70
71        let data = item.as_ref();
72        if data.len() > self.max_frame_size {
73            return Err(std::io::Error::new(
74                std::io::ErrorKind::InvalidData,
75                "frame too large",
76            ));
77        }
78
79        dst.reserve(4 + data.len());
80        dst.put_u32_le(data.len() as u32);
81        dst.put_slice(data);
82        Ok(())
83    }
84}
85
86/// Client session wrapping a TCP connection.
87pub struct ClientSession {
88    framed: Framed<TcpStream, SbeFrameCodec>,
89}
90
91impl ClientSession {
92    /// Creates a new client session from a TCP stream.
93    #[must_use]
94    pub fn new(stream: TcpStream) -> Self {
95        Self {
96            framed: Framed::new(stream, SbeFrameCodec::default()),
97        }
98    }
99
100    /// Creates a new client session with custom frame size.
101    #[must_use]
102    pub fn with_max_frame_size(stream: TcpStream, max_frame_size: usize) -> Self {
103        Self {
104            framed: Framed::new(stream, SbeFrameCodec::with_max_frame_size(max_frame_size)),
105        }
106    }
107
108    /// Sends a message to the server.
109    ///
110    /// # Errors
111    /// Returns IO error if send fails.
112    pub async fn send(&mut self, message: &[u8]) -> std::io::Result<()> {
113        self.framed.send(message).await
114    }
115
116    /// Receives a message from the server.
117    ///
118    /// # Returns
119    /// `Ok(Some(bytes))` if received, `Ok(None)` if connection closed.
120    ///
121    /// # Errors
122    /// Returns IO error if receive fails.
123    pub async fn recv(&mut self) -> std::io::Result<Option<BytesMut>> {
124        match self.framed.next().await {
125            Some(result) => result.map(Some),
126            None => Ok(None),
127        }
128    }
129
130    /// Closes the session.
131    pub async fn close(mut self) -> std::io::Result<()> {
132        SinkExt::<&[u8]>::close(&mut self.framed).await
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_util::codec::Encoder;

    /// Builds a buffer holding exactly one length-prefixed frame around `payload`.
    fn frame_of(payload: &[u8]) -> BytesMut {
        let mut buf = BytesMut::with_capacity(4 + payload.len());
        buf.extend_from_slice(&(payload.len() as u32).to_le_bytes());
        buf.extend_from_slice(payload);
        buf
    }

    #[test]
    fn test_sbe_frame_codec_new() {
        let codec = SbeFrameCodec::new();
        assert_eq!(codec.max_frame_size, 64 * 1024);
    }

    #[test]
    fn test_sbe_frame_codec_with_max_frame_size() {
        let codec = SbeFrameCodec::with_max_frame_size(128 * 1024);
        assert_eq!(codec.max_frame_size, 128 * 1024);
    }

    #[test]
    fn test_sbe_frame_codec_default() {
        let codec = SbeFrameCodec::default();
        assert_eq!(codec.max_frame_size, 64 * 1024);
    }

    #[test]
    fn test_decode_incomplete_header() {
        let mut codec = SbeFrameCodec::new();
        let mut buf = BytesMut::from(&[0u8, 1, 2][..]);

        let frame = codec.decode(&mut buf).expect("short header is not an error");
        assert!(frame.is_none());
        // Nothing may be consumed until a whole frame is available.
        assert_eq!(buf.len(), 3);
    }

    #[test]
    fn test_decode_incomplete_frame() {
        let mut codec = SbeFrameCodec::new();
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&10u32.to_le_bytes()); // length = 10
        buf.extend_from_slice(&[1, 2, 3, 4, 5]); // only 5 bytes, need 10

        let frame = codec.decode(&mut buf).expect("partial frame is not an error");
        assert!(frame.is_none());
        assert_eq!(buf.len(), 9);
    }

    #[test]
    fn test_decode_complete_frame() {
        let mut codec = SbeFrameCodec::new();
        let mut buf = frame_of(b"Hello");

        let frame = codec
            .decode(&mut buf)
            .expect("valid frame")
            .expect("frame is complete");
        assert_eq!(&frame[..], b"Hello");
        assert!(buf.is_empty());
    }

    #[test]
    fn test_decode_empty_frame() {
        let mut codec = SbeFrameCodec::new();
        let mut buf = frame_of(b"");

        let frame = codec
            .decode(&mut buf)
            .expect("valid frame")
            .expect("frame is complete");
        assert!(frame.is_empty());
        assert!(buf.is_empty());
    }

    #[test]
    fn test_decode_back_to_back_frames() {
        let mut codec = SbeFrameCodec::new();
        let mut buf = frame_of(b"first");
        buf.extend_from_slice(&frame_of(b"second"));

        let first = codec.decode(&mut buf).expect("valid").expect("complete");
        assert_eq!(&first[..], b"first");
        let second = codec.decode(&mut buf).expect("valid").expect("complete");
        assert_eq!(&second[..], b"second");
        assert!(codec.decode(&mut buf).expect("valid").is_none());
    }

    #[test]
    fn test_decode_frame_at_max_size() {
        let mut codec = SbeFrameCodec::with_max_frame_size(5);
        let mut buf = frame_of(b"12345");

        let frame = codec.decode(&mut buf).expect("at limit").expect("complete");
        assert_eq!(&frame[..], b"12345");
    }

    #[test]
    fn test_decode_frame_too_large() {
        let mut codec = SbeFrameCodec::with_max_frame_size(10);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&100u32.to_le_bytes()); // length = 100, exceeds max

        let err = codec.decode(&mut buf).expect_err("oversized frame");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_encode_frame() {
        let mut codec = SbeFrameCodec::new();
        let mut buf = BytesMut::new();
        let data = b"Hello";

        codec.encode(data.as_slice(), &mut buf).expect("encode");

        // Length prefix
        let len = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
        assert_eq!(len, data.len());
        // Payload
        assert_eq!(&buf[4..], data);
    }

    #[test]
    fn test_encode_frame_at_max_size() {
        let mut codec = SbeFrameCodec::with_max_frame_size(5);
        let mut buf = BytesMut::new();

        codec.encode(&b"12345"[..], &mut buf).expect("at limit");
        assert_eq!(buf.len(), 4 + 5);
    }

    #[test]
    fn test_encode_frame_too_large() {
        let mut codec = SbeFrameCodec::with_max_frame_size(5);
        let mut buf = BytesMut::new();
        let data = b"Hello World"; // 11 bytes, exceeds max of 5

        let err = codec.encode(data.as_slice(), &mut buf).expect_err("oversized");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_encode_decode_round_trip() {
        let mut codec = SbeFrameCodec::new();
        let mut buf = BytesMut::new();
        codec.encode(&b"abc"[..], &mut buf).expect("encode abc");
        codec.encode(&b""[..], &mut buf).expect("encode empty");

        let first = codec.decode(&mut buf).expect("valid").expect("complete");
        assert_eq!(&first[..], b"abc");
        let second = codec.decode(&mut buf).expect("valid").expect("complete");
        assert!(second.is_empty());
        assert!(codec.decode(&mut buf).expect("valid").is_none());
    }
}