librqbit_bencode/
serde_bencode_de.rs

1use buffers::ByteBuf;
2use serde::de::Error as DeError;
3
/// Bencode deserializer that borrows from an in-memory byte slice.
pub struct BencodeDeserializer<'de> {
    /// Input that has not been consumed yet; parsing advances this slice.
    buf: &'de [u8],
    /// Stack of dictionary keys leading to the value currently being parsed.
    /// It is copied into errors so that messages can name the offending field.
    field_context: Vec<ByteBuf<'de>>,
    /// True while a dictionary key is being deserialized, so that the key is
    /// pushed onto `field_context`.
    parsing_key: bool,

    // HACK: the fields below exist so the raw bytes of the `info` value can be
    // captured during deserialization instead of re-encoding it afterwards.
    /// When set (and a sha1 feature is compiled in), the value parsed while the
    /// key context is exactly `info` is hashed and recorded in the fields below.
    pub is_torrent_info: bool,
    /// SHA-1 digest of the captured `info` value bytes, if one was captured.
    pub torrent_info_digest: Option<[u8; 20]>,
    /// The raw input bytes that made up the captured `info` value.
    pub torrent_info_bytes: Option<&'de [u8]>,
}
14
15impl<'de> BencodeDeserializer<'de> {
16    pub fn new_from_buf(buf: &'de [u8]) -> BencodeDeserializer<'de> {
17        Self {
18            buf,
19            field_context: Default::default(),
20            parsing_key: false,
21            is_torrent_info: false,
22            torrent_info_digest: None,
23            torrent_info_bytes: None,
24        }
25    }
26    pub fn into_remaining(self) -> &'de [u8] {
27        self.buf
28    }
29    fn parse_integer(&mut self) -> Result<i64, Error> {
30        match self.buf.iter().copied().position(|e| e == b'e') {
31            Some(end) => {
32                let intbytes = &self.buf[1..end];
33                let value: i64 = std::str::from_utf8(intbytes)
34                    .map_err(|e| Error::new_from_err(e).set_context(self))?
35                    .parse()
36                    .map_err(|e| Error::new_from_err(e).set_context(self))?;
37                let rem = self.buf.get(end + 1..).unwrap_or_default();
38                self.buf = rem;
39                Ok(value)
40            }
41            None => Err(Error::custom("cannot parse integer, unexpected EOF").set_context(self)),
42        }
43    }
44
45    fn parse_bytes(&mut self) -> Result<&'de [u8], Error> {
46        match self.buf.iter().copied().position(|e| e == b':') {
47            Some(length_delim) => {
48                let lenbytes = &self.buf[..length_delim];
49                let length: usize = std::str::from_utf8(lenbytes)
50                    .map_err(|e| Error::new_from_err(e).set_context(self))?
51                    .parse()
52                    .map_err(|e| Error::new_from_err(e).set_context(self))?;
53                let bytes_start = length_delim + 1;
54                let bytes_end = bytes_start + length;
55                let bytes = &self.buf.get(bytes_start..bytes_end).ok_or_else(|| {
56                    Error::custom(format!(
57                        "could not get byte range {}..{}, data in the buffer: {:?}",
58                        bytes_start, bytes_end, &self.buf
59                    ))
60                    .set_context(self)
61                })?;
62                let rem = self.buf.get(bytes_end..).unwrap_or_default();
63                self.buf = rem;
64                Ok(bytes)
65            }
66            None => Err(Error::custom("cannot parse bytes, unexpected EOF").set_context(self)),
67        }
68    }
69
70    fn parse_bytes_checked(&mut self) -> Result<&'de [u8], Error> {
71        let first = match self.buf.first().copied() {
72            Some(first) => first,
73            None => return Err(Error::custom("expected bencode bytes, got EOF").set_context(self)),
74        };
75        match first {
76            b'0'..=b'9' => {}
77            _ => return Err(Error::custom("expected bencode bytes").set_context(self)),
78        }
79        let b = self.parse_bytes()?;
80        if self.parsing_key {
81            self.field_context.push(ByteBuf(b));
82        }
83        Ok(b)
84    }
85}
86
87pub fn from_bytes<'a, T>(buf: &'a [u8]) -> anyhow::Result<T>
88where
89    T: serde::de::Deserialize<'a>,
90{
91    let mut de = BencodeDeserializer::new_from_buf(buf);
92    let v = T::deserialize(&mut de)?;
93    if !de.buf.is_empty() {
94        anyhow::bail!(
95            "deserialized successfully, but {} bytes remaining",
96            de.buf.len()
97        )
98    }
99    Ok(v)
100}
101
/// The underlying cause of a deserialization [`Error`].
#[derive(Debug)]
enum ErrorKind {
    /// Any other failure, carried as an `anyhow::Error`.
    Other(anyhow::Error),
    /// The requested serde data type cannot be deserialized by this format;
    /// the payload is a static description used when displaying the error.
    NotSupported(&'static str),
}
107
/// Location information attached to an [`Error`].
#[derive(Debug, Default)]
pub struct ErrorContext {
    /// Dictionary keys, outermost first, that were being parsed when the
    /// error was created (a snapshot of the deserializer's field stack).
    field_stack: Vec<String>,
}
112
113impl std::fmt::Display for ErrorContext {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        let mut it = self.field_stack.iter();
116        if let Some(field) = it.next() {
117            write!(f, "\"{field}\"")?;
118        } else {
119            return Ok(());
120        }
121        for field in self.field_stack.iter().skip(1) {
122            write!(f, " -> \"{field}\"")?;
123        }
124        f.write_str(": ")
125    }
126}
127
/// Error type produced by [`BencodeDeserializer`].
#[derive(Debug)]
pub struct Error {
    /// What went wrong.
    kind: ErrorKind,
    /// Which dictionary fields were being parsed when it went wrong.
    context: ErrorContext,
}
133
134impl std::fmt::Display for ErrorKind {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        match self {
137            ErrorKind::Other(err) => err.fmt(f),
138            ErrorKind::NotSupported(s) => write!(f, "{s} is not supported by bencode"),
139        }
140    }
141}
142
143impl std::fmt::Display for Error {
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        write!(f, "{}{}", self.context, self.kind)
146    }
147}
148
149impl std::error::Error for Error {
150    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
151        match &self.kind {
152            ErrorKind::Other(err) => err.source(),
153            _ => None,
154        }
155    }
156}
157
158impl Error {
159    fn new_from_err<E>(e: E) -> Self
160    where
161        E: std::error::Error + Send + Sync + 'static,
162    {
163        Error {
164            kind: ErrorKind::Other(anyhow::Error::new(e)),
165            context: Default::default(),
166        }
167    }
168
169    fn new_from_kind(kind: ErrorKind) -> Self {
170        Self {
171            kind,
172            context: Default::default(),
173        }
174    }
175
176    fn new_from_anyhow(e: anyhow::Error) -> Self {
177        Error {
178            kind: ErrorKind::Other(e),
179            context: Default::default(),
180        }
181    }
182    fn custom_with_de<M: std::fmt::Display>(msg: M, de: &BencodeDeserializer<'_>) -> Self {
183        Self::custom(msg).set_context(de)
184    }
185    fn set_context(mut self, de: &BencodeDeserializer<'_>) -> Self {
186        self.context = ErrorContext {
187            field_stack: de.field_context.iter().map(|s| format!("{s}")).collect(),
188        };
189        self
190    }
191}
192
193impl serde::de::Error for Error {
194    fn custom<T>(msg: T) -> Self
195    where
196        T: std::fmt::Display,
197    {
198        Self {
199            kind: ErrorKind::Other(anyhow::anyhow!("{}", msg)),
200            context: Default::default(),
201        }
202    }
203}
204
205impl<'de> serde::de::Deserializer<'de> for &mut BencodeDeserializer<'de> {
206    type Error = Error;
207
208    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
209    where
210        V: serde::de::Visitor<'de>,
211    {
212        match self.buf.first().copied() {
213            Some(b'd') => self.deserialize_map(visitor),
214            Some(b'i') => self.deserialize_u64(visitor),
215            Some(b'l') => self.deserialize_seq(visitor),
216            Some(_) => self.deserialize_bytes(visitor),
217            None => Err(Error::custom_with_de("empty input", self)),
218        }
219    }
220
221    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
222    where
223        V: serde::de::Visitor<'de>,
224    {
225        if !self.buf.starts_with(b"i") {
226            return Err(Error::custom_with_de(
227                "expected bencode int to represent bool",
228                self,
229            ));
230        }
231        let value = self.parse_integer()?;
232        if value > 1 {
233            return Err(Error::custom_with_de(
234                format!("expected 0 or 1 for boolean, but got {value}"),
235                self,
236            ));
237        }
238        visitor
239            .visit_bool(value == 1)
240            .map_err(|e: Self::Error| e.set_context(self))
241    }
242
    /// Forwards to `deserialize_i64`; all integer widths share one code path
    /// and the visitor receives the value through `visit_i64`.
    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_i64(visitor)
    }
249
    /// Forwards to `deserialize_i64`; all integer widths share one code path
    /// and the visitor receives the value through `visit_i64`.
    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_i64(visitor)
    }
256
    /// Forwards to `deserialize_i64`; all integer widths share one code path
    /// and the visitor receives the value through `visit_i64`.
    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_i64(visitor)
    }
263
264    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
265    where
266        V: serde::de::Visitor<'de>,
267    {
268        if !self.buf.starts_with(b"i") {
269            return Err(Error::custom_with_de("expected bencode int", self));
270        }
271        visitor
272            .visit_i64(self.parse_integer()?)
273            .map_err(|e: Self::Error| e.set_context(self))
274    }
275
    /// Forwards to `deserialize_i64`; unsigned integers are parsed as `i64`
    /// and passed to the visitor through `visit_i64`.
    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_i64(visitor)
    }
282
    /// Forwards to `deserialize_i64`; unsigned integers are parsed as `i64`
    /// and passed to the visitor through `visit_i64`.
    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_i64(visitor)
    }
289
    /// Forwards to `deserialize_i64`; unsigned integers are parsed as `i64`
    /// and passed to the visitor through `visit_i64`.
    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_i64(visitor)
    }
296
    /// Forwards to `deserialize_i64`.
    ///
    /// Because the integer is parsed as `i64`, values above `i64::MAX` fail
    /// to parse even though they would fit in a `u64`.
    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_i64(visitor)
    }
303
304    fn deserialize_f32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
305    where
306        V: serde::de::Visitor<'de>,
307    {
308        Err(
309            Error::new_from_kind(ErrorKind::NotSupported("bencode doesn't support floats"))
310                .set_context(self),
311        )
312    }
313
314    fn deserialize_f64<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
315    where
316        V: serde::de::Visitor<'de>,
317    {
318        Err(
319            Error::new_from_kind(ErrorKind::NotSupported("bencode doesn't support floats"))
320                .set_context(self),
321        )
322    }
323
324    fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
325    where
326        V: serde::de::Visitor<'de>,
327    {
328        Err(
329            Error::new_from_kind(ErrorKind::NotSupported("bencode doesn't support chars"))
330                .set_context(self),
331        )
332    }
333
334    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
335    where
336        V: serde::de::Visitor<'de>,
337    {
338        let first = match self.buf.first().copied() {
339            Some(first) => first,
340            None => {
341                return Err(Error::custom_with_de(
342                    "expected bencode string, got EOF",
343                    self,
344                ))
345            }
346        };
347        match first {
348            b'0'..=b'9' => {}
349            _ => return Err(Error::custom_with_de("expected bencode string", self)),
350        }
351        let b = self.parse_bytes()?;
352        let s = std::str::from_utf8(b).map_err(|e| {
353            Error::new_from_anyhow(anyhow::anyhow!("error reading utf-8: {}", e)).set_context(self)
354        })?;
355        visitor
356            .visit_borrowed_str(s)
357            .map_err(|e: Self::Error| e.set_context(self))
358    }
359
    /// Forwards to `deserialize_str`; the visitor receives a borrowed string.
    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_str(visitor)
    }
366
367    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
368    where
369        V: serde::de::Visitor<'de>,
370    {
371        let b = self.parse_bytes_checked()?;
372        visitor
373            .visit_borrowed_bytes(b)
374            .map_err(|e: Self::Error| e.set_context(self))
375    }
376
    /// Forwards to `deserialize_bytes`; the visitor receives borrowed bytes.
    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_bytes(visitor)
    }
383
384    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
385    where
386        V: serde::de::Visitor<'de>,
387    {
388        visitor
389            .visit_some(&mut *self)
390            .map_err(|e: Self::Error| e.set_context(self))
391    }
392
393    fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
394    where
395        V: serde::de::Visitor<'de>,
396    {
397        Err(Error::new_from_kind(ErrorKind::NotSupported(
398            "bencode doesn't support unit types",
399        ))
400        .set_context(self))
401    }
402
403    fn deserialize_unit_struct<V>(
404        self,
405        _name: &'static str,
406        _visitor: V,
407    ) -> Result<V::Value, Self::Error>
408    where
409        V: serde::de::Visitor<'de>,
410    {
411        Err(Error::new_from_kind(ErrorKind::NotSupported(
412            "bencode doesn't support unit structs",
413        ))
414        .set_context(self))
415    }
416
417    fn deserialize_newtype_struct<V>(
418        self,
419        _name: &'static str,
420        _visitor: V,
421    ) -> Result<V::Value, Self::Error>
422    where
423        V: serde::de::Visitor<'de>,
424    {
425        Err(
426            Error::new_from_kind(ErrorKind::NotSupported("bencode doesn't newtype structs"))
427                .set_context(self),
428        )
429    }
430
431    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
432    where
433        V: serde::de::Visitor<'de>,
434    {
435        if !self.buf.starts_with(b"l") {
436            return Err(Error::custom(format!(
437                "expected bencode list, but got {}",
438                self.buf[0] as char,
439            )));
440        }
441        self.buf = self.buf.get(1..).unwrap_or_default();
442        visitor
443            .visit_seq(SeqAccess { de: self })
444            .map_err(|e: Self::Error| e.set_context(self))
445    }
446
    /// Forwards to `deserialize_seq`; the expected length is ignored here.
    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_seq(visitor)
    }
453
    /// Forwards to `deserialize_seq`; the struct name and expected length are
    /// ignored here.
    fn deserialize_tuple_struct<V>(
        self,
        _name: &'static str,
        _len: usize,
        visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_seq(visitor)
    }
465
466    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
467    where
468        V: serde::de::Visitor<'de>,
469    {
470        if !self.buf.starts_with(b"d") {
471            return Err(Error::custom("expected bencode dict"));
472        }
473        self.buf = self.buf.get(1..).unwrap_or_default();
474        visitor
475            .visit_map(MapAccess { de: self })
476            .map_err(|e: Self::Error| e.set_context(self))
477    }
478
    /// Forwards to `deserialize_map`; structs are read from dictionaries and
    /// the struct name and field list are ignored here.
    fn deserialize_struct<V>(
        self,
        _name: &'static str,
        _fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_map(visitor)
    }
490
491    fn deserialize_enum<V>(
492        self,
493        _name: &'static str,
494        _variants: &'static [&'static str],
495        _visitor: V,
496    ) -> Result<V::Value, Self::Error>
497    where
498        V: serde::de::Visitor<'de>,
499    {
500        Err(
501            Error::new_from_kind(ErrorKind::NotSupported("deserializing enums not supported"))
502                .set_context(self),
503        )
504    }
505
506    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
507    where
508        V: serde::de::Visitor<'de>,
509    {
510        let name = self.parse_bytes_checked()?;
511        visitor
512            .visit_borrowed_bytes(name)
513            .map_err(|e: Self::Error| e.set_context(self))
514    }
515
    /// Forwards to `deserialize_any`, so a skipped value is still fully
    /// parsed in order to advance past it.
    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_any(visitor)
    }
522}
523
/// Serde map accessor used while reading a dictionary; it is created after
/// the leading `d` has been consumed and reads entries until the closing `e`.
struct MapAccess<'a, 'de> {
    /// The deserializer whose buffer is being advanced.
    de: &'a mut BencodeDeserializer<'de>,
}
527
/// Serde sequence accessor used while reading a list; it is created after
/// the leading `l` has been consumed and reads elements until the closing `e`.
struct SeqAccess<'a, 'de> {
    /// The deserializer whose buffer is being advanced.
    de: &'a mut BencodeDeserializer<'de>,
}
531
532impl<'de> serde::de::MapAccess<'de> for MapAccess<'_, 'de> {
533    type Error = Error;
534
535    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
536    where
537        K: serde::de::DeserializeSeed<'de>,
538    {
539        if self.de.buf.starts_with(b"e") {
540            self.de.buf = self.de.buf.get(1..).unwrap_or_default();
541            return Ok(None);
542        }
543        self.de.parsing_key = true;
544        let retval = seed.deserialize(&mut *self.de)?;
545        self.de.parsing_key = false;
546        Ok(Some(retval))
547    }
548
    /// Deserializes the value belonging to the most recently read key, then
    /// pops that key from the field context.
    ///
    /// With a sha1 feature enabled and `is_torrent_info` set, the raw bytes of
    /// the value parsed while the key context is exactly `info` are hashed and
    /// stored on the deserializer.
    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
    where
        V: serde::de::DeserializeSeed<'de>,
    {
        // Remember where the value starts so its raw bytes can be recovered.
        #[cfg(any(feature = "sha1-crypto-hash", feature = "sha1-ring"))]
        let buf_before = self.de.buf;
        let value = seed.deserialize(&mut *self.de)?;
        #[cfg(any(feature = "sha1-crypto-hash", feature = "sha1-ring"))]
        {
            use sha1w::{ISha1, Sha1};
            if self.de.is_torrent_info && self.de.field_context.as_slice() == [ByteBuf(b"info")] {
                // The remaining buffer is a suffix of `buf_before`, so the
                // pointer difference is the number of bytes the value used.
                let len = self.de.buf.as_ptr() as usize - buf_before.as_ptr() as usize;
                let mut hash = Sha1::new();
                let torrent_info_bytes = &buf_before[..len];
                hash.update(torrent_info_bytes);
                let digest = hash.finish();
                self.de.torrent_info_digest = Some(digest);
                self.de.torrent_info_bytes = Some(torrent_info_bytes);
            }
        }
        // The key was pushed while it was parsed; the value is done, drop it.
        self.de.field_context.pop();
        Ok(value)
    }
572}
573
574impl<'de> serde::de::SeqAccess<'de> for SeqAccess<'_, 'de> {
575    type Error = Error;
576
577    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
578    where
579        T: serde::de::DeserializeSeed<'de>,
580    {
581        if self.de.buf.starts_with(b"e") {
582            self.de.buf = self.de.buf.get(1..).unwrap_or_default();
583            return Ok(None);
584        }
585        Ok(Some(seed.deserialize(&mut *self.de)?))
586    }
587}