#![allow(
clippy::cast_possible_truncation,
reason = "M175: BEP 10 extension protocol — message-id bytes bounded by extension count (u8)"
)]
use std::collections::BTreeMap;
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
/// Extension handshake dictionary for the BEP 10 extension protocol.
///
/// Field names double as the bencode dictionary keys (there is no
/// `#[serde(rename)]`), so renaming a field changes the wire format. Every
/// field is `#[serde(default)]`, so keys missing from a decoded handshake
/// fall back to their defaults. `None` optionals are left out when encoding.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ExtHandshake {
    /// Extension name -> message id. `ExtHandshake::ext_id` looks names up here.
    /// An entry with id 0 is kept as-is; the tests treat 0 as "disabled".
    #[serde(default)]
    pub m: BTreeMap<String, u8>,
    /// Client name/version string (`ExtHandshake::new` sends "Torrent 0.65.0").
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub v: Option<String>,
    /// NOTE(review): presumably the BEP 10 `p` key (the sender's listen port) — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub p: Option<u16>,
    /// NOTE(review): presumably the BEP 10 `reqq` key (max outstanding
    /// requests) — confirm. `ExtHandshake::new` advertises 250.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reqq: Option<u32>,
    /// NOTE(review): presumably the size in bytes of the torrent metadata
    /// offered for exchange — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata_size: Option<u64>,
    /// Upload-only flag. Any non-zero value means upload-only
    /// (see `ExtHandshake::is_upload_only`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upload_only: Option<u8>,
}
impl ExtHandshake {
    /// Message id given to the first plugin in [`ExtHandshake::new_with_plugins`].
    /// Later plugins take consecutive ids up to `u8::MAX`. The value stays clear
    /// of the built-in ids 1-5 assigned by [`ExtHandshake::new`].
    const FIRST_PLUGIN_EXT_ID: u8 = 10;

    /// Builds this client's default extension handshake.
    ///
    /// Advertises `ut_metadata` (1), `ut_pex` (2), `lt_trackers` (3),
    /// `ut_holepunch` (4) and `lt_donthave` (5). It also sets the client string
    /// `"Torrent 0.65.0"` and `reqq` 250. All other optional fields are `None`.
    #[must_use]
    pub fn new() -> Self {
        let mut m = BTreeMap::new();
        m.insert("ut_metadata".into(), 1);
        m.insert("ut_pex".into(), 2);
        m.insert("lt_trackers".into(), 3);
        m.insert("ut_holepunch".into(), 4);
        m.insert("lt_donthave".into(), 5);
        Self {
            m,
            v: Some("Torrent 0.65.0".into()),
            p: None,
            reqq: Some(250),
            metadata_size: None,
            upload_only: None,
        }
    }

    /// Like [`ExtHandshake::new`], plus one `m` entry per name in `plugin_names`.
    ///
    /// Plugins get ids 10, 11, 12, ... in slice order. A name that matches a
    /// built-in extension (or an earlier plugin) overwrites that entry's id.
    ///
    /// Message ids are a single byte, so only the first 246 names receive an
    /// id. Any further names are ignored. Computing their ids would overflow
    /// `u8`: debug builds would panic, and release builds would wrap onto 0
    /// ("disabled") or onto ids already in use.
    #[must_use]
    pub fn new_with_plugins(plugin_names: &[&str]) -> Self {
        let mut hs = Self::new();
        for (id, name) in (Self::FIRST_PLUGIN_EXT_ID..=u8::MAX).zip(plugin_names) {
            hs.m.insert((*name).to_owned(), id);
        }
        hs
    }

    /// Like [`ExtHandshake::new`], but with the `upload_only` flag set to 1.
    #[must_use]
    pub fn new_upload_only() -> Self {
        let mut hs = Self::new();
        hs.upload_only = Some(1);
        hs
    }

    /// Returns `true` when `upload_only` is present and non-zero.
    #[must_use]
    pub fn is_upload_only(&self) -> bool {
        self.upload_only.unwrap_or(0) != 0
    }

    /// Encodes the handshake as a bencoded dictionary.
    ///
    /// # Errors
    ///
    /// Propagates any error from the bencode serializer.
    pub fn to_bytes(&self) -> Result<Bytes> {
        let data = irontide_bencode::to_bytes(self)?;
        Ok(Bytes::from(data))
    }

    /// Decodes a handshake dictionary with `irontide_bencode::from_bytes_lenient`.
    /// Missing keys take their `#[serde(default)]` values.
    ///
    /// # Errors
    ///
    /// Propagates any bencode decoding error.
    pub fn from_bytes(data: &[u8]) -> Result<Self> {
        Ok(irontide_bencode::from_bytes_lenient(data)?)
    }

    /// Returns the message id advertised for extension `name`, or `None` if it
    /// is absent.
    ///
    /// An entry set to 0 is returned as `Some(0)`. BEP 10 uses id 0 to mean the
    /// extension is disabled, so callers must check for it.
    #[must_use]
    pub fn ext_id(&self, name: &str) -> Option<u8> {
        self.m.get(name).copied()
    }
}
/// An extension-protocol message.
///
/// NOTE(review): nothing in this file constructs or matches this enum, so the
/// exact contents of each variant are not established here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExtMessage {
    /// Extension handshake. Presumably this holds the raw bencoded handshake
    /// payload (decodable with `ExtHandshake::from_bytes`) — confirm with callers.
    Handshake(Bytes),
    /// A parsed metadata-exchange message (see `MetadataMessage`).
    Metadata(MetadataMessage),
}
/// The `msg_type` value of a metadata message.
///
/// The discriminants are the on-the-wire values: `MetadataMessage::to_bytes`
/// casts the variant with `as u8`, and `MetadataMessage::from_bytes` maps
/// 0/1/2 back. They must not be changed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MetadataMessageType {
    /// Requests a metadata piece (built by `MetadataMessage::request`).
    Request = 0,
    /// Carries a metadata piece; the piece bytes follow the dictionary.
    Data = 1,
    /// Refuses a piece request (built by `MetadataMessage::reject`).
    Reject = 2,
}
/// A metadata-exchange message: a bencoded dictionary optionally followed by
/// raw payload bytes (see `MetadataMessage::to_bytes` / `from_bytes`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MetadataMessage {
    /// Kind of message; encoded under the `msg_type` key.
    pub msg_type: MetadataMessageType,
    /// Index of the metadata piece this message refers to; encoded as `piece`.
    pub piece: u32,
    /// Encoded as `total_size` when present. `MetadataMessage::data` sets it;
    /// `request`/`reject` leave it `None`.
    pub total_size: Option<u64>,
    /// Raw bytes appended after the dictionary (the piece payload for `Data`).
    pub data: Option<Bytes>,
}
/// Wire form of the dictionary part of a `MetadataMessage`.
///
/// Field names are the bencode keys. `total_size` is omitted when `None` and
/// decodes to `None` when absent. The fields are declared in lexicographic
/// key order, which bencode dictionaries require. NOTE(review): keep that
/// order unless the serializer is confirmed to sort keys itself.
#[derive(Serialize, Deserialize)]
struct MetadataDict {
    msg_type: u8,
    piece: u32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    total_size: Option<u64>,
}
impl MetadataMessage {
    /// Builds a `Request` for metadata piece `piece` (no `total_size`, no payload).
    #[must_use]
    pub fn request(piece: u32) -> Self {
        Self {
            msg_type: MetadataMessageType::Request,
            piece,
            total_size: None,
            data: None,
        }
    }

    /// Builds a `Data` message carrying metadata piece `piece`.
    ///
    /// `total_size` is encoded in the dictionary. [`MetadataMessage::to_bytes`]
    /// appends `data` verbatim after the dictionary.
    #[must_use]
    pub fn data(piece: u32, total_size: u64, data: Bytes) -> Self {
        Self {
            msg_type: MetadataMessageType::Data,
            piece,
            total_size: Some(total_size),
            data: Some(data),
        }
    }

    /// Builds a `Reject` for metadata piece `piece` (no `total_size`, no payload).
    #[must_use]
    pub fn reject(piece: u32) -> Self {
        Self {
            msg_type: MetadataMessageType::Reject,
            piece,
            total_size: None,
            data: None,
        }
    }

    /// Encodes the message as a bencoded dictionary followed immediately by
    /// the raw `data` bytes, if any.
    ///
    /// The dictionary contains `msg_type`, `piece` and, when set, `total_size`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the bencode serializer.
    pub fn to_bytes(&self) -> Result<Bytes> {
        let dict = MetadataDict {
            msg_type: self.msg_type as u8,
            piece: self.piece,
            total_size: self.total_size,
        };
        let mut buf = irontide_bencode::to_bytes(&dict)?;
        if let Some(payload) = &self.data {
            buf.extend_from_slice(payload);
        }
        Ok(Bytes::from(buf))
    }

    /// Decodes a message laid out as produced by [`MetadataMessage::to_bytes`].
    ///
    /// The leading dictionary is not length-prefixed, so `find_bencode_dict_end`
    /// locates its end with a structural scan. The dictionary is then decoded
    /// with `irontide_bencode::from_bytes_lenient`. Everything after the
    /// dictionary becomes `data`, which is `None` when nothing follows it.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidExtended`] if the dictionary cannot be delimited
    /// or `msg_type` is not 0, 1 or 2. Bencode decoding errors are propagated.
    pub fn from_bytes(data: &[u8]) -> Result<Self> {
        let dict_end = find_bencode_dict_end(data)?;
        let dict: MetadataDict = irontide_bencode::from_bytes_lenient(&data[..dict_end])?;
        let msg_type = match dict.msg_type {
            0 => MetadataMessageType::Request,
            1 => MetadataMessageType::Data,
            2 => MetadataMessageType::Reject,
            n => {
                return Err(Error::InvalidExtended(format!(
                    "unknown metadata msg_type {n}"
                )));
            }
        };
        // `find_bencode_dict_end` never returns an offset past `data.len()`.
        let trailing =
            (dict_end < data.len()).then(|| Bytes::copy_from_slice(&data[dict_end..]));
        Ok(Self {
            msg_type,
            piece: dict.piece,
            total_size: dict.total_size,
            data: trailing,
        })
    }
}
fn find_bencode_dict_end(data: &[u8]) -> Result<usize> {
if data.first() != Some(&b'd') {
return Err(Error::InvalidExtended("expected bencode dict".into()));
}
let mut pos = 1;
let mut depth = 1u32;
while pos < data.len() && depth > 0 {
match data[pos] {
b'd' | b'l' => {
depth += 1;
pos += 1;
}
b'e' => {
depth -= 1;
pos += 1;
}
b'i' => {
pos += 1;
while pos < data.len() && data[pos] != b'e' {
pos += 1;
}
pos += 1; }
b'0'..=b'9' => {
let len_start = pos;
while pos < data.len() && data[pos] != b':' {
pos += 1;
}
let len: usize = std::str::from_utf8(&data[len_start..pos])
.map_err(|_| Error::InvalidExtended("bad string length".into()))?
.parse()
.map_err(|_| Error::InvalidExtended("bad string length".into()))?;
pos += 1 + len; }
b => {
return Err(Error::InvalidExtended(format!(
"unexpected byte {b:#04x} at position {pos}"
)));
}
}
}
if depth != 0 {
return Err(Error::InvalidExtended("unterminated dict".into()));
}
Ok(pos)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ext_handshake_round_trip() {
        let hs = ExtHandshake::new();
        let bytes = hs.to_bytes().unwrap();
        let parsed = ExtHandshake::from_bytes(&bytes).unwrap();
        assert_eq!(hs.m, parsed.m);
        assert_eq!(hs.v, parsed.v);
        assert_eq!(hs.reqq, parsed.reqq);
        // Unset optionals are omitted on the wire and must decode back to `None`.
        assert_eq!(parsed.p, None);
        assert_eq!(parsed.metadata_size, None);
        assert_eq!(parsed.upload_only, None);
    }

    #[test]
    fn ext_handshake_optional_fields_round_trip() {
        let mut hs = ExtHandshake::new();
        hs.p = Some(6881);
        hs.metadata_size = Some(31_415);
        let bytes = hs.to_bytes().unwrap();
        let parsed = ExtHandshake::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.p, Some(6881));
        assert_eq!(parsed.metadata_size, Some(31_415));
    }

    #[test]
    fn ext_handshake_ext_id_lookup() {
        let hs = ExtHandshake::new();
        assert_eq!(hs.ext_id("ut_metadata"), Some(1));
        assert_eq!(hs.ext_id("ut_pex"), Some(2));
        assert_eq!(hs.ext_id("lt_trackers"), Some(3));
        assert_eq!(hs.ext_id("ut_holepunch"), Some(4));
        assert_eq!(hs.ext_id("lt_donthave"), Some(5));
        assert_eq!(hs.ext_id("unknown"), None);
    }

    #[test]
    fn ext_handshake_upload_only_round_trip() {
        let hs = ExtHandshake::new_upload_only();
        assert!(hs.is_upload_only());
        let bytes = hs.to_bytes().unwrap();
        let parsed = ExtHandshake::from_bytes(&bytes).unwrap();
        assert!(parsed.is_upload_only());
        assert_eq!(parsed.upload_only, Some(1));
    }

    #[test]
    fn ext_handshake_no_upload_only_default() {
        let hs = ExtHandshake::new();
        assert!(!hs.is_upload_only());
        assert_eq!(hs.upload_only, None);
    }

    #[test]
    fn ext_handshake_with_plugins() {
        let hs = ExtHandshake::new_with_plugins(&["ut_comment", "ut_holepunch"]);
        assert_eq!(hs.ext_id("ut_metadata"), Some(1));
        assert_eq!(hs.ext_id("ut_pex"), Some(2));
        assert_eq!(hs.ext_id("lt_trackers"), Some(3));
        assert_eq!(hs.ext_id("ut_comment"), Some(10));
        // A plugin named like a built-in extension overrides its id.
        assert_eq!(hs.ext_id("ut_holepunch"), Some(11));
    }

    #[test]
    fn ext_handshake_with_plugins_round_trip() {
        let hs = ExtHandshake::new_with_plugins(&["ut_echo"]);
        let bytes = hs.to_bytes().unwrap();
        let parsed = ExtHandshake::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.ext_id("ut_echo"), Some(10));
        assert_eq!(parsed.ext_id("ut_metadata"), Some(1));
    }

    #[test]
    fn ext_handshake_no_plugins() {
        let hs = ExtHandshake::new_with_plugins(&[]);
        assert_eq!(hs.m.len(), 5, "only the five built-in extensions");
    }

    #[test]
    fn ext_handshake_holepunch_can_be_removed() {
        let mut hs = ExtHandshake::new();
        hs.m.remove("ut_holepunch");
        assert_eq!(hs.ext_id("ut_holepunch"), None);
        assert_eq!(hs.ext_id("ut_metadata"), Some(1));
        assert_eq!(hs.ext_id("ut_pex"), Some(2));
    }

    #[test]
    fn ext_handshake_disable_extension_via_zero() {
        let mut hs = ExtHandshake::new();
        hs.m.insert("ut_pex".into(), 0);
        assert_eq!(hs.ext_id("ut_pex"), Some(0));
        let bytes = hs.to_bytes().unwrap();
        let parsed = ExtHandshake::from_bytes(&bytes).unwrap();
        assert_eq!(
            parsed.ext_id("ut_pex"),
            Some(0),
            "BEP 10: message ID 0 means disabled, but must survive round-trip"
        );
        assert_eq!(parsed.ext_id("ut_metadata"), Some(1));
        assert_eq!(parsed.ext_id("lt_trackers"), Some(3));
        assert_eq!(parsed.ext_id("nonexistent"), None);
    }

    #[test]
    fn ext_handshake_includes_lt_donthave() {
        let hs = ExtHandshake::new();
        assert_eq!(hs.ext_id("lt_donthave"), Some(5));
        let bytes = hs.to_bytes().unwrap();
        let parsed = ExtHandshake::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.ext_id("lt_donthave"), Some(5));
    }

    #[test]
    fn metadata_request_round_trip() {
        let msg = MetadataMessage::request(3);
        let bytes = msg.to_bytes().unwrap();
        let parsed = MetadataMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.msg_type, MetadataMessageType::Request);
        assert_eq!(parsed.piece, 3);
        assert!(parsed.data.is_none());
    }

    #[test]
    fn metadata_data_with_trailing() {
        let msg = MetadataMessage {
            msg_type: MetadataMessageType::Data,
            piece: 0,
            total_size: Some(31415),
            data: Some(Bytes::from_static(b"raw metadata bytes here")),
        };
        let bytes = msg.to_bytes().unwrap();
        let parsed = MetadataMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.msg_type, MetadataMessageType::Data);
        assert_eq!(parsed.piece, 0);
        assert_eq!(parsed.total_size, Some(31415));
        assert_eq!(
            parsed.data.as_deref(),
            Some(b"raw metadata bytes here".as_ref())
        );
    }

    #[test]
    fn metadata_data_payload_that_is_itself_a_dict() {
        // Real metadata pieces are bencoded too; only the first dict is the header.
        let payload = Bytes::from_static(b"d4:name3:fooe");
        let msg = MetadataMessage::data(0, 13, payload.clone());
        let bytes = msg.to_bytes().unwrap();
        let parsed = MetadataMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.msg_type, MetadataMessageType::Data);
        assert_eq!(parsed.total_size, Some(13));
        assert_eq!(parsed.data, Some(payload));
    }

    #[test]
    fn metadata_reject() {
        let msg = MetadataMessage::reject(5);
        let bytes = msg.to_bytes().unwrap();
        let parsed = MetadataMessage::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.msg_type, MetadataMessageType::Reject);
        assert_eq!(parsed.piece, 5);
    }

    #[test]
    fn metadata_from_bytes_rejects_unknown_msg_type() {
        let err = MetadataMessage::from_bytes(b"d8:msg_typei7e5:piecei0ee").unwrap_err();
        assert!(matches!(err, Error::InvalidExtended(_)));
    }

    #[test]
    fn find_dict_end_handles_nesting_and_trailing_bytes() {
        assert_eq!(find_bencode_dict_end(b"d1:ald1:bi1eeeeXYZ").unwrap(), 15);
        assert_eq!(find_bencode_dict_end(b"de").unwrap(), 2);
    }

    #[test]
    fn find_dict_end_rejects_malformed_input() {
        let inputs: [&[u8]; 7] = [b"", b"li1ee", b"d1:ai1e", b"di1", b"d3", b"d5:ab", b"dx"];
        for input in inputs {
            assert!(find_bencode_dict_end(input).is_err(), "{input:?}");
        }
    }
}