1#![allow(
2 clippy::cast_possible_truncation,
3 reason = "M175: BEP 10 extension protocol — message-id bytes bounded by extension count (u8)"
4)]
5
6use std::collections::BTreeMap;
7
8use bytes::Bytes;
9use serde::{Deserialize, Serialize};
10
11use crate::error::{Error, Result};
12
/// BEP 10 extended-handshake dictionary, (de)serialized through serde.
///
/// Every field is `#[serde(default)]`, so a peer handshake that omits a key still
/// decodes. Unset optional fields are skipped when serializing. The derived
/// `Default` yields an empty `m` with every optional field unset; use
/// `ExtHandshake::new` for the handshake this client advertises.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ExtHandshake {
    /// `m`: extension name -> message ID. An ID of 0 is stored and round-tripped
    /// as-is (the tests treat 0 as "disabled" per BEP 10).
    #[serde(default)]
    pub m: BTreeMap<String, u8>,
    /// `v`: client name/version string; `ExtHandshake::new` sets it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub v: Option<String>,
    /// `p`: presumably the local listen port (BEP 10); not set by the
    /// constructors here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub p: Option<u16>,
    /// `reqq`: `ExtHandshake::new` advertises 250; presumably the BEP 10
    /// outstanding-request limit.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reqq: Option<u32>,
    /// `metadata_size`: not set by the constructors here; presumably the size of
    /// the metadata served over `ut_metadata`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata_size: Option<u64>,
    /// `upload_only`: any non-zero value means upload-only (see `is_upload_only`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upload_only: Option<u8>,
}
37
38impl ExtHandshake {
39 #[must_use]
41 pub fn new() -> Self {
42 let mut m = BTreeMap::new();
43 m.insert("ut_metadata".into(), 1);
44 m.insert("ut_pex".into(), 2);
45 m.insert("lt_trackers".into(), 3);
46 m.insert("ut_holepunch".into(), 4);
47 m.insert("lt_donthave".into(), 5);
48
49 Self {
50 m,
51 v: Some("Torrent 0.65.0".into()),
52 p: None,
53 reqq: Some(250),
54 metadata_size: None,
55 upload_only: None,
56 }
57 }
58
59 #[must_use]
64 pub fn new_with_plugins(plugin_names: &[&str]) -> Self {
65 let mut hs = Self::new();
66 for (i, name) in plugin_names.iter().enumerate() {
67 hs.m.insert((*name).into(), 10 + i as u8);
68 }
69 hs
70 }
71
72 #[must_use]
74 pub fn new_upload_only() -> Self {
75 let mut hs = Self::new();
76 hs.upload_only = Some(1);
77 hs
78 }
79
80 #[must_use]
82 pub fn is_upload_only(&self) -> bool {
83 self.upload_only.unwrap_or(0) != 0
84 }
85
86 pub fn to_bytes(&self) -> Result<Bytes> {
92 let data = irontide_bencode::to_bytes(self)?;
93 Ok(Bytes::from(data))
94 }
95
96 pub fn from_bytes(data: &[u8]) -> Result<Self> {
105 Ok(irontide_bencode::from_bytes_lenient(data)?)
106 }
107
108 #[must_use]
110 pub fn ext_id(&self, name: &str) -> Option<u8> {
111 self.m.get(name).copied()
112 }
113}
114
115#[derive(Debug, Clone, PartialEq, Eq)]
117pub enum ExtMessage {
118 Handshake(Bytes),
120 Metadata(MetadataMessage),
122}
123
/// `msg_type` values of a `ut_metadata` (BEP 9) message.
///
/// The discriminants are the on-the-wire values. `#[repr(u8)]` guarantees they
/// fit the single byte the encoder casts them to with `as u8`, and makes the
/// compiler reject any future discriminant that would not.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum MetadataMessageType {
    /// Ask the peer for one metadata piece.
    Request = 0,
    /// Carries one metadata piece; the piece bytes follow the bencoded dict.
    Data = 1,
    /// The peer declines to serve the requested piece.
    Reject = 2,
}
134
135#[derive(Debug, Clone, PartialEq, Eq)]
137pub struct MetadataMessage {
138 pub msg_type: MetadataMessageType,
140 pub piece: u32,
142 pub total_size: Option<u64>,
144 pub data: Option<Bytes>,
146}
147
/// Bencoded header dict of a `ut_metadata` message: the part before any payload bytes.
///
/// Fields are declared in sorted key order (`msg_type` < `piece` < `total_size`).
/// NOTE(review): keep that order in case the bencode serializer emits fields in
/// declaration order; bencode requires sorted dict keys.
#[derive(Serialize, Deserialize)]
struct MetadataDict {
    /// Wire value of `MetadataMessageType` (0, 1 or 2 are accepted on decode).
    msg_type: u8,
    /// Index of the metadata piece.
    piece: u32,
    /// Set only by `MetadataMessage::data`; omitted from the output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    total_size: Option<u64>,
}
156
157impl MetadataMessage {
158 #[must_use]
160 pub fn request(piece: u32) -> Self {
161 Self {
162 msg_type: MetadataMessageType::Request,
163 piece,
164 total_size: None,
165 data: None,
166 }
167 }
168
169 pub fn data(piece: u32, total_size: u64, data: Bytes) -> Self {
171 Self {
172 msg_type: MetadataMessageType::Data,
173 piece,
174 total_size: Some(total_size),
175 data: Some(data),
176 }
177 }
178
179 #[must_use]
181 pub fn reject(piece: u32) -> Self {
182 Self {
183 msg_type: MetadataMessageType::Reject,
184 piece,
185 total_size: None,
186 data: None,
187 }
188 }
189
190 pub fn to_bytes(&self) -> Result<Bytes> {
196 let dict = MetadataDict {
197 msg_type: self.msg_type as u8,
198 piece: self.piece,
199 total_size: self.total_size,
200 };
201 let mut buf = irontide_bencode::to_bytes(&dict)?;
202 if let Some(ref data) = self.data {
203 buf.extend_from_slice(data);
204 }
205 Ok(Bytes::from(buf))
206 }
207
208 pub fn from_bytes(data: &[u8]) -> Result<Self> {
214 let dict_end = find_bencode_dict_end(data)?;
216 let dict: MetadataDict = irontide_bencode::from_bytes_lenient(&data[..dict_end])?;
217
218 let msg_type = match dict.msg_type {
219 0 => MetadataMessageType::Request,
220 1 => MetadataMessageType::Data,
221 2 => MetadataMessageType::Reject,
222 n => {
223 return Err(Error::InvalidExtended(format!(
224 "unknown metadata msg_type {n}"
225 )));
226 }
227 };
228
229 let trailing = if dict_end < data.len() {
230 Some(Bytes::copy_from_slice(&data[dict_end..]))
231 } else {
232 None
233 };
234
235 Ok(Self {
236 msg_type,
237 piece: dict.piece,
238 total_size: dict.total_size,
239 data: trailing,
240 })
241 }
242}
243
244fn find_bencode_dict_end(data: &[u8]) -> Result<usize> {
246 if data.first() != Some(&b'd') {
247 return Err(Error::InvalidExtended("expected bencode dict".into()));
248 }
249 let mut pos = 1;
250 let mut depth = 1u32;
251
252 while pos < data.len() && depth > 0 {
253 match data[pos] {
254 b'd' | b'l' => {
255 depth += 1;
256 pos += 1;
257 }
258 b'e' => {
259 depth -= 1;
260 pos += 1;
261 }
262 b'i' => {
263 pos += 1;
264 while pos < data.len() && data[pos] != b'e' {
265 pos += 1;
266 }
267 pos += 1; }
269 b'0'..=b'9' => {
270 let len_start = pos;
272 while pos < data.len() && data[pos] != b':' {
273 pos += 1;
274 }
275 let len: usize = std::str::from_utf8(&data[len_start..pos])
276 .map_err(|_| Error::InvalidExtended("bad string length".into()))?
277 .parse()
278 .map_err(|_| Error::InvalidExtended("bad string length".into()))?;
279 pos += 1 + len; }
281 b => {
282 return Err(Error::InvalidExtended(format!(
283 "unexpected byte {b:#04x} at position {pos}"
284 )));
285 }
286 }
287 }
288
289 if depth != 0 {
290 return Err(Error::InvalidExtended("unterminated dict".into()));
291 }
292 Ok(pos)
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ext_handshake_round_trip() {
        let hs = ExtHandshake::new();
        let bytes = hs.to_bytes().unwrap();
        let parsed = ExtHandshake::from_bytes(&bytes).unwrap();
        assert_eq!(hs.m, parsed.m);
        assert_eq!(hs.v, parsed.v);
        assert_eq!(hs.reqq, parsed.reqq);
    }

    #[test]
    fn ext_handshake_ext_id_lookup() {
        let hs = ExtHandshake::new();
        assert_eq!(hs.ext_id("ut_metadata"), Some(1));
        assert_eq!(hs.ext_id("ut_pex"), Some(2));
        assert_eq!(hs.ext_id("lt_trackers"), Some(3));
        assert_eq!(hs.ext_id("ut_holepunch"), Some(4));
        assert_eq!(hs.ext_id("unknown"), None);
    }

    #[test]
    fn ext_handshake_upload_only_round_trip() {
        let hs = ExtHandshake::new_upload_only();
        assert!(hs.is_upload_only());
        let bytes = hs.to_bytes().unwrap();
        let parsed = ExtHandshake::from_bytes(&bytes).unwrap();
        assert!(parsed.is_upload_only());
        assert_eq!(parsed.upload_only, Some(1));
    }

    #[test]
    fn ext_handshake_no_upload_only_default() {
        let hs = ExtHandshake::new();
        assert!(!hs.is_upload_only());
        assert_eq!(hs.upload_only, None);
    }

    #[test]
    fn ext_handshake_with_plugins() {
        let hs = ExtHandshake::new_with_plugins(&["ut_comment", "ut_holepunch"]);
        assert_eq!(hs.ext_id("ut_metadata"), Some(1));
        assert_eq!(hs.ext_id("ut_pex"), Some(2));
        assert_eq!(hs.ext_id("lt_trackers"), Some(3));
        assert_eq!(hs.ext_id("ut_comment"), Some(10));
        // A plugin with a built-in's name overrides the built-in ID.
        assert_eq!(hs.ext_id("ut_holepunch"), Some(11));
    }

    #[test]
    fn ext_handshake_with_plugins_round_trip() {
        let hs = ExtHandshake::new_with_plugins(&["ut_echo"]);
        let bytes = hs.to_bytes().unwrap();
        let parsed = ExtHandshake::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.ext_id("ut_echo"), Some(10));
        assert_eq!(parsed.ext_id("ut_metadata"), Some(1));
    }

    #[test]
    fn ext_handshake_no_plugins() {
        let hs = ExtHandshake::new_with_plugins(&[]);
        // Only the five built-in extensions.
        assert_eq!(hs.m.len(), 5);
    }

    #[test]
    fn ext_handshake_holepunch_can_be_removed() {
        let mut hs = ExtHandshake::new();
        hs.m.remove("ut_holepunch");
        assert_eq!(hs.ext_id("ut_holepunch"), None);
        assert_eq!(hs.ext_id("ut_metadata"), Some(1));
        assert_eq!(hs.ext_id("ut_pex"), Some(2));
    }

    #[test]
    fn metadata_request_round_trip() {
        let msg = MetadataMessage::request(3);
        let bytes = msg.to_bytes().unwrap();
        let parsed = MetadataMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.msg_type, MetadataMessageType::Request);
        assert_eq!(parsed.piece, 3);
        assert!(parsed.data.is_none());
    }

    #[test]
    fn metadata_data_with_trailing() {
        let msg = MetadataMessage {
            msg_type: MetadataMessageType::Data,
            piece: 0,
            total_size: Some(31415),
            data: Some(Bytes::from_static(b"raw metadata bytes here")),
        };
        let bytes = msg.to_bytes().unwrap();
        let parsed = MetadataMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.msg_type, MetadataMessageType::Data);
        assert_eq!(parsed.piece, 0);
        assert_eq!(parsed.total_size, Some(31415));
        assert_eq!(
            parsed.data.as_deref(),
            Some(b"raw metadata bytes here".as_ref())
        );
    }

    #[test]
    fn metadata_data_constructor_round_trip() {
        let msg = MetadataMessage::data(2, 100, Bytes::from_static(b"xyz"));
        let bytes = msg.to_bytes().unwrap();
        let parsed = MetadataMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed, msg);
    }

    #[test]
    fn metadata_trailing_payload_that_looks_like_bencode() {
        // The payload is raw bytes: the dict scanner must stop at the header's
        // closing `e` and hand everything after it back untouched.
        let msg = MetadataMessage::data(0, 5, Bytes::from_static(b"d1:ae"));
        let bytes = msg.to_bytes().unwrap();
        let parsed = MetadataMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.data.as_deref(), Some(b"d1:ae".as_ref()));
    }

    #[test]
    fn metadata_reject() {
        let msg = MetadataMessage::reject(5);
        let bytes = msg.to_bytes().unwrap();
        let parsed = MetadataMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.msg_type, MetadataMessageType::Reject);
        assert_eq!(parsed.piece, 5);
        assert!(parsed.data.is_none());
    }

    #[test]
    fn metadata_unknown_msg_type_rejected() {
        assert!(MetadataMessage::from_bytes(b"d8:msg_typei7e5:piecei0ee").is_err());
    }

    #[test]
    fn metadata_from_bytes_rejects_malformed_framing() {
        // Empty input, a non-dict value, and a dict missing its closing `e`.
        assert!(MetadataMessage::from_bytes(b"").is_err());
        assert!(MetadataMessage::from_bytes(b"li1ee").is_err());
        assert!(MetadataMessage::from_bytes(b"d8:msg_typei0e").is_err());
    }

    #[test]
    fn ext_handshake_disable_extension_via_zero() {
        let mut hs = ExtHandshake::new();
        hs.m.insert("ut_pex".into(), 0);

        assert_eq!(hs.ext_id("ut_pex"), Some(0));

        let bytes = hs.to_bytes().unwrap();
        let parsed = ExtHandshake::from_bytes(&bytes).unwrap();

        assert_eq!(
            parsed.ext_id("ut_pex"),
            Some(0),
            "BEP 10: message ID 0 means disabled, but must survive round-trip"
        );

        assert_eq!(parsed.ext_id("ut_metadata"), Some(1));
        assert_eq!(parsed.ext_id("lt_trackers"), Some(3));

        assert_eq!(parsed.ext_id("nonexistent"), None);
    }

    #[test]
    fn ext_handshake_includes_lt_donthave() {
        let hs = ExtHandshake::new();
        assert_eq!(hs.ext_id("lt_donthave"), Some(5));
        let bytes = hs.to_bytes().unwrap();
        let parsed = ExtHandshake::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.ext_id("lt_donthave"), Some(5));
    }
}