gemachain_program/
short_vec.rs

1#![allow(clippy::integer_arithmetic)]
2use serde::{
3    de::{self, Deserializer, SeqAccess, Visitor},
4    ser::{self, SerializeTuple, Serializer},
5    {Deserialize, Serialize},
6};
7use std::{convert::TryFrom, fmt, marker::PhantomData};
8
9/// Same as u16, but serialized with 1 to 3 bytes. If the value is above
10/// 0x7f, the top bit is set and the remaining value is stored in the next
11/// bytes. Each byte follows the same pattern until the 3rd byte. The 3rd
12/// byte, if needed, uses all 8 bits to store the last byte of the original
13/// value.
14#[derive(AbiExample)]
15pub struct ShortU16(pub u16);
16
17impl Serialize for ShortU16 {
18    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
19    where
20        S: Serializer,
21    {
22        // Pass a non-zero value to serialize_tuple() so that serde_json will
23        // generate an open bracket.
24        let mut seq = serializer.serialize_tuple(1)?;
25
26        let mut rem_val = self.0;
27        loop {
28            let mut elem = (rem_val & 0x7f) as u8;
29            rem_val >>= 7;
30            if rem_val == 0 {
31                seq.serialize_element(&elem)?;
32                break;
33            } else {
34                elem |= 0x80;
35                seq.serialize_element(&elem)?;
36            }
37        }
38        seq.end()
39    }
40}
41
/// Outcome of folding one more encoded byte into a partially decoded value.
enum VisitStatus {
    /// The byte was terminal; the payload is the fully decoded value.
    Done(u16),
    /// The byte asked for a continuation; the payload is the value
    /// accumulated so far.
    More(u16),
}
46
/// Reasons a byte sequence is rejected while decoding a `ShortU16`.
#[derive(Debug)]
enum VisitError {
    /// More bytes were supplied than the encoding allows; carries the length.
    TooLong(usize),
    /// The input ended before a terminal byte; carries the length that was
    /// needed.
    TooShort(usize),
    /// The decoded value does not fit in a `u16`; carries the oversized value.
    Overflow(u32),
    /// The bytes are a non-canonical (longer than necessary) encoding.
    Alias,
    /// The third byte still had its continuation bit set.
    ByteThreeContinues,
}
55
56impl VisitError {
57    fn into_de_error<'de, A>(self) -> A::Error
58    where
59        A: SeqAccess<'de>,
60    {
61        match self {
62            VisitError::TooLong(len) => {
63                de::Error::invalid_length(len as usize, &"three or fewer bytes")
64            }
65            VisitError::TooShort(len) => de::Error::invalid_length(len, &"more bytes"),
66            VisitError::Overflow(val) => de::Error::invalid_value(
67                de::Unexpected::Unsigned(val as u64),
68                &"a value in the range [0, 65535]",
69            ),
70            VisitError::Alias => de::Error::invalid_value(
71                de::Unexpected::Other("alias encoding"),
72                &"strict form encoding",
73            ),
74            VisitError::ByteThreeContinues => de::Error::invalid_value(
75                de::Unexpected::Other("continue signal on byte-three"),
76                &"a terminal signal on or before byte-three",
77            ),
78        }
79    }
80}
81
82type VisitResult = Result<VisitStatus, VisitError>;
83
84const MAX_ENCODING_LENGTH: usize = 3;
85fn visit_byte(elem: u8, val: u16, nth_byte: usize) -> VisitResult {
86    if elem == 0 && nth_byte != 0 {
87        return Err(VisitError::Alias);
88    }
89
90    let val = u32::from(val);
91    let elem = u32::from(elem);
92    let elem_val = elem & 0x7f;
93    let elem_done = (elem & 0x80) == 0;
94
95    if nth_byte >= MAX_ENCODING_LENGTH {
96        return Err(VisitError::TooLong(nth_byte.saturating_add(1)));
97    } else if nth_byte == MAX_ENCODING_LENGTH.saturating_sub(1) && !elem_done {
98        return Err(VisitError::ByteThreeContinues);
99    }
100
101    let shift = u32::try_from(nth_byte)
102        .unwrap_or(std::u32::MAX)
103        .saturating_mul(7);
104    let elem_val = elem_val.checked_shl(shift).unwrap_or(std::u32::MAX);
105
106    let new_val = val | elem_val;
107    let val = u16::try_from(new_val).map_err(|_| VisitError::Overflow(new_val))?;
108
109    if elem_done {
110        Ok(VisitStatus::Done(val))
111    } else {
112        Ok(VisitStatus::More(val))
113    }
114}
115
116struct ShortU16Visitor;
117
118impl<'de> Visitor<'de> for ShortU16Visitor {
119    type Value = ShortU16;
120
121    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
122        formatter.write_str("a ShortU16")
123    }
124
125    fn visit_seq<A>(self, mut seq: A) -> Result<ShortU16, A::Error>
126    where
127        A: SeqAccess<'de>,
128    {
129        // Decodes an unsigned 16 bit integer one-to-one encoded as follows:
130        // 1 byte  : 0xxxxxxx                   => 00000000 0xxxxxxx :      0 -    127
131        // 2 bytes : 1xxxxxxx 0yyyyyyy          => 00yyyyyy yxxxxxxx :    128 - 16,383
132        // 3 bytes : 1xxxxxxx 1yyyyyyy 000000zz => zzyyyyyy yxxxxxxx : 16,384 - 65,535
133        let mut val: u16 = 0;
134        for nth_byte in 0..MAX_ENCODING_LENGTH {
135            let elem: u8 = seq.next_element()?.ok_or_else(|| {
136                VisitError::TooShort(nth_byte.saturating_add(1)).into_de_error::<A>()
137            })?;
138            match visit_byte(elem, val, nth_byte).map_err(|e| e.into_de_error::<A>())? {
139                VisitStatus::Done(new_val) => return Ok(ShortU16(new_val)),
140                VisitStatus::More(new_val) => val = new_val,
141            }
142        }
143
144        Err(VisitError::ByteThreeContinues.into_de_error::<A>())
145    }
146}
147
148impl<'de> Deserialize<'de> for ShortU16 {
149    fn deserialize<D>(deserializer: D) -> Result<ShortU16, D::Error>
150    where
151        D: Deserializer<'de>,
152    {
153        deserializer.deserialize_tuple(3, ShortU16Visitor)
154    }
155}
156
157/// If you don't want to use the ShortVec newtype, you can do ShortVec
158/// serialization on an ordinary vector with the following field annotation:
159///
160/// #[serde(with = "short_vec")]
161///
162pub fn serialize<S: Serializer, T: Serialize>(
163    elements: &[T],
164    serializer: S,
165) -> Result<S::Ok, S::Error> {
166    // Pass a non-zero value to serialize_tuple() so that serde_json will
167    // generate an open bracket.
168    let mut seq = serializer.serialize_tuple(1)?;
169
170    let len = elements.len();
171    if len > std::u16::MAX as usize {
172        return Err(ser::Error::custom("length larger than u16"));
173    }
174    let short_len = ShortU16(len as u16);
175    seq.serialize_element(&short_len)?;
176
177    for element in elements {
178        seq.serialize_element(element)?;
179    }
180    seq.end()
181}
182
183struct ShortVecVisitor<T> {
184    _t: PhantomData<T>,
185}
186
187impl<'de, T> Visitor<'de> for ShortVecVisitor<T>
188where
189    T: Deserialize<'de>,
190{
191    type Value = Vec<T>;
192
193    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
194        formatter.write_str("a Vec with a multi-byte length")
195    }
196
197    fn visit_seq<A>(self, mut seq: A) -> Result<Vec<T>, A::Error>
198    where
199        A: SeqAccess<'de>,
200    {
201        let short_len: ShortU16 = seq
202            .next_element()?
203            .ok_or_else(|| de::Error::invalid_length(0, &self))?;
204        let len = short_len.0 as usize;
205
206        let mut result = Vec::with_capacity(len);
207        for i in 0..len {
208            let elem = seq
209                .next_element()?
210                .ok_or_else(|| de::Error::invalid_length(i, &self))?;
211            result.push(elem);
212        }
213        Ok(result)
214    }
215}
216
217/// If you don't want to use the ShortVec newtype, you can do ShortVec
218/// deserialization on an ordinary vector with the following field annotation:
219///
220/// #[serde(with = "short_vec")]
221///
222pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Vec<T>, D::Error>
223where
224    D: Deserializer<'de>,
225    T: Deserialize<'de>,
226{
227    let visitor = ShortVecVisitor { _t: PhantomData };
228    deserializer.deserialize_tuple(std::usize::MAX, visitor)
229}
230
/// Newtype around `Vec<T>` whose serde implementations use the short-vec
/// format (a `ShortU16` element count followed by the elements).
pub struct ShortVec<T>(pub Vec<T>);
232
233impl<T: Serialize> Serialize for ShortVec<T> {
234    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
235    where
236        S: Serializer,
237    {
238        serialize(&self.0, serializer)
239    }
240}
241
242impl<'de, T: Deserialize<'de>> Deserialize<'de> for ShortVec<T> {
243    fn deserialize<D>(deserializer: D) -> Result<ShortVec<T>, D::Error>
244    where
245        D: Deserializer<'de>,
246    {
247        deserialize(deserializer).map(ShortVec)
248    }
249}
250
251/// Return the decoded value and how many bytes it consumed.
252#[allow(clippy::result_unit_err)]
253pub fn decode_shortu16_len(bytes: &[u8]) -> Result<(usize, usize), ()> {
254    let mut val = 0;
255    for (nth_byte, byte) in bytes.iter().take(MAX_ENCODING_LENGTH).enumerate() {
256        match visit_byte(*byte, val, nth_byte).map_err(|_| ())? {
257            VisitStatus::More(new_val) => val = new_val,
258            VisitStatus::Done(new_val) => {
259                return Ok((usize::from(new_val), nth_byte.saturating_add(1)));
260            }
261        }
262    }
263    Err(())
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use bincode::{deserialize, serialize};

    /// Encode `len` as a `ShortU16` with bincode and return the bytes.
    fn encode_len(len: u16) -> Vec<u8> {
        let encoded = bincode::serialize(&ShortU16(len));
        encoded.unwrap()
    }

    /// Check that `len` encodes to exactly `bytes`, and that `bytes` decodes
    /// back to `len` while consuming the whole slice.
    fn assert_len_encoding(len: u16, bytes: &[u8]) {
        let encoded = encode_len(len);
        assert_eq!(encoded, bytes, "unexpected usize encoding");

        let decoded = decode_shortu16_len(bytes).unwrap();
        let expected = (usize::from(len), bytes.len());
        assert_eq!(decoded, expected, "unexpected usize decoding");
    }

    #[test]
    fn test_short_vec_encode_len() {
        let cases: &[(u16, &[u8])] = &[
            (0x0, &[0x0]),
            (0x7f, &[0x7f]),
            (0x80, &[0x80, 0x01]),
            (0xff, &[0xff, 0x01]),
            (0x100, &[0x80, 0x02]),
            (0x7fff, &[0xff, 0xff, 0x01]),
            (0xffff, &[0xff, 0xff, 0x03]),
        ];
        for &(len, bytes) in cases {
            assert_len_encoding(len, bytes);
        }
    }

    /// Check that `bytes` deserializes to a `ShortU16` holding `value`.
    fn assert_good_deserialized_value(value: u16, bytes: &[u8]) {
        let ShortU16(decoded) = deserialize::<ShortU16>(bytes).unwrap();
        assert_eq!(value, decoded);
    }

    /// Check that `bytes` is rejected when deserialized as a `ShortU16`.
    fn assert_bad_deserialized_value(bytes: &[u8]) {
        let outcome = deserialize::<ShortU16>(bytes);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_deserialize() {
        let accepted: &[(u16, &[u8])] = &[
            (0x0000, &[0x00]),
            (0x007f, &[0x7f]),
            (0x0080, &[0x80, 0x01]),
            (0x00ff, &[0xff, 0x01]),
            (0x0100, &[0x80, 0x02]),
            (0x07ff, &[0xff, 0x0f]),
            (0x3fff, &[0xff, 0x7f]),
            (0x4000, &[0x80, 0x80, 0x01]),
            (0xffff, &[0xff, 0xff, 0x03]),
        ];
        for &(value, bytes) in accepted {
            assert_good_deserialized_value(value, bytes);
        }

        let rejected: &[&[u8]] = &[
            // aliases
            // 0x0000
            &[0x80, 0x00],
            &[0x80, 0x80, 0x00],
            // 0x007f
            &[0xff, 0x00],
            &[0xff, 0x80, 0x00],
            // 0x0080
            &[0x80, 0x81, 0x00],
            // 0x00ff
            &[0xff, 0x81, 0x00],
            // 0x0100
            &[0x80, 0x82, 0x00],
            // 0x07ff
            &[0xff, 0x8f, 0x00],
            // 0x3fff
            &[0xff, 0xff, 0x00],
            // too short
            &[],
            &[0x80],
            // too long
            &[0x80, 0x80, 0x80, 0x00],
            // too large
            // 0x0001_0000
            &[0x80, 0x80, 0x04],
            // 0x0001_8000
            &[0x80, 0x80, 0x06],
        ];
        for &bytes in rejected {
            assert_bad_deserialized_value(bytes);
        }
    }

    #[test]
    fn test_short_vec_u8() {
        let original = ShortVec(vec![4u8; 32]);
        let encoded = serialize(&original).unwrap();
        // 32 is below 0x80, so the length prefix takes a single byte.
        assert_eq!(encoded.len(), original.0.len() + 1);

        let decoded: ShortVec<u8> = deserialize(&encoded).unwrap();
        assert_eq!(original.0, decoded.0);
    }

    #[test]
    fn test_short_vec_u8_too_long() {
        let max_len = std::u16::MAX as usize;

        let at_limit = ShortVec(vec![4u8; max_len]);
        assert_matches!(serialize(&at_limit), Ok(_));

        let over_limit = ShortVec(vec![4u8; max_len + 1]);
        assert_matches!(serialize(&over_limit), Err(_));
    }

    #[test]
    fn test_short_vec_json() {
        let json = serde_json::to_string(&ShortVec(vec![0, 1, 2])).unwrap();
        assert_eq!(json, "[[3],0,1,2]");
    }

    #[test]
    fn test_short_vec_aliased_length() {
        let bytes = [
            0x81, 0x80, 0x00, // 3-byte alias of 1
            0x00,
        ];
        let outcome = deserialize::<ShortVec<u8>>(&bytes);
        assert!(outcome.is_err());
    }
}
381        assert!(deserialize::<ShortVec<u8>>(&bytes).is_err());
382    }
383}