Skip to main content

irontide_bencode/
de.rs

1use serde::de::{self, Visitor};
2
3use crate::error::{Error, Result};
4
5/// Bencode deserializer.
6///
7/// Parses bencode from a byte slice, supporting zero-copy deserialization
8/// of byte strings via `deserialize_bytes`.
9pub struct Deserializer<'de> {
10    input: &'de [u8],
11    pos: usize,
12    strict_order: bool,
13}
14
15impl<'de> Deserializer<'de> {
16    /// Create a new deserializer from a byte slice.
17    ///
18    /// By default, dictionary key ordering is enforced per BEP 3.
19    /// Use [`Deserializer::lenient`] to accept unsorted keys.
20    pub fn new(input: &'de [u8]) -> Self {
21        Deserializer {
22            input,
23            pos: 0,
24            strict_order: true,
25        }
26    }
27
28    /// Create a lenient deserializer that accepts unsorted dictionary keys.
29    ///
30    /// Many real-world clients send extension handshakes with unsorted keys.
31    /// Use this for parsing peer wire messages where strict BEP 3 compliance
32    /// would reject otherwise valid data.
33    pub fn lenient(input: &'de [u8]) -> Self {
34        Deserializer {
35            input,
36            pos: 0,
37            strict_order: false,
38        }
39    }
40
41    /// Verify all input has been consumed. Call after deserializing.
42    pub fn finish(&self) -> Result<()> {
43        if self.pos < self.input.len() {
44            Err(Error::TrailingData {
45                position: self.pos,
46                count: self.input.len() - self.pos,
47            })
48        } else {
49            Ok(())
50        }
51    }
52
53    fn peek(&self) -> Result<u8> {
54        self.input
55            .get(self.pos)
56            .copied()
57            .ok_or(Error::UnexpectedEof {
58                position: self.pos,
59                context: "expected more data".into(),
60            })
61    }
62
63    fn next(&mut self) -> Result<u8> {
64        let byte = self.peek()?;
65        self.pos += 1;
66        Ok(byte)
67    }
68
69    fn expect(&mut self, expected: u8) -> Result<()> {
70        let byte = self.next()?;
71        if byte != expected {
72            Err(Error::UnexpectedByte {
73                byte,
74                position: self.pos - 1,
75                expected: match expected {
76                    b'e' => "'e' (end marker)",
77                    b'i' => "'i' (integer start)",
78                    b'l' => "'l' (list start)",
79                    b'd' => "'d' (dict start)",
80                    b':' => "':' (string separator)",
81                    _ => "specific byte",
82                },
83            })
84        } else {
85            Ok(())
86        }
87    }
88
89    fn parse_integer_value(&mut self) -> Result<i64> {
90        let start = self.pos;
91
92        // Find the 'e' terminator
93        let end = self.input[self.pos..]
94            .iter()
95            .position(|&b| b == b'e')
96            .ok_or(Error::UnexpectedEof {
97                position: self.pos,
98                context: "unterminated integer".into(),
99            })?;
100
101        let num_bytes = &self.input[self.pos..self.pos + end];
102        self.pos += end + 1; // skip past 'e'
103
104        if num_bytes.is_empty() {
105            return Err(Error::InvalidInteger {
106                position: start,
107                detail: "empty integer".into(),
108            });
109        }
110
111        // Reject leading zeros (except "0" itself)
112        if num_bytes.len() > 1 && num_bytes[0] == b'0' {
113            return Err(Error::InvalidInteger {
114                position: start,
115                detail: "leading zero".into(),
116            });
117        }
118
119        // Reject negative zero
120        if num_bytes == b"-0" {
121            return Err(Error::InvalidInteger {
122                position: start,
123                detail: "negative zero".into(),
124            });
125        }
126
127        // Reject bare minus sign
128        if num_bytes == b"-" {
129            return Err(Error::InvalidInteger {
130                position: start,
131                detail: "bare minus sign".into(),
132            });
133        }
134
135        // Reject leading zeros in negative numbers
136        if num_bytes.len() > 2 && num_bytes[0] == b'-' && num_bytes[1] == b'0' {
137            return Err(Error::InvalidInteger {
138                position: start,
139                detail: "leading zero in negative".into(),
140            });
141        }
142
143        let s = std::str::from_utf8(num_bytes).map_err(|_| Error::InvalidInteger {
144            position: start,
145            detail: "non-ASCII integer".into(),
146        })?;
147
148        s.parse::<i64>().map_err(|e| Error::InvalidInteger {
149            position: start,
150            detail: e.to_string(),
151        })
152    }
153
154    fn parse_byte_string(&mut self) -> Result<&'de [u8]> {
155        let start = self.pos;
156
157        // Parse length prefix
158        let colon = self.input[self.pos..]
159            .iter()
160            .position(|&b| b == b':')
161            .ok_or(Error::InvalidByteString {
162                position: start,
163                detail: "missing ':' separator".into(),
164            })?;
165
166        let len_bytes = &self.input[self.pos..self.pos + colon];
167        if len_bytes.is_empty() {
168            return Err(Error::InvalidByteString {
169                position: start,
170                detail: "empty length prefix".into(),
171            });
172        }
173
174        // Reject leading zeros in length (except "0" itself)
175        if len_bytes.len() > 1 && len_bytes[0] == b'0' {
176            return Err(Error::InvalidByteString {
177                position: start,
178                detail: "leading zero in length".into(),
179            });
180        }
181
182        let len_str = std::str::from_utf8(len_bytes).map_err(|_| Error::InvalidByteString {
183            position: start,
184            detail: "non-ASCII length".into(),
185        })?;
186
187        let len: usize =
188            len_str
189                .parse()
190                .map_err(|e: std::num::ParseIntError| Error::InvalidByteString {
191                    position: start,
192                    detail: e.to_string(),
193                })?;
194
195        self.pos += colon + 1; // skip past ':'
196
197        if self.pos + len > self.input.len() {
198            return Err(Error::UnexpectedEof {
199                position: self.pos,
200                context: format!(
201                    "byte string needs {len} bytes, only {} available",
202                    self.input.len() - self.pos
203                ),
204            });
205        }
206
207        let data = &self.input[self.pos..self.pos + len];
208        self.pos += len;
209        Ok(data)
210    }
211}
212
213impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
214    type Error = Error;
215
216    fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
217        match self.peek()? {
218            b'i' => {
219                self.pos += 1;
220                let val = self.parse_integer_value()?;
221                visitor.visit_i64(val)
222            }
223            b'l' => self.deserialize_seq(visitor),
224            b'd' => self.deserialize_map(visitor),
225            b'0'..=b'9' => {
226                let data = self.parse_byte_string()?;
227                visitor.visit_borrowed_bytes(data)
228            }
229            byte => Err(Error::UnexpectedByte {
230                byte,
231                position: self.pos,
232                expected: "integer ('i'), string ('0'-'9'), list ('l'), or dict ('d')",
233            }),
234        }
235    }
236
237    fn deserialize_bool<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
238        self.expect(b'i')?;
239        let val = self.parse_integer_value()?;
240        visitor.visit_bool(val != 0)
241    }
242
243    fn deserialize_i8<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
244        self.expect(b'i')?;
245        visitor.visit_i64(self.parse_integer_value()?)
246    }
247
248    fn deserialize_i16<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
249        self.expect(b'i')?;
250        visitor.visit_i64(self.parse_integer_value()?)
251    }
252
253    fn deserialize_i32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
254        self.expect(b'i')?;
255        visitor.visit_i64(self.parse_integer_value()?)
256    }
257
258    fn deserialize_i64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
259        self.expect(b'i')?;
260        visitor.visit_i64(self.parse_integer_value()?)
261    }
262
263    fn deserialize_u8<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
264        self.expect(b'i')?;
265        visitor.visit_i64(self.parse_integer_value()?)
266    }
267
268    fn deserialize_u16<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
269        self.expect(b'i')?;
270        visitor.visit_i64(self.parse_integer_value()?)
271    }
272
273    fn deserialize_u32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
274        self.expect(b'i')?;
275        visitor.visit_i64(self.parse_integer_value()?)
276    }
277
278    fn deserialize_u64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
279        self.expect(b'i')?;
280        visitor.visit_i64(self.parse_integer_value()?)
281    }
282
283    fn deserialize_f32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
284        Err(Error::Custom("bencode does not support floats".into()))
285    }
286
287    fn deserialize_f64<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
288        Err(Error::Custom("bencode does not support floats".into()))
289    }
290
291    fn deserialize_char<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
292        let data = self.parse_byte_string()?;
293        let s = std::str::from_utf8(data)
294            .map_err(|_| Error::Custom("char is not valid UTF-8".into()))?;
295        let mut chars = s.chars();
296        let c = chars
297            .next()
298            .ok_or_else(|| Error::Custom("empty string for char".into()))?;
299        if chars.next().is_some() {
300            return Err(Error::Custom("multi-char string for char".into()));
301        }
302        visitor.visit_char(c)
303    }
304
305    fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
306        let data = self.parse_byte_string()?;
307        let s = std::str::from_utf8(data).map_err(|_| {
308            Error::Custom("byte string is not valid UTF-8, use bytes instead".into())
309        })?;
310        visitor.visit_borrowed_str(s)
311    }
312
313    fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
314        self.deserialize_str(visitor)
315    }
316
317    fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
318        let data = self.parse_byte_string()?;
319        visitor.visit_borrowed_bytes(data)
320    }
321
322    fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
323        self.deserialize_bytes(visitor)
324    }
325
326    fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
327        visitor.visit_some(self)
328    }
329
330    fn deserialize_unit<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
331        Err(Error::Custom("bencode does not support unit".into()))
332    }
333
334    fn deserialize_unit_struct<V: Visitor<'de>>(
335        self,
336        _name: &'static str,
337        _visitor: V,
338    ) -> Result<V::Value> {
339        Err(Error::Custom(
340            "bencode does not support unit structs".into(),
341        ))
342    }
343
344    fn deserialize_newtype_struct<V: Visitor<'de>>(
345        self,
346        _name: &'static str,
347        visitor: V,
348    ) -> Result<V::Value> {
349        visitor.visit_newtype_struct(self)
350    }
351
352    fn deserialize_seq<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
353        self.expect(b'l')?;
354        let value = visitor.visit_seq(SeqAccess { de: self })?;
355        self.expect(b'e')?;
356        Ok(value)
357    }
358
359    fn deserialize_tuple<V: Visitor<'de>>(self, _len: usize, visitor: V) -> Result<V::Value> {
360        self.deserialize_seq(visitor)
361    }
362
363    fn deserialize_tuple_struct<V: Visitor<'de>>(
364        self,
365        _name: &'static str,
366        _len: usize,
367        visitor: V,
368    ) -> Result<V::Value> {
369        self.deserialize_seq(visitor)
370    }
371
372    fn deserialize_map<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
373        self.expect(b'd')?;
374        let strict_order = self.strict_order;
375        let value = visitor.visit_map(MapAccess {
376            de: self,
377            last_key: None,
378            strict_order,
379        })?;
380        self.expect(b'e')?;
381        Ok(value)
382    }
383
384    fn deserialize_struct<V: Visitor<'de>>(
385        self,
386        _name: &'static str,
387        _fields: &'static [&'static str],
388        visitor: V,
389    ) -> Result<V::Value> {
390        self.deserialize_map(visitor)
391    }
392
393    fn deserialize_enum<V: Visitor<'de>>(
394        self,
395        _name: &'static str,
396        _variants: &'static [&'static str],
397        visitor: V,
398    ) -> Result<V::Value> {
399        match self.peek()? {
400            b'd' => {
401                // Dict-based enum variant: d<variant-name><value>e
402                self.pos += 1;
403                let value = visitor.visit_enum(EnumAccess { de: self })?;
404                self.expect(b'e')?;
405                Ok(value)
406            }
407            _ => {
408                // String-based unit variant
409                visitor.visit_enum(UnitVariantAccess { de: self })
410            }
411        }
412    }
413
414    fn deserialize_identifier<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
415        self.deserialize_str(visitor)
416    }
417
418    fn deserialize_ignored_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
419        // Skip over any bencode value
420        self.skip_value()?;
421        visitor.visit_unit()
422    }
423}
424
425impl Deserializer<'_> {
426    /// Skip over a complete bencode value without allocating.
427    fn skip_value(&mut self) -> Result<()> {
428        match self.peek()? {
429            b'i' => {
430                self.pos += 1;
431                self.parse_integer_value()?;
432                Ok(())
433            }
434            b'l' => {
435                self.pos += 1;
436                while self.peek()? != b'e' {
437                    self.skip_value()?;
438                }
439                self.pos += 1;
440                Ok(())
441            }
442            b'd' => {
443                self.pos += 1;
444                while self.peek()? != b'e' {
445                    self.parse_byte_string()?; // key
446                    self.skip_value()?; // value
447                }
448                self.pos += 1;
449                Ok(())
450            }
451            b'0'..=b'9' => {
452                self.parse_byte_string()?;
453                Ok(())
454            }
455            byte => Err(Error::UnexpectedByte {
456                byte,
457                position: self.pos,
458                expected: "bencode value",
459            }),
460        }
461    }
462}
463
464struct SeqAccess<'a, 'de> {
465    de: &'a mut Deserializer<'de>,
466}
467
468impl<'de> de::SeqAccess<'de> for SeqAccess<'_, 'de> {
469    type Error = Error;
470
471    fn next_element_seed<T: de::DeserializeSeed<'de>>(
472        &mut self,
473        seed: T,
474    ) -> Result<Option<T::Value>> {
475        if self.de.peek()? == b'e' {
476            return Ok(None);
477        }
478        seed.deserialize(&mut *self.de).map(Some)
479    }
480}
481
482struct MapAccess<'a, 'de> {
483    de: &'a mut Deserializer<'de>,
484    last_key: Option<Vec<u8>>,
485    strict_order: bool,
486}
487
488impl<'de> de::MapAccess<'de> for MapAccess<'_, 'de> {
489    type Error = Error;
490
491    fn next_key_seed<K: de::DeserializeSeed<'de>>(&mut self, seed: K) -> Result<Option<K::Value>> {
492        if self.de.peek()? == b'e' {
493            return Ok(None);
494        }
495
496        // Read key bytes for sort-order validation
497        let key_start = self.de.pos;
498        let key_data = self.de.parse_byte_string()?;
499        let key_vec = key_data.to_vec();
500
501        // Validate sorted order (strict mode only — lenient accepts unsorted)
502        if let Some(ref last) = self.last_key
503            && self.strict_order
504            && key_vec <= *last
505        {
506            return Err(Error::UnsortedKeys {
507                position: key_start,
508            });
509        }
510        self.last_key = Some(key_vec);
511
512        // Now deserialize the key via serde using a special deserializer
513        // that yields the already-parsed string
514        let key_de = BorrowedStrDeserializer(key_data);
515        seed.deserialize(key_de).map(Some)
516    }
517
518    fn next_value_seed<V: de::DeserializeSeed<'de>>(&mut self, seed: V) -> Result<V::Value> {
519        seed.deserialize(&mut *self.de)
520    }
521}
522
523/// Minimal deserializer that yields a single already-parsed byte string.
524struct BorrowedStrDeserializer<'de>(&'de [u8]);
525
526impl<'de> de::Deserializer<'de> for BorrowedStrDeserializer<'de> {
527    type Error = Error;
528
529    fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
530        visitor.visit_borrowed_bytes(self.0)
531    }
532
533    fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
534        let s = std::str::from_utf8(self.0)
535            .map_err(|_| Error::Custom("dict key is not valid UTF-8".into()))?;
536        visitor.visit_borrowed_str(s)
537    }
538
539    fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
540        self.deserialize_str(visitor)
541    }
542
543    fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
544        visitor.visit_borrowed_bytes(self.0)
545    }
546
547    fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
548        self.deserialize_bytes(visitor)
549    }
550
551    fn deserialize_identifier<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
552        self.deserialize_str(visitor)
553    }
554
555    fn deserialize_ignored_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
556        visitor.visit_unit()
557    }
558
559    fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
560        visitor.visit_some(self)
561    }
562
563    fn deserialize_newtype_struct<V: Visitor<'de>>(
564        self,
565        _name: &'static str,
566        visitor: V,
567    ) -> Result<V::Value> {
568        visitor.visit_newtype_struct(self)
569    }
570
571    serde::forward_to_deserialize_any! {
572        bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char
573        unit unit_struct seq tuple tuple_struct map struct
574        enum
575    }
576}
577
578// Enum support
579
580struct EnumAccess<'a, 'de> {
581    de: &'a mut Deserializer<'de>,
582}
583
584impl<'de> de::EnumAccess<'de> for EnumAccess<'_, 'de> {
585    type Error = Error;
586    type Variant = Self;
587
588    fn variant_seed<V: de::DeserializeSeed<'de>>(
589        self,
590        seed: V,
591    ) -> Result<(V::Value, Self::Variant)> {
592        let val = seed.deserialize(&mut *self.de)?;
593        Ok((val, self))
594    }
595}
596
597impl<'de> de::VariantAccess<'de> for EnumAccess<'_, 'de> {
598    type Error = Error;
599
600    fn unit_variant(self) -> Result<()> {
601        Err(Error::Custom(
602            "expected newtype/tuple/struct variant inside dict".into(),
603        ))
604    }
605
606    fn newtype_variant_seed<T: de::DeserializeSeed<'de>>(self, seed: T) -> Result<T::Value> {
607        seed.deserialize(&mut *self.de)
608    }
609
610    fn tuple_variant<V: Visitor<'de>>(self, _len: usize, visitor: V) -> Result<V::Value> {
611        de::Deserializer::deserialize_seq(&mut *self.de, visitor)
612    }
613
614    fn struct_variant<V: Visitor<'de>>(
615        self,
616        _fields: &'static [&'static str],
617        visitor: V,
618    ) -> Result<V::Value> {
619        de::Deserializer::deserialize_map(&mut *self.de, visitor)
620    }
621}
622
623struct UnitVariantAccess<'a, 'de> {
624    de: &'a mut Deserializer<'de>,
625}
626
627impl<'de> de::EnumAccess<'de> for UnitVariantAccess<'_, 'de> {
628    type Error = Error;
629    type Variant = Self;
630
631    fn variant_seed<V: de::DeserializeSeed<'de>>(
632        self,
633        seed: V,
634    ) -> Result<(V::Value, Self::Variant)> {
635        let val = seed.deserialize(&mut *self.de)?;
636        Ok((val, self))
637    }
638}
639
640impl<'de> de::VariantAccess<'de> for UnitVariantAccess<'_, 'de> {
641    type Error = Error;
642
643    fn unit_variant(self) -> Result<()> {
644        Ok(())
645    }
646
647    fn newtype_variant_seed<T: de::DeserializeSeed<'de>>(self, _seed: T) -> Result<T::Value> {
648        Err(Error::Custom(
649            "expected unit variant for string enum".into(),
650        ))
651    }
652
653    fn tuple_variant<V: Visitor<'de>>(self, _len: usize, _visitor: V) -> Result<V::Value> {
654        Err(Error::Custom(
655            "expected unit variant for string enum".into(),
656        ))
657    }
658
659    fn struct_variant<V: Visitor<'de>>(
660        self,
661        _fields: &'static [&'static str],
662        _visitor: V,
663    ) -> Result<V::Value> {
664        Err(Error::Custom(
665            "expected unit variant for string enum".into(),
666        ))
667    }
668}
669
#[cfg(test)]
mod tests {
    use crate::from_bytes;

    /// Positive, zero and negative integers decode to their `i64` value.
    #[test]
    fn deserialize_integer() {
        let positive: i64 = from_bytes(b"i42e").unwrap();
        let zero: i64 = from_bytes(b"i0e").unwrap();
        let negative: i64 = from_bytes(b"i-1e").unwrap();

        assert_eq!(positive, 42);
        assert_eq!(zero, 0);
        assert_eq!(negative, -1);
    }

    /// Non-empty and empty byte strings decode to `String`.
    #[test]
    fn deserialize_string() {
        let spam: String = from_bytes(b"4:spam").unwrap();
        let empty: String = from_bytes(b"0:").unwrap();

        assert_eq!(spam, "spam");
        assert_eq!(empty, "");
    }

    #[test]
    fn reject_negative_zero() {
        let outcome = from_bytes::<i64>(b"i-0e");
        assert!(outcome.is_err());
    }

    #[test]
    fn reject_leading_zeros() {
        let outcome = from_bytes::<i64>(b"i03e");
        assert!(outcome.is_err());
    }

    #[test]
    fn reject_trailing_data() {
        let outcome = from_bytes::<i64>(b"i42eXXX");
        assert!(outcome.is_err());
    }

    #[test]
    fn strict_rejects_unsorted_dict_keys() {
        use std::collections::BTreeMap;

        // d2:zz1:a2:aa1:be — keys "zz" then "aa" are unsorted
        let unsorted = b"d2:zz1:a2:aa1:be";
        let outcome = from_bytes::<BTreeMap<String, String>>(unsorted);
        assert!(outcome.is_err());
    }

    #[test]
    fn lenient_accepts_unsorted_dict_keys() {
        use crate::from_bytes_lenient;
        use std::collections::BTreeMap;

        // d2:zz1:a2:aa1:be — keys "zz" then "aa" are unsorted
        let unsorted = b"d2:zz1:a2:aa1:be";
        let decoded: BTreeMap<String, String> = from_bytes_lenient(unsorted).unwrap();

        assert_eq!(decoded.get("zz").unwrap(), "a");
        assert_eq!(decoded.get("aa").unwrap(), "b");
    }
}