1use core::str;
2
3use serde::{Deserialize, Deserializer};
4
5use crate::{Error, ErrorKind};
6
/// Cursor over a borrowed bencode buffer, used as the state of the serde deserializer.
///
/// The buffer is never copied: values handed to visitors borrow from it for `'de`.
pub struct BencodeDeserializer<'de> {
    /// Bytes that have not been consumed yet; shrinks from the front as parsing advances.
    input: &'de [u8],
    /// Count of bytes consumed so far; stamped onto errors built by `error`.
    position: usize,
}
11
12impl<'de> BencodeDeserializer<'de> {
13 pub fn from_str(input: &'de str) -> Self {
14 BencodeDeserializer {
15 input: input.as_bytes(),
16 position: 0,
17 }
18 }
19
20 pub fn from_bytes(input: &'de [u8]) -> Self {
21 BencodeDeserializer { input, position: 0 }
22 }
23
24 pub fn error(&self, kind: ErrorKind) -> Error {
25 let err: Error = kind.into();
26 err.set_position(self.position)
27 }
28
29 pub(crate) fn move_cursor(&mut self, by: usize) {
30 self.position += by;
31 self.input = &self.input[by..]
32 }
33}
34
35impl<'de> BencodeDeserializer<'de> {
36 fn parse_integer(&mut self) -> Result<i64, Error> {
37 match self.input.iter().position(|byte| *byte == b'e') {
38 Some(pos) => {
39 let integer = &self.input[..pos];
40 let parsed_integer: i64 = std::str::from_utf8(integer)
41 .map_err(|_| self.error(ErrorKind::BadInputData("Bad integer")))?
42 .parse()
43 .map_err(|_| {
44 self.error(ErrorKind::BadInputData(
45 "Unnable to parse integer from the provided data",
46 ))
47 })?;
48 self.move_cursor(pos + 1);
49 Ok(parsed_integer)
50 }
51 _ => Err(self.error(ErrorKind::BadInputData("expected closing delimiter for integer"))),
52 }
53 }
54
55 fn parse_bytes(&mut self) -> Result<&'de [u8], Error> {
56 match self.input.iter().position(|byte| *byte == b':') {
57 Some(delim_pos) => {
58 let raw_bytes_len = &self.input[..delim_pos];
59 let bytes_len: usize = std::str::from_utf8(raw_bytes_len)
60 .map_err(|_| self.error(ErrorKind::BadInputData("bytes length is not valid utf8")))?
61 .parse()
62 .map_err(|_| self.error(ErrorKind::BadInputData("expected valid bytes length")))?;
63 self.move_cursor(delim_pos + 1);
65 let raw_bytes = &self.input[..bytes_len];
66 self.move_cursor(bytes_len);
67 Ok(raw_bytes)
68 }
69 _ => Err(self.error(ErrorKind::BadInputData("expected bytes delimiter ':'"))),
70 }
71 }
72
73 fn parse_bytes_checked(&mut self) -> Result<&'de [u8], Error> {
74 match self
75 .input
76 .first()
77 .ok_or(self.error(ErrorKind::UnexpectedEof("bytes")))?
78 {
79 b'0'..=b'9' => self.parse_bytes(),
80 _ => Err(self.error(ErrorKind::BadInputData("expected bytes length"))),
81 }
82 }
83}
84
85impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut BencodeDeserializer<'de> {
86 type Error = Error;
87
88 fn deserialize_bool<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
89 where
90 V: serde::de::Visitor<'de>,
91 {
92 Err(self.error(ErrorKind::Unsupported("bool")))
93 }
94
95 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
96 where
97 V: serde::de::Visitor<'de>,
98 {
99 match self
100 .input
101 .first()
102 .ok_or(self.error(ErrorKind::UnexpectedEof("any bencode value")))?
103 {
104 b'd' => self.deserialize_map(visitor),
105 b'l' => self.deserialize_seq(visitor),
106 b'i' => self.deserialize_i64(visitor),
107 _ => self.deserialize_bytes(visitor),
108 }
109 }
110
111 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
112 where
113 V: serde::de::Visitor<'de>,
114 {
115 self.deserialize_i64(visitor)
116 }
117
118 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
119 where
120 V: serde::de::Visitor<'de>,
121 {
122 self.deserialize_i64(visitor)
123 }
124
125 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
126 where
127 V: serde::de::Visitor<'de>,
128 {
129 self.deserialize_i64(visitor)
130 }
131
132 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
133 where
134 V: serde::de::Visitor<'de>,
135 {
136 match self
137 .input
138 .first()
139 .ok_or(self.error(ErrorKind::UnexpectedEof("integer")))?
140 {
141 b'i' => {
142 self.move_cursor(1);
144 visitor.visit_i64(self.parse_integer()?)
145 }
146 _ => Err(self.error(ErrorKind::BadInputData("expected integer label 'i'"))),
147 }
148 }
149
150 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
151 where
152 V: serde::de::Visitor<'de>,
153 {
154 self.deserialize_i64(visitor)
155 }
156
157 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
158 where
159 V: serde::de::Visitor<'de>,
160 {
161 self.deserialize_i64(visitor)
162 }
163
164 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
165 where
166 V: serde::de::Visitor<'de>,
167 {
168 self.deserialize_i64(visitor)
169 }
170
171 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
172 where
173 V: serde::de::Visitor<'de>,
174 {
175 self.deserialize_i64(visitor)
176 }
177
178 fn deserialize_f32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
179 where
180 V: serde::de::Visitor<'de>,
181 {
182 Err(self.error(ErrorKind::Unsupported("f32")))
183 }
184
185 fn deserialize_f64<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
186 where
187 V: serde::de::Visitor<'de>,
188 {
189 Err(self.error(ErrorKind::Unsupported("f64")))
190 }
191
192 fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
193 where
194 V: serde::de::Visitor<'de>,
195 {
196 Err(self.error(ErrorKind::Unsupported("char")))
197 }
198
199 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
200 where
201 V: serde::de::Visitor<'de>,
202 {
203 let str = str::from_utf8(self.parse_bytes_checked()?)
204 .map_err(|_| self.error(ErrorKind::BadInputData("expected valid utf8 string")))?;
205 visitor.visit_borrowed_str(str)
206 }
207
208 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
209 where
210 V: serde::de::Visitor<'de>,
211 {
212 self.deserialize_str(visitor)
213 }
214
215 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
216 where
217 V: serde::de::Visitor<'de>,
218 {
219 visitor.visit_borrowed_bytes(self.parse_bytes_checked()?)
220 }
221
222 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
223 where
224 V: serde::de::Visitor<'de>,
225 {
226 self.deserialize_bytes(visitor)
227 }
228
229 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
230 where
231 V: serde::de::Visitor<'de>,
232 {
233 match self
234 .input
235 .first()
236 .ok_or(self.error(ErrorKind::UnexpectedEof("null string")))?
237 {
238 b'0' => {
239 let _ = self.parse_bytes()?;
240 visitor.visit_none()
241 }
242 _ => visitor.visit_some(&mut *self),
243 }
244 }
245
246 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
247 where
248 V: serde::de::Visitor<'de>,
249 {
250 let bytes = self.parse_bytes_checked()?;
251 if !bytes.is_empty() {
252 return Err(self.error(ErrorKind::BadInputData("expected bencode string of length 0")));
253 }
254 visitor.visit_unit()
255 }
256
257 fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value, Self::Error>
258 where
259 V: serde::de::Visitor<'de>,
260 {
261 self.deserialize_unit(visitor)
262 }
263
264 fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value, Self::Error>
265 where
266 V: serde::de::Visitor<'de>,
267 {
268 visitor.visit_newtype_struct(self)
269 }
270
271 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
272 where
273 V: serde::de::Visitor<'de>,
274 {
275 match self
276 .input
277 .first()
278 .ok_or(self.error(ErrorKind::UnexpectedEof("bencode list")))?
279 {
280 b'l' => {
281 self.move_cursor(1);
283 let value = visitor.visit_seq(BencodeAccessor { de: self });
284 match self
285 .input
286 .first()
287 .ok_or(self.error(ErrorKind::UnexpectedEof("bencode list end")))?
288 {
289 b'e' => {
290 self.move_cursor(1);
292 value
293 }
294 _ => Err(self.error(ErrorKind::BadInputData("expected bencode list end"))),
295 }
296 }
297 _ => Err(self.error(ErrorKind::BadInputData("expected bencode list"))),
298 }
299 }
300
301 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
302 where
303 V: serde::de::Visitor<'de>,
304 {
305 self.deserialize_seq(visitor)
306 }
307
308 fn deserialize_tuple_struct<V>(self, _name: &'static str, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
309 where
310 V: serde::de::Visitor<'de>,
311 {
312 self.deserialize_seq(visitor)
313 }
314
315 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
316 where
317 V: serde::de::Visitor<'de>,
318 {
319 match self
320 .input
321 .first()
322 .ok_or(self.error(ErrorKind::UnexpectedEof("bencode dictionary")))?
323 {
324 b'd' => {
325 self.move_cursor(1);
327 let value = visitor.visit_map(BencodeAccessor { de: self });
328 match self
329 .input
330 .first()
331 .ok_or(self.error(ErrorKind::UnexpectedEof("bencode dictionary end")))?
332 {
333 b'e' => {
334 self.move_cursor(1);
336 value
337 }
338 _ => Err(self.error(ErrorKind::BadInputData("expected bencode dictionary end"))),
339 }
340 }
341 _ => Err(self.error(ErrorKind::BadInputData("expected bencode dictionary"))),
342 }
343 }
344
345 fn deserialize_struct<V>(
346 self,
347 _name: &'static str,
348 _fields: &'static [&'static str],
349 visitor: V,
350 ) -> Result<V::Value, Self::Error>
351 where
352 V: serde::de::Visitor<'de>,
353 {
354 self.deserialize_map(visitor)
355 }
356
357 fn deserialize_enum<V>(
358 self,
359 _name: &'static str,
360 _variants: &'static [&'static str],
361 visitor: V,
362 ) -> Result<V::Value, Self::Error>
363 where
364 V: serde::de::Visitor<'de>,
365 {
366 match self
367 .input
368 .first()
369 .ok_or(self.error(ErrorKind::UnexpectedEof("bencode dictionary")))?
370 {
371 b'd' => {
372 self.move_cursor(1);
374 let value = visitor.visit_enum(BencodeAccessor { de: self });
375 match self
376 .input
377 .first()
378 .ok_or(self.error(ErrorKind::UnexpectedEof("bencode dictionary end")))?
379 {
380 b'e' => {
381 self.move_cursor(1);
383 value
384 }
385 _ => Err(self.error(ErrorKind::BadInputData("expected bencode dictionary end"))),
386 }
387 }
388 _ => Err(self.error(ErrorKind::BadInputData("expected bencode dictionary"))),
389 }
390 }
391
392 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
393 where
394 V: serde::de::Visitor<'de>,
395 {
396 self.deserialize_str(visitor)
397 }
398
399 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
400 where
401 V: serde::de::Visitor<'de>,
402 {
403 self.deserialize_any(visitor)
404 }
405}
406
407struct BencodeAccessor<'a, 'de> {
408 de: &'a mut BencodeDeserializer<'de>,
409}
410
411impl<'a, 'de> serde::de::SeqAccess<'de> for BencodeAccessor<'a, 'de> {
412 type Error = Error;
413
414 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
415 where
416 T: serde::de::DeserializeSeed<'de>,
417 {
418 if *self
419 .de
420 .input
421 .first()
422 .ok_or(self.de.error(ErrorKind::UnexpectedEof("next seq element or end")))?
423 == b'e'
424 {
425 return Ok(None);
426 }
427 seed.deserialize(&mut *self.de).map(Some)
428 }
429}
430
431impl<'a, 'de> serde::de::MapAccess<'de> for BencodeAccessor<'a, 'de> {
432 type Error = Error;
433
434 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
435 where
436 K: serde::de::DeserializeSeed<'de>,
437 {
438 if *self
439 .de
440 .input
441 .first()
442 .ok_or(self.de.error(ErrorKind::UnexpectedEof("next dict key or end")))?
443 == b'e'
444 {
445 return Ok(None);
446 }
447 seed.deserialize(&mut *self.de).map(Some)
448 }
449
450 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
451 where
452 V: serde::de::DeserializeSeed<'de>,
453 {
454 seed.deserialize(&mut *self.de)
455 }
456}
457
458impl<'de, 'a> serde::de::EnumAccess<'de> for BencodeAccessor<'a, 'de> {
459 type Error = Error;
460 type Variant = Self;
461
462 fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
463 where
464 V: serde::de::DeserializeSeed<'de>,
465 {
466 Ok((seed.deserialize(&mut *self.de)?, self))
467 }
468}
469
470impl<'de, 'a> serde::de::VariantAccess<'de> for BencodeAccessor<'a, 'de> {
471 type Error = Error;
472
473 fn unit_variant(self) -> Result<(), Self::Error> {
474 let bytes = self.de.parse_bytes_checked()?;
475 if !bytes.is_empty() {
476 return Err(self
477 .de
478 .error(ErrorKind::BadInputData("expected bencode string of length 0")));
479 }
480 Ok(())
481 }
482
483 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
484 where
485 T: serde::de::DeserializeSeed<'de>,
486 {
487 seed.deserialize(&mut *self.de)
488 }
489
490 fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
491 where
492 V: serde::de::Visitor<'de>,
493 {
494 self.de.deserialize_seq(visitor)
495 }
496
497 fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value, Self::Error>
498 where
499 V: serde::de::Visitor<'de>,
500 {
501 self.de.deserialize_map(visitor)
502 }
503}
504
505pub fn from_bytes<'de, T: Deserialize<'de>>(input: &'de [u8]) -> Result<T, Error> {
506 let mut deserializer = BencodeDeserializer::from_bytes(input);
507 let deserialized = T::deserialize(&mut deserializer)?;
508 if !deserializer.input.is_empty() {
509 return Err(ErrorKind::Custom(format!(
510 "Trailing bytes after deserialization: {}",
511 deserializer.input.len()
512 ))
513 .into());
514 }
515 Ok(deserialized)
516}
517
518pub fn from_str<'de, T: Deserialize<'de>>(input: &'de str) -> Result<T, Error> {
519 from_bytes(input.as_bytes())
520}