1#![allow(clippy::integer_arithmetic)]
2use serde::{
3 de::{self, Deserializer, SeqAccess, Visitor},
4 ser::{self, SerializeTuple, Serializer},
5 {Deserialize, Serialize},
6};
7use std::{convert::TryFrom, fmt, marker::PhantomData};
8
9#[derive(AbiExample)]
15pub struct ShortU16(pub u16);
16
17impl Serialize for ShortU16 {
18 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
19 where
20 S: Serializer,
21 {
22 let mut seq = serializer.serialize_tuple(1)?;
25
26 let mut rem_val = self.0;
27 loop {
28 let mut elem = (rem_val & 0x7f) as u8;
29 rem_val >>= 7;
30 if rem_val == 0 {
31 seq.serialize_element(&elem)?;
32 break;
33 } else {
34 elem |= 0x80;
35 seq.serialize_element(&elem)?;
36 }
37 }
38 seq.end()
39 }
40}
41
/// Outcome of successfully folding one encoded byte into the running value.
enum VisitStatus {
    /// The byte had its continuation bit clear; the payload is the final value.
    Done(u16),
    /// The byte had its continuation bit set; the payload is the value so far.
    More(u16),
}
46
47#[derive(Debug)]
48enum VisitError {
49 TooLong(usize),
50 TooShort(usize),
51 Overflow(u32),
52 Alias,
53 ByteThreeContinues,
54}
55
56impl VisitError {
57 fn into_de_error<'de, A>(self) -> A::Error
58 where
59 A: SeqAccess<'de>,
60 {
61 match self {
62 VisitError::TooLong(len) => {
63 de::Error::invalid_length(len as usize, &"three or fewer bytes")
64 }
65 VisitError::TooShort(len) => de::Error::invalid_length(len, &"more bytes"),
66 VisitError::Overflow(val) => de::Error::invalid_value(
67 de::Unexpected::Unsigned(val as u64),
68 &"a value in the range [0, 65535]",
69 ),
70 VisitError::Alias => de::Error::invalid_value(
71 de::Unexpected::Other("alias encoding"),
72 &"strict form encoding",
73 ),
74 VisitError::ByteThreeContinues => de::Error::invalid_value(
75 de::Unexpected::Other("continue signal on byte-three"),
76 &"a terminal signal on or before byte-three",
77 ),
78 }
79 }
80}
81
82type VisitResult = Result<VisitStatus, VisitError>;
83
84const MAX_ENCODING_LENGTH: usize = 3;
85fn visit_byte(elem: u8, val: u16, nth_byte: usize) -> VisitResult {
86 if elem == 0 && nth_byte != 0 {
87 return Err(VisitError::Alias);
88 }
89
90 let val = u32::from(val);
91 let elem = u32::from(elem);
92 let elem_val = elem & 0x7f;
93 let elem_done = (elem & 0x80) == 0;
94
95 if nth_byte >= MAX_ENCODING_LENGTH {
96 return Err(VisitError::TooLong(nth_byte.saturating_add(1)));
97 } else if nth_byte == MAX_ENCODING_LENGTH.saturating_sub(1) && !elem_done {
98 return Err(VisitError::ByteThreeContinues);
99 }
100
101 let shift = u32::try_from(nth_byte)
102 .unwrap_or(std::u32::MAX)
103 .saturating_mul(7);
104 let elem_val = elem_val.checked_shl(shift).unwrap_or(std::u32::MAX);
105
106 let new_val = val | elem_val;
107 let val = u16::try_from(new_val).map_err(|_| VisitError::Overflow(new_val))?;
108
109 if elem_done {
110 Ok(VisitStatus::Done(val))
111 } else {
112 Ok(VisitStatus::More(val))
113 }
114}
115
116struct ShortU16Visitor;
117
118impl<'de> Visitor<'de> for ShortU16Visitor {
119 type Value = ShortU16;
120
121 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
122 formatter.write_str("a ShortU16")
123 }
124
125 fn visit_seq<A>(self, mut seq: A) -> Result<ShortU16, A::Error>
126 where
127 A: SeqAccess<'de>,
128 {
129 let mut val: u16 = 0;
134 for nth_byte in 0..MAX_ENCODING_LENGTH {
135 let elem: u8 = seq.next_element()?.ok_or_else(|| {
136 VisitError::TooShort(nth_byte.saturating_add(1)).into_de_error::<A>()
137 })?;
138 match visit_byte(elem, val, nth_byte).map_err(|e| e.into_de_error::<A>())? {
139 VisitStatus::Done(new_val) => return Ok(ShortU16(new_val)),
140 VisitStatus::More(new_val) => val = new_val,
141 }
142 }
143
144 Err(VisitError::ByteThreeContinues.into_de_error::<A>())
145 }
146}
147
148impl<'de> Deserialize<'de> for ShortU16 {
149 fn deserialize<D>(deserializer: D) -> Result<ShortU16, D::Error>
150 where
151 D: Deserializer<'de>,
152 {
153 deserializer.deserialize_tuple(3, ShortU16Visitor)
154 }
155}
156
157pub fn serialize<S: Serializer, T: Serialize>(
163 elements: &[T],
164 serializer: S,
165) -> Result<S::Ok, S::Error> {
166 let mut seq = serializer.serialize_tuple(1)?;
169
170 let len = elements.len();
171 if len > std::u16::MAX as usize {
172 return Err(ser::Error::custom("length larger than u16"));
173 }
174 let short_len = ShortU16(len as u16);
175 seq.serialize_element(&short_len)?;
176
177 for element in elements {
178 seq.serialize_element(element)?;
179 }
180 seq.end()
181}
182
183struct ShortVecVisitor<T> {
184 _t: PhantomData<T>,
185}
186
187impl<'de, T> Visitor<'de> for ShortVecVisitor<T>
188where
189 T: Deserialize<'de>,
190{
191 type Value = Vec<T>;
192
193 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
194 formatter.write_str("a Vec with a multi-byte length")
195 }
196
197 fn visit_seq<A>(self, mut seq: A) -> Result<Vec<T>, A::Error>
198 where
199 A: SeqAccess<'de>,
200 {
201 let short_len: ShortU16 = seq
202 .next_element()?
203 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
204 let len = short_len.0 as usize;
205
206 let mut result = Vec::with_capacity(len);
207 for i in 0..len {
208 let elem = seq
209 .next_element()?
210 .ok_or_else(|| de::Error::invalid_length(i, &self))?;
211 result.push(elem);
212 }
213 Ok(result)
214 }
215}
216
217pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Vec<T>, D::Error>
223where
224 D: Deserializer<'de>,
225 T: Deserialize<'de>,
226{
227 let visitor = ShortVecVisitor { _t: PhantomData };
228 deserializer.deserialize_tuple(std::usize::MAX, visitor)
229}
230
231pub struct ShortVec<T>(pub Vec<T>);
232
233impl<T: Serialize> Serialize for ShortVec<T> {
234 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
235 where
236 S: Serializer,
237 {
238 serialize(&self.0, serializer)
239 }
240}
241
242impl<'de, T: Deserialize<'de>> Deserialize<'de> for ShortVec<T> {
243 fn deserialize<D>(deserializer: D) -> Result<ShortVec<T>, D::Error>
244 where
245 D: Deserializer<'de>,
246 {
247 deserialize(deserializer).map(ShortVec)
248 }
249}
250
251#[allow(clippy::result_unit_err)]
253pub fn decode_shortu16_len(bytes: &[u8]) -> Result<(usize, usize), ()> {
254 let mut val = 0;
255 for (nth_byte, byte) in bytes.iter().take(MAX_ENCODING_LENGTH).enumerate() {
256 match visit_byte(*byte, val, nth_byte).map_err(|_| ())? {
257 VisitStatus::More(new_val) => val = new_val,
258 VisitStatus::Done(new_val) => {
259 return Ok((usize::from(new_val), nth_byte.saturating_add(1)));
260 }
261 }
262 }
263 Err(())
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use bincode::{deserialize, serialize};

    /// Bincode-encodes `len` wrapped in a `ShortU16`.
    fn encode_len(len: u16) -> Vec<u8> {
        let encoded = bincode::serialize(&ShortU16(len));
        encoded.unwrap()
    }

    /// Checks that `len` encodes to `bytes` and that `bytes` decodes back to
    /// `len`, consuming the whole slice.
    fn assert_len_encoding(len: u16, bytes: &[u8]) {
        assert_eq!(encode_len(len), bytes, "unexpected usize encoding");

        let expected = (usize::from(len), bytes.len());
        assert_eq!(
            decode_shortu16_len(bytes).unwrap(),
            expected,
            "unexpected usize decoding"
        );
    }

    #[test]
    fn test_short_vec_encode_len() {
        let cases: &[(u16, &[u8])] = &[
            (0x0, &[0x0]),
            (0x7f, &[0x7f]),
            (0x80, &[0x80, 0x01]),
            (0xff, &[0xff, 0x01]),
            (0x100, &[0x80, 0x02]),
            (0x7fff, &[0xff, 0xff, 0x01]),
            (0xffff, &[0xff, 0xff, 0x03]),
        ];
        for &(len, bytes) in cases {
            assert_len_encoding(len, bytes);
        }
    }

    /// Asserts that `bytes` deserializes to a `ShortU16` holding `value`.
    fn assert_good_deserialized_value(value: u16, bytes: &[u8]) {
        let decoded = deserialize::<ShortU16>(bytes).unwrap();
        assert_eq!(value, decoded.0);
    }

    /// Asserts that `bytes` is rejected when deserialized as a `ShortU16`.
    fn assert_bad_deserialized_value(bytes: &[u8]) {
        let outcome = deserialize::<ShortU16>(bytes);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_deserialize() {
        let good: &[(u16, &[u8])] = &[
            (0x0000, &[0x00]),
            (0x007f, &[0x7f]),
            (0x0080, &[0x80, 0x01]),
            (0x00ff, &[0xff, 0x01]),
            (0x0100, &[0x80, 0x02]),
            (0x07ff, &[0xff, 0x0f]),
            (0x3fff, &[0xff, 0x7f]),
            (0x4000, &[0x80, 0x80, 0x01]),
            (0xffff, &[0xff, 0xff, 0x03]),
        ];
        for &(value, bytes) in good {
            assert_good_deserialized_value(value, bytes);
        }

        let bad: &[&[u8]] = &[
            // Encodings that end in a zero byte after the first position.
            &[0x80, 0x00],
            &[0x80, 0x80, 0x00],
            &[0xff, 0x00],
            &[0xff, 0x80, 0x00],
            &[0x80, 0x81, 0x00],
            &[0xff, 0x81, 0x00],
            &[0x80, 0x82, 0x00],
            &[0xff, 0x8f, 0x00],
            &[0xff, 0xff, 0x00],
            // Input that ends before a terminating byte.
            &[],
            &[0x80],
            // Continuation bit still set on the third byte.
            &[0x80, 0x80, 0x80, 0x00],
            // Third byte pushes the value past what a u16 can hold.
            &[0x80, 0x80, 0x04],
            &[0x80, 0x80, 0x06],
        ];
        for &bytes in bad {
            assert_bad_deserialized_value(bytes);
        }
    }

    #[test]
    fn test_short_vec_u8() {
        let original = ShortVec(vec![4u8; 32]);
        let encoded = serialize(&original).unwrap();
        // One length byte plus one byte per element.
        assert_eq!(encoded.len(), original.0.len() + 1);

        let decoded: ShortVec<u8> = deserialize(&encoded).unwrap();
        assert_eq!(original.0, decoded.0);
    }

    #[test]
    fn test_short_vec_u8_too_long() {
        let max_len = std::u16::MAX as usize;

        let at_limit = ShortVec(vec![4u8; max_len]);
        assert_matches!(serialize(&at_limit), Ok(_));

        let over_limit = ShortVec(vec![4u8; max_len + 1]);
        assert_matches!(serialize(&over_limit), Err(_));
    }

    #[test]
    fn test_short_vec_json() {
        let vec = ShortVec(vec![0, 1, 2]);
        let json = serde_json::to_string(&vec).unwrap();
        assert_eq!(json, "[[3],0,1,2]");
    }

    #[test]
    fn test_short_vec_aliased_length() {
        // The length prefix ends in a zero byte after the first position,
        // which the decoder rejects as an alias encoding.
        let bytes = [0x81, 0x80, 0x00, 0x00];
        let outcome = deserialize::<ShortVec<u8>>(&bytes);
        assert!(outcome.is_err());
    }
}