Skip to main content

irontide_wire/
extended.rs

1use std::collections::BTreeMap;
2
3use bytes::Bytes;
4use serde::{Deserialize, Serialize};
5
6use crate::error::{Error, Result};
7
8/// Extension handshake (BEP 10, ext_id=0).
9///
10/// Exchanged after the standard handshake to negotiate extension IDs.
11#[derive(Debug, Clone, Default, Serialize, Deserialize)]
12pub struct ExtHandshake {
13    /// Map of extension name → assigned message ID.
14    #[serde(default)]
15    pub m: BTreeMap<String, u8>,
16    /// Client name.
17    #[serde(default, skip_serializing_if = "Option::is_none")]
18    pub v: Option<String>,
19    /// TCP listen port for incoming connections.
20    #[serde(default, skip_serializing_if = "Option::is_none")]
21    pub p: Option<u16>,
22    /// Number of outstanding request queue size.
23    #[serde(default, skip_serializing_if = "Option::is_none")]
24    pub reqq: Option<u32>,
25    /// Total size of metadata (for ut_metadata).
26    #[serde(default, skip_serializing_if = "Option::is_none")]
27    pub metadata_size: Option<u64>,
28    /// BEP 21: upload-only flag (1 = seeder, 0 or absent = leecher).
29    #[serde(default, skip_serializing_if = "Option::is_none")]
30    pub upload_only: Option<u8>,
31}
32
33impl ExtHandshake {
34    /// Create a handshake advertising our supported extensions.
35    pub fn new() -> Self {
36        let mut m = BTreeMap::new();
37        m.insert("ut_metadata".into(), 1);
38        m.insert("ut_pex".into(), 2);
39        m.insert("lt_trackers".into(), 3);
40        m.insert("ut_holepunch".into(), 4);
41        m.insert("lt_donthave".into(), 5);
42
43        ExtHandshake {
44            m,
45            v: Some("Torrent 0.65.0".into()),
46            p: None,
47            reqq: Some(250),
48            metadata_size: None,
49            upload_only: None,
50        }
51    }
52
53    /// Create a handshake advertising built-in + plugin extensions.
54    ///
55    /// Built-in extensions are assigned IDs 1–4. Plugin names are assigned
56    /// IDs starting at 10, in the order provided.
57    pub fn new_with_plugins(plugin_names: &[&str]) -> Self {
58        let mut hs = Self::new();
59        for (i, name) in plugin_names.iter().enumerate() {
60            hs.m.insert((*name).into(), 10 + i as u8);
61        }
62        hs
63    }
64
65    /// Create a handshake advertising upload-only (BEP 21 seeder) status.
66    pub fn new_upload_only() -> Self {
67        let mut hs = Self::new();
68        hs.upload_only = Some(1);
69        hs
70    }
71
72    /// Returns true if the peer declared BEP 21 upload-only status.
73    pub fn is_upload_only(&self) -> bool {
74        self.upload_only.unwrap_or(0) != 0
75    }
76
77    /// Encode to bencode bytes.
78    pub fn to_bytes(&self) -> Result<Bytes> {
79        let data = irontide_bencode::to_bytes(self)?;
80        Ok(Bytes::from(data))
81    }
82
83    /// Decode from bencode bytes.
84    ///
85    /// Uses lenient parsing to accept unsorted dictionary keys, which many
86    /// real-world clients send in their extension handshakes.
87    pub fn from_bytes(data: &[u8]) -> Result<Self> {
88        Ok(irontide_bencode::from_bytes_lenient(data)?)
89    }
90
91    /// Look up the message ID for an extension by name.
92    pub fn ext_id(&self, name: &str) -> Option<u8> {
93        self.m.get(name).copied()
94    }
95}
96
97/// Parsed extension messages.
98#[derive(Debug, Clone, PartialEq, Eq)]
99pub enum ExtMessage {
100    /// Extension handshake (ext_id=0).
101    Handshake(Bytes),
102    /// ut_metadata message (BEP 9).
103    Metadata(MetadataMessage),
104}
105
/// Kinds of BEP 9 `ut_metadata` message.
///
/// The explicit discriminants double as the on-wire `msg_type` values.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MetadataMessageType {
    /// `0`: ask the peer for one metadata piece.
    Request = 0,
    /// `1`: carries the contents of one metadata piece.
    Data = 1,
    /// `2`: the peer will not send the requested piece.
    Reject = 2,
}
116
117/// ut_metadata message (BEP 9).
118#[derive(Debug, Clone, PartialEq, Eq)]
119pub struct MetadataMessage {
120    /// Message type (request, data, or reject).
121    pub msg_type: MetadataMessageType,
122    /// Zero-based metadata piece index.
123    pub piece: u32,
124    /// Total metadata size (included in Data messages).
125    pub total_size: Option<u64>,
126    /// Metadata piece data (appended after the bencode dict for Data messages).
127    pub data: Option<Bytes>,
128}
129
130/// Raw bencode structure for ut_metadata messages.
131#[derive(Serialize, Deserialize)]
132struct MetadataDict {
133    msg_type: u8,
134    piece: u32,
135    #[serde(default, skip_serializing_if = "Option::is_none")]
136    total_size: Option<u64>,
137}
138
139impl MetadataMessage {
140    /// Create a request for metadata piece.
141    pub fn request(piece: u32) -> Self {
142        MetadataMessage {
143            msg_type: MetadataMessageType::Request,
144            piece,
145            total_size: None,
146            data: None,
147        }
148    }
149
150    /// Create a data response for a metadata piece.
151    pub fn data(piece: u32, total_size: u64, data: Bytes) -> Self {
152        MetadataMessage {
153            msg_type: MetadataMessageType::Data,
154            piece,
155            total_size: Some(total_size),
156            data: Some(data),
157        }
158    }
159
160    /// Create a reject for metadata piece.
161    pub fn reject(piece: u32) -> Self {
162        MetadataMessage {
163            msg_type: MetadataMessageType::Reject,
164            piece,
165            total_size: None,
166            data: None,
167        }
168    }
169
170    /// Encode to bytes (bencode dict + optional trailing data).
171    pub fn to_bytes(&self) -> Result<Bytes> {
172        let dict = MetadataDict {
173            msg_type: self.msg_type as u8,
174            piece: self.piece,
175            total_size: self.total_size,
176        };
177        let mut buf = irontide_bencode::to_bytes(&dict)?;
178        if let Some(ref data) = self.data {
179            buf.extend_from_slice(data);
180        }
181        Ok(Bytes::from(buf))
182    }
183
184    /// Parse from bytes. The bencode dict may be followed by raw data.
185    pub fn from_bytes(data: &[u8]) -> Result<Self> {
186        // Find the end of the bencode dict by scanning for the matching 'e'
187        let dict_end = find_bencode_dict_end(data)?;
188        let dict: MetadataDict = irontide_bencode::from_bytes_lenient(&data[..dict_end])?;
189
190        let msg_type = match dict.msg_type {
191            0 => MetadataMessageType::Request,
192            1 => MetadataMessageType::Data,
193            2 => MetadataMessageType::Reject,
194            n => {
195                return Err(Error::InvalidExtended(format!(
196                    "unknown metadata msg_type {n}"
197                )));
198            }
199        };
200
201        let trailing = if dict_end < data.len() {
202            Some(Bytes::copy_from_slice(&data[dict_end..]))
203        } else {
204            None
205        };
206
207        Ok(MetadataMessage {
208            msg_type,
209            piece: dict.piece,
210            total_size: dict.total_size,
211            data: trailing,
212        })
213    }
214}
215
216/// Find the end position of a bencode dictionary.
217fn find_bencode_dict_end(data: &[u8]) -> Result<usize> {
218    if data.first() != Some(&b'd') {
219        return Err(Error::InvalidExtended("expected bencode dict".into()));
220    }
221    let mut pos = 1;
222    let mut depth = 1u32;
223
224    while pos < data.len() && depth > 0 {
225        match data[pos] {
226            b'd' | b'l' => {
227                depth += 1;
228                pos += 1;
229            }
230            b'e' => {
231                depth -= 1;
232                pos += 1;
233            }
234            b'i' => {
235                pos += 1;
236                while pos < data.len() && data[pos] != b'e' {
237                    pos += 1;
238                }
239                pos += 1; // skip 'e'
240            }
241            b'0'..=b'9' => {
242                // byte string: parse length, skip content
243                let len_start = pos;
244                while pos < data.len() && data[pos] != b':' {
245                    pos += 1;
246                }
247                let len: usize = std::str::from_utf8(&data[len_start..pos])
248                    .map_err(|_| Error::InvalidExtended("bad string length".into()))?
249                    .parse()
250                    .map_err(|_| Error::InvalidExtended("bad string length".into()))?;
251                pos += 1 + len; // skip ':' + content
252            }
253            b => {
254                return Err(Error::InvalidExtended(format!(
255                    "unexpected byte {b:#04x} at position {pos}"
256                )));
257            }
258        }
259    }
260
261    if depth != 0 {
262        return Err(Error::InvalidExtended("unterminated dict".into()));
263    }
264    Ok(pos)
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `hs` and decode the result, panicking on any error.
    fn reencode(hs: &ExtHandshake) -> ExtHandshake {
        let encoded = hs.to_bytes().unwrap();
        ExtHandshake::from_bytes(&encoded).unwrap()
    }

    /// Encode `msg` and decode the result, panicking on any error.
    fn reencode_metadata(msg: &MetadataMessage) -> MetadataMessage {
        let encoded = msg.to_bytes().unwrap();
        MetadataMessage::from_bytes(&encoded).unwrap()
    }

    #[test]
    fn ext_handshake_round_trip() {
        let original = ExtHandshake::new();
        let decoded = reencode(&original);
        assert_eq!(original.m, decoded.m);
        assert_eq!(original.v, decoded.v);
        assert_eq!(original.reqq, decoded.reqq);
    }

    #[test]
    fn ext_handshake_ext_id_lookup() {
        let hs = ExtHandshake::new();
        let expected = [
            ("ut_metadata", Some(1)),
            ("ut_pex", Some(2)),
            ("lt_trackers", Some(3)),
            ("ut_holepunch", Some(4)),
            ("unknown", None),
        ];
        for (name, id) in expected {
            assert_eq!(hs.ext_id(name), id, "ext_id({name})");
        }
    }

    #[test]
    fn ext_handshake_upload_only_round_trip() {
        let hs = ExtHandshake::new_upload_only();
        assert!(hs.is_upload_only());
        let decoded = reencode(&hs);
        assert!(decoded.is_upload_only());
        assert_eq!(decoded.upload_only, Some(1));
    }

    #[test]
    fn ext_handshake_no_upload_only_default() {
        let hs = ExtHandshake::new();
        assert_eq!(hs.upload_only, None);
        assert!(!hs.is_upload_only());
    }

    #[test]
    fn ext_handshake_with_plugins() {
        let hs = ExtHandshake::new_with_plugins(&["ut_comment", "ut_holepunch"]);
        // Built-in IDs survive plugin registration.
        for (name, id) in [("ut_metadata", 1), ("ut_pex", 2), ("lt_trackers", 3)] {
            assert_eq!(hs.ext_id(name), Some(id), "built-in {name}");
        }
        // Plugins are numbered from 10 in the order given. A plugin sharing a
        // built-in's name (ut_holepunch) takes the plugin ID.
        assert_eq!(hs.ext_id("ut_comment"), Some(10));
        assert_eq!(hs.ext_id("ut_holepunch"), Some(11));
    }

    #[test]
    fn ext_handshake_with_plugins_round_trip() {
        let decoded = reencode(&ExtHandshake::new_with_plugins(&["ut_echo"]));
        assert_eq!(decoded.ext_id("ut_echo"), Some(10));
        assert_eq!(decoded.ext_id("ut_metadata"), Some(1));
    }

    #[test]
    fn ext_handshake_no_plugins() {
        // With no plugins, only the five built-ins are advertised.
        assert_eq!(ExtHandshake::new_with_plugins(&[]).m.len(), 5);
    }

    #[test]
    fn ext_handshake_holepunch_can_be_removed() {
        let mut hs = ExtHandshake::new();
        hs.m.remove("ut_holepunch");
        assert_eq!(hs.ext_id("ut_holepunch"), None);
        assert_eq!(hs.ext_id("ut_metadata"), Some(1));
        assert_eq!(hs.ext_id("ut_pex"), Some(2));
    }

    #[test]
    fn metadata_request_round_trip() {
        let decoded = reencode_metadata(&MetadataMessage::request(3));
        assert_eq!(decoded.msg_type, MetadataMessageType::Request);
        assert_eq!(decoded.piece, 3);
        assert!(decoded.data.is_none());
    }

    #[test]
    fn metadata_data_with_trailing() {
        let payload: &[u8] = b"raw metadata bytes here";
        let original = MetadataMessage {
            msg_type: MetadataMessageType::Data,
            piece: 0,
            total_size: Some(31415),
            data: Some(Bytes::from_static(b"raw metadata bytes here")),
        };
        let decoded = reencode_metadata(&original);
        assert_eq!(decoded.msg_type, MetadataMessageType::Data);
        assert_eq!(decoded.piece, 0);
        assert_eq!(decoded.total_size, Some(31415));
        assert_eq!(decoded.data.as_deref(), Some(payload));
    }

    #[test]
    fn metadata_reject() {
        let decoded = reencode_metadata(&MetadataMessage::reject(5));
        assert_eq!(decoded.msg_type, MetadataMessageType::Reject);
        assert_eq!(decoded.piece, 5);
    }

    /// BEP 10: mapping an extension to message ID 0 disables it. The zero entry
    /// must survive an encode/decode round trip, and `ext_id()` must report it
    /// as `Some(0)`, distinct from an absent extension (`None`).
    #[test]
    fn ext_handshake_disable_extension_via_zero() {
        let mut hs = ExtHandshake::new();
        hs.m.insert("ut_pex".into(), 0);
        assert_eq!(hs.ext_id("ut_pex"), Some(0));

        let decoded = reencode(&hs);
        assert_eq!(
            decoded.ext_id("ut_pex"),
            Some(0),
            "BEP 10: message ID 0 means disabled, but must survive round-trip"
        );

        // The remaining built-ins are unaffected.
        assert_eq!(decoded.ext_id("ut_metadata"), Some(1));
        assert_eq!(decoded.ext_id("lt_trackers"), Some(3));

        // Disabled (Some(0)) is distinguishable from absent (None).
        assert_eq!(decoded.ext_id("nonexistent"), None);
        assert_eq!(decoded.ext_id("ut_pex"), Some(0));
    }

    /// BEP 54: the default handshake advertises lt_donthave as ID 5, and the
    /// ID survives a round trip.
    #[test]
    fn ext_handshake_includes_lt_donthave() {
        let hs = ExtHandshake::new();
        assert_eq!(hs.ext_id("lt_donthave"), Some(5));
        assert_eq!(reencode(&hs).ext_id("lt_donthave"), Some(5));
    }
}