1use serde::de::{self, Visitor};
2
3use crate::error::{Error, Result};
4
5pub struct Deserializer<'de> {
10 input: &'de [u8],
11 pos: usize,
12 strict_order: bool,
13}
14
15impl<'de> Deserializer<'de> {
16 pub fn new(input: &'de [u8]) -> Self {
21 Deserializer {
22 input,
23 pos: 0,
24 strict_order: true,
25 }
26 }
27
28 pub fn lenient(input: &'de [u8]) -> Self {
34 Deserializer {
35 input,
36 pos: 0,
37 strict_order: false,
38 }
39 }
40
41 pub fn finish(&self) -> Result<()> {
43 if self.pos < self.input.len() {
44 Err(Error::TrailingData {
45 position: self.pos,
46 count: self.input.len() - self.pos,
47 })
48 } else {
49 Ok(())
50 }
51 }
52
53 fn peek(&self) -> Result<u8> {
54 self.input
55 .get(self.pos)
56 .copied()
57 .ok_or(Error::UnexpectedEof {
58 position: self.pos,
59 context: "expected more data".into(),
60 })
61 }
62
63 fn next(&mut self) -> Result<u8> {
64 let byte = self.peek()?;
65 self.pos += 1;
66 Ok(byte)
67 }
68
69 fn expect(&mut self, expected: u8) -> Result<()> {
70 let byte = self.next()?;
71 if byte != expected {
72 Err(Error::UnexpectedByte {
73 byte,
74 position: self.pos - 1,
75 expected: match expected {
76 b'e' => "'e' (end marker)",
77 b'i' => "'i' (integer start)",
78 b'l' => "'l' (list start)",
79 b'd' => "'d' (dict start)",
80 b':' => "':' (string separator)",
81 _ => "specific byte",
82 },
83 })
84 } else {
85 Ok(())
86 }
87 }
88
89 fn parse_integer_value(&mut self) -> Result<i64> {
90 let start = self.pos;
91
92 let end = self.input[self.pos..]
94 .iter()
95 .position(|&b| b == b'e')
96 .ok_or(Error::UnexpectedEof {
97 position: self.pos,
98 context: "unterminated integer".into(),
99 })?;
100
101 let num_bytes = &self.input[self.pos..self.pos + end];
102 self.pos += end + 1; if num_bytes.is_empty() {
105 return Err(Error::InvalidInteger {
106 position: start,
107 detail: "empty integer".into(),
108 });
109 }
110
111 if num_bytes.len() > 1 && num_bytes[0] == b'0' {
113 return Err(Error::InvalidInteger {
114 position: start,
115 detail: "leading zero".into(),
116 });
117 }
118
119 if num_bytes == b"-0" {
121 return Err(Error::InvalidInteger {
122 position: start,
123 detail: "negative zero".into(),
124 });
125 }
126
127 if num_bytes == b"-" {
129 return Err(Error::InvalidInteger {
130 position: start,
131 detail: "bare minus sign".into(),
132 });
133 }
134
135 if num_bytes.len() > 2 && num_bytes[0] == b'-' && num_bytes[1] == b'0' {
137 return Err(Error::InvalidInteger {
138 position: start,
139 detail: "leading zero in negative".into(),
140 });
141 }
142
143 let s = std::str::from_utf8(num_bytes).map_err(|_| Error::InvalidInteger {
144 position: start,
145 detail: "non-ASCII integer".into(),
146 })?;
147
148 s.parse::<i64>().map_err(|e| Error::InvalidInteger {
149 position: start,
150 detail: e.to_string(),
151 })
152 }
153
154 fn parse_byte_string(&mut self) -> Result<&'de [u8]> {
155 let start = self.pos;
156
157 let colon = self.input[self.pos..]
159 .iter()
160 .position(|&b| b == b':')
161 .ok_or(Error::InvalidByteString {
162 position: start,
163 detail: "missing ':' separator".into(),
164 })?;
165
166 let len_bytes = &self.input[self.pos..self.pos + colon];
167 if len_bytes.is_empty() {
168 return Err(Error::InvalidByteString {
169 position: start,
170 detail: "empty length prefix".into(),
171 });
172 }
173
174 if len_bytes.len() > 1 && len_bytes[0] == b'0' {
176 return Err(Error::InvalidByteString {
177 position: start,
178 detail: "leading zero in length".into(),
179 });
180 }
181
182 let len_str = std::str::from_utf8(len_bytes).map_err(|_| Error::InvalidByteString {
183 position: start,
184 detail: "non-ASCII length".into(),
185 })?;
186
187 let len: usize =
188 len_str
189 .parse()
190 .map_err(|e: std::num::ParseIntError| Error::InvalidByteString {
191 position: start,
192 detail: e.to_string(),
193 })?;
194
195 self.pos += colon + 1; if self.pos + len > self.input.len() {
198 return Err(Error::UnexpectedEof {
199 position: self.pos,
200 context: format!(
201 "byte string needs {len} bytes, only {} available",
202 self.input.len() - self.pos
203 ),
204 });
205 }
206
207 let data = &self.input[self.pos..self.pos + len];
208 self.pos += len;
209 Ok(data)
210 }
211}
212
213impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
214 type Error = Error;
215
216 fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
217 match self.peek()? {
218 b'i' => {
219 self.pos += 1;
220 let val = self.parse_integer_value()?;
221 visitor.visit_i64(val)
222 }
223 b'l' => self.deserialize_seq(visitor),
224 b'd' => self.deserialize_map(visitor),
225 b'0'..=b'9' => {
226 let data = self.parse_byte_string()?;
227 visitor.visit_borrowed_bytes(data)
228 }
229 byte => Err(Error::UnexpectedByte {
230 byte,
231 position: self.pos,
232 expected: "integer ('i'), string ('0'-'9'), list ('l'), or dict ('d')",
233 }),
234 }
235 }
236
237 fn deserialize_bool<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
238 self.expect(b'i')?;
239 let val = self.parse_integer_value()?;
240 visitor.visit_bool(val != 0)
241 }
242
243 fn deserialize_i8<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
244 self.expect(b'i')?;
245 visitor.visit_i64(self.parse_integer_value()?)
246 }
247
248 fn deserialize_i16<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
249 self.expect(b'i')?;
250 visitor.visit_i64(self.parse_integer_value()?)
251 }
252
253 fn deserialize_i32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
254 self.expect(b'i')?;
255 visitor.visit_i64(self.parse_integer_value()?)
256 }
257
258 fn deserialize_i64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
259 self.expect(b'i')?;
260 visitor.visit_i64(self.parse_integer_value()?)
261 }
262
263 fn deserialize_u8<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
264 self.expect(b'i')?;
265 visitor.visit_i64(self.parse_integer_value()?)
266 }
267
268 fn deserialize_u16<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
269 self.expect(b'i')?;
270 visitor.visit_i64(self.parse_integer_value()?)
271 }
272
273 fn deserialize_u32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
274 self.expect(b'i')?;
275 visitor.visit_i64(self.parse_integer_value()?)
276 }
277
278 fn deserialize_u64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
279 self.expect(b'i')?;
280 visitor.visit_i64(self.parse_integer_value()?)
281 }
282
283 fn deserialize_f32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
284 Err(Error::Custom("bencode does not support floats".into()))
285 }
286
287 fn deserialize_f64<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
288 Err(Error::Custom("bencode does not support floats".into()))
289 }
290
291 fn deserialize_char<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
292 let data = self.parse_byte_string()?;
293 let s = std::str::from_utf8(data)
294 .map_err(|_| Error::Custom("char is not valid UTF-8".into()))?;
295 let mut chars = s.chars();
296 let c = chars
297 .next()
298 .ok_or_else(|| Error::Custom("empty string for char".into()))?;
299 if chars.next().is_some() {
300 return Err(Error::Custom("multi-char string for char".into()));
301 }
302 visitor.visit_char(c)
303 }
304
305 fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
306 let data = self.parse_byte_string()?;
307 let s = std::str::from_utf8(data).map_err(|_| {
308 Error::Custom("byte string is not valid UTF-8, use bytes instead".into())
309 })?;
310 visitor.visit_borrowed_str(s)
311 }
312
313 fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
314 self.deserialize_str(visitor)
315 }
316
317 fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
318 let data = self.parse_byte_string()?;
319 visitor.visit_borrowed_bytes(data)
320 }
321
322 fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
323 self.deserialize_bytes(visitor)
324 }
325
326 fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
327 visitor.visit_some(self)
328 }
329
330 fn deserialize_unit<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
331 Err(Error::Custom("bencode does not support unit".into()))
332 }
333
334 fn deserialize_unit_struct<V: Visitor<'de>>(
335 self,
336 _name: &'static str,
337 _visitor: V,
338 ) -> Result<V::Value> {
339 Err(Error::Custom(
340 "bencode does not support unit structs".into(),
341 ))
342 }
343
344 fn deserialize_newtype_struct<V: Visitor<'de>>(
345 self,
346 _name: &'static str,
347 visitor: V,
348 ) -> Result<V::Value> {
349 visitor.visit_newtype_struct(self)
350 }
351
352 fn deserialize_seq<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
353 self.expect(b'l')?;
354 let value = visitor.visit_seq(SeqAccess { de: self })?;
355 self.expect(b'e')?;
356 Ok(value)
357 }
358
359 fn deserialize_tuple<V: Visitor<'de>>(self, _len: usize, visitor: V) -> Result<V::Value> {
360 self.deserialize_seq(visitor)
361 }
362
363 fn deserialize_tuple_struct<V: Visitor<'de>>(
364 self,
365 _name: &'static str,
366 _len: usize,
367 visitor: V,
368 ) -> Result<V::Value> {
369 self.deserialize_seq(visitor)
370 }
371
372 fn deserialize_map<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
373 self.expect(b'd')?;
374 let strict_order = self.strict_order;
375 let value = visitor.visit_map(MapAccess {
376 de: self,
377 last_key: None,
378 strict_order,
379 })?;
380 self.expect(b'e')?;
381 Ok(value)
382 }
383
384 fn deserialize_struct<V: Visitor<'de>>(
385 self,
386 _name: &'static str,
387 _fields: &'static [&'static str],
388 visitor: V,
389 ) -> Result<V::Value> {
390 self.deserialize_map(visitor)
391 }
392
393 fn deserialize_enum<V: Visitor<'de>>(
394 self,
395 _name: &'static str,
396 _variants: &'static [&'static str],
397 visitor: V,
398 ) -> Result<V::Value> {
399 match self.peek()? {
400 b'd' => {
401 self.pos += 1;
403 let value = visitor.visit_enum(EnumAccess { de: self })?;
404 self.expect(b'e')?;
405 Ok(value)
406 }
407 _ => {
408 visitor.visit_enum(UnitVariantAccess { de: self })
410 }
411 }
412 }
413
414 fn deserialize_identifier<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
415 self.deserialize_str(visitor)
416 }
417
418 fn deserialize_ignored_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
419 self.skip_value()?;
421 visitor.visit_unit()
422 }
423}
424
425impl Deserializer<'_> {
426 fn skip_value(&mut self) -> Result<()> {
428 match self.peek()? {
429 b'i' => {
430 self.pos += 1;
431 self.parse_integer_value()?;
432 Ok(())
433 }
434 b'l' => {
435 self.pos += 1;
436 while self.peek()? != b'e' {
437 self.skip_value()?;
438 }
439 self.pos += 1;
440 Ok(())
441 }
442 b'd' => {
443 self.pos += 1;
444 while self.peek()? != b'e' {
445 self.parse_byte_string()?; self.skip_value()?; }
448 self.pos += 1;
449 Ok(())
450 }
451 b'0'..=b'9' => {
452 self.parse_byte_string()?;
453 Ok(())
454 }
455 byte => Err(Error::UnexpectedByte {
456 byte,
457 position: self.pos,
458 expected: "bencode value",
459 }),
460 }
461 }
462}
463
464struct SeqAccess<'a, 'de> {
465 de: &'a mut Deserializer<'de>,
466}
467
468impl<'de> de::SeqAccess<'de> for SeqAccess<'_, 'de> {
469 type Error = Error;
470
471 fn next_element_seed<T: de::DeserializeSeed<'de>>(
472 &mut self,
473 seed: T,
474 ) -> Result<Option<T::Value>> {
475 if self.de.peek()? == b'e' {
476 return Ok(None);
477 }
478 seed.deserialize(&mut *self.de).map(Some)
479 }
480}
481
482struct MapAccess<'a, 'de> {
483 de: &'a mut Deserializer<'de>,
484 last_key: Option<Vec<u8>>,
485 strict_order: bool,
486}
487
488impl<'de> de::MapAccess<'de> for MapAccess<'_, 'de> {
489 type Error = Error;
490
491 fn next_key_seed<K: de::DeserializeSeed<'de>>(&mut self, seed: K) -> Result<Option<K::Value>> {
492 if self.de.peek()? == b'e' {
493 return Ok(None);
494 }
495
496 let key_start = self.de.pos;
498 let key_data = self.de.parse_byte_string()?;
499 let key_vec = key_data.to_vec();
500
501 if let Some(ref last) = self.last_key
503 && self.strict_order
504 && key_vec <= *last
505 {
506 return Err(Error::UnsortedKeys {
507 position: key_start,
508 });
509 }
510 self.last_key = Some(key_vec);
511
512 let key_de = BorrowedStrDeserializer(key_data);
515 seed.deserialize(key_de).map(Some)
516 }
517
518 fn next_value_seed<V: de::DeserializeSeed<'de>>(&mut self, seed: V) -> Result<V::Value> {
519 seed.deserialize(&mut *self.de)
520 }
521}
522
523struct BorrowedStrDeserializer<'de>(&'de [u8]);
525
526impl<'de> de::Deserializer<'de> for BorrowedStrDeserializer<'de> {
527 type Error = Error;
528
529 fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
530 visitor.visit_borrowed_bytes(self.0)
531 }
532
533 fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
534 let s = std::str::from_utf8(self.0)
535 .map_err(|_| Error::Custom("dict key is not valid UTF-8".into()))?;
536 visitor.visit_borrowed_str(s)
537 }
538
539 fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
540 self.deserialize_str(visitor)
541 }
542
543 fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
544 visitor.visit_borrowed_bytes(self.0)
545 }
546
547 fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
548 self.deserialize_bytes(visitor)
549 }
550
551 fn deserialize_identifier<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
552 self.deserialize_str(visitor)
553 }
554
555 fn deserialize_ignored_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
556 visitor.visit_unit()
557 }
558
559 fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
560 visitor.visit_some(self)
561 }
562
563 fn deserialize_newtype_struct<V: Visitor<'de>>(
564 self,
565 _name: &'static str,
566 visitor: V,
567 ) -> Result<V::Value> {
568 visitor.visit_newtype_struct(self)
569 }
570
571 serde::forward_to_deserialize_any! {
572 bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char
573 unit unit_struct seq tuple tuple_struct map struct
574 enum
575 }
576}
577
578struct EnumAccess<'a, 'de> {
581 de: &'a mut Deserializer<'de>,
582}
583
584impl<'de> de::EnumAccess<'de> for EnumAccess<'_, 'de> {
585 type Error = Error;
586 type Variant = Self;
587
588 fn variant_seed<V: de::DeserializeSeed<'de>>(
589 self,
590 seed: V,
591 ) -> Result<(V::Value, Self::Variant)> {
592 let val = seed.deserialize(&mut *self.de)?;
593 Ok((val, self))
594 }
595}
596
597impl<'de> de::VariantAccess<'de> for EnumAccess<'_, 'de> {
598 type Error = Error;
599
600 fn unit_variant(self) -> Result<()> {
601 Err(Error::Custom(
602 "expected newtype/tuple/struct variant inside dict".into(),
603 ))
604 }
605
606 fn newtype_variant_seed<T: de::DeserializeSeed<'de>>(self, seed: T) -> Result<T::Value> {
607 seed.deserialize(&mut *self.de)
608 }
609
610 fn tuple_variant<V: Visitor<'de>>(self, _len: usize, visitor: V) -> Result<V::Value> {
611 de::Deserializer::deserialize_seq(&mut *self.de, visitor)
612 }
613
614 fn struct_variant<V: Visitor<'de>>(
615 self,
616 _fields: &'static [&'static str],
617 visitor: V,
618 ) -> Result<V::Value> {
619 de::Deserializer::deserialize_map(&mut *self.de, visitor)
620 }
621}
622
623struct UnitVariantAccess<'a, 'de> {
624 de: &'a mut Deserializer<'de>,
625}
626
627impl<'de> de::EnumAccess<'de> for UnitVariantAccess<'_, 'de> {
628 type Error = Error;
629 type Variant = Self;
630
631 fn variant_seed<V: de::DeserializeSeed<'de>>(
632 self,
633 seed: V,
634 ) -> Result<(V::Value, Self::Variant)> {
635 let val = seed.deserialize(&mut *self.de)?;
636 Ok((val, self))
637 }
638}
639
640impl<'de> de::VariantAccess<'de> for UnitVariantAccess<'_, 'de> {
641 type Error = Error;
642
643 fn unit_variant(self) -> Result<()> {
644 Ok(())
645 }
646
647 fn newtype_variant_seed<T: de::DeserializeSeed<'de>>(self, _seed: T) -> Result<T::Value> {
648 Err(Error::Custom(
649 "expected unit variant for string enum".into(),
650 ))
651 }
652
653 fn tuple_variant<V: Visitor<'de>>(self, _len: usize, _visitor: V) -> Result<V::Value> {
654 Err(Error::Custom(
655 "expected unit variant for string enum".into(),
656 ))
657 }
658
659 fn struct_variant<V: Visitor<'de>>(
660 self,
661 _fields: &'static [&'static str],
662 _visitor: V,
663 ) -> Result<V::Value> {
664 Err(Error::Custom(
665 "expected unit variant for string enum".into(),
666 ))
667 }
668}
669
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use crate::{from_bytes, from_bytes_lenient};

    /// `{"zz": "a", "aa": "b"}` encoded with its keys in descending order.
    const UNSORTED_DICT: &[u8] = b"d2:zz1:a2:aa1:be";

    #[test]
    fn deserialize_integer() {
        let cases: [(&[u8], i64); 3] = [(b"i42e", 42), (b"i0e", 0), (b"i-1e", -1)];
        for (encoded, expected) in cases {
            assert_eq!(from_bytes::<i64>(encoded).unwrap(), expected);
        }
    }

    #[test]
    fn deserialize_string() {
        let cases: [(&[u8], &str); 2] = [(b"4:spam", "spam"), (b"0:", "")];
        for (encoded, expected) in cases {
            assert_eq!(from_bytes::<String>(encoded).unwrap(), expected);
        }
    }

    #[test]
    fn reject_negative_zero() {
        assert!(from_bytes::<i64>(b"i-0e").is_err());
    }

    #[test]
    fn reject_leading_zeros() {
        assert!(from_bytes::<i64>(b"i03e").is_err());
    }

    #[test]
    fn reject_trailing_data() {
        assert!(from_bytes::<i64>(b"i42eXXX").is_err());
    }

    #[test]
    fn strict_rejects_unsorted_dict_keys() {
        assert!(from_bytes::<BTreeMap<String, String>>(UNSORTED_DICT).is_err());
    }

    #[test]
    fn lenient_accepts_unsorted_dict_keys() {
        let map: BTreeMap<String, String> = from_bytes_lenient(UNSORTED_DICT).unwrap();
        assert_eq!(map.get("zz").map(String::as_str), Some("a"));
        assert_eq!(map.get("aa").map(String::as_str), Some("b"));
    }
}