Skip to main content

rialo_s_short_vec/
lib.rs

1//! Compact serde-encoding of vectors with small length.
2#![cfg_attr(feature = "frozen-abi", feature(min_specialization))]
3#![allow(clippy::arithmetic_side_effects)]
4use std::{convert::TryFrom, fmt, marker::PhantomData};
5
6#[cfg(feature = "frozen-abi")]
7use rialo_frozen_abi_macro::AbiExample;
8use serde::{
9    de::{self, Deserializer, SeqAccess, Visitor},
10    ser::{self, SerializeTuple, Serializer},
11    Deserialize, Serialize,
12};
13
/// A `u16` with a compact, variable-length serde encoding of 1 to 3 bytes.
///
/// The value is written 7 bits at a time, least-significant group first. If the
/// value is above 0x7f, the top bit of the byte is set and the remaining value
/// is stored in the next bytes. Each byte follows the same pattern until the
/// 3rd byte. The 3rd byte may only have the 2 least-significant bits set,
/// otherwise the encoded value will overflow the u16.
#[cfg_attr(feature = "frozen-abi", derive(AbiExample))]
pub struct ShortU16(pub u16);
21
22impl Serialize for ShortU16 {
23    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
24    where
25        S: Serializer,
26    {
27        // Pass a non-zero value to serialize_tuple() so that serde_json will
28        // generate an open bracket.
29        let mut seq = serializer.serialize_tuple(1)?;
30
31        let mut rem_val = self.0;
32        loop {
33            let mut elem = (rem_val & 0x7f) as u8;
34            rem_val >>= 7;
35            if rem_val == 0 {
36                seq.serialize_element(&elem)?;
37                break;
38            } else {
39                elem |= 0x80;
40                seq.serialize_element(&elem)?;
41            }
42        }
43        seq.end()
44    }
45}
46
/// Outcome of feeding one encoded byte to `visit_byte`.
enum VisitStatus {
    /// The byte had no continuation bit; the payload is the fully decoded value.
    Done(u16),
    /// The byte had its continuation bit set; the payload is the value
    /// accumulated so far.
    More(u16),
}
51
52#[derive(Debug)]
53enum VisitError {
54    TooLong(usize),
55    TooShort(usize),
56    Overflow(u32),
57    Alias,
58    ByteThreeContinues,
59}
60
61impl VisitError {
62    fn into_de_error<'de, A>(self) -> A::Error
63    where
64        A: SeqAccess<'de>,
65    {
66        match self {
67            VisitError::TooLong(len) => de::Error::invalid_length(len, &"three or fewer bytes"),
68            VisitError::TooShort(len) => de::Error::invalid_length(len, &"more bytes"),
69            VisitError::Overflow(val) => de::Error::invalid_value(
70                de::Unexpected::Unsigned(val as u64),
71                &"a value in the range [0, 65535]",
72            ),
73            VisitError::Alias => de::Error::invalid_value(
74                de::Unexpected::Other("alias encoding"),
75                &"strict form encoding",
76            ),
77            VisitError::ByteThreeContinues => de::Error::invalid_value(
78                de::Unexpected::Other("continue signal on byte-three"),
79                &"a terminal signal on or before byte-three",
80            ),
81        }
82    }
83}
84
85type VisitResult = Result<VisitStatus, VisitError>;
86
87const MAX_ENCODING_LENGTH: usize = 3;
88fn visit_byte(elem: u8, val: u16, nth_byte: usize) -> VisitResult {
89    if elem == 0 && nth_byte != 0 {
90        return Err(VisitError::Alias);
91    }
92
93    let val = u32::from(val);
94    let elem = u32::from(elem);
95    let elem_val = elem & 0x7f;
96    let elem_done = (elem & 0x80) == 0;
97
98    if nth_byte >= MAX_ENCODING_LENGTH {
99        return Err(VisitError::TooLong(nth_byte.saturating_add(1)));
100    } else if nth_byte == MAX_ENCODING_LENGTH.saturating_sub(1) && !elem_done {
101        return Err(VisitError::ByteThreeContinues);
102    }
103
104    let shift = u32::try_from(nth_byte)
105        .unwrap_or(u32::MAX)
106        .saturating_mul(7);
107    let elem_val = elem_val.checked_shl(shift).unwrap_or(u32::MAX);
108
109    let new_val = val | elem_val;
110    let val = u16::try_from(new_val).map_err(|_| VisitError::Overflow(new_val))?;
111
112    if elem_done {
113        Ok(VisitStatus::Done(val))
114    } else {
115        Ok(VisitStatus::More(val))
116    }
117}
118
119struct ShortU16Visitor;
120
121impl<'de> Visitor<'de> for ShortU16Visitor {
122    type Value = ShortU16;
123
124    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
125        formatter.write_str("a ShortU16")
126    }
127
128    fn visit_seq<A>(self, mut seq: A) -> Result<ShortU16, A::Error>
129    where
130        A: SeqAccess<'de>,
131    {
132        // Decodes an unsigned 16 bit integer one-to-one encoded as follows:
133        // 1 byte  : 0xxxxxxx                   => 00000000 0xxxxxxx :      0 -    127
134        // 2 bytes : 1xxxxxxx 0yyyyyyy          => 00yyyyyy yxxxxxxx :    128 - 16,383
135        // 3 bytes : 1xxxxxxx 1yyyyyyy 000000zz => zzyyyyyy yxxxxxxx : 16,384 - 65,535
136        let mut val: u16 = 0;
137        for nth_byte in 0..MAX_ENCODING_LENGTH {
138            let elem: u8 = seq.next_element()?.ok_or_else(|| {
139                VisitError::TooShort(nth_byte.saturating_add(1)).into_de_error::<A>()
140            })?;
141            match visit_byte(elem, val, nth_byte).map_err(|e| e.into_de_error::<A>())? {
142                VisitStatus::Done(new_val) => return Ok(ShortU16(new_val)),
143                VisitStatus::More(new_val) => val = new_val,
144            }
145        }
146
147        Err(VisitError::ByteThreeContinues.into_de_error::<A>())
148    }
149}
150
151impl<'de> Deserialize<'de> for ShortU16 {
152    fn deserialize<D>(deserializer: D) -> Result<ShortU16, D::Error>
153    where
154        D: Deserializer<'de>,
155    {
156        deserializer.deserialize_tuple(3, ShortU16Visitor)
157    }
158}
159
160/// If you don't want to use the ShortVec newtype, you can do ShortVec
161/// serialization on an ordinary vector with the following field annotation:
162///
163/// #[serde(with = "short_vec")]
164///
165pub fn serialize<S: Serializer, T: Serialize>(
166    elements: &[T],
167    serializer: S,
168) -> Result<S::Ok, S::Error> {
169    // Pass a non-zero value to serialize_tuple() so that serde_json will
170    // generate an open bracket.
171    let mut seq = serializer.serialize_tuple(1)?;
172
173    let len = elements.len();
174    if len > u16::MAX as usize {
175        return Err(ser::Error::custom("length larger than u16"));
176    }
177    let short_len = ShortU16(len as u16);
178    seq.serialize_element(&short_len)?;
179
180    for element in elements {
181        seq.serialize_element(element)?;
182    }
183    seq.end()
184}
185
186struct ShortVecVisitor<T> {
187    _t: PhantomData<T>,
188}
189
190impl<'de, T> Visitor<'de> for ShortVecVisitor<T>
191where
192    T: Deserialize<'de>,
193{
194    type Value = Vec<T>;
195
196    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
197        formatter.write_str("a Vec with a multi-byte length")
198    }
199
200    fn visit_seq<A>(self, mut seq: A) -> Result<Vec<T>, A::Error>
201    where
202        A: SeqAccess<'de>,
203    {
204        let short_len: ShortU16 = seq
205            .next_element()?
206            .ok_or_else(|| de::Error::invalid_length(0, &self))?;
207        let len = short_len.0 as usize;
208
209        let mut result = Vec::with_capacity(len);
210        for i in 0..len {
211            let elem = seq
212                .next_element()?
213                .ok_or_else(|| de::Error::invalid_length(i, &self))?;
214            result.push(elem);
215        }
216        Ok(result)
217    }
218}
219
220/// If you don't want to use the ShortVec newtype, you can do ShortVec
221/// deserialization on an ordinary vector with the following field annotation:
222///
223/// #[serde(with = "short_vec")]
224///
225pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Vec<T>, D::Error>
226where
227    D: Deserializer<'de>,
228    T: Deserialize<'de>,
229{
230    let visitor = ShortVecVisitor { _t: PhantomData };
231    deserializer.deserialize_tuple(usize::MAX, visitor)
232}
233
234pub struct ShortVec<T>(pub Vec<T>);
235
236impl<T: Serialize> Serialize for ShortVec<T> {
237    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
238    where
239        S: Serializer,
240    {
241        serialize(&self.0, serializer)
242    }
243}
244
245impl<'de, T: Deserialize<'de>> Deserialize<'de> for ShortVec<T> {
246    fn deserialize<D>(deserializer: D) -> Result<ShortVec<T>, D::Error>
247    where
248        D: Deserializer<'de>,
249    {
250        deserialize(deserializer).map(ShortVec)
251    }
252}
253
254/// Return the decoded value and how many bytes it consumed.
255#[allow(clippy::result_unit_err)]
256pub fn decode_shortu16_len(bytes: &[u8]) -> Result<(usize, usize), ()> {
257    let mut val = 0;
258    for (nth_byte, byte) in bytes.iter().take(MAX_ENCODING_LENGTH).enumerate() {
259        match visit_byte(*byte, val, nth_byte).map_err(|_| ())? {
260            VisitStatus::More(new_val) => val = new_val,
261            VisitStatus::Done(new_val) => {
262                return Ok((usize::from(new_val), nth_byte.saturating_add(1)));
263            }
264        }
265    }
266    Err(())
267}
268
// Unit tests for the ShortU16 codec and the ShortVec wrapper, exercised
// through bincode and serde_json.
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use bincode::{deserialize, serialize};

    use super::*;

    /// Serialize `len` as a `ShortU16` with bincode and return the encoded bytes.
    fn encode_len(len: u16) -> Vec<u8> {
        bincode::serialize(&ShortU16(len)).unwrap()
    }

    /// Assert that `len` encodes to exactly `bytes` and that
    /// `decode_shortu16_len` maps `bytes` back to `len`, consuming all of them.
    fn assert_len_encoding(len: u16, bytes: &[u8]) {
        assert_eq!(encode_len(len), bytes, "unexpected usize encoding");
        assert_eq!(
            decode_shortu16_len(bytes).unwrap(),
            (usize::from(len), bytes.len()),
            "unexpected usize decoding"
        );
    }

    // Round-trips values at the boundaries between 1-, 2- and 3-byte encodings.
    #[test]
    fn test_short_vec_encode_len() {
        assert_len_encoding(0x0, &[0x0]);
        assert_len_encoding(0x7f, &[0x7f]);
        assert_len_encoding(0x80, &[0x80, 0x01]);
        assert_len_encoding(0xff, &[0xff, 0x01]);
        assert_len_encoding(0x100, &[0x80, 0x02]);
        assert_len_encoding(0x7fff, &[0xff, 0xff, 0x01]);
        assert_len_encoding(0xffff, &[0xff, 0xff, 0x03]);
    }

    /// Assert that `bytes` deserializes (via bincode) to a `ShortU16` holding `value`.
    fn assert_good_deserialized_value(value: u16, bytes: &[u8]) {
        assert_eq!(value, deserialize::<ShortU16>(bytes).unwrap().0);
    }

    /// Assert that deserializing `bytes` as a `ShortU16` fails.
    fn assert_bad_deserialized_value(bytes: &[u8]) {
        assert!(deserialize::<ShortU16>(bytes).is_err());
    }

    // Covers canonical encodings plus every rejection category: aliases,
    // truncated input, over-long input and values that overflow a u16.
    #[test]
    fn test_deserialize() {
        assert_good_deserialized_value(0x0000, &[0x00]);
        assert_good_deserialized_value(0x007f, &[0x7f]);
        assert_good_deserialized_value(0x0080, &[0x80, 0x01]);
        assert_good_deserialized_value(0x00ff, &[0xff, 0x01]);
        assert_good_deserialized_value(0x0100, &[0x80, 0x02]);
        assert_good_deserialized_value(0x07ff, &[0xff, 0x0f]);
        assert_good_deserialized_value(0x3fff, &[0xff, 0x7f]);
        assert_good_deserialized_value(0x4000, &[0x80, 0x80, 0x01]);
        assert_good_deserialized_value(0xffff, &[0xff, 0xff, 0x03]);

        // aliases
        // 0x0000
        assert_bad_deserialized_value(&[0x80, 0x00]);
        assert_bad_deserialized_value(&[0x80, 0x80, 0x00]);
        // 0x007f
        assert_bad_deserialized_value(&[0xff, 0x00]);
        assert_bad_deserialized_value(&[0xff, 0x80, 0x00]);
        // 0x0080
        assert_bad_deserialized_value(&[0x80, 0x81, 0x00]);
        // 0x00ff
        assert_bad_deserialized_value(&[0xff, 0x81, 0x00]);
        // 0x0100
        assert_bad_deserialized_value(&[0x80, 0x82, 0x00]);
        // 0x07ff
        assert_bad_deserialized_value(&[0xff, 0x8f, 0x00]);
        // 0x3fff
        assert_bad_deserialized_value(&[0xff, 0xff, 0x00]);

        // too short
        assert_bad_deserialized_value(&[]);
        assert_bad_deserialized_value(&[0x80]);

        // too long
        assert_bad_deserialized_value(&[0x80, 0x80, 0x80, 0x00]);

        // too large
        // 0x0001_0000
        assert_bad_deserialized_value(&[0x80, 0x80, 0x04]);
        // 0x0001_8000
        assert_bad_deserialized_value(&[0x80, 0x80, 0x06]);
    }

    // A 32-element byte vector needs a single length byte, and round-trips.
    #[test]
    fn test_short_vec_u8() {
        let vec = ShortVec(vec![4u8; 32]);
        let bytes = serialize(&vec).unwrap();
        assert_eq!(bytes.len(), vec.0.len() + 1);

        let vec1: ShortVec<u8> = deserialize(&bytes).unwrap();
        assert_eq!(vec.0, vec1.0);
    }

    // u16::MAX elements is the largest length that serializes; one more fails.
    #[test]
    fn test_short_vec_u8_too_long() {
        let vec = ShortVec(vec![4u8; u16::MAX as usize]);
        assert_matches!(serialize(&vec), Ok(_));

        let vec = ShortVec(vec![4u8; u16::MAX as usize + 1]);
        assert_matches!(serialize(&vec), Err(_));
    }

    // In JSON the length prefix appears as a nested array ahead of the elements.
    #[test]
    fn test_short_vec_json() {
        let vec = ShortVec(vec![0, 1, 2]);
        let s = serde_json::to_string(&vec).unwrap();
        assert_eq!(s, "[[3],0,1,2]");
    }

    // A non-canonical length prefix must make the whole vector fail to decode.
    #[test]
    fn test_short_vec_aliased_length() {
        let bytes = [
            0x81, 0x80, 0x00, // 3-byte alias of 1
            0x00,
        ];
        assert!(deserialize::<ShortVec<u8>>(&bytes).is_err());
    }
}