1use buffers::ByteBuf;
2use serde::de::Error as DeError;
3
/// Deserializer state for decoding bencode out of a borrowed byte slice.
pub struct BencodeDeserializer<'de> {
    /// Input that has not been consumed yet; parsing replaces it with the remaining tail.
    buf: &'de [u8],
    /// Dictionary keys recorded while parsing; copied into errors as their field path.
    field_context: Vec<ByteBuf<'de>>,
    /// Set while a dictionary key is being deserialized, so the key bytes get recorded.
    parsing_key: bool,

    /// When set (and a sha1 feature is enabled), the value stored under the key path
    /// `info` has its raw bytes and SHA-1 digest captured into the two fields below.
    pub is_torrent_info: bool,
    /// SHA-1 digest of the captured `info` value, if one was captured.
    pub torrent_info_digest: Option<[u8; 20]>,
    /// Raw encoded bytes of the captured `info` value, if one was captured.
    pub torrent_info_bytes: Option<&'de [u8]>,
}
14
15impl<'de> BencodeDeserializer<'de> {
16 pub fn new_from_buf(buf: &'de [u8]) -> BencodeDeserializer<'de> {
17 Self {
18 buf,
19 field_context: Default::default(),
20 parsing_key: false,
21 is_torrent_info: false,
22 torrent_info_digest: None,
23 torrent_info_bytes: None,
24 }
25 }
26 pub fn into_remaining(self) -> &'de [u8] {
27 self.buf
28 }
29 fn parse_integer(&mut self) -> Result<i64, Error> {
30 match self.buf.iter().copied().position(|e| e == b'e') {
31 Some(end) => {
32 let intbytes = &self.buf[1..end];
33 let value: i64 = std::str::from_utf8(intbytes)
34 .map_err(|e| Error::new_from_err(e).set_context(self))?
35 .parse()
36 .map_err(|e| Error::new_from_err(e).set_context(self))?;
37 let rem = self.buf.get(end + 1..).unwrap_or_default();
38 self.buf = rem;
39 Ok(value)
40 }
41 None => Err(Error::custom("cannot parse integer, unexpected EOF").set_context(self)),
42 }
43 }
44
45 fn parse_bytes(&mut self) -> Result<&'de [u8], Error> {
46 match self.buf.iter().copied().position(|e| e == b':') {
47 Some(length_delim) => {
48 let lenbytes = &self.buf[..length_delim];
49 let length: usize = std::str::from_utf8(lenbytes)
50 .map_err(|e| Error::new_from_err(e).set_context(self))?
51 .parse()
52 .map_err(|e| Error::new_from_err(e).set_context(self))?;
53 let bytes_start = length_delim + 1;
54 let bytes_end = bytes_start + length;
55 let bytes = &self.buf.get(bytes_start..bytes_end).ok_or_else(|| {
56 Error::custom(format!(
57 "could not get byte range {}..{}, data in the buffer: {:?}",
58 bytes_start, bytes_end, &self.buf
59 ))
60 .set_context(self)
61 })?;
62 let rem = self.buf.get(bytes_end..).unwrap_or_default();
63 self.buf = rem;
64 Ok(bytes)
65 }
66 None => Err(Error::custom("cannot parse bytes, unexpected EOF").set_context(self)),
67 }
68 }
69
70 fn parse_bytes_checked(&mut self) -> Result<&'de [u8], Error> {
71 let first = match self.buf.first().copied() {
72 Some(first) => first,
73 None => return Err(Error::custom("expected bencode bytes, got EOF").set_context(self)),
74 };
75 match first {
76 b'0'..=b'9' => {}
77 _ => return Err(Error::custom("expected bencode bytes").set_context(self)),
78 }
79 let b = self.parse_bytes()?;
80 if self.parsing_key {
81 self.field_context.push(ByteBuf(b));
82 }
83 Ok(b)
84 }
85}
86
87pub fn from_bytes<'a, T>(buf: &'a [u8]) -> anyhow::Result<T>
88where
89 T: serde::de::Deserialize<'a>,
90{
91 let mut de = BencodeDeserializer::new_from_buf(buf);
92 let v = T::deserialize(&mut de)?;
93 if !de.buf.is_empty() {
94 anyhow::bail!(
95 "deserialized successfully, but {} bytes remaining",
96 de.buf.len()
97 )
98 }
99 Ok(v)
100}
101
/// The underlying cause carried by an [`Error`].
#[derive(Debug)]
enum ErrorKind {
    /// Any failure represented as an `anyhow::Error`.
    Other(anyhow::Error),
    /// An operation the deserializer rejects; the string is rendered by `Display`
    /// followed by " is not supported by bencode".
    NotSupported(&'static str),
}
107
/// The path of dictionary keys that was active when an error was created.
#[derive(Debug, Default)]
pub struct ErrorContext {
    field_stack: Vec<String>,
}

impl std::fmt::Display for ErrorContext {
    /// Renders the path as `"a" -> "b": `; an empty path renders as nothing at all.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.field_stack.split_first() {
            None => Ok(()),
            Some((head, tail)) => {
                write!(f, "\"{head}\"")?;
                for field in tail {
                    write!(f, " -> \"{field}\"")?;
                }
                f.write_str(": ")
            }
        }
    }
}
127
/// Error produced by the bencode deserializer: a cause plus the field path that was
/// active when the error was created.
#[derive(Debug)]
pub struct Error {
    /// What went wrong.
    kind: ErrorKind,
    /// Where it went wrong; rendered before the cause by `Display`.
    context: ErrorContext,
}
133
134impl std::fmt::Display for ErrorKind {
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 match self {
137 ErrorKind::Other(err) => err.fmt(f),
138 ErrorKind::NotSupported(s) => write!(f, "{s} is not supported by bencode"),
139 }
140 }
141}
142
143impl std::fmt::Display for Error {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 write!(f, "{}{}", self.context, self.kind)
146 }
147}
148
149impl std::error::Error for Error {
150 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
151 match &self.kind {
152 ErrorKind::Other(err) => err.source(),
153 _ => None,
154 }
155 }
156}
157
158impl Error {
159 fn new_from_err<E>(e: E) -> Self
160 where
161 E: std::error::Error + Send + Sync + 'static,
162 {
163 Error {
164 kind: ErrorKind::Other(anyhow::Error::new(e)),
165 context: Default::default(),
166 }
167 }
168
169 fn new_from_kind(kind: ErrorKind) -> Self {
170 Self {
171 kind,
172 context: Default::default(),
173 }
174 }
175
176 fn new_from_anyhow(e: anyhow::Error) -> Self {
177 Error {
178 kind: ErrorKind::Other(e),
179 context: Default::default(),
180 }
181 }
182 fn custom_with_de<M: std::fmt::Display>(msg: M, de: &BencodeDeserializer<'_>) -> Self {
183 Self::custom(msg).set_context(de)
184 }
185 fn set_context(mut self, de: &BencodeDeserializer<'_>) -> Self {
186 self.context = ErrorContext {
187 field_stack: de.field_context.iter().map(|s| format!("{s}")).collect(),
188 };
189 self
190 }
191}
192
193impl serde::de::Error for Error {
194 fn custom<T>(msg: T) -> Self
195 where
196 T: std::fmt::Display,
197 {
198 Self {
199 kind: ErrorKind::Other(anyhow::anyhow!("{}", msg)),
200 context: Default::default(),
201 }
202 }
203}
204
205impl<'de> serde::de::Deserializer<'de> for &mut BencodeDeserializer<'de> {
206 type Error = Error;
207
208 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
209 where
210 V: serde::de::Visitor<'de>,
211 {
212 match self.buf.first().copied() {
213 Some(b'd') => self.deserialize_map(visitor),
214 Some(b'i') => self.deserialize_u64(visitor),
215 Some(b'l') => self.deserialize_seq(visitor),
216 Some(_) => self.deserialize_bytes(visitor),
217 None => Err(Error::custom_with_de("empty input", self)),
218 }
219 }
220
221 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
222 where
223 V: serde::de::Visitor<'de>,
224 {
225 if !self.buf.starts_with(b"i") {
226 return Err(Error::custom_with_de(
227 "expected bencode int to represent bool",
228 self,
229 ));
230 }
231 let value = self.parse_integer()?;
232 if value > 1 {
233 return Err(Error::custom_with_de(
234 format!("expected 0 or 1 for boolean, but got {value}"),
235 self,
236 ));
237 }
238 visitor
239 .visit_bool(value == 1)
240 .map_err(|e: Self::Error| e.set_context(self))
241 }
242
243 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
244 where
245 V: serde::de::Visitor<'de>,
246 {
247 self.deserialize_i64(visitor)
248 }
249
250 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
251 where
252 V: serde::de::Visitor<'de>,
253 {
254 self.deserialize_i64(visitor)
255 }
256
257 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
258 where
259 V: serde::de::Visitor<'de>,
260 {
261 self.deserialize_i64(visitor)
262 }
263
264 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
265 where
266 V: serde::de::Visitor<'de>,
267 {
268 if !self.buf.starts_with(b"i") {
269 return Err(Error::custom_with_de("expected bencode int", self));
270 }
271 visitor
272 .visit_i64(self.parse_integer()?)
273 .map_err(|e: Self::Error| e.set_context(self))
274 }
275
276 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
277 where
278 V: serde::de::Visitor<'de>,
279 {
280 self.deserialize_i64(visitor)
281 }
282
283 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
284 where
285 V: serde::de::Visitor<'de>,
286 {
287 self.deserialize_i64(visitor)
288 }
289
290 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
291 where
292 V: serde::de::Visitor<'de>,
293 {
294 self.deserialize_i64(visitor)
295 }
296
297 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
298 where
299 V: serde::de::Visitor<'de>,
300 {
301 self.deserialize_i64(visitor)
302 }
303
304 fn deserialize_f32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
305 where
306 V: serde::de::Visitor<'de>,
307 {
308 Err(
309 Error::new_from_kind(ErrorKind::NotSupported("bencode doesn't support floats"))
310 .set_context(self),
311 )
312 }
313
314 fn deserialize_f64<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
315 where
316 V: serde::de::Visitor<'de>,
317 {
318 Err(
319 Error::new_from_kind(ErrorKind::NotSupported("bencode doesn't support floats"))
320 .set_context(self),
321 )
322 }
323
324 fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
325 where
326 V: serde::de::Visitor<'de>,
327 {
328 Err(
329 Error::new_from_kind(ErrorKind::NotSupported("bencode doesn't support chars"))
330 .set_context(self),
331 )
332 }
333
334 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
335 where
336 V: serde::de::Visitor<'de>,
337 {
338 let first = match self.buf.first().copied() {
339 Some(first) => first,
340 None => {
341 return Err(Error::custom_with_de(
342 "expected bencode string, got EOF",
343 self,
344 ))
345 }
346 };
347 match first {
348 b'0'..=b'9' => {}
349 _ => return Err(Error::custom_with_de("expected bencode string", self)),
350 }
351 let b = self.parse_bytes()?;
352 let s = std::str::from_utf8(b).map_err(|e| {
353 Error::new_from_anyhow(anyhow::anyhow!("error reading utf-8: {}", e)).set_context(self)
354 })?;
355 visitor
356 .visit_borrowed_str(s)
357 .map_err(|e: Self::Error| e.set_context(self))
358 }
359
360 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
361 where
362 V: serde::de::Visitor<'de>,
363 {
364 self.deserialize_str(visitor)
365 }
366
367 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
368 where
369 V: serde::de::Visitor<'de>,
370 {
371 let b = self.parse_bytes_checked()?;
372 visitor
373 .visit_borrowed_bytes(b)
374 .map_err(|e: Self::Error| e.set_context(self))
375 }
376
377 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
378 where
379 V: serde::de::Visitor<'de>,
380 {
381 self.deserialize_bytes(visitor)
382 }
383
384 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
385 where
386 V: serde::de::Visitor<'de>,
387 {
388 visitor
389 .visit_some(&mut *self)
390 .map_err(|e: Self::Error| e.set_context(self))
391 }
392
393 fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
394 where
395 V: serde::de::Visitor<'de>,
396 {
397 Err(Error::new_from_kind(ErrorKind::NotSupported(
398 "bencode doesn't support unit types",
399 ))
400 .set_context(self))
401 }
402
403 fn deserialize_unit_struct<V>(
404 self,
405 _name: &'static str,
406 _visitor: V,
407 ) -> Result<V::Value, Self::Error>
408 where
409 V: serde::de::Visitor<'de>,
410 {
411 Err(Error::new_from_kind(ErrorKind::NotSupported(
412 "bencode doesn't support unit structs",
413 ))
414 .set_context(self))
415 }
416
417 fn deserialize_newtype_struct<V>(
418 self,
419 _name: &'static str,
420 _visitor: V,
421 ) -> Result<V::Value, Self::Error>
422 where
423 V: serde::de::Visitor<'de>,
424 {
425 Err(
426 Error::new_from_kind(ErrorKind::NotSupported("bencode doesn't newtype structs"))
427 .set_context(self),
428 )
429 }
430
431 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
432 where
433 V: serde::de::Visitor<'de>,
434 {
435 if !self.buf.starts_with(b"l") {
436 return Err(Error::custom(format!(
437 "expected bencode list, but got {}",
438 self.buf[0] as char,
439 )));
440 }
441 self.buf = self.buf.get(1..).unwrap_or_default();
442 visitor
443 .visit_seq(SeqAccess { de: self })
444 .map_err(|e: Self::Error| e.set_context(self))
445 }
446
447 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
448 where
449 V: serde::de::Visitor<'de>,
450 {
451 self.deserialize_seq(visitor)
452 }
453
454 fn deserialize_tuple_struct<V>(
455 self,
456 _name: &'static str,
457 _len: usize,
458 visitor: V,
459 ) -> Result<V::Value, Self::Error>
460 where
461 V: serde::de::Visitor<'de>,
462 {
463 self.deserialize_seq(visitor)
464 }
465
466 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
467 where
468 V: serde::de::Visitor<'de>,
469 {
470 if !self.buf.starts_with(b"d") {
471 return Err(Error::custom("expected bencode dict"));
472 }
473 self.buf = self.buf.get(1..).unwrap_or_default();
474 visitor
475 .visit_map(MapAccess { de: self })
476 .map_err(|e: Self::Error| e.set_context(self))
477 }
478
479 fn deserialize_struct<V>(
480 self,
481 _name: &'static str,
482 _fields: &'static [&'static str],
483 visitor: V,
484 ) -> Result<V::Value, Self::Error>
485 where
486 V: serde::de::Visitor<'de>,
487 {
488 self.deserialize_map(visitor)
489 }
490
491 fn deserialize_enum<V>(
492 self,
493 _name: &'static str,
494 _variants: &'static [&'static str],
495 _visitor: V,
496 ) -> Result<V::Value, Self::Error>
497 where
498 V: serde::de::Visitor<'de>,
499 {
500 Err(
501 Error::new_from_kind(ErrorKind::NotSupported("deserializing enums not supported"))
502 .set_context(self),
503 )
504 }
505
506 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
507 where
508 V: serde::de::Visitor<'de>,
509 {
510 let name = self.parse_bytes_checked()?;
511 visitor
512 .visit_borrowed_bytes(name)
513 .map_err(|e: Self::Error| e.set_context(self))
514 }
515
516 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
517 where
518 V: serde::de::Visitor<'de>,
519 {
520 self.deserialize_any(visitor)
521 }
522}
523
/// Adapter handed to serde visitors to walk the entries of a bencode dictionary.
struct MapAccess<'a, 'de> {
    /// The deserializer whose buffer is being consumed.
    de: &'a mut BencodeDeserializer<'de>,
}
527
/// Adapter handed to serde visitors to walk the elements of a bencode list.
struct SeqAccess<'a, 'de> {
    /// The deserializer whose buffer is being consumed.
    de: &'a mut BencodeDeserializer<'de>,
}
531
532impl<'de> serde::de::MapAccess<'de> for MapAccess<'_, 'de> {
533 type Error = Error;
534
535 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
536 where
537 K: serde::de::DeserializeSeed<'de>,
538 {
539 if self.de.buf.starts_with(b"e") {
540 self.de.buf = self.de.buf.get(1..).unwrap_or_default();
541 return Ok(None);
542 }
543 self.de.parsing_key = true;
544 let retval = seed.deserialize(&mut *self.de)?;
545 self.de.parsing_key = false;
546 Ok(Some(retval))
547 }
548
549 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
550 where
551 V: serde::de::DeserializeSeed<'de>,
552 {
553 #[cfg(any(feature = "sha1-crypto-hash", feature = "sha1-ring"))]
554 let buf_before = self.de.buf;
555 let value = seed.deserialize(&mut *self.de)?;
556 #[cfg(any(feature = "sha1-crypto-hash", feature = "sha1-ring"))]
557 {
558 use sha1w::{ISha1, Sha1};
559 if self.de.is_torrent_info && self.de.field_context.as_slice() == [ByteBuf(b"info")] {
560 let len = self.de.buf.as_ptr() as usize - buf_before.as_ptr() as usize;
561 let mut hash = Sha1::new();
562 let torrent_info_bytes = &buf_before[..len];
563 hash.update(torrent_info_bytes);
564 let digest = hash.finish();
565 self.de.torrent_info_digest = Some(digest);
566 self.de.torrent_info_bytes = Some(torrent_info_bytes);
567 }
568 }
569 self.de.field_context.pop();
570 Ok(value)
571 }
572}
573
574impl<'de> serde::de::SeqAccess<'de> for SeqAccess<'_, 'de> {
575 type Error = Error;
576
577 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
578 where
579 T: serde::de::DeserializeSeed<'de>,
580 {
581 if self.de.buf.starts_with(b"e") {
582 self.de.buf = self.de.buf.get(1..).unwrap_or_default();
583 return Ok(None);
584 }
585 Ok(Some(seed.deserialize(&mut *self.de)?))
586 }
587}