1use serde::de::{self, Visitor};
2
3use crate::error::{Error, Result};
4
/// A bencode deserializer that reads values out of a borrowed byte slice.
///
/// Values decoded as strings or byte strings borrow directly from `input`
/// for the `'de` lifetime.
pub struct Deserializer<'de> {
    /// The complete input buffer being decoded.
    input: &'de [u8],
    /// Index into `input` of the next byte that has not been consumed yet.
    pos: usize,
    /// When `true`, map decoding rejects a dictionary key that is not strictly
    /// greater than the key before it.
    strict_order: bool,
}
14
15impl<'de> Deserializer<'de> {
16 #[must_use]
21 pub fn new(input: &'de [u8]) -> Self {
22 Deserializer {
23 input,
24 pos: 0,
25 strict_order: true,
26 }
27 }
28
29 #[must_use]
35 pub fn lenient(input: &'de [u8]) -> Self {
36 Deserializer {
37 input,
38 pos: 0,
39 strict_order: false,
40 }
41 }
42
43 pub fn finish(&self) -> Result<()> {
49 if self.pos < self.input.len() {
50 Err(Error::TrailingData {
51 position: self.pos,
52 count: self.input.len() - self.pos,
53 })
54 } else {
55 Ok(())
56 }
57 }
58
59 fn peek(&self) -> Result<u8> {
60 self.input
61 .get(self.pos)
62 .copied()
63 .ok_or_else(|| Error::UnexpectedEof {
64 position: self.pos,
65 context: "expected more data".into(),
66 })
67 }
68
69 fn next(&mut self) -> Result<u8> {
70 let byte = self.peek()?;
71 self.pos += 1;
72 Ok(byte)
73 }
74
75 fn expect(&mut self, expected: u8) -> Result<()> {
76 let byte = self.next()?;
77 if byte == expected {
78 Ok(())
79 } else {
80 Err(Error::UnexpectedByte {
81 byte,
82 position: self.pos - 1,
83 expected: match expected {
84 b'e' => "'e' (end marker)",
85 b'i' => "'i' (integer start)",
86 b'l' => "'l' (list start)",
87 b'd' => "'d' (dict start)",
88 b':' => "':' (string separator)",
89 _ => "specific byte",
90 },
91 })
92 }
93 }
94
95 fn parse_integer_value(&mut self) -> Result<i64> {
96 let start = self.pos;
97
98 let end = self.input[self.pos..]
100 .iter()
101 .position(|&b| b == b'e')
102 .ok_or_else(|| Error::UnexpectedEof {
103 position: self.pos,
104 context: "unterminated integer".into(),
105 })?;
106
107 let num_bytes = &self.input[self.pos..self.pos + end];
108 self.pos += end + 1; if num_bytes.is_empty() {
111 return Err(Error::InvalidInteger {
112 position: start,
113 detail: "empty integer".into(),
114 });
115 }
116
117 if num_bytes.len() > 1 && num_bytes[0] == b'0' {
119 return Err(Error::InvalidInteger {
120 position: start,
121 detail: "leading zero".into(),
122 });
123 }
124
125 if num_bytes == b"-0" {
127 return Err(Error::InvalidInteger {
128 position: start,
129 detail: "negative zero".into(),
130 });
131 }
132
133 if num_bytes == b"-" {
135 return Err(Error::InvalidInteger {
136 position: start,
137 detail: "bare minus sign".into(),
138 });
139 }
140
141 if num_bytes.len() > 2 && num_bytes[0] == b'-' && num_bytes[1] == b'0' {
143 return Err(Error::InvalidInteger {
144 position: start,
145 detail: "leading zero in negative".into(),
146 });
147 }
148
149 let s = std::str::from_utf8(num_bytes).map_err(|_| Error::InvalidInteger {
150 position: start,
151 detail: "non-ASCII integer".into(),
152 })?;
153
154 s.parse::<i64>().map_err(|e| Error::InvalidInteger {
155 position: start,
156 detail: e.to_string(),
157 })
158 }
159
160 fn parse_byte_string(&mut self) -> Result<&'de [u8]> {
161 let start = self.pos;
162
163 let colon = self.input[self.pos..]
165 .iter()
166 .position(|&b| b == b':')
167 .ok_or_else(|| Error::InvalidByteString {
168 position: start,
169 detail: "missing ':' separator".into(),
170 })?;
171
172 let len_bytes = &self.input[self.pos..self.pos + colon];
173 if len_bytes.is_empty() {
174 return Err(Error::InvalidByteString {
175 position: start,
176 detail: "empty length prefix".into(),
177 });
178 }
179
180 if len_bytes.len() > 1 && len_bytes[0] == b'0' {
182 return Err(Error::InvalidByteString {
183 position: start,
184 detail: "leading zero in length".into(),
185 });
186 }
187
188 let len_str = std::str::from_utf8(len_bytes).map_err(|_| Error::InvalidByteString {
189 position: start,
190 detail: "non-ASCII length".into(),
191 })?;
192
193 let len: usize =
194 len_str
195 .parse()
196 .map_err(|e: std::num::ParseIntError| Error::InvalidByteString {
197 position: start,
198 detail: e.to_string(),
199 })?;
200
201 self.pos += colon + 1; if self.pos + len > self.input.len() {
204 return Err(Error::UnexpectedEof {
205 position: self.pos,
206 context: format!(
207 "byte string needs {len} bytes, only {} available",
208 self.input.len() - self.pos
209 ),
210 });
211 }
212
213 let data = &self.input[self.pos..self.pos + len];
214 self.pos += len;
215 Ok(data)
216 }
217}
218
219impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
220 type Error = Error;
221
222 fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
223 match self.peek()? {
224 b'i' => {
225 self.pos += 1;
226 let val = self.parse_integer_value()?;
227 visitor.visit_i64(val)
228 }
229 b'l' => self.deserialize_seq(visitor),
230 b'd' => self.deserialize_map(visitor),
231 b'0'..=b'9' => {
232 let data = self.parse_byte_string()?;
233 visitor.visit_borrowed_bytes(data)
234 }
235 byte => Err(Error::UnexpectedByte {
236 byte,
237 position: self.pos,
238 expected: "integer ('i'), string ('0'-'9'), list ('l'), or dict ('d')",
239 }),
240 }
241 }
242
243 fn deserialize_bool<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
244 self.expect(b'i')?;
245 let val = self.parse_integer_value()?;
246 visitor.visit_bool(val != 0)
247 }
248
249 fn deserialize_i8<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
250 self.expect(b'i')?;
251 visitor.visit_i64(self.parse_integer_value()?)
252 }
253
254 fn deserialize_i16<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
255 self.expect(b'i')?;
256 visitor.visit_i64(self.parse_integer_value()?)
257 }
258
259 fn deserialize_i32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
260 self.expect(b'i')?;
261 visitor.visit_i64(self.parse_integer_value()?)
262 }
263
264 fn deserialize_i64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
265 self.expect(b'i')?;
266 visitor.visit_i64(self.parse_integer_value()?)
267 }
268
269 fn deserialize_u8<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
270 self.expect(b'i')?;
271 visitor.visit_i64(self.parse_integer_value()?)
272 }
273
274 fn deserialize_u16<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
275 self.expect(b'i')?;
276 visitor.visit_i64(self.parse_integer_value()?)
277 }
278
279 fn deserialize_u32<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
280 self.expect(b'i')?;
281 visitor.visit_i64(self.parse_integer_value()?)
282 }
283
284 fn deserialize_u64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
285 self.expect(b'i')?;
286 visitor.visit_i64(self.parse_integer_value()?)
287 }
288
289 fn deserialize_f32<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
290 Err(Error::Custom("bencode does not support floats".into()))
291 }
292
293 fn deserialize_f64<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
294 Err(Error::Custom("bencode does not support floats".into()))
295 }
296
297 fn deserialize_char<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
298 let data = self.parse_byte_string()?;
299 let s = std::str::from_utf8(data)
300 .map_err(|_| Error::Custom("char is not valid UTF-8".into()))?;
301 let mut chars = s.chars();
302 let c = chars
303 .next()
304 .ok_or_else(|| Error::Custom("empty string for char".into()))?;
305 if chars.next().is_some() {
306 return Err(Error::Custom("multi-char string for char".into()));
307 }
308 visitor.visit_char(c)
309 }
310
311 fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
312 let data = self.parse_byte_string()?;
313 let s = std::str::from_utf8(data).map_err(|_| {
314 Error::Custom("byte string is not valid UTF-8, use bytes instead".into())
315 })?;
316 visitor.visit_borrowed_str(s)
317 }
318
319 fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
320 self.deserialize_str(visitor)
321 }
322
323 fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
324 let data = self.parse_byte_string()?;
325 visitor.visit_borrowed_bytes(data)
326 }
327
328 fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
329 self.deserialize_bytes(visitor)
330 }
331
332 fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
333 visitor.visit_some(self)
334 }
335
336 fn deserialize_unit<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
337 Err(Error::Custom("bencode does not support unit".into()))
338 }
339
340 fn deserialize_unit_struct<V: Visitor<'de>>(
341 self,
342 _name: &'static str,
343 _visitor: V,
344 ) -> Result<V::Value> {
345 Err(Error::Custom(
346 "bencode does not support unit structs".into(),
347 ))
348 }
349
350 fn deserialize_newtype_struct<V: Visitor<'de>>(
351 self,
352 _name: &'static str,
353 visitor: V,
354 ) -> Result<V::Value> {
355 visitor.visit_newtype_struct(self)
356 }
357
358 fn deserialize_seq<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
359 self.expect(b'l')?;
360 let value = visitor.visit_seq(SeqAccess { de: self })?;
361 self.expect(b'e')?;
362 Ok(value)
363 }
364
365 fn deserialize_tuple<V: Visitor<'de>>(self, _len: usize, visitor: V) -> Result<V::Value> {
366 self.deserialize_seq(visitor)
367 }
368
369 fn deserialize_tuple_struct<V: Visitor<'de>>(
370 self,
371 _name: &'static str,
372 _len: usize,
373 visitor: V,
374 ) -> Result<V::Value> {
375 self.deserialize_seq(visitor)
376 }
377
378 fn deserialize_map<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
379 self.expect(b'd')?;
380 let strict_order = self.strict_order;
381 let value = visitor.visit_map(MapAccess {
382 de: self,
383 last_key: None,
384 strict_order,
385 })?;
386 self.expect(b'e')?;
387 Ok(value)
388 }
389
390 fn deserialize_struct<V: Visitor<'de>>(
391 self,
392 _name: &'static str,
393 _fields: &'static [&'static str],
394 visitor: V,
395 ) -> Result<V::Value> {
396 self.deserialize_map(visitor)
397 }
398
399 fn deserialize_enum<V: Visitor<'de>>(
400 self,
401 _name: &'static str,
402 _variants: &'static [&'static str],
403 visitor: V,
404 ) -> Result<V::Value> {
405 match self.peek()? {
406 b'd' => {
407 self.pos += 1;
409 let value = visitor.visit_enum(EnumAccess { de: self })?;
410 self.expect(b'e')?;
411 Ok(value)
412 }
413 _ => {
414 visitor.visit_enum(UnitVariantAccess { de: self })
416 }
417 }
418 }
419
420 fn deserialize_identifier<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
421 self.deserialize_str(visitor)
422 }
423
424 fn deserialize_ignored_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
425 self.skip_value()?;
427 visitor.visit_unit()
428 }
429}
430
431impl Deserializer<'_> {
432 fn skip_value(&mut self) -> Result<()> {
434 match self.peek()? {
435 b'i' => {
436 self.pos += 1;
437 self.parse_integer_value()?;
438 Ok(())
439 }
440 b'l' => {
441 self.pos += 1;
442 while self.peek()? != b'e' {
443 self.skip_value()?;
444 }
445 self.pos += 1;
446 Ok(())
447 }
448 b'd' => {
449 self.pos += 1;
450 while self.peek()? != b'e' {
451 self.parse_byte_string()?; self.skip_value()?; }
454 self.pos += 1;
455 Ok(())
456 }
457 b'0'..=b'9' => {
458 self.parse_byte_string()?;
459 Ok(())
460 }
461 byte => Err(Error::UnexpectedByte {
462 byte,
463 position: self.pos,
464 expected: "bencode value",
465 }),
466 }
467 }
468}
469
470struct SeqAccess<'a, 'de> {
471 de: &'a mut Deserializer<'de>,
472}
473
474impl<'de> de::SeqAccess<'de> for SeqAccess<'_, 'de> {
475 type Error = Error;
476
477 fn next_element_seed<T: de::DeserializeSeed<'de>>(
478 &mut self,
479 seed: T,
480 ) -> Result<Option<T::Value>> {
481 if self.de.peek()? == b'e' {
482 return Ok(None);
483 }
484 seed.deserialize(&mut *self.de).map(Some)
485 }
486}
487
488struct MapAccess<'a, 'de> {
489 de: &'a mut Deserializer<'de>,
490 last_key: Option<Vec<u8>>,
491 strict_order: bool,
492}
493
494impl<'de> de::MapAccess<'de> for MapAccess<'_, 'de> {
495 type Error = Error;
496
497 fn next_key_seed<K: de::DeserializeSeed<'de>>(&mut self, seed: K) -> Result<Option<K::Value>> {
498 if self.de.peek()? == b'e' {
499 return Ok(None);
500 }
501
502 let key_start = self.de.pos;
504 let key_data = self.de.parse_byte_string()?;
505 let key_vec = key_data.to_vec();
506
507 if let Some(ref last) = self.last_key
509 && self.strict_order
510 && key_vec <= *last
511 {
512 return Err(Error::UnsortedKeys {
513 position: key_start,
514 });
515 }
516 self.last_key = Some(key_vec);
517
518 let key_de = BorrowedStrDeserializer(key_data);
521 seed.deserialize(key_de).map(Some)
522 }
523
524 fn next_value_seed<V: de::DeserializeSeed<'de>>(&mut self, seed: V) -> Result<V::Value> {
525 seed.deserialize(&mut *self.de)
526 }
527}
528
529struct BorrowedStrDeserializer<'de>(&'de [u8]);
531
532impl<'de> de::Deserializer<'de> for BorrowedStrDeserializer<'de> {
533 type Error = Error;
534
535 fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
536 visitor.visit_borrowed_bytes(self.0)
537 }
538
539 fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
540 let s = std::str::from_utf8(self.0)
541 .map_err(|_| Error::Custom("dict key is not valid UTF-8".into()))?;
542 visitor.visit_borrowed_str(s)
543 }
544
545 fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
546 self.deserialize_str(visitor)
547 }
548
549 fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
550 visitor.visit_borrowed_bytes(self.0)
551 }
552
553 fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
554 self.deserialize_bytes(visitor)
555 }
556
557 fn deserialize_identifier<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
558 self.deserialize_str(visitor)
559 }
560
561 fn deserialize_ignored_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
562 visitor.visit_unit()
563 }
564
565 fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
566 visitor.visit_some(self)
567 }
568
569 fn deserialize_newtype_struct<V: Visitor<'de>>(
570 self,
571 _name: &'static str,
572 visitor: V,
573 ) -> Result<V::Value> {
574 visitor.visit_newtype_struct(self)
575 }
576
577 serde::forward_to_deserialize_any! {
578 bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char
579 unit unit_struct seq tuple tuple_struct map struct
580 enum
581 }
582}
583
584struct EnumAccess<'a, 'de> {
587 de: &'a mut Deserializer<'de>,
588}
589
590impl<'de> de::EnumAccess<'de> for EnumAccess<'_, 'de> {
591 type Error = Error;
592 type Variant = Self;
593
594 fn variant_seed<V: de::DeserializeSeed<'de>>(
595 self,
596 seed: V,
597 ) -> Result<(V::Value, Self::Variant)> {
598 let val = seed.deserialize(&mut *self.de)?;
599 Ok((val, self))
600 }
601}
602
603impl<'de> de::VariantAccess<'de> for EnumAccess<'_, 'de> {
604 type Error = Error;
605
606 fn unit_variant(self) -> Result<()> {
607 Err(Error::Custom(
608 "expected newtype/tuple/struct variant inside dict".into(),
609 ))
610 }
611
612 fn newtype_variant_seed<T: de::DeserializeSeed<'de>>(self, seed: T) -> Result<T::Value> {
613 seed.deserialize(&mut *self.de)
614 }
615
616 fn tuple_variant<V: Visitor<'de>>(self, _len: usize, visitor: V) -> Result<V::Value> {
617 de::Deserializer::deserialize_seq(&mut *self.de, visitor)
618 }
619
620 fn struct_variant<V: Visitor<'de>>(
621 self,
622 _fields: &'static [&'static str],
623 visitor: V,
624 ) -> Result<V::Value> {
625 de::Deserializer::deserialize_map(&mut *self.de, visitor)
626 }
627}
628
629struct UnitVariantAccess<'a, 'de> {
630 de: &'a mut Deserializer<'de>,
631}
632
633impl<'de> de::EnumAccess<'de> for UnitVariantAccess<'_, 'de> {
634 type Error = Error;
635 type Variant = Self;
636
637 fn variant_seed<V: de::DeserializeSeed<'de>>(
638 self,
639 seed: V,
640 ) -> Result<(V::Value, Self::Variant)> {
641 let val = seed.deserialize(&mut *self.de)?;
642 Ok((val, self))
643 }
644}
645
646impl<'de> de::VariantAccess<'de> for UnitVariantAccess<'_, 'de> {
647 type Error = Error;
648
649 fn unit_variant(self) -> Result<()> {
650 Ok(())
651 }
652
653 fn newtype_variant_seed<T: de::DeserializeSeed<'de>>(self, _seed: T) -> Result<T::Value> {
654 Err(Error::Custom(
655 "expected unit variant for string enum".into(),
656 ))
657 }
658
659 fn tuple_variant<V: Visitor<'de>>(self, _len: usize, _visitor: V) -> Result<V::Value> {
660 Err(Error::Custom(
661 "expected unit variant for string enum".into(),
662 ))
663 }
664
665 fn struct_variant<V: Visitor<'de>>(
666 self,
667 _fields: &'static [&'static str],
668 _visitor: V,
669 ) -> Result<V::Value> {
670 Err(Error::Custom(
671 "expected unit variant for string enum".into(),
672 ))
673 }
674}
675
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use crate::from_bytes;

    /// Integers decode to their numeric value, including zero and negatives.
    #[test]
    fn deserialize_integer() {
        let answer = from_bytes::<i64>(b"i42e").unwrap();
        let zero = from_bytes::<i64>(b"i0e").unwrap();
        let minus_one = from_bytes::<i64>(b"i-1e").unwrap();

        assert_eq!(answer, 42);
        assert_eq!(zero, 0);
        assert_eq!(minus_one, -1);
    }

    /// Length-prefixed strings decode to owned `String`s, including the empty string.
    #[test]
    fn deserialize_string() {
        let spam = from_bytes::<String>(b"4:spam").unwrap();
        let empty = from_bytes::<String>(b"0:").unwrap();

        assert_eq!(spam, "spam");
        assert_eq!(empty, "");
    }

    /// `-0` is not an accepted integer.
    #[test]
    fn reject_negative_zero() {
        let outcome = from_bytes::<i64>(b"i-0e");
        assert!(outcome.is_err());
    }

    /// An integer with a leading zero is not accepted.
    #[test]
    fn reject_leading_zeros() {
        let outcome = from_bytes::<i64>(b"i03e");
        assert!(outcome.is_err());
    }

    /// Bytes left over after a complete value are an error.
    #[test]
    fn reject_trailing_data() {
        let outcome = from_bytes::<i64>(b"i42eXXX");
        assert!(outcome.is_err());
    }

    /// The default entry point refuses a dictionary whose keys are out of order.
    #[test]
    fn strict_rejects_unsorted_dict_keys() {
        // "zz" comes before "aa" in the encoded dictionary.
        let out_of_order = b"d2:zz1:a2:aa1:be";
        let outcome = from_bytes::<BTreeMap<String, String>>(out_of_order);
        assert!(outcome.is_err());
    }

    /// The lenient entry point decodes the same dictionary successfully.
    #[test]
    fn lenient_accepts_unsorted_dict_keys() {
        use crate::from_bytes_lenient;

        // "zz" comes before "aa" in the encoded dictionary.
        let out_of_order = b"d2:zz1:a2:aa1:be";
        let decoded: BTreeMap<String, String> = from_bytes_lenient(out_of_order).unwrap();

        assert_eq!(decoded.get("zz").unwrap(), "a");
        assert_eq!(decoded.get("aa").unwrap(), "b");
    }
}