Skip to main content

irontide_bencode/
de.rs

1use serde::de::{self, Visitor};
2
3use crate::error::{Error, Result};
4
/// A bencode deserializer that reads from a borrowed byte slice.
///
/// Byte strings are handed to visitors as slices borrowed straight from the
/// input (see `deserialize_bytes`), so deserializing them needs no copy.
pub struct Deserializer<'de> {
    /// The complete input buffer being parsed.
    input: &'de [u8],
    /// Index into `input` of the next unread byte.
    pos: usize,
    /// When `true`, dictionary keys must appear in strictly ascending order.
    strict_order: bool,
}
14
15impl<'de> Deserializer<'de> {
16    /// Create a new deserializer from a byte slice.
17    ///
18    /// By default, dictionary key ordering is enforced per BEP 3.
19    /// Use [`Deserializer::lenient`] to accept unsorted keys.
20    #[must_use]
21    pub fn new(input: &'de [u8]) -> Self {
22        Deserializer {
23            input,
24            pos: 0,
25            strict_order: true,
26        }
27    }
28
29    /// Create a lenient deserializer that accepts unsorted dictionary keys.
30    ///
31    /// Many real-world clients send extension handshakes with unsorted keys.
32    /// Use this for parsing peer wire messages where strict BEP 3 compliance
33    /// would reject otherwise valid data.
34    #[must_use]
35    pub fn lenient(input: &'de [u8]) -> Self {
36        Deserializer {
37            input,
38            pos: 0,
39            strict_order: false,
40        }
41    }
42
43    /// Verify all input has been consumed. Call after deserializing.
44    ///
45    /// # Errors
46    ///
47    /// Returns an error if there is trailing data after the deserialized value.
48    pub fn finish(&self) -> Result<()> {
49        if self.pos < self.input.len() {
50            Err(Error::TrailingData {
51                position: self.pos,
52                count: self.input.len() - self.pos,
53            })
54        } else {
55            Ok(())
56        }
57    }
58
59    fn peek(&self) -> Result<u8> {
60        self.input
61            .get(self.pos)
62            .copied()
63            .ok_or_else(|| Error::UnexpectedEof {
64                position: self.pos,
65                context: "expected more data".into(),
66            })
67    }
68
69    fn next(&mut self) -> Result<u8> {
70        let byte = self.peek()?;
71        self.pos += 1;
72        Ok(byte)
73    }
74
75    fn expect(&mut self, expected: u8) -> Result<()> {
76        let byte = self.next()?;
77        if byte == expected {
78            Ok(())
79        } else {
80            Err(Error::UnexpectedByte {
81                byte,
82                position: self.pos - 1,
83                expected: match expected {
84                    b'e' => "'e' (end marker)",
85                    b'i' => "'i' (integer start)",
86                    b'l' => "'l' (list start)",
87                    b'd' => "'d' (dict start)",
88                    b':' => "':' (string separator)",
89                    _ => "specific byte",
90                },
91            })
92        }
93    }
94
95    fn parse_integer_value(&mut self) -> Result<i64> {
96        let start = self.pos;
97
98        // Find the 'e' terminator
99        let end = self.input[self.pos..]
100            .iter()
101            .position(|&b| b == b'e')
102            .ok_or_else(|| Error::UnexpectedEof {
103                position: self.pos,
104                context: "unterminated integer".into(),
105            })?;
106
107        let num_bytes = &self.input[self.pos..self.pos + end];
108        self.pos += end + 1; // skip past 'e'
109
110        if num_bytes.is_empty() {
111            return Err(Error::InvalidInteger {
112                position: start,
113                detail: "empty integer".into(),
114            });
115        }
116
117        // Reject leading zeros (except "0" itself)
118        if num_bytes.len() > 1 && num_bytes[0] == b'0' {
119            return Err(Error::InvalidInteger {
120                position: start,
121                detail: "leading zero".into(),
122            });
123        }
124
125        // Reject negative zero
126        if num_bytes == b"-0" {
127            return Err(Error::InvalidInteger {
128                position: start,
129                detail: "negative zero".into(),
130            });
131        }
132
133        // Reject bare minus sign
134        if num_bytes == b"-" {
135            return Err(Error::InvalidInteger {
136                position: start,
137                detail: "bare minus sign".into(),
138            });
139        }
140
141        // Reject leading zeros in negative numbers
142        if num_bytes.len() > 2 && num_bytes[0] == b'-' && num_bytes[1] == b'0' {
143            return Err(Error::InvalidInteger {
144                position: start,
145                detail: "leading zero in negative".into(),
146            });
147        }
148
149        let s = std::str::from_utf8(num_bytes).map_err(|_| Error::InvalidInteger {
150            position: start,
151            detail: "non-ASCII integer".into(),
152        })?;
153
154        s.parse::<i64>().map_err(|e| Error::InvalidInteger {
155            position: start,
156            detail: e.to_string(),
157        })
158    }
159
160    fn parse_byte_string(&mut self) -> Result<&'de [u8]> {
161        let start = self.pos;
162
163        // Parse length prefix
164        let colon = self.input[self.pos..]
165            .iter()
166            .position(|&b| b == b':')
167            .ok_or_else(|| Error::InvalidByteString {
168                position: start,
169                detail: "missing ':' separator".into(),
170            })?;
171
172        let len_bytes = &self.input[self.pos..self.pos + colon];
173        if len_bytes.is_empty() {
174            return Err(Error::InvalidByteString {
175                position: start,
176                detail: "empty length prefix".into(),
177            });
178        }
179
180        // Reject leading zeros in length (except "0" itself)
181        if len_bytes.len() > 1 && len_bytes[0] == b'0' {
182            return Err(Error::InvalidByteString {
183                position: start,
184                detail: "leading zero in length".into(),
185            });
186        }
187
188        let len_str = std::str::from_utf8(len_bytes).map_err(|_| Error::InvalidByteString {
189            position: start,
190            detail: "non-ASCII length".into(),
191        })?;
192
193        let len: usize =
194            len_str
195                .parse()
196                .map_err(|e: std::num::ParseIntError| Error::InvalidByteString {
197                    position: start,
198                    detail: e.to_string(),
199                })?;
200
201        self.pos += colon + 1; // skip past ':'
202
203        if self.pos + len > self.input.len() {
204            return Err(Error::UnexpectedEof {
205                position: self.pos,
206                context: format!(
207                    "byte string needs {len} bytes, only {} available",
208                    self.input.len() - self.pos
209                ),
210            });
211        }
212
213        let data = &self.input[self.pos..self.pos + len];
214        self.pos += len;
215        Ok(data)
216    }
217}
218
219impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
220    type Error = Error;
221
222    fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
223        match self.peek()? {
224            b'i' => {
225                self.pos += 1;
226                let val = self.parse_integer_value()?;
227                visitor.visit_i64(val)
228            }
229            b'l' => self.deserialize_seq(visitor),
230            b'd' => self.deserialize_map(visitor),
231            b'0'..=b'9' => {
232                let data = self.parse_byte_string()?;
233                visitor.visit_borrowed_bytes(data)
234            }
235            byte => Err(Error::UnexpectedByte {
236                byte,
237                position: self.pos,
238                expected: "integer ('i'), string ('0'-'9'), list ('l'), or dict ('d')",
239            }),
240        }
241    }
242
243    fn deserialize_bool<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
244        self.expect(b'i')?;
245        let val = self.parse_integer_value()?;
246        visitor.visit_bool(val != 0)
247    }
248
249    fn deserialize_i8<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
250        self.expect(b'i')?;
251        visitor.visit_i64(self.parse_integer_value()?)
252    }
253
254    fn deserialize_i16<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
255        self.expect(b'i')?;
256        visitor.visit_i64(self.parse_integer_value()?)
257    }
258
259    fn deserialize_i32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
260        self.expect(b'i')?;
261        visitor.visit_i64(self.parse_integer_value()?)
262    }
263
264    fn deserialize_i64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
265        self.expect(b'i')?;
266        visitor.visit_i64(self.parse_integer_value()?)
267    }
268
269    fn deserialize_u8<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
270        self.expect(b'i')?;
271        visitor.visit_i64(self.parse_integer_value()?)
272    }
273
274    fn deserialize_u16<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
275        self.expect(b'i')?;
276        visitor.visit_i64(self.parse_integer_value()?)
277    }
278
279    fn deserialize_u32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
280        self.expect(b'i')?;
281        visitor.visit_i64(self.parse_integer_value()?)
282    }
283
284    fn deserialize_u64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
285        self.expect(b'i')?;
286        visitor.visit_i64(self.parse_integer_value()?)
287    }
288
289    fn deserialize_f32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
290        Err(Error::Custom("bencode does not support floats".into()))
291    }
292
293    fn deserialize_f64<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
294        Err(Error::Custom("bencode does not support floats".into()))
295    }
296
297    fn deserialize_char<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
298        let data = self.parse_byte_string()?;
299        let s = std::str::from_utf8(data)
300            .map_err(|_| Error::Custom("char is not valid UTF-8".into()))?;
301        let mut chars = s.chars();
302        let c = chars
303            .next()
304            .ok_or_else(|| Error::Custom("empty string for char".into()))?;
305        if chars.next().is_some() {
306            return Err(Error::Custom("multi-char string for char".into()));
307        }
308        visitor.visit_char(c)
309    }
310
311    fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
312        let data = self.parse_byte_string()?;
313        let s = std::str::from_utf8(data).map_err(|_| {
314            Error::Custom("byte string is not valid UTF-8, use bytes instead".into())
315        })?;
316        visitor.visit_borrowed_str(s)
317    }
318
319    fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
320        self.deserialize_str(visitor)
321    }
322
323    fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
324        let data = self.parse_byte_string()?;
325        visitor.visit_borrowed_bytes(data)
326    }
327
328    fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
329        self.deserialize_bytes(visitor)
330    }
331
332    fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
333        visitor.visit_some(self)
334    }
335
336    fn deserialize_unit<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
337        Err(Error::Custom("bencode does not support unit".into()))
338    }
339
340    fn deserialize_unit_struct<V: Visitor<'de>>(
341        self,
342        _name: &'static str,
343        _visitor: V,
344    ) -> Result<V::Value> {
345        Err(Error::Custom(
346            "bencode does not support unit structs".into(),
347        ))
348    }
349
350    fn deserialize_newtype_struct<V: Visitor<'de>>(
351        self,
352        _name: &'static str,
353        visitor: V,
354    ) -> Result<V::Value> {
355        visitor.visit_newtype_struct(self)
356    }
357
358    fn deserialize_seq<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
359        self.expect(b'l')?;
360        let value = visitor.visit_seq(SeqAccess { de: self })?;
361        self.expect(b'e')?;
362        Ok(value)
363    }
364
365    fn deserialize_tuple<V: Visitor<'de>>(self, _len: usize, visitor: V) -> Result<V::Value> {
366        self.deserialize_seq(visitor)
367    }
368
369    fn deserialize_tuple_struct<V: Visitor<'de>>(
370        self,
371        _name: &'static str,
372        _len: usize,
373        visitor: V,
374    ) -> Result<V::Value> {
375        self.deserialize_seq(visitor)
376    }
377
378    fn deserialize_map<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
379        self.expect(b'd')?;
380        let strict_order = self.strict_order;
381        let value = visitor.visit_map(MapAccess {
382            de: self,
383            last_key: None,
384            strict_order,
385        })?;
386        self.expect(b'e')?;
387        Ok(value)
388    }
389
390    fn deserialize_struct<V: Visitor<'de>>(
391        self,
392        _name: &'static str,
393        _fields: &'static [&'static str],
394        visitor: V,
395    ) -> Result<V::Value> {
396        self.deserialize_map(visitor)
397    }
398
399    fn deserialize_enum<V: Visitor<'de>>(
400        self,
401        _name: &'static str,
402        _variants: &'static [&'static str],
403        visitor: V,
404    ) -> Result<V::Value> {
405        match self.peek()? {
406            b'd' => {
407                // Dict-based enum variant: d<variant-name><value>e
408                self.pos += 1;
409                let value = visitor.visit_enum(EnumAccess { de: self })?;
410                self.expect(b'e')?;
411                Ok(value)
412            }
413            _ => {
414                // String-based unit variant
415                visitor.visit_enum(UnitVariantAccess { de: self })
416            }
417        }
418    }
419
420    fn deserialize_identifier<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
421        self.deserialize_str(visitor)
422    }
423
424    fn deserialize_ignored_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
425        // Skip over any bencode value
426        self.skip_value()?;
427        visitor.visit_unit()
428    }
429}
430
431impl Deserializer<'_> {
432    /// Skip over a complete bencode value without allocating.
433    fn skip_value(&mut self) -> Result<()> {
434        match self.peek()? {
435            b'i' => {
436                self.pos += 1;
437                self.parse_integer_value()?;
438                Ok(())
439            }
440            b'l' => {
441                self.pos += 1;
442                while self.peek()? != b'e' {
443                    self.skip_value()?;
444                }
445                self.pos += 1;
446                Ok(())
447            }
448            b'd' => {
449                self.pos += 1;
450                while self.peek()? != b'e' {
451                    self.parse_byte_string()?; // key
452                    self.skip_value()?; // value
453                }
454                self.pos += 1;
455                Ok(())
456            }
457            b'0'..=b'9' => {
458                self.parse_byte_string()?;
459                Ok(())
460            }
461            byte => Err(Error::UnexpectedByte {
462                byte,
463                position: self.pos,
464                expected: "bencode value",
465            }),
466        }
467    }
468}
469
470struct SeqAccess<'a, 'de> {
471    de: &'a mut Deserializer<'de>,
472}
473
474impl<'de> de::SeqAccess<'de> for SeqAccess<'_, 'de> {
475    type Error = Error;
476
477    fn next_element_seed<T: de::DeserializeSeed<'de>>(
478        &mut self,
479        seed: T,
480    ) -> Result<Option<T::Value>> {
481        if self.de.peek()? == b'e' {
482            return Ok(None);
483        }
484        seed.deserialize(&mut *self.de).map(Some)
485    }
486}
487
488struct MapAccess<'a, 'de> {
489    de: &'a mut Deserializer<'de>,
490    last_key: Option<Vec<u8>>,
491    strict_order: bool,
492}
493
494impl<'de> de::MapAccess<'de> for MapAccess<'_, 'de> {
495    type Error = Error;
496
497    fn next_key_seed<K: de::DeserializeSeed<'de>>(&mut self, seed: K) -> Result<Option<K::Value>> {
498        if self.de.peek()? == b'e' {
499            return Ok(None);
500        }
501
502        // Read key bytes for sort-order validation
503        let key_start = self.de.pos;
504        let key_data = self.de.parse_byte_string()?;
505        let key_vec = key_data.to_vec();
506
507        // Validate sorted order (strict mode only — lenient accepts unsorted)
508        if let Some(ref last) = self.last_key
509            && self.strict_order
510            && key_vec <= *last
511        {
512            return Err(Error::UnsortedKeys {
513                position: key_start,
514            });
515        }
516        self.last_key = Some(key_vec);
517
518        // Now deserialize the key via serde using a special deserializer
519        // that yields the already-parsed string
520        let key_de = BorrowedStrDeserializer(key_data);
521        seed.deserialize(key_de).map(Some)
522    }
523
524    fn next_value_seed<V: de::DeserializeSeed<'de>>(&mut self, seed: V) -> Result<V::Value> {
525        seed.deserialize(&mut *self.de)
526    }
527}
528
529/// Minimal deserializer that yields a single already-parsed byte string.
530struct BorrowedStrDeserializer<'de>(&'de [u8]);
531
532impl<'de> de::Deserializer<'de> for BorrowedStrDeserializer<'de> {
533    type Error = Error;
534
535    fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
536        visitor.visit_borrowed_bytes(self.0)
537    }
538
539    fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
540        let s = std::str::from_utf8(self.0)
541            .map_err(|_| Error::Custom("dict key is not valid UTF-8".into()))?;
542        visitor.visit_borrowed_str(s)
543    }
544
545    fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
546        self.deserialize_str(visitor)
547    }
548
549    fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
550        visitor.visit_borrowed_bytes(self.0)
551    }
552
553    fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
554        self.deserialize_bytes(visitor)
555    }
556
557    fn deserialize_identifier<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
558        self.deserialize_str(visitor)
559    }
560
561    fn deserialize_ignored_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
562        visitor.visit_unit()
563    }
564
565    fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
566        visitor.visit_some(self)
567    }
568
569    fn deserialize_newtype_struct<V: Visitor<'de>>(
570        self,
571        _name: &'static str,
572        visitor: V,
573    ) -> Result<V::Value> {
574        visitor.visit_newtype_struct(self)
575    }
576
577    serde::forward_to_deserialize_any! {
578        bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char
579        unit unit_struct seq tuple tuple_struct map struct
580        enum
581    }
582}
583
584// Enum support
585
586struct EnumAccess<'a, 'de> {
587    de: &'a mut Deserializer<'de>,
588}
589
590impl<'de> de::EnumAccess<'de> for EnumAccess<'_, 'de> {
591    type Error = Error;
592    type Variant = Self;
593
594    fn variant_seed<V: de::DeserializeSeed<'de>>(
595        self,
596        seed: V,
597    ) -> Result<(V::Value, Self::Variant)> {
598        let val = seed.deserialize(&mut *self.de)?;
599        Ok((val, self))
600    }
601}
602
603impl<'de> de::VariantAccess<'de> for EnumAccess<'_, 'de> {
604    type Error = Error;
605
606    fn unit_variant(self) -> Result<()> {
607        Err(Error::Custom(
608            "expected newtype/tuple/struct variant inside dict".into(),
609        ))
610    }
611
612    fn newtype_variant_seed<T: de::DeserializeSeed<'de>>(self, seed: T) -> Result<T::Value> {
613        seed.deserialize(&mut *self.de)
614    }
615
616    fn tuple_variant<V: Visitor<'de>>(self, _len: usize, visitor: V) -> Result<V::Value> {
617        de::Deserializer::deserialize_seq(&mut *self.de, visitor)
618    }
619
620    fn struct_variant<V: Visitor<'de>>(
621        self,
622        _fields: &'static [&'static str],
623        visitor: V,
624    ) -> Result<V::Value> {
625        de::Deserializer::deserialize_map(&mut *self.de, visitor)
626    }
627}
628
629struct UnitVariantAccess<'a, 'de> {
630    de: &'a mut Deserializer<'de>,
631}
632
633impl<'de> de::EnumAccess<'de> for UnitVariantAccess<'_, 'de> {
634    type Error = Error;
635    type Variant = Self;
636
637    fn variant_seed<V: de::DeserializeSeed<'de>>(
638        self,
639        seed: V,
640    ) -> Result<(V::Value, Self::Variant)> {
641        let val = seed.deserialize(&mut *self.de)?;
642        Ok((val, self))
643    }
644}
645
646impl<'de> de::VariantAccess<'de> for UnitVariantAccess<'_, 'de> {
647    type Error = Error;
648
649    fn unit_variant(self) -> Result<()> {
650        Ok(())
651    }
652
653    fn newtype_variant_seed<T: de::DeserializeSeed<'de>>(self, _seed: T) -> Result<T::Value> {
654        Err(Error::Custom(
655            "expected unit variant for string enum".into(),
656        ))
657    }
658
659    fn tuple_variant<V: Visitor<'de>>(self, _len: usize, _visitor: V) -> Result<V::Value> {
660        Err(Error::Custom(
661            "expected unit variant for string enum".into(),
662        ))
663    }
664
665    fn struct_variant<V: Visitor<'de>>(
666        self,
667        _fields: &'static [&'static str],
668        _visitor: V,
669    ) -> Result<V::Value> {
670        Err(Error::Custom(
671            "expected unit variant for string enum".into(),
672        ))
673    }
674}
675
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use crate::{from_bytes, from_bytes_lenient};

    /// `d2:zz1:a2:aa1:be` — the keys "zz" then "aa" are out of order.
    const UNSORTED_DICT: &[u8] = b"d2:zz1:a2:aa1:be";

    #[test]
    fn deserialize_integer() {
        let cases: [(&[u8], i64); 3] = [(b"i42e", 42), (b"i0e", 0), (b"i-1e", -1)];
        for (encoded, expected) in cases {
            assert_eq!(from_bytes::<i64>(encoded).unwrap(), expected);
        }
    }

    #[test]
    fn deserialize_string() {
        let cases: [(&[u8], &str); 2] = [(b"4:spam", "spam"), (b"0:", "")];
        for (encoded, expected) in cases {
            assert_eq!(from_bytes::<String>(encoded).unwrap(), expected);
        }
    }

    #[test]
    fn reject_negative_zero() {
        let parsed = from_bytes::<i64>(b"i-0e");
        assert!(parsed.is_err());
    }

    #[test]
    fn reject_leading_zeros() {
        let parsed = from_bytes::<i64>(b"i03e");
        assert!(parsed.is_err());
    }

    #[test]
    fn reject_trailing_data() {
        let parsed = from_bytes::<i64>(b"i42eXXX");
        assert!(parsed.is_err());
    }

    #[test]
    fn strict_rejects_unsorted_dict_keys() {
        let parsed = from_bytes::<BTreeMap<String, String>>(UNSORTED_DICT);
        assert!(parsed.is_err());
    }

    #[test]
    fn lenient_accepts_unsorted_dict_keys() {
        let map: BTreeMap<String, String> = from_bytes_lenient(UNSORTED_DICT).unwrap();
        assert_eq!(map.get("zz").map(String::as_str), Some("a"));
        assert_eq!(map.get("aa").map(String::as_str), Some("b"));
    }
}