1#![cfg_attr(feature = "frozen-abi", feature(min_specialization))]
3#![allow(clippy::arithmetic_side_effects)]
4use std::{convert::TryFrom, fmt, marker::PhantomData};
5
6#[cfg(feature = "frozen-abi")]
7use rialo_frozen_abi_macro::AbiExample;
8use serde::{
9 de::{self, Deserializer, SeqAccess, Visitor},
10 ser::{self, SerializeTuple, Serializer},
11 Deserialize, Serialize,
12};
13
/// A `u16` with a compact, variable-length serde representation.
///
/// When serialized, the value is split into 7-bit groups emitted
/// least-significant group first; every byte except the last has its high bit
/// (`0x80`) set to signal that more bytes follow. Decoding accepts at most
/// `MAX_ENCODING_LENGTH` bytes and rejects encodings whose combined value does
/// not fit in a `u16`.
#[cfg_attr(feature = "frozen-abi", derive(AbiExample))]
pub struct ShortU16(pub u16);
21
22impl Serialize for ShortU16 {
23 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
24 where
25 S: Serializer,
26 {
27 let mut seq = serializer.serialize_tuple(1)?;
30
31 let mut rem_val = self.0;
32 loop {
33 let mut elem = (rem_val & 0x7f) as u8;
34 rem_val >>= 7;
35 if rem_val == 0 {
36 seq.serialize_element(&elem)?;
37 break;
38 } else {
39 elem |= 0x80;
40 seq.serialize_element(&elem)?;
41 }
42 }
43 seq.end()
44 }
45}
46
/// Result of successfully folding one encoded byte into the value decoded so
/// far.
enum VisitStatus {
    /// The byte's continuation bit was clear; carries the fully decoded value.
    Done(u16),
    /// The byte's continuation bit was set; carries the partial value
    /// accumulated so far.
    More(u16),
}
51
/// Reasons a `ShortU16` byte sequence can be rejected while decoding.
#[derive(Debug)]
enum VisitError {
    /// A byte was visited at an index at or beyond `MAX_ENCODING_LENGTH`;
    /// carries that index plus one.
    TooLong(usize),
    /// The input ended before a terminating byte was seen; carries the
    /// one-based position of the byte that was missing.
    TooShort(usize),
    /// The combined value does not fit in a `u16`; carries the offending value.
    Overflow(u32),
    /// A zero byte appeared after the first position, i.e. a longer-than-needed
    /// encoding of a value that has a shorter form.
    Alias,
    /// The last permitted byte still had its continuation bit set.
    ByteThreeContinues,
}
60
61impl VisitError {
62 fn into_de_error<'de, A>(self) -> A::Error
63 where
64 A: SeqAccess<'de>,
65 {
66 match self {
67 VisitError::TooLong(len) => de::Error::invalid_length(len, &"three or fewer bytes"),
68 VisitError::TooShort(len) => de::Error::invalid_length(len, &"more bytes"),
69 VisitError::Overflow(val) => de::Error::invalid_value(
70 de::Unexpected::Unsigned(val as u64),
71 &"a value in the range [0, 65535]",
72 ),
73 VisitError::Alias => de::Error::invalid_value(
74 de::Unexpected::Other("alias encoding"),
75 &"strict form encoding",
76 ),
77 VisitError::ByteThreeContinues => de::Error::invalid_value(
78 de::Unexpected::Other("continue signal on byte-three"),
79 &"a terminal signal on or before byte-three",
80 ),
81 }
82 }
83}
84
/// Shorthand for the outcome of visiting a single encoded byte.
type VisitResult = Result<VisitStatus, VisitError>;
86
/// Maximum number of bytes a `ShortU16` encoding may occupy.
const MAX_ENCODING_LENGTH: usize = 3;
88fn visit_byte(elem: u8, val: u16, nth_byte: usize) -> VisitResult {
89 if elem == 0 && nth_byte != 0 {
90 return Err(VisitError::Alias);
91 }
92
93 let val = u32::from(val);
94 let elem = u32::from(elem);
95 let elem_val = elem & 0x7f;
96 let elem_done = (elem & 0x80) == 0;
97
98 if nth_byte >= MAX_ENCODING_LENGTH {
99 return Err(VisitError::TooLong(nth_byte.saturating_add(1)));
100 } else if nth_byte == MAX_ENCODING_LENGTH.saturating_sub(1) && !elem_done {
101 return Err(VisitError::ByteThreeContinues);
102 }
103
104 let shift = u32::try_from(nth_byte)
105 .unwrap_or(u32::MAX)
106 .saturating_mul(7);
107 let elem_val = elem_val.checked_shl(shift).unwrap_or(u32::MAX);
108
109 let new_val = val | elem_val;
110 let val = u16::try_from(new_val).map_err(|_| VisitError::Overflow(new_val))?;
111
112 if elem_done {
113 Ok(VisitStatus::Done(val))
114 } else {
115 Ok(VisitStatus::More(val))
116 }
117}
118
/// Stateless serde visitor that decodes a `ShortU16` from a byte sequence.
struct ShortU16Visitor;
120
121impl<'de> Visitor<'de> for ShortU16Visitor {
122 type Value = ShortU16;
123
124 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
125 formatter.write_str("a ShortU16")
126 }
127
128 fn visit_seq<A>(self, mut seq: A) -> Result<ShortU16, A::Error>
129 where
130 A: SeqAccess<'de>,
131 {
132 let mut val: u16 = 0;
137 for nth_byte in 0..MAX_ENCODING_LENGTH {
138 let elem: u8 = seq.next_element()?.ok_or_else(|| {
139 VisitError::TooShort(nth_byte.saturating_add(1)).into_de_error::<A>()
140 })?;
141 match visit_byte(elem, val, nth_byte).map_err(|e| e.into_de_error::<A>())? {
142 VisitStatus::Done(new_val) => return Ok(ShortU16(new_val)),
143 VisitStatus::More(new_val) => val = new_val,
144 }
145 }
146
147 Err(VisitError::ByteThreeContinues.into_de_error::<A>())
148 }
149}
150
151impl<'de> Deserialize<'de> for ShortU16 {
152 fn deserialize<D>(deserializer: D) -> Result<ShortU16, D::Error>
153 where
154 D: Deserializer<'de>,
155 {
156 deserializer.deserialize_tuple(3, ShortU16Visitor)
157 }
158}
159
160pub fn serialize<S: Serializer, T: Serialize>(
166 elements: &[T],
167 serializer: S,
168) -> Result<S::Ok, S::Error> {
169 let mut seq = serializer.serialize_tuple(1)?;
172
173 let len = elements.len();
174 if len > u16::MAX as usize {
175 return Err(ser::Error::custom("length larger than u16"));
176 }
177 let short_len = ShortU16(len as u16);
178 seq.serialize_element(&short_len)?;
179
180 for element in elements {
181 seq.serialize_element(element)?;
182 }
183 seq.end()
184}
185
/// Serde visitor that decodes a `Vec<T>` written as a `ShortU16` length prefix
/// followed by that many elements.
struct ShortVecVisitor<T> {
    // Zero-sized marker tying the visitor to its element type `T`.
    _t: PhantomData<T>,
}
189
190impl<'de, T> Visitor<'de> for ShortVecVisitor<T>
191where
192 T: Deserialize<'de>,
193{
194 type Value = Vec<T>;
195
196 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
197 formatter.write_str("a Vec with a multi-byte length")
198 }
199
200 fn visit_seq<A>(self, mut seq: A) -> Result<Vec<T>, A::Error>
201 where
202 A: SeqAccess<'de>,
203 {
204 let short_len: ShortU16 = seq
205 .next_element()?
206 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
207 let len = short_len.0 as usize;
208
209 let mut result = Vec::with_capacity(len);
210 for i in 0..len {
211 let elem = seq
212 .next_element()?
213 .ok_or_else(|| de::Error::invalid_length(i, &self))?;
214 result.push(elem);
215 }
216 Ok(result)
217 }
218}
219
220pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Vec<T>, D::Error>
226where
227 D: Deserializer<'de>,
228 T: Deserialize<'de>,
229{
230 let visitor = ShortVecVisitor { _t: PhantomData };
231 deserializer.deserialize_tuple(usize::MAX, visitor)
232}
233
/// Wrapper around `Vec<T>` whose serde implementations use the
/// `ShortU16`-prefixed encoding provided by this module's `serialize` and
/// `deserialize` functions.
pub struct ShortVec<T>(pub Vec<T>);
235
236impl<T: Serialize> Serialize for ShortVec<T> {
237 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
238 where
239 S: Serializer,
240 {
241 serialize(&self.0, serializer)
242 }
243}
244
245impl<'de, T: Deserialize<'de>> Deserialize<'de> for ShortVec<T> {
246 fn deserialize<D>(deserializer: D) -> Result<ShortVec<T>, D::Error>
247 where
248 D: Deserializer<'de>,
249 {
250 deserialize(deserializer).map(ShortVec)
251 }
252}
253
254#[allow(clippy::result_unit_err)]
256pub fn decode_shortu16_len(bytes: &[u8]) -> Result<(usize, usize), ()> {
257 let mut val = 0;
258 for (nth_byte, byte) in bytes.iter().take(MAX_ENCODING_LENGTH).enumerate() {
259 match visit_byte(*byte, val, nth_byte).map_err(|_| ())? {
260 VisitStatus::More(new_val) => val = new_val,
261 VisitStatus::Done(new_val) => {
262 return Ok((usize::from(new_val), nth_byte.saturating_add(1)));
263 }
264 }
265 }
266 Err(())
267}
268
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use bincode::{deserialize, serialize};

    use super::*;

    /// Returns the bincode serialization of `len` wrapped in a `ShortU16`.
    fn encode_len(len: u16) -> Vec<u8> {
        bincode::serialize(&ShortU16(len)).unwrap()
    }

    /// Asserts that `len` encodes to exactly `bytes`, and that decoding
    /// `bytes` yields `len` while consuming every byte.
    fn assert_len_encoding(len: u16, bytes: &[u8]) {
        assert_eq!(encode_len(len), bytes, "unexpected usize encoding");
        assert_eq!(
            decode_shortu16_len(bytes).unwrap(),
            (usize::from(len), bytes.len()),
            "unexpected usize decoding"
        );
    }

    #[test]
    fn test_short_vec_encode_len() {
        // Boundaries of the one-, two- and three-byte encodings.
        assert_len_encoding(0x0, &[0x0]);
        assert_len_encoding(0x7f, &[0x7f]);
        assert_len_encoding(0x80, &[0x80, 0x01]);
        assert_len_encoding(0xff, &[0xff, 0x01]);
        assert_len_encoding(0x100, &[0x80, 0x02]);
        assert_len_encoding(0x7fff, &[0xff, 0xff, 0x01]);
        assert_len_encoding(0xffff, &[0xff, 0xff, 0x03]);
    }

    /// Asserts that `bytes` deserializes to a `ShortU16` holding `value`.
    fn assert_good_deserialized_value(value: u16, bytes: &[u8]) {
        assert_eq!(value, deserialize::<ShortU16>(bytes).unwrap().0);
    }

    /// Asserts that deserializing `bytes` as a `ShortU16` fails.
    fn assert_bad_deserialized_value(bytes: &[u8]) {
        assert!(deserialize::<ShortU16>(bytes).is_err());
    }

    #[test]
    fn test_deserialize() {
        assert_good_deserialized_value(0x0000, &[0x00]);
        assert_good_deserialized_value(0x007f, &[0x7f]);
        assert_good_deserialized_value(0x0080, &[0x80, 0x01]);
        assert_good_deserialized_value(0x00ff, &[0xff, 0x01]);
        assert_good_deserialized_value(0x0100, &[0x80, 0x02]);
        assert_good_deserialized_value(0x07ff, &[0xff, 0x0f]);
        assert_good_deserialized_value(0x3fff, &[0xff, 0x7f]);
        assert_good_deserialized_value(0x4000, &[0x80, 0x80, 0x01]);
        assert_good_deserialized_value(0xffff, &[0xff, 0xff, 0x03]);

        // Aliases: longer-than-necessary encodings ending in a zero byte.
        // 0x0000
        assert_bad_deserialized_value(&[0x80, 0x00]);
        // 0x0000
        assert_bad_deserialized_value(&[0x80, 0x80, 0x00]);
        // 0x007f
        assert_bad_deserialized_value(&[0xff, 0x00]);
        // 0x007f
        assert_bad_deserialized_value(&[0xff, 0x80, 0x00]);
        // 0x0080
        assert_bad_deserialized_value(&[0x80, 0x81, 0x00]);
        // 0x00ff
        assert_bad_deserialized_value(&[0xff, 0x81, 0x00]);
        // 0x0100
        assert_bad_deserialized_value(&[0x80, 0x82, 0x00]);
        // 0x07ff
        assert_bad_deserialized_value(&[0xff, 0x8f, 0x00]);
        // 0x3fff
        assert_bad_deserialized_value(&[0xff, 0xff, 0x00]);

        // Too short: input ends before a terminating byte.
        assert_bad_deserialized_value(&[]);
        assert_bad_deserialized_value(&[0x80]);

        // Too long: the third byte still has its continuation bit set.
        assert_bad_deserialized_value(&[0x80, 0x80, 0x80, 0x00]);

        // Too large: the combined value does not fit in a u16.
        // 0x0001_0000
        assert_bad_deserialized_value(&[0x80, 0x80, 0x04]);
        // 0x0001_8000
        assert_bad_deserialized_value(&[0x80, 0x80, 0x06]);
    }

    #[test]
    fn test_short_vec_u8() {
        let vec = ShortVec(vec![4u8; 32]);
        let bytes = serialize(&vec).unwrap();
        // A length of 32 fits in a single prefix byte.
        assert_eq!(bytes.len(), vec.0.len() + 1);

        let vec1: ShortVec<u8> = deserialize(&bytes).unwrap();
        assert_eq!(vec.0, vec1.0);
    }

    #[test]
    fn test_short_vec_u8_too_long() {
        // Exactly u16::MAX elements is the largest length that serializes.
        let vec = ShortVec(vec![4u8; u16::MAX as usize]);
        assert_matches!(serialize(&vec), Ok(_));

        // One more element no longer fits the u16 length prefix.
        let vec = ShortVec(vec![4u8; u16::MAX as usize + 1]);
        assert_matches!(serialize(&vec), Err(_));
    }

    #[test]
    fn test_short_vec_json() {
        let vec = ShortVec(vec![0, 1, 2]);
        let s = serde_json::to_string(&vec).unwrap();
        // The length prefix appears as its own nested array ahead of the items.
        assert_eq!(s, "[[3],0,1,2]");
    }

    #[test]
    fn test_short_vec_aliased_length() {
        // A three-byte alias of length 1 followed by one element byte; the
        // aliased prefix must be rejected.
        let bytes = [
            0x81, 0x80, 0x00, 0x00,
        ];
        assert!(deserialize::<ShortVec<u8>>(&bytes).is_err());
    }
}