msgpacker/
extension.rs

1use super::{
2    error::Error,
3    helpers::{take_buffer, take_buffer_iter, take_byte, take_byte_iter, take_num, take_num_iter},
4    Format, Packable, Unpackable,
5};
6use alloc::{vec, vec::Vec};
7use core::{iter, time::Duration};
8
/// An extension value that owns its payload.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Extension {
    /// Application-defined extension: the type tag followed by its raw payload bytes.
    Ext(i8, Vec<u8>),
    /// The timestamp extension reserved by the protocol, held as a `Duration`.
    Timestamp(Duration),
}
17
18impl Extension {
19    /// Protocol constant for a timestamp extension
20    pub const TIMESTAMP: i8 = -1;
21}
22
23impl Packable for Extension {
24    #[allow(unreachable_code)]
25    fn pack<T>(&self, buf: &mut T) -> usize
26    where
27        T: Extend<u8>,
28    {
29        match self {
30            Extension::Ext(t, b) if b.len() == 1 => {
31                buf.extend(
32                    iter::once(Format::FIXEXT1)
33                        .chain(iter::once(*t as u8))
34                        .chain(iter::once(b[0])),
35                );
36                3
37            }
38
39            Extension::Ext(t, b) if b.len() == 2 => {
40                buf.extend(
41                    iter::once(Format::FIXEXT2)
42                        .chain(iter::once(*t as u8))
43                        .chain(b.iter().copied()),
44                );
45                4
46            }
47
48            Extension::Ext(t, b) if b.len() == 4 => {
49                buf.extend(
50                    iter::once(Format::FIXEXT4)
51                        .chain(iter::once(*t as u8))
52                        .chain(b.iter().copied()),
53                );
54                6
55            }
56
57            Extension::Ext(t, b) if b.len() == 8 => {
58                buf.extend(
59                    iter::once(Format::FIXEXT8)
60                        .chain(iter::once(*t as u8))
61                        .chain(b.iter().copied()),
62                );
63                10
64            }
65
66            Extension::Ext(t, b) if b.len() == 16 => {
67                buf.extend(
68                    iter::once(Format::FIXEXT16)
69                        .chain(iter::once(*t as u8))
70                        .chain(b.iter().copied()),
71                );
72                18
73            }
74
75            Extension::Ext(t, b) if b.len() <= u8::MAX as usize => {
76                buf.extend(
77                    iter::once(Format::EXT8)
78                        .chain(iter::once(b.len() as u8))
79                        .chain(iter::once(*t as u8))
80                        .chain(b.iter().copied()),
81                );
82                3 + b.len()
83            }
84
85            Extension::Ext(t, b) if b.len() <= u16::MAX as usize => {
86                buf.extend(
87                    iter::once(Format::EXT16)
88                        .chain((b.len() as u16).to_be_bytes().iter().copied())
89                        .chain(iter::once(*t as u8))
90                        .chain(b.iter().copied()),
91                );
92                4 + b.len()
93            }
94
95            Extension::Ext(t, b) if b.len() <= u32::MAX as usize => {
96                buf.extend(
97                    iter::once(Format::EXT32)
98                        .chain((b.len() as u32).to_be_bytes().iter().copied())
99                        .chain(iter::once(*t as u8))
100                        .chain(b.iter().copied()),
101                );
102                6 + b.len()
103            }
104
105            Extension::Ext(_, _) => {
106                #[cfg(feature = "strict")]
107                panic!("strict serialization enabled; the buffer is too large");
108                0
109            }
110
111            Extension::Timestamp(d) if d.as_secs() <= u32::MAX as u64 && d.subsec_nanos() == 0 => {
112                buf.extend(
113                    iter::once(Format::FIXEXT4)
114                        .chain(iter::once(Self::TIMESTAMP as u8))
115                        .chain((d.as_secs() as u32).to_be_bytes().iter().copied()),
116                );
117                6
118            }
119
120            Extension::Timestamp(d)
121                if d.as_secs() < 1u64 << 34 && d.subsec_nanos() < 1u32 << 30 =>
122            {
123                let secs = d.as_secs();
124                let secs_nanos = ((secs >> 32) & 0b11) as u32;
125                let secs = secs as u32;
126
127                let nanos = d.subsec_nanos() << 2;
128                let nanos = nanos | secs_nanos;
129
130                buf.extend(
131                    iter::once(Format::FIXEXT8)
132                        .chain(iter::once(Self::TIMESTAMP as u8))
133                        .chain(nanos.to_be_bytes().iter().copied())
134                        .chain(secs.to_be_bytes().iter().copied()),
135                );
136                10
137            }
138
139            Extension::Timestamp(d) => {
140                buf.extend(
141                    iter::once(Format::EXT8)
142                        .chain(iter::once(12))
143                        .chain(iter::once(Self::TIMESTAMP as u8))
144                        .chain(d.subsec_nanos().to_be_bytes().iter().copied())
145                        .chain(d.as_secs().to_be_bytes().iter().copied()),
146                );
147                15
148            }
149        }
150    }
151}
152
153impl Unpackable for Extension {
154    type Error = Error;
155
156    fn unpack(mut buf: &[u8]) -> Result<(usize, Self), Self::Error> {
157        let format = take_byte(&mut buf)?;
158        match format {
159            Format::FIXEXT1 => {
160                let t = take_byte(&mut buf)? as i8;
161                let x = take_byte(&mut buf)?;
162                Ok((3, Extension::Ext(t, vec![x])))
163            }
164            Format::FIXEXT2 => {
165                let t = take_byte(&mut buf)? as i8;
166                let b = take_buffer(&mut buf, 2)?;
167                Ok((4, Extension::Ext(t, b.to_vec())))
168            }
169            Format::FIXEXT4 => {
170                let t = take_byte(&mut buf)? as i8;
171                if t == Self::TIMESTAMP {
172                    let secs = take_num(&mut buf, u32::from_be_bytes)?;
173                    Ok((6, Extension::Timestamp(Duration::from_secs(secs as u64))))
174                } else {
175                    let b = take_buffer(&mut buf, 4)?;
176                    Ok((6, Extension::Ext(t, b.to_vec())))
177                }
178            }
179            Format::FIXEXT8 => {
180                let t = take_byte(&mut buf)? as i8;
181                if t == Self::TIMESTAMP {
182                    let data = take_num(&mut buf, u64::from_be_bytes)?;
183
184                    let nanos = (data >> 34) as u32;
185                    let secs = data & ((1u64 << 34) - 1);
186
187                    Ok((10, Extension::Timestamp(Duration::new(secs, nanos))))
188                } else {
189                    let b = take_buffer(&mut buf, 8)?;
190                    Ok((10, Extension::Ext(t, b.to_vec())))
191                }
192            }
193            Format::FIXEXT16 => {
194                let t = take_byte(&mut buf)? as i8;
195                let b = take_buffer(&mut buf, 16)?;
196                Ok((18, Extension::Ext(t, b.to_vec())))
197            }
198            Format::EXT8 => {
199                let len = take_byte(&mut buf)? as usize;
200                let t = take_byte(&mut buf)? as i8;
201                if len == 12 && t == Self::TIMESTAMP {
202                    let nanos = take_num(&mut buf, u32::from_be_bytes)?;
203                    let secs = take_num(&mut buf, u64::from_be_bytes)?;
204                    Ok((15, Extension::Timestamp(Duration::new(secs, nanos))))
205                } else {
206                    let b = take_buffer(&mut buf, len)?;
207                    Ok((3 + len, Extension::Ext(t, b.to_vec())))
208                }
209            }
210            Format::EXT16 => {
211                let len = take_num(&mut buf, u16::from_be_bytes)? as usize;
212                let t = take_byte(&mut buf)? as i8;
213                let b = take_buffer(&mut buf, len)?;
214                Ok((4 + len, Extension::Ext(t, b.to_vec())))
215            }
216            Format::EXT32 => {
217                let len = take_num(&mut buf, u32::from_be_bytes)? as usize;
218                let t = take_byte(&mut buf)? as i8;
219                let b = take_buffer(&mut buf, len)?;
220                Ok((6 + len, Extension::Ext(t, b.to_vec())))
221            }
222            _ => Err(Error::InvalidExtension),
223        }
224    }
225
226    fn unpack_iter<I>(bytes: I) -> Result<(usize, Self), Self::Error>
227    where
228        I: IntoIterator<Item = u8>,
229    {
230        let mut bytes = bytes.into_iter();
231        let format = take_byte_iter(bytes.by_ref())?;
232        match format {
233            Format::FIXEXT1 => {
234                let t = take_byte_iter(bytes.by_ref())? as i8;
235                let x = take_byte_iter(bytes.by_ref())?;
236                Ok((3, Extension::Ext(t, vec![x])))
237            }
238            Format::FIXEXT2 => {
239                let t = take_byte_iter(bytes.by_ref())? as i8;
240                let b = take_buffer_iter(bytes.by_ref(), 2)?;
241                Ok((4, Extension::Ext(t, b)))
242            }
243            Format::FIXEXT4 => {
244                let t = take_byte_iter(bytes.by_ref())? as i8;
245                if t == Self::TIMESTAMP {
246                    let secs = take_num_iter(bytes.by_ref(), u32::from_be_bytes)?;
247                    Ok((6, Extension::Timestamp(Duration::from_secs(secs as u64))))
248                } else {
249                    let b = take_buffer_iter(bytes.by_ref(), 4)?;
250                    Ok((6, Extension::Ext(t, b)))
251                }
252            }
253            Format::FIXEXT8 => {
254                let t = take_byte_iter(bytes.by_ref())? as i8;
255                if t == Self::TIMESTAMP {
256                    let data = take_num_iter(bytes.by_ref(), u64::from_be_bytes)?;
257
258                    let nanos = (data >> 34) as u32;
259                    let secs = data & ((1u64 << 34) - 1);
260
261                    Ok((10, Extension::Timestamp(Duration::new(secs, nanos))))
262                } else {
263                    let b = take_buffer_iter(bytes.by_ref(), 8)?;
264                    Ok((10, Extension::Ext(t, b)))
265                }
266            }
267            Format::FIXEXT16 => {
268                let t = take_byte_iter(bytes.by_ref())? as i8;
269                let b = take_buffer_iter(bytes.by_ref(), 16)?;
270                Ok((18, Extension::Ext(t, b)))
271            }
272            Format::EXT8 => {
273                let len = take_byte_iter(bytes.by_ref())? as usize;
274                let t = take_byte_iter(bytes.by_ref())? as i8;
275                if len == 12 && t == Self::TIMESTAMP {
276                    let nanos = take_num_iter(bytes.by_ref(), u32::from_be_bytes)?;
277                    let secs = take_num_iter(bytes.by_ref(), u64::from_be_bytes)?;
278                    Ok((15, Extension::Timestamp(Duration::new(secs, nanos))))
279                } else {
280                    let b = take_buffer_iter(bytes.by_ref(), len)?;
281                    Ok((3 + len, Extension::Ext(t, b)))
282                }
283            }
284            Format::EXT16 => {
285                let len = take_num_iter(bytes.by_ref(), u16::from_be_bytes)? as usize;
286                let t = take_byte_iter(bytes.by_ref())? as i8;
287                let b = take_buffer_iter(bytes.by_ref(), len)?;
288                Ok((4 + len, Extension::Ext(t, b)))
289            }
290            Format::EXT32 => {
291                let len = take_num_iter(bytes.by_ref(), u32::from_be_bytes)? as usize;
292                let t = take_byte_iter(bytes.by_ref())? as i8;
293                let b = take_buffer_iter(bytes.by_ref(), len)?;
294                Ok((6 + len, Extension::Ext(t, b)))
295            }
296            _ => Err(Error::InvalidExtension),
297        }
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Packs `x`, then checks that the slice and the iterator decoders both give it back and
    /// that every reported size matches the number of bytes actually written.
    fn assert_roundtrip(x: Extension) {
        let mut bytes = vec![];
        let written = x.pack(&mut bytes);
        assert_eq!(written, bytes.len());

        let (read, y) = Extension::unpack(&bytes).unwrap();
        assert_eq!(written, read);
        assert_eq!(x, y);

        let (read, y) = Extension::unpack_iter(bytes).unwrap();
        assert_eq!(written, read);
        assert_eq!(x, y);
    }

    proptest! {
        #[test]
        fn extension_bytes(mut t: i8, b: Vec<u8>) {
            // The timestamp tag is reserved: such a payload may decode as a timestamp
            // instead of round-tripping as raw bytes, so steer away from it.
            if t == Extension::TIMESTAMP {
                t -= 1;
            }
            assert_roundtrip(Extension::Ext(t, b));
        }

        #[test]
        fn extension_duration(d: Duration) {
            assert_roundtrip(Extension::Timestamp(d));
        }
    }
}