Skip to main content

irontide_wire/
extended.rs

1#![allow(
2    clippy::cast_possible_truncation,
3    reason = "M175: BEP 10 extension protocol — message-id bytes bounded by extension count (u8)"
4)]
5
6use std::collections::BTreeMap;
7
8use bytes::Bytes;
9use serde::{Deserialize, Serialize};
10
11use crate::error::{Error, Result};
12
13/// Extension handshake (BEP 10, `ext_id=0`).
14///
15/// Exchanged after the standard handshake to negotiate extension IDs.
16#[derive(Debug, Clone, Default, Serialize, Deserialize)]
17pub struct ExtHandshake {
18    /// Map of extension name → assigned message ID.
19    #[serde(default)]
20    pub m: BTreeMap<String, u8>,
21    /// Client name.
22    #[serde(default, skip_serializing_if = "Option::is_none")]
23    pub v: Option<String>,
24    /// TCP listen port for incoming connections.
25    #[serde(default, skip_serializing_if = "Option::is_none")]
26    pub p: Option<u16>,
27    /// Number of outstanding request queue size.
28    #[serde(default, skip_serializing_if = "Option::is_none")]
29    pub reqq: Option<u32>,
30    /// Total size of metadata (for `ut_metadata`).
31    #[serde(default, skip_serializing_if = "Option::is_none")]
32    pub metadata_size: Option<u64>,
33    /// BEP 21: upload-only flag (1 = seeder, 0 or absent = leecher).
34    #[serde(default, skip_serializing_if = "Option::is_none")]
35    pub upload_only: Option<u8>,
36}
37
38impl ExtHandshake {
39    /// Create a handshake advertising our supported extensions.
40    #[must_use]
41    pub fn new() -> Self {
42        let mut m = BTreeMap::new();
43        m.insert("ut_metadata".into(), 1);
44        m.insert("ut_pex".into(), 2);
45        m.insert("lt_trackers".into(), 3);
46        m.insert("ut_holepunch".into(), 4);
47        m.insert("lt_donthave".into(), 5);
48
49        Self {
50            m,
51            v: Some("Torrent 0.65.0".into()),
52            p: None,
53            reqq: Some(250),
54            metadata_size: None,
55            upload_only: None,
56        }
57    }
58
59    /// Create a handshake advertising built-in + plugin extensions.
60    ///
61    /// Built-in extensions are assigned IDs 1–4. Plugin names are assigned
62    /// IDs starting at 10, in the order provided.
63    #[must_use]
64    pub fn new_with_plugins(plugin_names: &[&str]) -> Self {
65        let mut hs = Self::new();
66        for (i, name) in plugin_names.iter().enumerate() {
67            hs.m.insert((*name).into(), 10 + i as u8);
68        }
69        hs
70    }
71
72    /// Create a handshake advertising upload-only (BEP 21 seeder) status.
73    #[must_use]
74    pub fn new_upload_only() -> Self {
75        let mut hs = Self::new();
76        hs.upload_only = Some(1);
77        hs
78    }
79
80    /// Returns true if the peer declared BEP 21 upload-only status.
81    #[must_use]
82    pub fn is_upload_only(&self) -> bool {
83        self.upload_only.unwrap_or(0) != 0
84    }
85
86    /// Encode to bencode bytes.
87    ///
88    /// # Errors
89    ///
90    /// Returns an error if serialization fails.
91    pub fn to_bytes(&self) -> Result<Bytes> {
92        let data = irontide_bencode::to_bytes(self)?;
93        Ok(Bytes::from(data))
94    }
95
96    /// Decode from bencode bytes.
97    ///
98    /// Uses lenient parsing to accept unsorted dictionary keys, which many
99    /// real-world clients send in their extension handshakes.
100    ///
101    /// # Errors
102    ///
103    /// Returns an error if the data is not valid bencode.
104    pub fn from_bytes(data: &[u8]) -> Result<Self> {
105        Ok(irontide_bencode::from_bytes_lenient(data)?)
106    }
107
108    /// Look up the message ID for an extension by name.
109    #[must_use]
110    pub fn ext_id(&self, name: &str) -> Option<u8> {
111        self.m.get(name).copied()
112    }
113}
114
115/// Parsed extension messages.
116#[derive(Debug, Clone, PartialEq, Eq)]
117pub enum ExtMessage {
118    /// Extension handshake (`ext_id=0`).
119    Handshake(Bytes),
120    /// `ut_metadata` message (BEP 9).
121    Metadata(MetadataMessage),
122}
123
/// Kind of a BEP 9 `ut_metadata` message.
///
/// The explicit discriminants are the values written to the `msg_type` key of
/// the message dictionary.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MetadataMessageType {
    /// Ask the peer for one metadata piece.
    Request = 0,
    /// Carries the payload of one metadata piece.
    Data = 1,
    /// The peer does not have the requested metadata piece.
    Reject = 2,
}
134
135/// `ut_metadata` message (BEP 9).
136#[derive(Debug, Clone, PartialEq, Eq)]
137pub struct MetadataMessage {
138    /// Message type (request, data, or reject).
139    pub msg_type: MetadataMessageType,
140    /// Zero-based metadata piece index.
141    pub piece: u32,
142    /// Total metadata size (included in Data messages).
143    pub total_size: Option<u64>,
144    /// Metadata piece data (appended after the bencode dict for Data messages).
145    pub data: Option<Bytes>,
146}
147
148/// Raw bencode structure for `ut_metadata` messages.
149#[derive(Serialize, Deserialize)]
150struct MetadataDict {
151    msg_type: u8,
152    piece: u32,
153    #[serde(default, skip_serializing_if = "Option::is_none")]
154    total_size: Option<u64>,
155}
156
157impl MetadataMessage {
158    /// Create a request for metadata piece.
159    #[must_use]
160    pub fn request(piece: u32) -> Self {
161        Self {
162            msg_type: MetadataMessageType::Request,
163            piece,
164            total_size: None,
165            data: None,
166        }
167    }
168
169    /// Create a data response for a metadata piece.
170    pub fn data(piece: u32, total_size: u64, data: Bytes) -> Self {
171        Self {
172            msg_type: MetadataMessageType::Data,
173            piece,
174            total_size: Some(total_size),
175            data: Some(data),
176        }
177    }
178
179    /// Create a reject for metadata piece.
180    #[must_use]
181    pub fn reject(piece: u32) -> Self {
182        Self {
183            msg_type: MetadataMessageType::Reject,
184            piece,
185            total_size: None,
186            data: None,
187        }
188    }
189
190    /// Encode to bytes (bencode dict + optional trailing data).
191    ///
192    /// # Errors
193    ///
194    /// Returns an error if serialization fails.
195    pub fn to_bytes(&self) -> Result<Bytes> {
196        let dict = MetadataDict {
197            msg_type: self.msg_type as u8,
198            piece: self.piece,
199            total_size: self.total_size,
200        };
201        let mut buf = irontide_bencode::to_bytes(&dict)?;
202        if let Some(ref data) = self.data {
203            buf.extend_from_slice(data);
204        }
205        Ok(Bytes::from(buf))
206    }
207
208    /// Parse from bytes. The bencode dict may be followed by raw data.
209    ///
210    /// # Errors
211    ///
212    /// Returns an error if the data is not a valid metadata message.
213    pub fn from_bytes(data: &[u8]) -> Result<Self> {
214        // Find the end of the bencode dict by scanning for the matching 'e'
215        let dict_end = find_bencode_dict_end(data)?;
216        let dict: MetadataDict = irontide_bencode::from_bytes_lenient(&data[..dict_end])?;
217
218        let msg_type = match dict.msg_type {
219            0 => MetadataMessageType::Request,
220            1 => MetadataMessageType::Data,
221            2 => MetadataMessageType::Reject,
222            n => {
223                return Err(Error::InvalidExtended(format!(
224                    "unknown metadata msg_type {n}"
225                )));
226            }
227        };
228
229        let trailing = if dict_end < data.len() {
230            Some(Bytes::copy_from_slice(&data[dict_end..]))
231        } else {
232            None
233        };
234
235        Ok(Self {
236            msg_type,
237            piece: dict.piece,
238            total_size: dict.total_size,
239            data: trailing,
240        })
241    }
242}
243
244/// Find the end position of a bencode dictionary.
245fn find_bencode_dict_end(data: &[u8]) -> Result<usize> {
246    if data.first() != Some(&b'd') {
247        return Err(Error::InvalidExtended("expected bencode dict".into()));
248    }
249    let mut pos = 1;
250    let mut depth = 1u32;
251
252    while pos < data.len() && depth > 0 {
253        match data[pos] {
254            b'd' | b'l' => {
255                depth += 1;
256                pos += 1;
257            }
258            b'e' => {
259                depth -= 1;
260                pos += 1;
261            }
262            b'i' => {
263                pos += 1;
264                while pos < data.len() && data[pos] != b'e' {
265                    pos += 1;
266                }
267                pos += 1; // skip 'e'
268            }
269            b'0'..=b'9' => {
270                // byte string: parse length, skip content
271                let len_start = pos;
272                while pos < data.len() && data[pos] != b':' {
273                    pos += 1;
274                }
275                let len: usize = std::str::from_utf8(&data[len_start..pos])
276                    .map_err(|_| Error::InvalidExtended("bad string length".into()))?
277                    .parse()
278                    .map_err(|_| Error::InvalidExtended("bad string length".into()))?;
279                pos += 1 + len; // skip ':' + content
280            }
281            b => {
282                return Err(Error::InvalidExtended(format!(
283                    "unexpected byte {b:#04x} at position {pos}"
284                )));
285            }
286        }
287    }
288
289    if depth != 0 {
290        return Err(Error::InvalidExtended("unterminated dict".into()));
291    }
292    Ok(pos)
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// A default handshake survives an encode/decode cycle.
    #[test]
    fn ext_handshake_round_trip() {
        let handshake = ExtHandshake::new();
        let encoded = handshake.to_bytes().unwrap();
        let decoded = ExtHandshake::from_bytes(&encoded).unwrap();
        assert_eq!(handshake.m, decoded.m);
        assert_eq!(handshake.v, decoded.v);
        assert_eq!(handshake.reqq, decoded.reqq);
    }

    /// Built-in extensions resolve to their fixed IDs; unknown names do not.
    #[test]
    fn ext_handshake_ext_id_lookup() {
        let handshake = ExtHandshake::new();
        assert_eq!(handshake.ext_id("ut_metadata"), Some(1));
        assert_eq!(handshake.ext_id("ut_pex"), Some(2));
        assert_eq!(handshake.ext_id("lt_trackers"), Some(3));
        assert_eq!(handshake.ext_id("ut_holepunch"), Some(4));
        assert_eq!(handshake.ext_id("unknown"), None);
    }

    /// The BEP 21 upload-only flag survives an encode/decode cycle.
    #[test]
    fn ext_handshake_upload_only_round_trip() {
        let handshake = ExtHandshake::new_upload_only();
        assert!(handshake.is_upload_only());
        let encoded = handshake.to_bytes().unwrap();
        let decoded = ExtHandshake::from_bytes(&encoded).unwrap();
        assert!(decoded.is_upload_only());
        assert_eq!(decoded.upload_only, Some(1));
    }

    /// A default handshake does not claim upload-only status.
    #[test]
    fn ext_handshake_no_upload_only_default() {
        let handshake = ExtHandshake::new();
        assert!(!handshake.is_upload_only());
        assert_eq!(handshake.upload_only, None);
    }

    /// Plugins are numbered from 10 in the order given.
    #[test]
    fn ext_handshake_with_plugins() {
        let handshake = ExtHandshake::new_with_plugins(&["ut_comment", "ut_holepunch"]);
        // The built-ins that were not overridden keep their IDs.
        assert_eq!(handshake.ext_id("ut_metadata"), Some(1));
        assert_eq!(handshake.ext_id("ut_pex"), Some(2));
        assert_eq!(handshake.ext_id("lt_trackers"), Some(3));
        // Plugin IDs start at 10.
        assert_eq!(handshake.ext_id("ut_comment"), Some(10));
        assert_eq!(handshake.ext_id("ut_holepunch"), Some(11));
    }

    /// Plugin IDs survive an encode/decode cycle.
    #[test]
    fn ext_handshake_with_plugins_round_trip() {
        let handshake = ExtHandshake::new_with_plugins(&["ut_echo"]);
        let encoded = handshake.to_bytes().unwrap();
        let decoded = ExtHandshake::from_bytes(&encoded).unwrap();
        assert_eq!(decoded.ext_id("ut_echo"), Some(10));
        assert_eq!(decoded.ext_id("ut_metadata"), Some(1));
    }

    /// An empty plugin list leaves only the built-ins.
    #[test]
    fn ext_handshake_no_plugins() {
        let handshake = ExtHandshake::new_with_plugins(&[]);
        assert_eq!(handshake.m.len(), 5); // only built-ins
    }

    /// Removing one built-in leaves the others untouched.
    #[test]
    fn ext_handshake_holepunch_can_be_removed() {
        let mut handshake = ExtHandshake::new();
        handshake.m.remove("ut_holepunch");
        assert_eq!(handshake.ext_id("ut_holepunch"), None);
        assert_eq!(handshake.ext_id("ut_metadata"), Some(1));
        assert_eq!(handshake.ext_id("ut_pex"), Some(2));
    }

    /// A metadata request carries no trailing data after a round trip.
    #[test]
    fn metadata_request_round_trip() {
        let encoded = MetadataMessage::request(3).to_bytes().unwrap();
        let decoded = MetadataMessage::from_bytes(&encoded).unwrap();
        assert_eq!(decoded.msg_type, MetadataMessageType::Request);
        assert_eq!(decoded.piece, 3);
        assert!(decoded.data.is_none());
    }

    /// The bytes appended after the dict come back as the message payload.
    #[test]
    fn metadata_data_with_trailing() {
        let message = MetadataMessage::data(
            0,
            31415,
            Bytes::from_static(b"raw metadata bytes here"),
        );
        let encoded = message.to_bytes().unwrap();
        let decoded = MetadataMessage::from_bytes(&encoded).unwrap();
        assert_eq!(decoded.msg_type, MetadataMessageType::Data);
        assert_eq!(decoded.piece, 0);
        assert_eq!(decoded.total_size, Some(31415));
        assert_eq!(
            decoded.data.as_deref(),
            Some(b"raw metadata bytes here".as_ref())
        );
    }

    /// A reject message keeps its type and piece index.
    #[test]
    fn metadata_reject() {
        let encoded = MetadataMessage::reject(5).to_bytes().unwrap();
        let decoded = MetadataMessage::from_bytes(&encoded).unwrap();
        assert_eq!(decoded.msg_type, MetadataMessageType::Reject);
        assert_eq!(decoded.piece, 5);
    }

    /// BEP 10: setting an extension's message ID to 0 disables it. Verify that
    /// encoding a handshake with a zero-ID extension round-trips correctly and
    /// that `ext_id()` returns Some(0) (the peer explicitly disabled it).
    #[test]
    fn ext_handshake_disable_extension_via_zero() {
        // Disable ut_pex by mapping it to ID 0.
        let mut handshake = ExtHandshake::new();
        handshake.m.insert("ut_pex".into(), 0);

        // The key stays in the map, now with value 0.
        assert_eq!(handshake.ext_id("ut_pex"), Some(0));

        // Encode and decode again.
        let encoded = handshake.to_bytes().unwrap();
        let decoded = ExtHandshake::from_bytes(&encoded).unwrap();

        // The zero-valued entry must still be there afterwards.
        assert_eq!(
            decoded.ext_id("ut_pex"),
            Some(0),
            "BEP 10: message ID 0 means disabled, but must survive round-trip"
        );

        // The remaining extensions are unaffected.
        assert_eq!(decoded.ext_id("ut_metadata"), Some(1));
        assert_eq!(decoded.ext_id("lt_trackers"), Some(3));

        // "Disabled" (ID 0) and "absent" are different answers.
        assert_eq!(decoded.ext_id("nonexistent"), None);
        assert_eq!(decoded.ext_id("ut_pex"), Some(0));
    }

    /// BEP 54: Verify that the extension handshake advertises `lt_donthave` with ID 5.
    #[test]
    fn ext_handshake_includes_lt_donthave() {
        let handshake = ExtHandshake::new();
        assert_eq!(handshake.ext_id("lt_donthave"), Some(5));
        // The ID must be unchanged after an encode/decode cycle.
        let encoded = handshake.to_bytes().unwrap();
        let decoded = ExtHandshake::from_bytes(&encoded).unwrap();
        assert_eq!(decoded.ext_id("lt_donthave"), Some(5));
    }
}