1use std::collections::BTreeMap;
2
3use bytes::Bytes;
4use serde::{Deserialize, Serialize};
5
6use crate::error::{Error, Result};
7
/// Payload of the BEP 10 extension-protocol handshake.
///
/// Encoded as a bencoded dictionary whose keys are the field names below (no
/// serde renames are applied). Every field carries `#[serde(default)]`, so a
/// peer may omit any of them; the `Option` fields are also omitted on output
/// when `None`.
///
/// Note: the derived [`Default`] yields an empty `m` and no optional fields,
/// whereas [`ExtHandshake::new`] builds the handshake this client sends.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ExtHandshake {
    /// Extension name -> message ID. An ID of 0 is stored and round-tripped
    /// as-is; BEP 10 gives it the meaning "extension disabled".
    #[serde(default)]
    pub m: BTreeMap<String, u8>,
    /// Client name/version string; [`ExtHandshake::new`] sets "Torrent 0.65.0".
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub v: Option<String>,
    /// BEP 10 `p` key — presumably the sender's listen port; not set by the
    /// constructors in this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub p: Option<u16>,
    /// BEP 10 `reqq` key (request queue depth); [`ExtHandshake::new`]
    /// advertises 250.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reqq: Option<u32>,
    /// `metadata_size` key — presumably the BEP 9 metadata length in bytes;
    /// not set by the constructors in this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata_size: Option<u64>,
    /// Upload-only flag: any non-zero value means upload-only
    /// (see [`ExtHandshake::is_upload_only`]).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upload_only: Option<u8>,
}
32
33impl ExtHandshake {
34 pub fn new() -> Self {
36 let mut m = BTreeMap::new();
37 m.insert("ut_metadata".into(), 1);
38 m.insert("ut_pex".into(), 2);
39 m.insert("lt_trackers".into(), 3);
40 m.insert("ut_holepunch".into(), 4);
41 m.insert("lt_donthave".into(), 5);
42
43 ExtHandshake {
44 m,
45 v: Some("Torrent 0.65.0".into()),
46 p: None,
47 reqq: Some(250),
48 metadata_size: None,
49 upload_only: None,
50 }
51 }
52
53 pub fn new_with_plugins(plugin_names: &[&str]) -> Self {
58 let mut hs = Self::new();
59 for (i, name) in plugin_names.iter().enumerate() {
60 hs.m.insert((*name).into(), 10 + i as u8);
61 }
62 hs
63 }
64
65 pub fn new_upload_only() -> Self {
67 let mut hs = Self::new();
68 hs.upload_only = Some(1);
69 hs
70 }
71
72 pub fn is_upload_only(&self) -> bool {
74 self.upload_only.unwrap_or(0) != 0
75 }
76
77 pub fn to_bytes(&self) -> Result<Bytes> {
79 let data = irontide_bencode::to_bytes(self)?;
80 Ok(Bytes::from(data))
81 }
82
83 pub fn from_bytes(data: &[u8]) -> Result<Self> {
88 Ok(irontide_bencode::from_bytes_lenient(data)?)
89 }
90
91 pub fn ext_id(&self, name: &str) -> Option<u8> {
93 self.m.get(name).copied()
94 }
95}
96
/// An extension-protocol message as represented by this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExtMessage {
    /// Handshake payload kept as raw bytes; presumably decoded on demand with
    /// [`ExtHandshake::from_bytes`] — confirm at the call site.
    Handshake(Bytes),
    /// A metadata-exchange message (request / data / reject).
    Metadata(MetadataMessage),
}
105
/// Kind of a metadata-exchange message.
///
/// Each discriminant is the numeric `msg_type` value written into the
/// encoded header dictionary.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MetadataMessageType {
    /// Asks the peer for one metadata piece.
    Request = 0,
    /// Carries one metadata piece as trailing payload.
    Data = 1,
    /// Declines a request for a metadata piece.
    Reject = 2,
}
116
/// A metadata-exchange message.
///
/// On the wire this is a bencoded header dictionary optionally followed by
/// raw payload bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MetadataMessage {
    /// Which kind of message this is.
    pub msg_type: MetadataMessageType,
    /// Index of the metadata piece the message refers to.
    pub piece: u32,
    /// `total_size` header value — presumably the full metadata length in
    /// bytes (BEP 9). Set by [`MetadataMessage::data`]; `None` for requests
    /// and rejects built here.
    pub total_size: Option<u64>,
    /// Raw bytes following the header dictionary (the piece payload for
    /// `Data`), or `None` when nothing follows it.
    pub data: Option<Bytes>,
}
129
/// Serde form of the [`MetadataMessage`] header dictionary (everything
/// before the trailing payload).
///
/// NOTE(review): fields are declared in sorted key order, matching bencode's
/// sorted-key rule. Keep that order in case the serializer emits fields in
/// declaration order.
#[derive(Serialize, Deserialize)]
struct MetadataDict {
    // Numeric MetadataMessageType discriminant.
    msg_type: u8,
    piece: u32,
    // Omitted from the encoding when None.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    total_size: Option<u64>,
}
138
139impl MetadataMessage {
140 pub fn request(piece: u32) -> Self {
142 MetadataMessage {
143 msg_type: MetadataMessageType::Request,
144 piece,
145 total_size: None,
146 data: None,
147 }
148 }
149
150 pub fn data(piece: u32, total_size: u64, data: Bytes) -> Self {
152 MetadataMessage {
153 msg_type: MetadataMessageType::Data,
154 piece,
155 total_size: Some(total_size),
156 data: Some(data),
157 }
158 }
159
160 pub fn reject(piece: u32) -> Self {
162 MetadataMessage {
163 msg_type: MetadataMessageType::Reject,
164 piece,
165 total_size: None,
166 data: None,
167 }
168 }
169
170 pub fn to_bytes(&self) -> Result<Bytes> {
172 let dict = MetadataDict {
173 msg_type: self.msg_type as u8,
174 piece: self.piece,
175 total_size: self.total_size,
176 };
177 let mut buf = irontide_bencode::to_bytes(&dict)?;
178 if let Some(ref data) = self.data {
179 buf.extend_from_slice(data);
180 }
181 Ok(Bytes::from(buf))
182 }
183
184 pub fn from_bytes(data: &[u8]) -> Result<Self> {
186 let dict_end = find_bencode_dict_end(data)?;
188 let dict: MetadataDict = irontide_bencode::from_bytes_lenient(&data[..dict_end])?;
189
190 let msg_type = match dict.msg_type {
191 0 => MetadataMessageType::Request,
192 1 => MetadataMessageType::Data,
193 2 => MetadataMessageType::Reject,
194 n => {
195 return Err(Error::InvalidExtended(format!(
196 "unknown metadata msg_type {n}"
197 )));
198 }
199 };
200
201 let trailing = if dict_end < data.len() {
202 Some(Bytes::copy_from_slice(&data[dict_end..]))
203 } else {
204 None
205 };
206
207 Ok(MetadataMessage {
208 msg_type,
209 piece: dict.piece,
210 total_size: dict.total_size,
211 data: trailing,
212 })
213 }
214}
215
216fn find_bencode_dict_end(data: &[u8]) -> Result<usize> {
218 if data.first() != Some(&b'd') {
219 return Err(Error::InvalidExtended("expected bencode dict".into()));
220 }
221 let mut pos = 1;
222 let mut depth = 1u32;
223
224 while pos < data.len() && depth > 0 {
225 match data[pos] {
226 b'd' | b'l' => {
227 depth += 1;
228 pos += 1;
229 }
230 b'e' => {
231 depth -= 1;
232 pos += 1;
233 }
234 b'i' => {
235 pos += 1;
236 while pos < data.len() && data[pos] != b'e' {
237 pos += 1;
238 }
239 pos += 1; }
241 b'0'..=b'9' => {
242 let len_start = pos;
244 while pos < data.len() && data[pos] != b':' {
245 pos += 1;
246 }
247 let len: usize = std::str::from_utf8(&data[len_start..pos])
248 .map_err(|_| Error::InvalidExtended("bad string length".into()))?
249 .parse()
250 .map_err(|_| Error::InvalidExtended("bad string length".into()))?;
251 pos += 1 + len; }
253 b => {
254 return Err(Error::InvalidExtended(format!(
255 "unexpected byte {b:#04x} at position {pos}"
256 )));
257 }
258 }
259 }
260
261 if depth != 0 {
262 return Err(Error::InvalidExtended("unterminated dict".into()));
263 }
264 Ok(pos)
265}
266
// Unit tests for handshake construction/round-trips and metadata message
// encoding/decoding.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ext_handshake_round_trip() {
        let hs = ExtHandshake::new();
        let bytes = hs.to_bytes().unwrap();
        let parsed = ExtHandshake::from_bytes(&bytes).unwrap();
        assert_eq!(hs.m, parsed.m);
        assert_eq!(hs.v, parsed.v);
        assert_eq!(hs.reqq, parsed.reqq);
    }

    #[test]
    fn ext_handshake_ext_id_lookup() {
        let hs = ExtHandshake::new();
        assert_eq!(hs.ext_id("ut_metadata"), Some(1));
        assert_eq!(hs.ext_id("ut_pex"), Some(2));
        assert_eq!(hs.ext_id("lt_trackers"), Some(3));
        assert_eq!(hs.ext_id("ut_holepunch"), Some(4));
        assert_eq!(hs.ext_id("unknown"), None);
    }

    #[test]
    fn ext_handshake_upload_only_round_trip() {
        let hs = ExtHandshake::new_upload_only();
        assert!(hs.is_upload_only());
        let bytes = hs.to_bytes().unwrap();
        let parsed = ExtHandshake::from_bytes(&bytes).unwrap();
        assert!(parsed.is_upload_only());
        assert_eq!(parsed.upload_only, Some(1));
    }

    #[test]
    fn ext_handshake_no_upload_only_default() {
        let hs = ExtHandshake::new();
        assert!(!hs.is_upload_only());
        assert_eq!(hs.upload_only, None);
    }

    #[test]
    fn ext_handshake_with_plugins() {
        let hs = ExtHandshake::new_with_plugins(&["ut_comment", "ut_holepunch"]);
        // Built-ins not named by a plugin keep their IDs.
        assert_eq!(hs.ext_id("ut_metadata"), Some(1));
        assert_eq!(hs.ext_id("ut_pex"), Some(2));
        assert_eq!(hs.ext_id("lt_trackers"), Some(3));
        // Plugins are numbered from 10 in slice order.
        assert_eq!(hs.ext_id("ut_comment"), Some(10));
        // A plugin named like a built-in replaces the built-in's ID (4 -> 11).
        assert_eq!(hs.ext_id("ut_holepunch"), Some(11));
    }

    #[test]
    fn ext_handshake_with_plugins_round_trip() {
        let hs = ExtHandshake::new_with_plugins(&["ut_echo"]);
        let bytes = hs.to_bytes().unwrap();
        let parsed = ExtHandshake::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.ext_id("ut_echo"), Some(10));
        assert_eq!(parsed.ext_id("ut_metadata"), Some(1));
    }

    #[test]
    fn ext_handshake_no_plugins() {
        let hs = ExtHandshake::new_with_plugins(&[]);
        // Only the five built-in extensions.
        assert_eq!(hs.m.len(), 5);
    }

    #[test]
    fn ext_handshake_holepunch_can_be_removed() {
        let mut hs = ExtHandshake::new();
        hs.m.remove("ut_holepunch");
        assert_eq!(hs.ext_id("ut_holepunch"), None);
        assert_eq!(hs.ext_id("ut_metadata"), Some(1));
        assert_eq!(hs.ext_id("ut_pex"), Some(2));
    }

    #[test]
    fn metadata_request_round_trip() {
        let msg = MetadataMessage::request(3);
        let bytes = msg.to_bytes().unwrap();
        let parsed = MetadataMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.msg_type, MetadataMessageType::Request);
        assert_eq!(parsed.piece, 3);
        assert!(parsed.data.is_none());
    }

    #[test]
    fn metadata_data_with_trailing() {
        // Payload bytes follow the header dict and must be split off intact.
        let msg = MetadataMessage {
            msg_type: MetadataMessageType::Data,
            piece: 0,
            total_size: Some(31415),
            data: Some(Bytes::from_static(b"raw metadata bytes here")),
        };
        let bytes = msg.to_bytes().unwrap();
        let parsed = MetadataMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.msg_type, MetadataMessageType::Data);
        assert_eq!(parsed.piece, 0);
        assert_eq!(parsed.total_size, Some(31415));
        assert_eq!(
            parsed.data.as_deref(),
            Some(b"raw metadata bytes here".as_ref())
        );
    }

    #[test]
    fn metadata_reject() {
        let msg = MetadataMessage::reject(5);
        let bytes = msg.to_bytes().unwrap();
        let parsed = MetadataMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.msg_type, MetadataMessageType::Reject);
        assert_eq!(parsed.piece, 5);
    }

    #[test]
    fn ext_handshake_disable_extension_via_zero() {
        let mut hs = ExtHandshake::new();
        // Advertise ut_pex with ID 0, i.e. explicitly disabled.
        hs.m.insert("ut_pex".into(), 0);

        assert_eq!(hs.ext_id("ut_pex"), Some(0));

        let bytes = hs.to_bytes().unwrap();
        let parsed = ExtHandshake::from_bytes(&bytes).unwrap();

        assert_eq!(
            parsed.ext_id("ut_pex"),
            Some(0),
            "BEP 10: message ID 0 means disabled, but must survive round-trip"
        );

        // Other entries are unaffected.
        assert_eq!(parsed.ext_id("ut_metadata"), Some(1));
        assert_eq!(parsed.ext_id("lt_trackers"), Some(3));

        // An absent name (None) stays distinct from an explicit 0 (Some(0)).
        assert_eq!(parsed.ext_id("nonexistent"), None);
        assert_eq!(parsed.ext_id("ut_pex"), Some(0));
    }

    #[test]
    fn ext_handshake_includes_lt_donthave() {
        let hs = ExtHandshake::new();
        assert_eq!(hs.ext_id("lt_donthave"), Some(5));
        let bytes = hs.to_bytes().unwrap();
        let parsed = ExtHandshake::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.ext_id("lt_donthave"), Some(5));
    }
}