ironsbe_client/
session.rs1use bytes::BytesMut;
4use futures::{SinkExt, StreamExt};
5use tokio::net::TcpStream;
6use tokio_util::codec::{Decoder, Framed};
7
/// Length-prefixed framing for SBE messages.
///
/// On the wire every frame is a 4-byte little-endian `u32` payload length
/// followed by that many payload bytes. The same `max_frame_size` limit is
/// enforced when decoding incoming frames and when encoding outgoing ones;
/// violations are reported as [`std::io::ErrorKind::InvalidData`].
#[derive(Debug, Clone)]
pub struct SbeFrameCodec {
    /// Largest accepted payload size in bytes (the 4-byte header is not counted).
    max_frame_size: usize,
}

impl SbeFrameCodec {
    /// Payload size limit used by [`SbeFrameCodec::new`] and [`Default`]: 64 KiB.
    pub const DEFAULT_MAX_FRAME_SIZE: usize = 64 * 1024;

    /// Creates a codec with the [`DEFAULT_MAX_FRAME_SIZE`](Self::DEFAULT_MAX_FRAME_SIZE)
    /// payload limit.
    #[must_use]
    pub const fn new() -> Self {
        Self::with_max_frame_size(Self::DEFAULT_MAX_FRAME_SIZE)
    }

    /// Creates a codec that rejects payloads longer than `max_frame_size` bytes.
    #[must_use]
    pub const fn with_max_frame_size(max_frame_size: usize) -> Self {
        Self { max_frame_size }
    }

    /// Returns the configured payload size limit in bytes.
    #[must_use]
    pub const fn max_frame_size(&self) -> usize {
        self.max_frame_size
    }
}

impl Default for SbeFrameCodec {
    /// Equivalent to [`SbeFrameCodec::new`].
    fn default() -> Self {
        Self::new()
    }
}
34
35impl Decoder for SbeFrameCodec {
36 type Item = BytesMut;
37 type Error = std::io::Error;
38
39 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
40 use bytes::Buf;
41
42 if src.len() < 4 {
43 return Ok(None);
44 }
45
46 let length = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
47
48 if length > self.max_frame_size {
49 return Err(std::io::Error::new(
50 std::io::ErrorKind::InvalidData,
51 "frame too large",
52 ));
53 }
54
55 if src.len() < 4 + length {
56 src.reserve(4 + length - src.len());
57 return Ok(None);
58 }
59
60 src.advance(4);
61 Ok(Some(src.split_to(length)))
62 }
63}
64
65impl<T: AsRef<[u8]>> tokio_util::codec::Encoder<T> for SbeFrameCodec {
66 type Error = std::io::Error;
67
68 fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
69 use bytes::BufMut;
70
71 let data = item.as_ref();
72 if data.len() > self.max_frame_size {
73 return Err(std::io::Error::new(
74 std::io::ErrorKind::InvalidData,
75 "frame too large",
76 ));
77 }
78
79 dst.reserve(4 + data.len());
80 dst.put_u32_le(data.len() as u32);
81 dst.put_slice(data);
82 Ok(())
83 }
84}
85
86pub struct ClientSession {
88 framed: Framed<TcpStream, SbeFrameCodec>,
89}
90
91impl ClientSession {
92 #[must_use]
94 pub fn new(stream: TcpStream) -> Self {
95 Self {
96 framed: Framed::new(stream, SbeFrameCodec::default()),
97 }
98 }
99
100 #[must_use]
102 pub fn with_max_frame_size(stream: TcpStream, max_frame_size: usize) -> Self {
103 Self {
104 framed: Framed::new(stream, SbeFrameCodec::with_max_frame_size(max_frame_size)),
105 }
106 }
107
108 pub async fn send(&mut self, message: &[u8]) -> std::io::Result<()> {
113 self.framed.send(message).await
114 }
115
116 pub async fn recv(&mut self) -> std::io::Result<Option<BytesMut>> {
124 match self.framed.next().await {
125 Some(result) => result.map(Some),
126 None => Ok(None),
127 }
128 }
129
130 pub async fn close(mut self) -> std::io::Result<()> {
132 SinkExt::<&[u8]>::close(&mut self.framed).await
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_util::codec::Encoder;

    /// Builds a wire frame (little-endian `u32` length header + payload) by hand,
    /// independently of the encoder under test.
    fn raw_frame(payload: &[u8]) -> BytesMut {
        let mut buf = BytesMut::with_capacity(4 + payload.len());
        buf.extend_from_slice(&(payload.len() as u32).to_le_bytes());
        buf.extend_from_slice(payload);
        buf
    }

    #[test]
    fn test_sbe_frame_codec_new() {
        let codec = SbeFrameCodec::new();
        assert_eq!(codec.max_frame_size, 64 * 1024);
    }

    #[test]
    fn test_sbe_frame_codec_with_max_frame_size() {
        let codec = SbeFrameCodec::with_max_frame_size(128 * 1024);
        assert_eq!(codec.max_frame_size, 128 * 1024);
    }

    #[test]
    fn test_sbe_frame_codec_default() {
        let codec = SbeFrameCodec::default();
        assert_eq!(codec.max_frame_size, 64 * 1024);
    }

    #[test]
    fn test_decode_incomplete_header() {
        let mut codec = SbeFrameCodec::new();
        let mut buf = BytesMut::from(&[0u8, 1, 2][..]);

        let frame = codec.decode(&mut buf).expect("a short header is not an error");
        assert!(frame.is_none());
        assert_eq!(buf.len(), 3, "nothing may be consumed before a full header arrives");
    }

    #[test]
    fn test_decode_incomplete_frame() {
        let mut codec = SbeFrameCodec::new();
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&10u32.to_le_bytes());
        buf.extend_from_slice(&[1, 2, 3, 4, 5]);

        let frame = codec.decode(&mut buf).expect("a partial payload is not an error");
        assert!(frame.is_none());
        assert_eq!(buf.len(), 9, "a partial frame must stay in the buffer");
    }

    #[test]
    fn test_decode_complete_frame() {
        let mut codec = SbeFrameCodec::new();
        let data = b"Hello";
        let mut buf = raw_frame(data);

        let frame = codec
            .decode(&mut buf)
            .expect("decode")
            .expect("complete frame");
        assert_eq!(frame.as_ref(), data);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_decode_empty_frame() {
        let mut codec = SbeFrameCodec::new();
        let mut buf = raw_frame(&[]);

        let frame = codec
            .decode(&mut buf)
            .expect("decode")
            .expect("zero-length frame");
        assert!(frame.is_empty());
        assert!(buf.is_empty());
    }

    #[test]
    fn test_decode_frame_at_max_size() {
        let mut codec = SbeFrameCodec::with_max_frame_size(5);
        let mut buf = raw_frame(b"Hello");

        let frame = codec
            .decode(&mut buf)
            .expect("a frame of exactly the limit is allowed")
            .expect("complete frame");
        assert_eq!(frame.as_ref(), b"Hello");
    }

    #[test]
    fn test_decode_back_to_back_frames() {
        let mut codec = SbeFrameCodec::new();
        let mut buf = raw_frame(b"ab");
        buf.extend_from_slice(&raw_frame(b"cde"));

        let first = codec.decode(&mut buf).expect("decode").expect("first frame");
        let second = codec.decode(&mut buf).expect("decode").expect("second frame");
        assert_eq!(first.as_ref(), b"ab");
        assert_eq!(second.as_ref(), b"cde");
        assert!(codec.decode(&mut buf).expect("decode").is_none());
    }

    #[test]
    fn test_decode_frame_too_large() {
        let mut codec = SbeFrameCodec::with_max_frame_size(10);
        let mut buf = BytesMut::new();
        buf.extend_from_slice(&100u32.to_le_bytes());

        let err = codec.decode(&mut buf).expect_err("oversized header must be rejected");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_encode_frame() {
        let mut codec = SbeFrameCodec::new();
        let mut buf = BytesMut::new();
        let data = b"Hello";

        codec.encode(data.as_slice(), &mut buf).expect("encode");

        let len = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
        assert_eq!(len, data.len());
        assert_eq!(&buf[4..], data);
    }

    #[test]
    fn test_encode_frame_too_large() {
        let mut codec = SbeFrameCodec::with_max_frame_size(5);
        let mut buf = BytesMut::new();
        let data = b"Hello World";

        let err = codec
            .encode(data.as_slice(), &mut buf)
            .expect_err("oversized payload must be rejected");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        assert!(buf.is_empty(), "nothing may be written on error");
    }

    #[test]
    fn test_encode_decode_round_trip() {
        let mut codec = SbeFrameCodec::new();
        let mut buf = BytesMut::new();

        codec.encode(b"payload".as_slice(), &mut buf).expect("encode");
        let frame = codec
            .decode(&mut buf)
            .expect("decode")
            .expect("complete frame");
        assert_eq!(frame.as_ref(), b"payload");
        assert!(buf.is_empty());
    }
}