Skip to main content

rift_torrent/
bencode.rs

1//! Bencode encoding and decoding (BEP-3 compatible).
2
3use std::collections::BTreeMap;
4
5use crate::{FileInfo, InfoHash, PieceHash, SrtExtension, TorrentError, TorrentMeta};
6
/// A decoded bencode value.
///
/// Byte strings are kept as raw bytes because bencode does not mandate a text
/// encoding (for example the `pieces` field of a metainfo file is binary data).
/// Dictionaries are stored in a `BTreeMap`, so iterating one yields keys in
/// ascending byte order.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum BValue {
    /// Byte string; may contain arbitrary, non-UTF-8 bytes.
    Bytes(Vec<u8>),
    /// Signed 64-bit integer.
    Int(i64),
    /// Ordered list of values.
    List(Vec<BValue>),
    /// Dictionary keyed by raw byte strings, kept sorted by key.
    Dict(BTreeMap<Vec<u8>, BValue>),
}

impl BValue {
    /// Borrow the payload of a [`BValue::Bytes`]; `None` for any other variant.
    pub fn as_bytes(&self) -> Option<&[u8]> {
        match self {
            BValue::Bytes(b) => Some(b),
            _ => None,
        }
    }

    /// Borrow a [`BValue::Bytes`] payload as `&str`.
    ///
    /// Returns `None` for other variants and for byte strings that are not
    /// valid UTF-8.
    pub fn as_str(&self) -> Option<&str> {
        self.as_bytes().and_then(|b| std::str::from_utf8(b).ok())
    }

    /// Return the value of a [`BValue::Int`]; `None` for any other variant.
    pub fn as_int(&self) -> Option<i64> {
        match self {
            BValue::Int(i) => Some(*i),
            _ => None,
        }
    }

    /// Borrow the elements of a [`BValue::List`]; `None` for any other variant.
    pub fn as_list(&self) -> Option<&[BValue]> {
        match self {
            BValue::List(l) => Some(l),
            _ => None,
        }
    }

    /// Borrow the map of a [`BValue::Dict`]; `None` for any other variant.
    pub fn as_dict(&self) -> Option<&BTreeMap<Vec<u8>, BValue>> {
        match self {
            BValue::Dict(d) => Some(d),
            _ => None,
        }
    }

    /// Look up `key` (compared as raw UTF-8 bytes) in a [`BValue::Dict`].
    ///
    /// Returns `None` if `self` is not a dictionary or the key is absent.
    pub fn get(&self, key: &str) -> Option<&BValue> {
        self.as_dict().and_then(|d| d.get(key.as_bytes()))
    }
}
63
64/// Decode bencode from bytes.
65pub fn decode(input: &[u8]) -> Result<(BValue, usize), TorrentError> {
66    if input.is_empty() {
67        return Err(TorrentError::Bencode("empty input".into()));
68    }
69
70    match input[0] {
71        b'i' => decode_int(input),
72        b'l' => decode_list(input),
73        b'd' => decode_dict(input),
74        b'0'..=b'9' => decode_bytes(input),
75        c => Err(TorrentError::Bencode(format!("unexpected byte: {}", c as char))),
76    }
77}
78
79fn decode_int(input: &[u8]) -> Result<(BValue, usize), TorrentError> {
80    let end = input
81        .iter()
82        .position(|&b| b == b'e')
83        .ok_or_else(|| TorrentError::Bencode("unterminated integer".into()))?;
84
85    let num_str = std::str::from_utf8(&input[1..end])
86        .map_err(|_| TorrentError::Bencode("invalid integer encoding".into()))?;
87
88    let num: i64 = num_str
89        .parse()
90        .map_err(|_| TorrentError::Bencode("invalid integer value".into()))?;
91
92    Ok((BValue::Int(num), end + 1))
93}
94
95fn decode_bytes(input: &[u8]) -> Result<(BValue, usize), TorrentError> {
96    let colon = input
97        .iter()
98        .position(|&b| b == b':')
99        .ok_or_else(|| TorrentError::Bencode("missing colon in byte string".into()))?;
100
101    let len_str = std::str::from_utf8(&input[..colon])
102        .map_err(|_| TorrentError::Bencode("invalid byte string length".into()))?;
103
104    let len: usize = len_str
105        .parse()
106        .map_err(|_| TorrentError::Bencode("invalid byte string length value".into()))?;
107
108    let start = colon + 1;
109    let end = start + len;
110
111    if end > input.len() {
112        return Err(TorrentError::Bencode("byte string exceeds input".into()));
113    }
114
115    Ok((BValue::Bytes(input[start..end].to_vec()), end))
116}
117
118fn decode_list(input: &[u8]) -> Result<(BValue, usize), TorrentError> {
119    let mut items = Vec::new();
120    let mut pos = 1; // Skip 'l'
121
122    while pos < input.len() && input[pos] != b'e' {
123        let (value, consumed) = decode(&input[pos..])?;
124        items.push(value);
125        pos += consumed;
126    }
127
128    if pos >= input.len() {
129        return Err(TorrentError::Bencode("unterminated list".into()));
130    }
131
132    Ok((BValue::List(items), pos + 1)) // +1 for 'e'
133}
134
135fn decode_dict(input: &[u8]) -> Result<(BValue, usize), TorrentError> {
136    let mut map = BTreeMap::new();
137    let mut pos = 1; // Skip 'd'
138
139    while pos < input.len() && input[pos] != b'e' {
140        // Key must be a byte string
141        let (key_val, key_consumed) = decode(&input[pos..])?;
142        let key = match key_val {
143            BValue::Bytes(b) => b,
144            _ => return Err(TorrentError::Bencode("dictionary key must be string".into())),
145        };
146        pos += key_consumed;
147
148        // Value
149        let (value, val_consumed) = decode(&input[pos..])?;
150        pos += val_consumed;
151
152        map.insert(key, value);
153    }
154
155    if pos >= input.len() {
156        return Err(TorrentError::Bencode("unterminated dictionary".into()));
157    }
158
159    Ok((BValue::Dict(map), pos + 1)) // +1 for 'e'
160}
161
162/// Encode bencode to bytes.
163pub fn encode(value: &BValue) -> Vec<u8> {
164    let mut out = Vec::new();
165    encode_into(value, &mut out);
166    out
167}
168
169fn encode_into(value: &BValue, out: &mut Vec<u8>) {
170    match value {
171        BValue::Bytes(b) => {
172            out.extend_from_slice(b.len().to_string().as_bytes());
173            out.push(b':');
174            out.extend_from_slice(b);
175        }
176        BValue::Int(i) => {
177            out.push(b'i');
178            out.extend_from_slice(i.to_string().as_bytes());
179            out.push(b'e');
180        }
181        BValue::List(items) => {
182            out.push(b'l');
183            for item in items {
184                encode_into(item, out);
185            }
186            out.push(b'e');
187        }
188        BValue::Dict(map) => {
189            out.push(b'd');
190            for (key, val) in map {
191                // Key as byte string
192                out.extend_from_slice(key.len().to_string().as_bytes());
193                out.push(b':');
194                out.extend_from_slice(key);
195                // Value
196                encode_into(val, out);
197            }
198            out.push(b'e');
199        }
200    }
201}
202
203/// Parse a .torrent file into TorrentMeta.
204pub fn parse_torrent(data: &[u8]) -> Result<TorrentMeta, TorrentError> {
205    let (root, _) = decode(data)?;
206    let dict = root
207        .as_dict()
208        .ok_or(TorrentError::InvalidTorrent("root not dict"))?;
209
210    // Extract info dictionary and compute infohash
211    let info_bval = dict
212        .get(b"info".as_slice())
213        .ok_or(TorrentError::InvalidTorrent("missing info"))?;
214    let info_bytes = encode(info_bval);
215    let info_hash = compute_sha1_infohash(&info_bytes);
216
217    let info = info_bval
218        .as_dict()
219        .ok_or(TorrentError::InvalidTorrent("info not dict"))?;
220
221    // Name (required)
222    let name = info
223        .get(b"name".as_slice())
224        .and_then(|v| v.as_str())
225        .ok_or(TorrentError::InvalidTorrent("missing name"))?
226        .to_string();
227
228    // Piece length (required)
229    let piece_length = info
230        .get(b"piece length".as_slice())
231        .and_then(|v| v.as_int())
232        .ok_or(TorrentError::InvalidTorrent("missing piece length"))? as u64;
233
234    // Pieces (required) - concatenated SHA1 hashes
235    let pieces_bytes = info
236        .get(b"pieces".as_slice())
237        .and_then(|v| v.as_bytes())
238        .ok_or(TorrentError::InvalidTorrent("missing pieces"))?;
239
240    if pieces_bytes.len() % 20 != 0 {
241        return Err(TorrentError::InvalidTorrent("pieces length not multiple of 20"));
242    }
243
244    let pieces: Vec<PieceHash> = pieces_bytes
245        .chunks_exact(20)
246        .map(|chunk| {
247            let mut h = [0u8; 20];
248            h.copy_from_slice(chunk);
249            PieceHash(h)
250        })
251        .collect();
252
253    // Files - either single file or multi-file
254    let (files, total_length) = if let Some(length_val) = info.get(b"length".as_slice()) {
255        // Single file mode
256        let length = length_val
257            .as_int()
258            .ok_or(TorrentError::InvalidTorrent("invalid length"))? as u64;
259        let file = FileInfo {
260            path: name.clone().into(),
261            length,
262        };
263        (vec![file], length)
264    } else if let Some(files_val) = info.get(b"files".as_slice()) {
265        // Multi-file mode
266        let files_list = files_val
267            .as_list()
268            .ok_or(TorrentError::InvalidTorrent("files not list"))?;
269
270        let mut files = Vec::new();
271        let mut total = 0u64;
272
273        for file_val in files_list {
274            let file_dict = file_val
275                .as_dict()
276                .ok_or(TorrentError::InvalidTorrent("file entry not dict"))?;
277
278            let length = file_dict
279                .get(b"length".as_slice())
280                .and_then(|v| v.as_int())
281                .ok_or(TorrentError::InvalidTorrent("file missing length"))? as u64;
282
283            let path_list = file_dict
284                .get(b"path".as_slice())
285                .and_then(|v| v.as_list())
286                .ok_or(TorrentError::InvalidTorrent("file missing path"))?;
287
288            let path: PathBuf = path_list
289                .iter()
290                .filter_map(|p| p.as_str())
291                .collect();
292
293            files.push(FileInfo { path, length });
294            total += length;
295        }
296
297        (files, total)
298    } else {
299        return Err(TorrentError::InvalidTorrent("missing length or files"));
300    };
301
302    // Optional fields
303    let announce = dict
304        .get(b"announce".as_slice())
305        .and_then(|v| v.as_str())
306        .map(String::from);
307
308    let announce_list = dict
309        .get(b"announce-list".as_slice())
310        .and_then(|v| v.as_list())
311        .map(|tiers| {
312            tiers
313                .iter()
314                .filter_map(|tier| {
315                    tier.as_list().map(|urls| {
316                        urls.iter()
317                            .filter_map(|u| u.as_str().map(String::from))
318                            .collect()
319                    })
320                })
321                .collect()
322        })
323        .unwrap_or_default();
324
325    let creation_date = dict
326        .get(b"creation date".as_slice())
327        .and_then(|v| v.as_int())
328        .map(|i| i as u64);
329
330    let comment = dict
331        .get(b"comment".as_slice())
332        .and_then(|v| v.as_str())
333        .map(String::from);
334
335    let created_by = dict
336        .get(b"created by".as_slice())
337        .and_then(|v| v.as_str())
338        .map(String::from);
339
340    // SRT extension (rift-specific)
341    let srt_extension = dict
342        .get(b"srt".as_slice())
343        .and_then(|v| parse_srt_extension(v).ok());
344
345    Ok(TorrentMeta {
346        info_hash,
347        name,
348        piece_length,
349        pieces,
350        files,
351        total_length,
352        announce,
353        announce_list,
354        creation_date,
355        srt_extension,
356        comment,
357        created_by,
358    })
359}
360
361fn parse_srt_extension(value: &BValue) -> Result<SrtExtension, TorrentError> {
362    let dict = value
363        .as_dict()
364        .ok_or(TorrentError::InvalidTorrent("srt not dict"))?;
365
366    let version = dict
367        .get(b"version".as_slice())
368        .and_then(|v| v.as_int())
369        .unwrap_or(1) as u8;
370
371    let t0_offset = dict
372        .get(b"t0_offset".as_slice())
373        .and_then(|v| v.as_int())
374        .unwrap_or(0) as u64;
375
376    let window_secs = dict
377        .get(b"window_secs".as_slice())
378        .and_then(|v| v.as_int())
379        .unwrap_or(300) as u64;
380
381    let slot_ms = dict
382        .get(b"slot_ms".as_slice())
383        .and_then(|v| v.as_int())
384        .unwrap_or(500) as u64;
385
386    let salt = dict.get(b"salt".as_slice()).and_then(|v| {
387        v.as_bytes().and_then(|b| {
388            if b.len() == 16 {
389                let mut arr = [0u8; 16];
390                arr.copy_from_slice(b);
391                Some(arr)
392            } else {
393                None
394            }
395        })
396    });
397
398    Ok(SrtExtension {
399        version,
400        t0_offset,
401        window_secs,
402        slot_ms,
403        salt,
404    })
405}
406
407fn compute_sha1_infohash(info_bytes: &[u8]) -> InfoHash {
408    use sha1::{Digest, Sha1};
409    let mut hasher = Sha1::new();
410    hasher.update(info_bytes);
411    let result = hasher.finalize();
412    let mut h = [0u8; 20];
413    h.copy_from_slice(&result);
414    InfoHash::Sha1(h)
415}
416
417use std::path::PathBuf;
418
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `BValue::Dict` from string keys, for terser fixtures.
    fn dict_of(entries: &[(&str, BValue)]) -> BValue {
        let map: BTreeMap<Vec<u8>, BValue> = entries
            .iter()
            .map(|(key, value)| (key.as_bytes().to_vec(), value.clone()))
            .collect();
        BValue::Dict(map)
    }

    #[test]
    fn decodes_positive_integer() {
        let (value, consumed) = decode(b"i42e").unwrap();
        assert_eq!(value, BValue::Int(42));
        assert_eq!(consumed, 4);
    }

    #[test]
    fn decodes_negative_integer() {
        let (value, _) = decode(b"i-123e").unwrap();
        assert_eq!(value, BValue::Int(-123));
    }

    #[test]
    fn decodes_byte_string() {
        let (value, consumed) = decode(b"5:hello").unwrap();
        assert_eq!(value, BValue::Bytes(b"hello".to_vec()));
        assert_eq!(consumed, 7);
    }

    #[test]
    fn decodes_list_of_integers() {
        let expected = BValue::List((1..=3).map(BValue::Int).collect());
        let (value, _) = decode(b"li1ei2ei3ee").unwrap();
        assert_eq!(value, expected);
    }

    #[test]
    fn decodes_single_entry_dict() {
        let (value, _) = decode(b"d3:fooi42ee").unwrap();
        assert_eq!(value, dict_of(&[("foo", BValue::Int(42))]));
    }

    #[test]
    fn encode_then_decode_is_identity() {
        let original = dict_of(&[
            ("name", BValue::Bytes(b"test".to_vec())),
            ("size", BValue::Int(1024)),
            ("list", BValue::List(vec![BValue::Int(1), BValue::Int(2)])),
        ]);

        let bytes = encode(&original);
        let (round_tripped, _) = decode(&bytes).unwrap();
        assert_eq!(round_tripped, original);
    }
}