Skip to main content

pack_io/
impls.rs

1//! `Serialize` / `Deserialize` implementations for primitive, container, and
2//! collection types.
3//!
4//! ## Wire format (full reference: [`docs/WIRE_FORMAT.md`])
5//!
6//! - `u8` / `i8` — one byte each (fixed). `i8` is two's-complement.
//! - `u16` / `u32` / `u64` / `u128` / `usize` — LEB128 varint. `usize` is
//!   encoded through `u64`. On decode, a value that does not fit the target
//!   type (for example anything above `u16::MAX`, or above `usize::MAX` on a
//!   32-bit target) is rejected with [`SerialError::IntegerOutOfRange`].
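//! - `i16` / `i32` / `i64` / `i128` / `isize` — ZigZag mapping followed by
//!   LEB128 varint. `isize` is encoded through `i64`; decoded values that do
//!   not fit the target type are rejected just as for the unsigned types.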
12//! - `bool` — one byte (`0x00` / `0x01`); any other byte is rejected.
13//! - `f32` / `f64` — IEEE 754 bit pattern, little-endian. NaN, ±Inf,
14//!   subnormals, and signed zeros all round-trip bit-for-bit.
15//! - `String` / `&str` — varint length prefix, then UTF-8 bytes.
16//! - `[T; N]` — `N` consecutive `T` encodings, no length prefix (the length
17//!   is in the type).
18//! - `Vec<T>` / `&[T]` — varint length prefix, then `len` consecutive `T`
19//!   encodings.
20//! - tuples (arity 1..=12) — fields concatenated in declaration order.
21//! - `Option<T>` — one tag byte (`0x00` = `None`, `0x01` = `Some`) followed
22//!   by the inner value when present.
23//! - `Result<T, E>` — one tag byte (`0x00` = `Ok`, `0x01` = `Err`) followed
24//!   by the inner value.
25//! - `()` (unit) — zero bytes.
26//! - `BTreeMap` / `BTreeSet` / `HashMap` / `HashSet` — varint count followed
27//!   by the entries sorted lexicographically by their **encoded key bytes**.
//!   This canonical ordering means a `HashMap` and a `BTreeMap` holding the
//!   same logical data encode to the same bytes, regardless of insertion
//!   order or the per-process random seed used by `HashMap`'s default hasher.
31//!
32//! [`docs/WIRE_FORMAT.md`]: https://github.com/jamesgober/pack-io/blob/main/docs/WIRE_FORMAT.md
33
34use alloc::collections::{BTreeMap, BTreeSet};
35use alloc::string::String;
36use alloc::vec::Vec;
37#[cfg(feature = "std")]
38use std::collections::{HashMap, HashSet};
39#[cfg(feature = "std")]
40use std::hash::{BuildHasher, Hash};
41
42use crate::codec::{Decode, Encode, Encoder};
43use crate::error::{Result, SerialError};
44use crate::traits::{Deserialize, Serialize};
45use crate::varint;
46
47// ---------------------------------------------------------------------------
48// Unsigned integers
49// ---------------------------------------------------------------------------
50
51impl Serialize for u8 {
52    #[inline]
53    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
54        encoder.write_byte(*self)
55    }
56}
57
58impl Deserialize for u8 {
59    #[inline]
60    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
61        decoder.read_byte()
62    }
63}
64
65impl Serialize for u16 {
66    #[inline]
67    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
68        encoder.write_varint_u64(u64::from(*self))
69    }
70}
71
72impl Deserialize for u16 {
73    #[inline]
74    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
75        let value = decoder.read_varint_u64()?;
76        u16::try_from(value).map_err(|_| SerialError::IntegerOutOfRange)
77    }
78}
79
80impl Serialize for u32 {
81    #[inline]
82    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
83        encoder.write_varint_u64(u64::from(*self))
84    }
85}
86
87impl Deserialize for u32 {
88    #[inline]
89    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
90        let value = decoder.read_varint_u64()?;
91        u32::try_from(value).map_err(|_| SerialError::IntegerOutOfRange)
92    }
93}
94
95impl Serialize for u64 {
96    #[inline]
97    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
98        encoder.write_varint_u64(*self)
99    }
100}
101
102impl Deserialize for u64 {
103    #[inline]
104    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
105        decoder.read_varint_u64()
106    }
107}
108
109impl Serialize for u128 {
110    #[inline]
111    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
112        encoder.write_varint_u128(*self)
113    }
114}
115
116impl Deserialize for u128 {
117    #[inline]
118    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
119        decoder.read_varint_u128()
120    }
121}
122
123impl Serialize for usize {
124    #[inline]
125    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
126        encoder.write_varint_u64(*self as u64)
127    }
128}
129
130impl Deserialize for usize {
131    #[inline]
132    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
133        let value = decoder.read_varint_u64()?;
134        usize::try_from(value).map_err(|_| SerialError::IntegerOutOfRange)
135    }
136}
137
138// ---------------------------------------------------------------------------
// Signed integers — `i8` as a raw byte; wider types ZigZag + varint
140// ---------------------------------------------------------------------------
141
142impl Serialize for i8 {
143    #[inline]
144    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
145        encoder.write_byte(*self as u8)
146    }
147}
148
149impl Deserialize for i8 {
150    #[inline]
151    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
152        Ok(decoder.read_byte()? as i8)
153    }
154}
155
156impl Serialize for i16 {
157    #[inline]
158    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
159        encoder.write_varint_u64(varint::zigzag_encode_i64(i64::from(*self)))
160    }
161}
162
163impl Deserialize for i16 {
164    #[inline]
165    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
166        let value = varint::zigzag_decode_i64(decoder.read_varint_u64()?);
167        i16::try_from(value).map_err(|_| SerialError::IntegerOutOfRange)
168    }
169}
170
171impl Serialize for i32 {
172    #[inline]
173    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
174        encoder.write_varint_u64(varint::zigzag_encode_i64(i64::from(*self)))
175    }
176}
177
178impl Deserialize for i32 {
179    #[inline]
180    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
181        let value = varint::zigzag_decode_i64(decoder.read_varint_u64()?);
182        i32::try_from(value).map_err(|_| SerialError::IntegerOutOfRange)
183    }
184}
185
186impl Serialize for i64 {
187    #[inline]
188    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
189        encoder.write_varint_u64(varint::zigzag_encode_i64(*self))
190    }
191}
192
193impl Deserialize for i64 {
194    #[inline]
195    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
196        Ok(varint::zigzag_decode_i64(decoder.read_varint_u64()?))
197    }
198}
199
200impl Serialize for i128 {
201    #[inline]
202    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
203        encoder.write_varint_u128(varint::zigzag_encode_i128(*self))
204    }
205}
206
207impl Deserialize for i128 {
208    #[inline]
209    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
210        Ok(varint::zigzag_decode_i128(decoder.read_varint_u128()?))
211    }
212}
213
214impl Serialize for isize {
215    #[inline]
216    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
217        encoder.write_varint_u64(varint::zigzag_encode_i64(*self as i64))
218    }
219}
220
221impl Deserialize for isize {
222    #[inline]
223    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
224        let value = varint::zigzag_decode_i64(decoder.read_varint_u64()?);
225        isize::try_from(value).map_err(|_| SerialError::IntegerOutOfRange)
226    }
227}
228
229// ---------------------------------------------------------------------------
230// Bool
231// ---------------------------------------------------------------------------
232
233impl Serialize for bool {
234    #[inline]
235    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
236        encoder.write_byte(u8::from(*self))
237    }
238}
239
240impl Deserialize for bool {
241    #[inline]
242    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
243        match decoder.read_byte()? {
244            0x00 => Ok(false),
245            0x01 => Ok(true),
246            other => Err(SerialError::InvalidBool { byte: other }),
247        }
248    }
249}
250
251// ---------------------------------------------------------------------------
252// Floats — IEEE 754 bit pattern, little-endian
253// ---------------------------------------------------------------------------
254
255impl Serialize for f32 {
256    #[inline]
257    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
258        encoder.write_bytes(&self.to_bits().to_le_bytes())
259    }
260}
261
262impl Deserialize for f32 {
263    #[inline]
264    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
265        let mut buf = [0u8; 4];
266        decoder.read_into(&mut buf)?;
267        Ok(f32::from_bits(u32::from_le_bytes(buf)))
268    }
269}
270
271impl Serialize for f64 {
272    #[inline]
273    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
274        encoder.write_bytes(&self.to_bits().to_le_bytes())
275    }
276}
277
278impl Deserialize for f64 {
279    #[inline]
280    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
281        let mut buf = [0u8; 8];
282        decoder.read_into(&mut buf)?;
283        Ok(f64::from_bits(u64::from_le_bytes(buf)))
284    }
285}
286
287// ---------------------------------------------------------------------------
288// String / &str
289// ---------------------------------------------------------------------------
290
291impl Serialize for str {
292    #[inline]
293    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
294        let bytes = self.as_bytes();
295        encoder.write_varint_u64(bytes.len() as u64)?;
296        encoder.write_bytes(bytes)
297    }
298}
299
300impl Serialize for String {
301    #[inline]
302    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
303        Serialize::serialize(self.as_str(), encoder)
304    }
305}
306
307impl Deserialize for String {
308    #[inline]
309    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
310        let bytes = decoder.read_length_prefixed()?;
311        String::from_utf8(bytes).map_err(|_| SerialError::InvalidUtf8)
312    }
313}
314
315// ---------------------------------------------------------------------------
316// Slices and Vec<T>
317// ---------------------------------------------------------------------------
318
319impl<T: Serialize> Serialize for [T] {
320    #[inline]
321    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
322        encoder.write_varint_u64(self.len() as u64)?;
323        for item in self {
324            item.serialize(encoder)?;
325        }
326        Ok(())
327    }
328}
329
330impl<T: Serialize> Serialize for Vec<T> {
331    #[inline]
332    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
333        Serialize::serialize(self.as_slice(), encoder)
334    }
335}
336
337impl<T: Deserialize> Deserialize for Vec<T> {
338    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
339        let declared = decoder.read_varint_u64()?;
340        let len = guard_element_count::<T, _>(declared, decoder)?;
341        let mut out = Vec::with_capacity(initial_capacity(len));
342        for _ in 0..len {
343            out.push(T::deserialize(decoder)?);
344        }
345        Ok(out)
346    }
347}
348
349// ---------------------------------------------------------------------------
350// Fixed-size arrays — [T; N]
351// ---------------------------------------------------------------------------
352
353impl<T: Serialize, const N: usize> Serialize for [T; N] {
354    #[inline]
355    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
356        for item in self {
357            item.serialize(encoder)?;
358        }
359        Ok(())
360    }
361}
362
363impl<T: Deserialize, const N: usize> Deserialize for [T; N] {
364    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
365        let mut out: Vec<T> = Vec::with_capacity(N);
366        for _ in 0..N {
367            out.push(T::deserialize(decoder)?);
368        }
369        out.try_into().map_err(|_| SerialError::IntegerOutOfRange)
370    }
371}
372
373// ---------------------------------------------------------------------------
374// Tuples — arity 0..=12
375// ---------------------------------------------------------------------------
376
377impl Serialize for () {
378    #[inline]
379    fn serialize<E: Encode + ?Sized>(&self, _encoder: &mut E) -> Result<()> {
380        Ok(())
381    }
382}
383
384impl Deserialize for () {
385    #[inline]
386    fn deserialize<D: Decode + ?Sized>(_decoder: &mut D) -> Result<Self> {
387        Ok(())
388    }
389}
390
391macro_rules! impl_tuple {
392    ($($name:ident: $idx:tt),+) => {
393        impl<$($name: Serialize),+> Serialize for ($($name,)+) {
394            #[inline]
395            fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
396                $( self.$idx.serialize(encoder)?; )+
397                Ok(())
398            }
399        }
400
401        impl<$($name: Deserialize),+> Deserialize for ($($name,)+) {
402            #[inline]
403            fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
404                Ok(( $( $name::deserialize(decoder)?, )+ ))
405            }
406        }
407    };
408}
409
410impl_tuple!(T0: 0);
411impl_tuple!(T0: 0, T1: 1);
412impl_tuple!(T0: 0, T1: 1, T2: 2);
413impl_tuple!(T0: 0, T1: 1, T2: 2, T3: 3);
414impl_tuple!(T0: 0, T1: 1, T2: 2, T3: 3, T4: 4);
415impl_tuple!(T0: 0, T1: 1, T2: 2, T3: 3, T4: 4, T5: 5);
416impl_tuple!(T0: 0, T1: 1, T2: 2, T3: 3, T4: 4, T5: 5, T6: 6);
417impl_tuple!(T0: 0, T1: 1, T2: 2, T3: 3, T4: 4, T5: 5, T6: 6, T7: 7);
418impl_tuple!(T0: 0, T1: 1, T2: 2, T3: 3, T4: 4, T5: 5, T6: 6, T7: 7, T8: 8);
419impl_tuple!(T0: 0, T1: 1, T2: 2, T3: 3, T4: 4, T5: 5, T6: 6, T7: 7, T8: 8, T9: 9);
420impl_tuple!(T0: 0, T1: 1, T2: 2, T3: 3, T4: 4, T5: 5, T6: 6, T7: 7, T8: 8, T9: 9, T10: 10);
421impl_tuple!(T0: 0, T1: 1, T2: 2, T3: 3, T4: 4, T5: 5, T6: 6, T7: 7, T8: 8, T9: 9, T10: 10, T11: 11);
422
423// ---------------------------------------------------------------------------
424// Option<T>
425// ---------------------------------------------------------------------------
426
427impl<T: Serialize> Serialize for Option<T> {
428    #[inline]
429    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
430        match self {
431            None => encoder.write_byte(0x00),
432            Some(value) => {
433                encoder.write_byte(0x01)?;
434                value.serialize(encoder)
435            }
436        }
437    }
438}
439
440impl<T: Deserialize> Deserialize for Option<T> {
441    #[inline]
442    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
443        match decoder.read_byte()? {
444            0x00 => Ok(None),
445            0x01 => Ok(Some(T::deserialize(decoder)?)),
446            tag => Err(SerialError::InvalidTag {
447                kind: "Option",
448                tag,
449            }),
450        }
451    }
452}
453
454// ---------------------------------------------------------------------------
455// Result<T, E>
456// ---------------------------------------------------------------------------
457
458impl<T: Serialize, E: Serialize> Serialize for core::result::Result<T, E> {
459    #[inline]
460    fn serialize<Enc: Encode + ?Sized>(&self, encoder: &mut Enc) -> Result<()> {
461        match self {
462            Ok(value) => {
463                encoder.write_byte(0x00)?;
464                value.serialize(encoder)
465            }
466            Err(err) => {
467                encoder.write_byte(0x01)?;
468                err.serialize(encoder)
469            }
470        }
471    }
472}
473
474impl<T: Deserialize, E: Deserialize> Deserialize for core::result::Result<T, E> {
475    #[inline]
476    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
477        match decoder.read_byte()? {
478            0x00 => Ok(Ok(T::deserialize(decoder)?)),
479            0x01 => Ok(Err(E::deserialize(decoder)?)),
480            tag => Err(SerialError::InvalidTag {
481                kind: "Result",
482                tag,
483            }),
484        }
485    }
486}
487
488// ---------------------------------------------------------------------------
489// References — &T forwards to T's Serialize impl
490// ---------------------------------------------------------------------------
491
492impl<T: Serialize + ?Sized> Serialize for &T {
493    #[inline]
494    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
495        (**self).serialize(encoder)
496    }
497}
498
499// ---------------------------------------------------------------------------
500// Map and set collections
501// ---------------------------------------------------------------------------
502//
503// Encoding contract: `varint(count) ++ sorted_entries`, where entries are
504// sorted lexicographically by their **encoded key bytes**. This means a
505// `HashMap` and a `BTreeMap` holding the same logical data encode to the
506// same bytes. Hash-randomisation across runs and insertion order are both
507// irrelevant to the output — the byte-determinism contract holds.
508
509/// Encode `count` entries as `varint(count) ++ each (key, value) pair`,
510/// where entries are pre-sorted by encoded-key bytes.
511///
512/// `count` is a fresh ascending iteration over the source collection. The
513/// helper encodes each `(K, V)` pair to a temporary `Vec<u8>` (capturing the
514/// length of the key portion so the sort step does not re-encode), sorts
515/// those byte representations, then concatenates them onto `encoder`.
516fn encode_map_like<K, V, I, E>(count: usize, entries: I, encoder: &mut E) -> Result<()>
517where
518    K: Serialize,
519    V: Serialize,
520    I: IntoIterator<Item = (K, V)>,
521    E: Encode + ?Sized,
522{
523    let mut buffered: Vec<(usize, Vec<u8>)> = Vec::with_capacity(count);
524    for (k, v) in entries {
525        let mut tmp = Encoder::new();
526        k.serialize(&mut tmp)?;
527        let key_len = tmp.as_bytes().len();
528        v.serialize(&mut tmp)?;
529        buffered.push((key_len, tmp.into_inner()));
530    }
531    buffered.sort_by(|a, b| {
532        let ka = &a.1[..a.0];
533        let kb = &b.1[..b.0];
534        ka.cmp(kb)
535    });
536    encoder.write_varint_u64(count as u64)?;
537    for (_, bytes) in &buffered {
538        encoder.write_bytes(bytes)?;
539    }
540    Ok(())
541}
542
543/// Encode `count` set elements as `varint(count) ++ sorted_elements`.
544fn encode_set_like<T, I, E>(count: usize, items: I, encoder: &mut E) -> Result<()>
545where
546    T: Serialize,
547    I: IntoIterator<Item = T>,
548    E: Encode + ?Sized,
549{
550    let mut buffered: Vec<Vec<u8>> = Vec::with_capacity(count);
551    for item in items {
552        let mut tmp = Encoder::new();
553        item.serialize(&mut tmp)?;
554        buffered.push(tmp.into_inner());
555    }
556    buffered.sort();
557    encoder.write_varint_u64(count as u64)?;
558    for bytes in &buffered {
559        encoder.write_bytes(bytes)?;
560    }
561    Ok(())
562}
563
564impl<K, V> Serialize for BTreeMap<K, V>
565where
566    K: Serialize,
567    V: Serialize,
568{
569    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
570        encode_map_like(self.len(), self.iter(), encoder)
571    }
572}
573
574impl<K, V> Deserialize for BTreeMap<K, V>
575where
576    K: Deserialize + Ord,
577    V: Deserialize,
578{
579    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
580        let declared = decoder.read_varint_u64()?;
581        let len = guard_element_count::<(K, V), _>(declared, decoder)?;
582        let mut out = BTreeMap::new();
583        for _ in 0..len {
584            let k = K::deserialize(decoder)?;
585            let v = V::deserialize(decoder)?;
586            let _ = out.insert(k, v);
587        }
588        Ok(out)
589    }
590}
591
592impl<T> Serialize for BTreeSet<T>
593where
594    T: Serialize,
595{
596    fn serialize<E: Encode + ?Sized>(&self, encoder: &mut E) -> Result<()> {
597        encode_set_like(self.len(), self.iter(), encoder)
598    }
599}
600
601impl<T> Deserialize for BTreeSet<T>
602where
603    T: Deserialize + Ord,
604{
605    fn deserialize<D: Decode + ?Sized>(decoder: &mut D) -> Result<Self> {
606        let declared = decoder.read_varint_u64()?;
607        let len = guard_element_count::<T, _>(declared, decoder)?;
608        let mut out = BTreeSet::new();
609        for _ in 0..len {
610            let _ = out.insert(T::deserialize(decoder)?);
611        }
612        Ok(out)
613    }
614}
615
/// Uses the same canonical encoding as `BTreeMap`, so the bytes do not
/// depend on hash seeds or insertion order.
#[cfg(feature = "std")]
impl<K, V, S> Serialize for HashMap<K, V, S>
where
    K: Serialize,
    V: Serialize,
{
    fn serialize<E: Encode + ?Sized>(&self, out: &mut E) -> Result<()> {
        let entry_count = self.len();
        encode_map_like(entry_count, self.iter(), out)
    }
}

/// Reads a guarded entry count, then that many `(key, value)` pairs, into a
/// map built with a default-constructed hasher. The initial reservation is
/// capped by `initial_capacity`; a repeated key keeps the last value.
#[cfg(feature = "std")]
impl<K, V, S> Deserialize for HashMap<K, V, S>
where
    K: Deserialize + Hash + Eq,
    V: Deserialize,
    S: BuildHasher + Default,
{
    fn deserialize<D: Decode + ?Sized>(input: &mut D) -> Result<Self> {
        let declared = input.read_varint_u64()?;
        let entries = guard_element_count::<(K, V), _>(declared, input)?;
        let reserve = initial_capacity(entries);
        let mut map = HashMap::with_capacity_and_hasher(reserve, S::default());
        for _ in 0..entries {
            let key = K::deserialize(input)?;
            let value = V::deserialize(input)?;
            let _ = map.insert(key, value);
        }
        Ok(map)
    }
}
646
/// Uses the same canonical encoding as `BTreeSet`, so the bytes do not
/// depend on hash seeds or insertion order.
#[cfg(feature = "std")]
impl<T, S> Serialize for HashSet<T, S>
where
    T: Serialize,
{
    fn serialize<E: Encode + ?Sized>(&self, out: &mut E) -> Result<()> {
        let element_count = self.len();
        encode_set_like(element_count, self.iter(), out)
    }
}

/// Reads a guarded element count, then that many elements, into a set built
/// with a default-constructed hasher and a capped initial reservation.
#[cfg(feature = "std")]
impl<T, S> Deserialize for HashSet<T, S>
where
    T: Deserialize + Hash + Eq,
    S: BuildHasher + Default,
{
    fn deserialize<D: Decode + ?Sized>(input: &mut D) -> Result<Self> {
        let declared = input.read_varint_u64()?;
        let count = guard_element_count::<T, _>(declared, input)?;
        let reserve = initial_capacity(count);
        let mut set = HashSet::with_capacity_and_hasher(reserve, S::default());
        for _ in 0..count {
            let element = T::deserialize(input)?;
            let _ = set.insert(element);
        }
        Ok(set)
    }
}
673
674/// Validate `declared` (an element count) against the decoder's
675/// `max_alloc`, treating each element as occupying at least one byte. This
676/// prevents the obvious "declare `u64::MAX` elements, force a giant
677/// `Vec::with_capacity`" attack — declaring more elements than the decoder
678/// could ever supply bytes for is refused before we allocate.
679#[inline]
680fn guard_element_count<T, D: Decode + ?Sized>(declared: u64, decoder: &D) -> Result<usize> {
681    let max = decoder.max_alloc() as u64;
682    if declared > max {
683        return Err(SerialError::InvalidLength {
684            declared,
685            remaining: 0,
686        });
687    }
688    // Type tag silences the unused-type-parameter lint and documents intent.
689    let _phantom: core::marker::PhantomData<T> = core::marker::PhantomData;
690    usize::try_from(declared).map_err(|_| SerialError::IntegerOutOfRange)
691}
692
/// Clamp the capacity reserved up front for a decoded collection.
///
/// `guard_element_count` only checks the declared count against
/// `max_alloc`, which is a byte budget. A single collection element,
/// especially a hash-table slot, can cost many more bytes than that. Trusting
/// the declared count for `with_capacity` could therefore still request an
/// enormous allocation, even though decoding would fail soon afterwards on
/// missing input.
///
/// Reserving at most a fixed number of slots lets genuinely large collections
/// grow during the decode loop. Hostile inputs then fail on the first missing
/// byte instead of exhausting memory first.
#[inline]
fn initial_capacity(declared: usize) -> usize {
    // Large enough to avoid most regrowth for ordinary collections, small
    // enough that an inflated count cannot force a multi-GiB reservation.
    const INITIAL_CAPACITY_CAP: usize = 4096;
    if declared > INITIAL_CAPACITY_CAP {
        INITIAL_CAPACITY_CAP
    } else {
        declared
    }
}
714
#[cfg(test)]
mod tests {
    //! Round-trip, rejection, and canonical-ordering tests for the impls in
    //! this module.

    use super::*;
    use crate::{decode, encode};
    use alloc::vec;

    /// Encode `value`, decode it back, and assert the result equals the input.
    fn round_trip<T>(value: T)
    where
        T: Serialize + Deserialize + PartialEq + core::fmt::Debug,
    {
        let bytes = encode(&value).expect("encode");
        let back: T = decode(&bytes).expect("decode");
        assert_eq!(back, value);
    }

    #[test]
    fn u8_round_trips() {
        for v in [0u8, 1, 127, 128, 255] {
            round_trip(v);
        }
    }

    #[test]
    fn u64_round_trips() {
        for v in [0u64, 1, u32::MAX as u64, u64::MAX] {
            round_trip(v);
        }
    }

    #[test]
    fn i64_round_trips() {
        for v in [0i64, -1, 1, i64::MIN, i64::MAX] {
            round_trip(v);
        }
    }

    #[test]
    fn i8_round_trips() {
        for v in [i8::MIN, -1, 0, 1, i8::MAX] {
            round_trip(v);
        }
    }

    #[test]
    fn narrow_integers_round_trip() {
        for v in [0u16, 1, u16::MAX] {
            round_trip(v);
        }
        for v in [0u32, 1, u32::MAX] {
            round_trip(v);
        }
        for v in [i16::MIN, -1, 0, 1, i16::MAX] {
            round_trip(v);
        }
        for v in [i32::MIN, -1, 0, 1, i32::MAX] {
            round_trip(v);
        }
    }

    #[test]
    fn wide_and_pointer_sized_integers_round_trip() {
        for v in [0u128, 1, u128::from(u64::MAX) + 1, u128::MAX] {
            round_trip(v);
        }
        for v in [i128::MIN, -1, 0, 1, i128::MAX] {
            round_trip(v);
        }
        for v in [0usize, 1, usize::MAX] {
            round_trip(v);
        }
        for v in [isize::MIN, -1, 0, 1, isize::MAX] {
            round_trip(v);
        }
    }

    #[test]
    fn out_of_range_integers_are_rejected() {
        let bytes = encode(&70_000u32).expect("encode");
        let err = decode::<u16>(&bytes).expect_err("70_000 does not fit in u16");
        assert!(matches!(err, SerialError::IntegerOutOfRange));

        let bytes = encode(&(i64::from(i32::MAX) + 1)).expect("encode");
        let err = decode::<i32>(&bytes).expect_err("i32::MAX + 1 does not fit in i32");
        assert!(matches!(err, SerialError::IntegerOutOfRange));
    }

    #[test]
    fn bool_round_trips() {
        round_trip(true);
        round_trip(false);
    }

    #[test]
    fn unit_round_trips_and_encodes_to_nothing() {
        round_trip(());
        assert!(encode(&()).unwrap().is_empty());
    }

    #[test]
    fn string_round_trips() {
        for s in ["", "hello", "a longer string with some content"] {
            round_trip(String::from(s));
        }
    }

    #[test]
    fn vec_u8_round_trips() {
        round_trip::<Vec<u8>>(vec![]);
        round_trip::<Vec<u8>>(vec![0u8, 1, 2, 3]);
        round_trip::<Vec<u8>>(vec![0xffu8; 1024]);
    }

    #[test]
    fn vec_u32_round_trips() {
        round_trip::<Vec<u32>>(vec![]);
        round_trip::<Vec<u32>>(vec![1, 2, u32::MAX]);
    }

    #[test]
    fn vec_string_round_trips() {
        round_trip(vec![
            String::from("hello"),
            String::from("world"),
            String::new(),
        ]);
    }

    #[test]
    fn array_round_trips() {
        round_trip([1u32, 2, 3, 4]);
        round_trip([0u8; 0]);
    }

    #[test]
    fn tuple_round_trips() {
        round_trip((1u8, 2u16, 3u32));
        round_trip((String::from("a"), 42u64, -3i32));
    }

    #[test]
    fn option_round_trips() {
        round_trip::<Option<u64>>(None);
        round_trip::<Option<u64>>(Some(42));
        round_trip::<Option<String>>(Some(String::from("hi")));
    }

    #[test]
    fn result_round_trips() {
        round_trip::<core::result::Result<u64, String>>(Ok(7));
        round_trip::<core::result::Result<u64, String>>(Err(String::from("nope")));
    }

    #[test]
    fn invalid_bool_byte_is_rejected() {
        let err = decode::<bool>(&[0x7f]).expect_err("0x7f is not a bool");
        assert!(matches!(err, SerialError::InvalidBool { byte: 0x7f }));
    }

    #[test]
    fn string_with_invalid_utf8_is_rejected() {
        let bytes = [0x02, 0xff, 0xff];
        let err = decode::<String>(&bytes).expect_err("invalid UTF-8 should fail");
        assert!(matches!(err, SerialError::InvalidUtf8));
    }

    #[test]
    fn option_invalid_tag_is_rejected() {
        let err = decode::<Option<u8>>(&[0x02]).expect_err("0x02 is not a valid Option tag");
        assert!(matches!(
            err,
            SerialError::InvalidTag {
                kind: "Option",
                tag: 0x02
            }
        ));
    }

    #[test]
    fn result_invalid_tag_is_rejected() {
        let err =
            decode::<core::result::Result<u8, u8>>(&[0x02]).expect_err("0x02 is not a Result tag");
        assert!(matches!(
            err,
            SerialError::InvalidTag {
                kind: "Result",
                tag: 0x02
            }
        ));
    }

    #[test]
    fn f32_round_trips_bit_for_bit() {
        for v in [
            0.0f32,
            -0.0,
            1.5,
            f32::MIN_POSITIVE,
            f32::MAX,
            f32::INFINITY,
            f32::NAN,
        ] {
            let bytes = encode(&v).unwrap();
            let back: f32 = decode(&bytes).unwrap();
            assert_eq!(back.to_bits(), v.to_bits());
        }
    }

    #[test]
    fn f64_round_trips_including_inf() {
        for v in [
            0.0f64,
            -0.0,
            1.0,
            -1.0,
            f64::MIN,
            f64::MAX,
            f64::INFINITY,
            f64::NEG_INFINITY,
        ] {
            let bytes = encode(&v).unwrap();
            let back: f64 = decode(&bytes).unwrap();
            assert_eq!(back.to_bits(), v.to_bits());
        }
    }

    #[test]
    fn f64_nan_round_trips_bit_for_bit() {
        let v = f64::NAN;
        let bytes = encode(&v).unwrap();
        let back: f64 = decode(&bytes).unwrap();
        assert_eq!(back.to_bits(), v.to_bits());
        assert!(back.is_nan());
    }

    #[test]
    fn btreemap_round_trips() {
        let mut m = BTreeMap::new();
        let _ = m.insert(String::from("a"), 1u32);
        let _ = m.insert(String::from("b"), 2);
        let _ = m.insert(String::from("c"), 3);
        round_trip(m);
    }

    #[test]
    fn btreemap_empty_round_trips() {
        round_trip(BTreeMap::<u32, u32>::new());
    }

    #[test]
    fn btreeset_round_trips() {
        let mut s = BTreeSet::new();
        let _ = s.insert(1u32);
        let _ = s.insert(2);
        let _ = s.insert(3);
        round_trip(s);
    }

    #[cfg(feature = "std")]
    #[test]
    fn hashmap_round_trips() {
        let mut m: HashMap<String, u32> = HashMap::new();
        let _ = m.insert(String::from("alpha"), 1);
        let _ = m.insert(String::from("beta"), 2);
        round_trip(m);
    }

    #[cfg(feature = "std")]
    #[test]
    fn hashset_round_trips() {
        let mut s: HashSet<u32> = HashSet::new();
        let _ = s.insert(1);
        let _ = s.insert(7);
        let _ = s.insert(42);
        round_trip(s);
    }

    #[cfg(feature = "std")]
    #[test]
    fn hashmap_and_btreemap_encode_identically_for_same_data() {
        let mut h: HashMap<String, u32> = HashMap::new();
        let mut b: BTreeMap<String, u32> = BTreeMap::new();
        for (k, v) in [("zeta", 5u32), ("alpha", 1), ("delta", 4), ("beta", 2)] {
            let _ = h.insert(k.into(), v);
            let _ = b.insert(k.into(), v);
        }
        assert_eq!(encode(&h).unwrap(), encode(&b).unwrap());
    }

    #[cfg(feature = "std")]
    #[test]
    fn hashset_and_btreeset_encode_identically_for_same_data() {
        let items = [42u32, 7, 1, 300];
        let h: HashSet<u32> = items.iter().copied().collect();
        let b: BTreeSet<u32> = items.iter().copied().collect();
        assert_eq!(encode(&h).unwrap(), encode(&b).unwrap());
    }

    #[cfg(feature = "std")]
    #[test]
    fn hashmap_insertion_order_independent() {
        let mut a: HashMap<u32, u32> = HashMap::new();
        let _ = a.insert(1, 10);
        let _ = a.insert(2, 20);
        let _ = a.insert(3, 30);

        let mut b: HashMap<u32, u32> = HashMap::new();
        let _ = b.insert(3, 30);
        let _ = b.insert(1, 10);
        let _ = b.insert(2, 20);

        assert_eq!(encode(&a).unwrap(), encode(&b).unwrap());
    }

    #[test]
    fn collection_with_hostile_element_count_is_rejected() {
        // LEB128 varint of u64::MAX: nine 0xff continuation bytes carrying
        // 63 one-bits, then a final 0x01 carrying the 64th bit.
        let bytes: [u8; 10] = [0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01];
        let err = decode::<Vec<u32>>(&bytes).expect_err("hostile count");
        assert!(matches!(err, SerialError::InvalidLength { .. }));
    }

    #[test]
    fn map_with_hostile_entry_count_is_rejected() {
        // Same u64::MAX count as above, this time as a map entry count.
        let bytes: [u8; 10] = [0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01];
        let err = decode::<BTreeMap<u32, u32>>(&bytes).expect_err("hostile count");
        assert!(matches!(err, SerialError::InvalidLength { .. }));
    }
}