aldrin_core/message/packetizer.rs

1use bytes::{Buf, BytesMut};
2use std::mem::MaybeUninit;
3
/// Smallest amount of spare capacity reserved at once when the buffer runs full (64 KiB).
const MIN_RESERVE_CAPACITY: usize = 1 << 16;

/// Upper bound on a single reservation made while growing towards a known message length
/// (4 MiB).
const MAX_RESERVE_CAPACITY: usize = 4 << 20;
6
7/// Splits a continuous stream of bytes into individual messages.
8#[derive(Debug)]
9pub struct Packetizer {
10    buf: BytesMut,
11    len: Option<usize>,
12}
13
14impl Packetizer {
15    pub fn new() -> Self {
16        Self {
17            buf: BytesMut::new(),
18            len: None,
19        }
20    }
21
22    pub fn extend_from_slice(&mut self, bytes: &[u8]) {
23        self.buf.extend_from_slice(bytes.as_ref());
24    }
25
26    /// Returns a slice of uninitialized bytes at the end of the internal buffer.
27    ///
28    /// This function, together with [`bytes_written`](Self::bytes_written), make it possible to
29    /// fill the packetizer without an intermediate buffer.
30    ///
31    /// The slice returned by this function is guaranteed to be non-empty.
32    pub fn spare_capacity_mut(&mut self) -> &mut [MaybeUninit<u8>] {
33        if let Some(len) = self.len {
34            if self.buf.capacity() < len {
35                let reserve =
36                    (len - self.buf.len()).clamp(MIN_RESERVE_CAPACITY, MAX_RESERVE_CAPACITY);
37                self.buf.reserve(reserve);
38            }
39        } else if self.buf.capacity() == self.buf.len() {
40            self.buf.reserve(MIN_RESERVE_CAPACITY);
41        }
42
43        let slice = self.buf.spare_capacity_mut();
44        debug_assert!(!slice.is_empty());
45        slice
46    }
47
48    /// Asserts that the next `len` bytes have been initialized.
49    ///
50    /// # Safety
51    ///
52    /// You must ensure that prior to calling this function, at least `len` bytes of the slice
53    /// returned by [`spare_capacity_mut`](Self::spare_capacity_mut) have been initialized.
54    pub unsafe fn bytes_written(&mut self, len: usize) {
55        unsafe {
56            self.buf.set_len(self.buf.len() + len);
57        }
58    }
59
60    pub fn next_message(&mut self) -> Option<BytesMut> {
61        if self.buf.len() < 4 {
62            return None;
63        }
64
65        let len = match self.len {
66            Some(len) => len,
67
68            None => {
69                let len = (&self.buf[..4]).get_u32_le() as usize;
70                self.len = Some(len);
71                len
72            }
73        };
74
75        if self.buf.len() >= len {
76            let mut msg = self.buf.split_to(len.max(4));
77            msg.truncate(len);
78            self.len = None;
79            Some(msg)
80        } else {
81            None
82        }
83    }
84}
85
86impl Default for Packetizer {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
#[cfg(test)]
mod test {
    use super::super::{CreateChannel, CreateObject, Message, MessageOps, Shutdown};
    use super::Packetizer;
    use crate::{ChannelEndWithCapacity, ObjectUuid};
    use bytes::Buf;
    use std::mem::MaybeUninit;
    use uuid::uuid;

    /// Serializes three sample messages into one contiguous stream and hands it to a fresh
    /// `Packetizer` in chunks of 3, 25 and 6 bytes via `feed`. After each chunk it checks that
    /// exactly the messages completed so far can be extracted.
    fn check_chunked_stream(mut feed: impl FnMut(&mut Packetizer, &[u8])) {
        let msg1 = Message::Shutdown(Shutdown);
        let msg2 = Message::CreateObject(CreateObject {
            serial: 1,
            uuid: ObjectUuid(uuid!("b7c3be13-5377-466e-b4bf-373876523d1b")),
        });
        let msg3 = Message::CreateChannel(CreateChannel {
            serial: 0,
            end: ChannelEndWithCapacity::Sender,
        });

        let mut stream = msg1.clone().serialize_message().unwrap();
        for msg in [&msg2, &msg3] {
            let tmp = msg.clone().serialize_message().unwrap();
            stream.extend_from_slice(&tmp);
        }
        assert_eq!(
            stream[..],
            [
                5, 0, 0, 0, 2, 22, 0, 0, 0, 3, 1, 0xb7, 0xc3, 0xbe, 0x13, 0x53, 0x77, 0x46, 0x6e,
                0xb4, 0xbf, 0x37, 0x38, 0x76, 0x52, 0x3d, 0x1b, 7, 0, 0, 0, 19, 0, 0,
            ]
        );

        let mut packetizer = Packetizer::new();
        assert_eq!(packetizer.next_message(), None);

        // Not even the length prefix of the first message is complete yet.
        feed(&mut packetizer, &stream[..3]);
        stream.advance(3);
        assert_eq!(packetizer.next_message(), None);

        // Completes the first two messages and delivers one byte of the third.
        feed(&mut packetizer, &stream[..25]);
        stream.advance(25);
        let msg1_serialized = packetizer.next_message().unwrap();
        assert_eq!(Message::deserialize_message(msg1_serialized), Ok(msg1));
        let msg2_serialized = packetizer.next_message().unwrap();
        assert_eq!(Message::deserialize_message(msg2_serialized), Ok(msg2));
        assert_eq!(packetizer.next_message(), None);

        // Completes the third message.
        feed(&mut packetizer, &stream[..6]);
        stream.advance(6);
        let msg3_serialized = packetizer.next_message().unwrap();
        assert_eq!(Message::deserialize_message(msg3_serialized), Ok(msg3));
        assert_eq!(packetizer.next_message(), None);

        assert_eq!(stream[..], []);
    }

    #[test]
    fn extend_from_slice() {
        check_chunked_stream(|packetizer, bytes| packetizer.extend_from_slice(bytes));
    }

    #[test]
    fn spare_capacity_mut() {
        fn write_slice(dst: &mut [MaybeUninit<u8>], src: &[u8]) {
            for (slot, &byte) in dst.iter_mut().zip(src) {
                slot.write(byte);
            }
        }

        check_chunked_stream(|packetizer, bytes| {
            write_slice(packetizer.spare_capacity_mut(), bytes);

            // SAFETY: The packetizer reserves at least 64 KiB of spare capacity, far more than
            // any chunk fed here, so `write_slice` initialized all `bytes.len()` leading bytes.
            unsafe {
                packetizer.bytes_written(bytes.len());
            }
        });
    }
}