cbe_program/
short_vec.rs

1//! Compact serde-encoding of vectors with small length.
2
3#![allow(clippy::integer_arithmetic)]
4use {
5    serde::{
6        de::{self, Deserializer, SeqAccess, Visitor},
7        ser::{self, SerializeTuple, Serializer},
8        Deserialize, Serialize,
9    },
10    std::{convert::TryFrom, fmt, marker::PhantomData},
11};
12
13/// Same as u16, but serialized with 1 to 3 bytes. If the value is above
14/// 0x7f, the top bit is set and the remaining value is stored in the next
15/// bytes. Each byte follows the same pattern until the 3rd byte. The 3rd
16/// byte, if needed, uses all 8 bits to store the last byte of the original
17/// value.
18#[derive(AbiExample)]
19pub struct ShortU16(pub u16);
20
21impl Serialize for ShortU16 {
22    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
23    where
24        S: Serializer,
25    {
26        // Pass a non-zero value to serialize_tuple() so that serde_json will
27        // generate an open bracket.
28        let mut seq = serializer.serialize_tuple(1)?;
29
30        let mut rem_val = self.0;
31        loop {
32            let mut elem = (rem_val & 0x7f) as u8;
33            rem_val >>= 7;
34            if rem_val == 0 {
35                seq.serialize_element(&elem)?;
36                break;
37            } else {
38                elem |= 0x80;
39                seq.serialize_element(&elem)?;
40            }
41        }
42        seq.end()
43    }
44}
45
/// Progress report produced after one encoded byte has been consumed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum VisitStatus {
    /// The byte had its continuation bit clear; the payload is the fully
    /// decoded value.
    Done(u16),
    /// The byte had its continuation bit set; the payload is the value
    /// accumulated so far and must be fed back in with the next byte.
    More(u16),
}
50
/// Reasons a byte sequence is rejected while decoding a compact `u16`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum VisitError {
    /// A byte was offered at or beyond the maximum encoding length; the
    /// payload is the one-based position of that byte.
    TooLong(usize),
    /// The input ended before a terminal byte was seen; the payload is the
    /// one-based position of the byte that was missing.
    TooShort(usize),
    /// The combined value does not fit in a `u16`; the payload is that value.
    Overflow(u32),
    /// A zero byte appeared after the first position, so the same value
    /// would have a shorter encoding.
    Alias,
    /// The third byte still had its continuation bit set.
    ByteThreeContinues,
}
59
60impl VisitError {
61    fn into_de_error<'de, A>(self) -> A::Error
62    where
63        A: SeqAccess<'de>,
64    {
65        match self {
66            VisitError::TooLong(len) => de::Error::invalid_length(len, &"three or fewer bytes"),
67            VisitError::TooShort(len) => de::Error::invalid_length(len, &"more bytes"),
68            VisitError::Overflow(val) => de::Error::invalid_value(
69                de::Unexpected::Unsigned(val as u64),
70                &"a value in the range [0, 65535]",
71            ),
72            VisitError::Alias => de::Error::invalid_value(
73                de::Unexpected::Other("alias encoding"),
74                &"strict form encoding",
75            ),
76            VisitError::ByteThreeContinues => de::Error::invalid_value(
77                de::Unexpected::Other("continue signal on byte-three"),
78                &"a terminal signal on or before byte-three",
79            ),
80        }
81    }
82}
83
84type VisitResult = Result<VisitStatus, VisitError>;
85
86const MAX_ENCODING_LENGTH: usize = 3;
87fn visit_byte(elem: u8, val: u16, nth_byte: usize) -> VisitResult {
88    if elem == 0 && nth_byte != 0 {
89        return Err(VisitError::Alias);
90    }
91
92    let val = u32::from(val);
93    let elem = u32::from(elem);
94    let elem_val = elem & 0x7f;
95    let elem_done = (elem & 0x80) == 0;
96
97    if nth_byte >= MAX_ENCODING_LENGTH {
98        return Err(VisitError::TooLong(nth_byte.saturating_add(1)));
99    } else if nth_byte == MAX_ENCODING_LENGTH.saturating_sub(1) && !elem_done {
100        return Err(VisitError::ByteThreeContinues);
101    }
102
103    let shift = u32::try_from(nth_byte)
104        .unwrap_or(std::u32::MAX)
105        .saturating_mul(7);
106    let elem_val = elem_val.checked_shl(shift).unwrap_or(std::u32::MAX);
107
108    let new_val = val | elem_val;
109    let val = u16::try_from(new_val).map_err(|_| VisitError::Overflow(new_val))?;
110
111    if elem_done {
112        Ok(VisitStatus::Done(val))
113    } else {
114        Ok(VisitStatus::More(val))
115    }
116}
117
118struct ShortU16Visitor;
119
120impl<'de> Visitor<'de> for ShortU16Visitor {
121    type Value = ShortU16;
122
123    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
124        formatter.write_str("a ShortU16")
125    }
126
127    fn visit_seq<A>(self, mut seq: A) -> Result<ShortU16, A::Error>
128    where
129        A: SeqAccess<'de>,
130    {
131        // Decodes an unsigned 16 bit integer one-to-one encoded as follows:
132        // 1 byte  : 0xxxxxxx                   => 00000000 0xxxxxxx :      0 -    127
133        // 2 bytes : 1xxxxxxx 0yyyyyyy          => 00yyyyyy yxxxxxxx :    128 - 16,383
134        // 3 bytes : 1xxxxxxx 1yyyyyyy 000000zz => zzyyyyyy yxxxxxxx : 16,384 - 65,535
135        let mut val: u16 = 0;
136        for nth_byte in 0..MAX_ENCODING_LENGTH {
137            let elem: u8 = seq.next_element()?.ok_or_else(|| {
138                VisitError::TooShort(nth_byte.saturating_add(1)).into_de_error::<A>()
139            })?;
140            match visit_byte(elem, val, nth_byte).map_err(|e| e.into_de_error::<A>())? {
141                VisitStatus::Done(new_val) => return Ok(ShortU16(new_val)),
142                VisitStatus::More(new_val) => val = new_val,
143            }
144        }
145
146        Err(VisitError::ByteThreeContinues.into_de_error::<A>())
147    }
148}
149
150impl<'de> Deserialize<'de> for ShortU16 {
151    fn deserialize<D>(deserializer: D) -> Result<ShortU16, D::Error>
152    where
153        D: Deserializer<'de>,
154    {
155        deserializer.deserialize_tuple(3, ShortU16Visitor)
156    }
157}
158
159/// If you don't want to use the ShortVec newtype, you can do ShortVec
160/// serialization on an ordinary vector with the following field annotation:
161///
162/// #[serde(with = "short_vec")]
163///
164pub fn serialize<S: Serializer, T: Serialize>(
165    elements: &[T],
166    serializer: S,
167) -> Result<S::Ok, S::Error> {
168    // Pass a non-zero value to serialize_tuple() so that serde_json will
169    // generate an open bracket.
170    let mut seq = serializer.serialize_tuple(1)?;
171
172    let len = elements.len();
173    if len > std::u16::MAX as usize {
174        return Err(ser::Error::custom("length larger than u16"));
175    }
176    let short_len = ShortU16(len as u16);
177    seq.serialize_element(&short_len)?;
178
179    for element in elements {
180        seq.serialize_element(element)?;
181    }
182    seq.end()
183}
184
185struct ShortVecVisitor<T> {
186    _t: PhantomData<T>,
187}
188
189impl<'de, T> Visitor<'de> for ShortVecVisitor<T>
190where
191    T: Deserialize<'de>,
192{
193    type Value = Vec<T>;
194
195    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
196        formatter.write_str("a Vec with a multi-byte length")
197    }
198
199    fn visit_seq<A>(self, mut seq: A) -> Result<Vec<T>, A::Error>
200    where
201        A: SeqAccess<'de>,
202    {
203        let short_len: ShortU16 = seq
204            .next_element()?
205            .ok_or_else(|| de::Error::invalid_length(0, &self))?;
206        let len = short_len.0 as usize;
207
208        let mut result = Vec::with_capacity(len);
209        for i in 0..len {
210            let elem = seq
211                .next_element()?
212                .ok_or_else(|| de::Error::invalid_length(i, &self))?;
213            result.push(elem);
214        }
215        Ok(result)
216    }
217}
218
219/// If you don't want to use the ShortVec newtype, you can do ShortVec
220/// deserialization on an ordinary vector with the following field annotation:
221///
222/// #[serde(with = "short_vec")]
223///
224pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Vec<T>, D::Error>
225where
226    D: Deserializer<'de>,
227    T: Deserialize<'de>,
228{
229    let visitor = ShortVecVisitor { _t: PhantomData };
230    deserializer.deserialize_tuple(std::usize::MAX, visitor)
231}
232
/// Newtype around `Vec<T>` whose serde representation is the vector's length
/// as a `ShortU16` followed by its elements, instead of serde's default
/// sequence encoding.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShortVec<T>(pub Vec<T>);
234
235impl<T: Serialize> Serialize for ShortVec<T> {
236    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
237    where
238        S: Serializer,
239    {
240        serialize(&self.0, serializer)
241    }
242}
243
244impl<'de, T: Deserialize<'de>> Deserialize<'de> for ShortVec<T> {
245    fn deserialize<D>(deserializer: D) -> Result<ShortVec<T>, D::Error>
246    where
247        D: Deserializer<'de>,
248    {
249        deserialize(deserializer).map(ShortVec)
250    }
251}
252
253/// Return the decoded value and how many bytes it consumed.
254#[allow(clippy::result_unit_err)]
255pub fn decode_shortu16_len(bytes: &[u8]) -> Result<(usize, usize), ()> {
256    let mut val = 0;
257    for (nth_byte, byte) in bytes.iter().take(MAX_ENCODING_LENGTH).enumerate() {
258        match visit_byte(*byte, val, nth_byte).map_err(|_| ())? {
259            VisitStatus::More(new_val) => val = new_val,
260            VisitStatus::Done(new_val) => {
261                return Ok((usize::from(new_val), nth_byte.saturating_add(1)));
262            }
263        }
264    }
265    Err(())
266}
267
#[cfg(test)]
mod tests {
    use {
        super::*,
        assert_matches::assert_matches,
        bincode::{deserialize, serialize},
    };

    /// Serialize `len` as a `ShortU16` with bincode and return the bytes.
    fn encode_len(len: u16) -> Vec<u8> {
        let wrapped = ShortU16(len);
        bincode::serialize(&wrapped).unwrap()
    }

    /// Check that `len` encodes to exactly `bytes` and that `bytes` decodes
    /// back to `len`, consuming every byte.
    fn assert_len_encoding(len: u16, bytes: &[u8]) {
        assert_eq!(encode_len(len), bytes, "unexpected usize encoding");

        let expected = (usize::from(len), bytes.len());
        assert_eq!(
            decode_shortu16_len(bytes).unwrap(),
            expected,
            "unexpected usize decoding"
        );
    }

    #[test]
    fn test_short_vec_encode_len() {
        let cases: &[(u16, &[u8])] = &[
            (0x0, &[0x0]),
            (0x7f, &[0x7f]),
            (0x80, &[0x80, 0x01]),
            (0xff, &[0xff, 0x01]),
            (0x100, &[0x80, 0x02]),
            (0x7fff, &[0xff, 0xff, 0x01]),
            (0xffff, &[0xff, 0xff, 0x03]),
        ];
        for (len, bytes) in cases.iter() {
            assert_len_encoding(*len, bytes);
        }
    }

    /// Check that `bytes` deserializes to `value`.
    fn assert_good_deserialized_value(value: u16, bytes: &[u8]) {
        let ShortU16(decoded) = deserialize::<ShortU16>(bytes).unwrap();
        assert_eq!(value, decoded);
    }

    /// Check that `bytes` is rejected by the deserializer.
    fn assert_bad_deserialized_value(bytes: &[u8]) {
        let outcome = deserialize::<ShortU16>(bytes);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_deserialize() {
        let good: &[(u16, &[u8])] = &[
            (0x0000, &[0x00]),
            (0x007f, &[0x7f]),
            (0x0080, &[0x80, 0x01]),
            (0x00ff, &[0xff, 0x01]),
            (0x0100, &[0x80, 0x02]),
            (0x07ff, &[0xff, 0x0f]),
            (0x3fff, &[0xff, 0x7f]),
            (0x4000, &[0x80, 0x80, 0x01]),
            (0xffff, &[0xff, 0xff, 0x03]),
        ];
        for (value, bytes) in good.iter() {
            assert_good_deserialized_value(*value, bytes);
        }

        let bad: &[&[u8]] = &[
            // aliases
            // 0x0000
            &[0x80, 0x00],
            &[0x80, 0x80, 0x00],
            // 0x007f
            &[0xff, 0x00],
            &[0xff, 0x80, 0x00],
            // 0x0080
            &[0x80, 0x81, 0x00],
            // 0x00ff
            &[0xff, 0x81, 0x00],
            // 0x0100
            &[0x80, 0x82, 0x00],
            // 0x07ff
            &[0xff, 0x8f, 0x00],
            // 0x3fff
            &[0xff, 0xff, 0x00],
            // too short
            &[],
            &[0x80],
            // too long
            &[0x80, 0x80, 0x80, 0x00],
            // too large
            // 0x0001_0000
            &[0x80, 0x80, 0x04],
            // 0x0001_8000
            &[0x80, 0x80, 0x06],
        ];
        for bytes in bad.iter() {
            assert_bad_deserialized_value(bytes);
        }
    }

    #[test]
    fn test_short_vec_u8() {
        let original = ShortVec(vec![4u8; 32]);
        let encoded = serialize(&original).unwrap();
        // One length byte precedes the 32 payload bytes.
        assert_eq!(encoded.len(), original.0.len() + 1);

        let decoded: ShortVec<u8> = deserialize(&encoded).unwrap();
        assert_eq!(original.0, decoded.0);
    }

    #[test]
    fn test_short_vec_u8_too_long() {
        let max_len = usize::from(u16::MAX);

        let at_limit = ShortVec(vec![4u8; max_len]);
        assert_matches!(serialize(&at_limit), Ok(_));

        let over_limit = ShortVec(vec![4u8; max_len + 1]);
        assert_matches!(serialize(&over_limit), Err(_));
    }

    #[test]
    fn test_short_vec_json() {
        let vec = ShortVec(vec![0, 1, 2]);
        let json = serde_json::to_string(&vec).unwrap();
        assert_eq!(json, "[[3],0,1,2]");
    }

    #[test]
    fn test_short_vec_aliased_length() {
        let bytes = [
            0x81, 0x80, 0x00, // 3-byte alias of 1
            0x00,
        ];
        let outcome = deserialize::<ShortVec<u8>>(&bytes);
        assert!(outcome.is_err());
    }
}
385        assert!(deserialize::<ShortVec<u8>>(&bytes).is_err());
386    }
387}