use std::collections::BTreeMap;
const MAX_STRING_LENGTH: usize = 100 * 1024 * 1024;
use std::fmt;
use crate::error::{EngineError, ProtocolErrorKind, Result};
#[derive(Clone, PartialEq, Eq)]
pub enum BencodeValue {
Integer(i64),
Bytes(Vec<u8>),
List(Vec<BencodeValue>),
Dict(BTreeMap<Vec<u8>, BencodeValue>),
}
impl fmt::Debug for BencodeValue {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Integer(n) => write!(f, "Integer({})", n),
Self::Bytes(b) => {
if let Ok(s) = std::str::from_utf8(b) {
if s.len() <= 50 {
write!(f, "Bytes(\"{}\")", s)
} else {
write!(f, "Bytes(\"{}...\" [{} bytes])", &s[..50], b.len())
}
} else {
write!(f, "Bytes([{} bytes])", b.len())
}
}
Self::List(l) => f.debug_tuple("List").field(l).finish(),
Self::Dict(d) => {
let readable: BTreeMap<String, &BencodeValue> = d
.iter()
.map(|(k, v)| {
let key = String::from_utf8_lossy(k).to_string();
(key, v)
})
.collect();
f.debug_tuple("Dict").field(&readable).finish()
}
}
}
}
pub struct ParseResult<'a> {
pub value: BencodeValue,
pub remaining: &'a [u8],
}
impl BencodeValue {
pub fn parse(data: &[u8]) -> Result<ParseResult<'_>> {
if data.is_empty() {
return Err(EngineError::protocol(
ProtocolErrorKind::BencodeParse,
"Empty input",
));
}
match data[0] {
b'i' => Self::parse_integer(data),
b'l' => Self::parse_list(data),
b'd' => Self::parse_dict(data),
b'0'..=b'9' => Self::parse_bytes(data),
c => Err(EngineError::protocol(
ProtocolErrorKind::BencodeParse,
format!("Invalid bencode type marker: {:?}", c as char),
)),
}
}
pub fn parse_exact(data: &[u8]) -> Result<Self> {
let result = Self::parse(data)?;
if !result.remaining.is_empty() {
return Err(EngineError::protocol(
ProtocolErrorKind::BencodeParse,
format!("Trailing data: {} bytes", result.remaining.len()),
));
}
Ok(result.value)
}
fn parse_integer(data: &[u8]) -> Result<ParseResult<'_>> {
if data.is_empty() || data[0] != b'i' {
return Err(EngineError::protocol(
ProtocolErrorKind::BencodeParse,
"Expected integer",
));
}
let end = data[1..].iter().position(|&c| c == b'e').ok_or_else(|| {
EngineError::protocol(ProtocolErrorKind::BencodeParse, "Unterminated integer")
})? + 1;
let num_str = std::str::from_utf8(&data[1..end]).map_err(|_| {
EngineError::protocol(ProtocolErrorKind::BencodeParse, "Invalid integer encoding")
})?;
if num_str.len() > 1 && num_str.starts_with('0') {
return Err(EngineError::protocol(
ProtocolErrorKind::BencodeParse,
"Invalid integer: leading zero",
));
}
if num_str == "-0" {
return Err(EngineError::protocol(
ProtocolErrorKind::BencodeParse,
"Invalid integer: negative zero",
));
}
if num_str.starts_with("-0") && num_str.len() > 2 {
return Err(EngineError::protocol(
ProtocolErrorKind::BencodeParse,
"Invalid integer: leading zero after minus",
));
}
let value = num_str.parse::<i64>().map_err(|_| {
EngineError::protocol(ProtocolErrorKind::BencodeParse, "Integer parse error")
})?;
Ok(ParseResult {
value: BencodeValue::Integer(value),
remaining: &data[end + 1..], })
}
fn parse_bytes(data: &[u8]) -> Result<ParseResult<'_>> {
let colon = data.iter().position(|&c| c == b':').ok_or_else(|| {
EngineError::protocol(ProtocolErrorKind::BencodeParse, "Expected colon in string")
})?;
let len_str = std::str::from_utf8(&data[..colon]).map_err(|_| {
EngineError::protocol(ProtocolErrorKind::BencodeParse, "Invalid string length")
})?;
let len = len_str.parse::<usize>().map_err(|_| {
EngineError::protocol(
ProtocolErrorKind::BencodeParse,
"Invalid string length number",
)
})?;
if len > MAX_STRING_LENGTH {
return Err(EngineError::protocol(
ProtocolErrorKind::BencodeParse,
format!(
"String length {} exceeds maximum allowed {} bytes",
len, MAX_STRING_LENGTH
),
));
}
let start = colon + 1;
let end = start + len;
if end > data.len() {
return Err(EngineError::protocol(
ProtocolErrorKind::BencodeParse,
format!(
"String length {} exceeds available data {}",
len,
data.len() - start
),
));
}
Ok(ParseResult {
value: BencodeValue::Bytes(data[start..end].to_vec()),
remaining: &data[end..],
})
}
fn parse_list(data: &[u8]) -> Result<ParseResult<'_>> {
if data.is_empty() || data[0] != b'l' {
return Err(EngineError::protocol(
ProtocolErrorKind::BencodeParse,
"Expected list",
));
}
let mut items = Vec::new();
let mut remaining = &data[1..];
while !remaining.is_empty() && remaining[0] != b'e' {
let result = Self::parse(remaining)?;
items.push(result.value);
remaining = result.remaining;
}
if remaining.is_empty() {
return Err(EngineError::protocol(
ProtocolErrorKind::BencodeParse,
"Unterminated list",
));
}
Ok(ParseResult {
value: BencodeValue::List(items),
remaining: &remaining[1..], })
}
fn parse_dict(data: &[u8]) -> Result<ParseResult<'_>> {
if data.is_empty() || data[0] != b'd' {
return Err(EngineError::protocol(
ProtocolErrorKind::BencodeParse,
"Expected dict",
));
}
let mut items = BTreeMap::new();
let mut remaining = &data[1..]; let mut last_key: Option<Vec<u8>> = None;
while !remaining.is_empty() && remaining[0] != b'e' {
let key_result = Self::parse_bytes(remaining)?;
let key = match key_result.value {
BencodeValue::Bytes(k) => k,
_ => {
return Err(EngineError::protocol(
ProtocolErrorKind::BencodeParse,
"Dict key must be a string",
))
}
};
if let Some(ref lk) = last_key {
if &key <= lk {
return Err(EngineError::protocol(
ProtocolErrorKind::BencodeParse,
"Dict keys not in sorted order",
));
}
}
last_key = Some(key.clone());
remaining = key_result.remaining;
let value_result = Self::parse(remaining)?;
items.insert(key, value_result.value);
remaining = value_result.remaining;
}
if remaining.is_empty() {
return Err(EngineError::protocol(
ProtocolErrorKind::BencodeParse,
"Unterminated dict",
));
}
Ok(ParseResult {
value: BencodeValue::Dict(items),
remaining: &remaining[1..], })
}
pub fn encode(&self) -> Vec<u8> {
let mut buf = Vec::new();
self.encode_to(&mut buf);
buf
}
pub fn encode_to(&self, buf: &mut Vec<u8>) {
match self {
Self::Integer(n) => {
buf.push(b'i');
buf.extend_from_slice(n.to_string().as_bytes());
buf.push(b'e');
}
Self::Bytes(b) => {
buf.extend_from_slice(b.len().to_string().as_bytes());
buf.push(b':');
buf.extend_from_slice(b);
}
Self::List(l) => {
buf.push(b'l');
for item in l {
item.encode_to(buf);
}
buf.push(b'e');
}
Self::Dict(d) => {
buf.push(b'd');
for (k, v) in d {
buf.extend_from_slice(k.len().to_string().as_bytes());
buf.push(b':');
buf.extend_from_slice(k);
v.encode_to(buf);
}
buf.push(b'e');
}
}
}
pub fn as_string(&self) -> Option<&str> {
match self {
Self::Bytes(b) => std::str::from_utf8(b).ok(),
_ => None,
}
}
pub fn as_int(&self) -> Option<i64> {
match self {
Self::Integer(n) => Some(*n),
_ => None,
}
}
pub fn as_uint(&self) -> Option<u64> {
match self {
Self::Integer(n) if *n >= 0 => Some(*n as u64),
_ => None,
}
}
pub fn as_bytes(&self) -> Option<&[u8]> {
match self {
Self::Bytes(b) => Some(b),
_ => None,
}
}
pub fn as_list(&self) -> Option<&[BencodeValue]> {
match self {
Self::List(l) => Some(l),
_ => None,
}
}
pub fn as_list_mut(&mut self) -> Option<&mut Vec<BencodeValue>> {
match self {
Self::List(l) => Some(l),
_ => None,
}
}
pub fn as_dict(&self) -> Option<&BTreeMap<Vec<u8>, BencodeValue>> {
match self {
Self::Dict(d) => Some(d),
_ => None,
}
}
pub fn as_dict_mut(&mut self) -> Option<&mut BTreeMap<Vec<u8>, BencodeValue>> {
match self {
Self::Dict(d) => Some(d),
_ => None,
}
}
pub fn get(&self, key: &str) -> Option<&BencodeValue> {
match self {
Self::Dict(d) => d.get(key.as_bytes()),
_ => None,
}
}
pub fn get_bytes(&self, key: &[u8]) -> Option<&BencodeValue> {
match self {
Self::Dict(d) => d.get(key),
_ => None,
}
}
pub fn is_dict(&self) -> bool {
matches!(self, Self::Dict(_))
}
pub fn is_list(&self) -> bool {
matches!(self, Self::List(_))
}
pub fn is_bytes(&self) -> bool {
matches!(self, Self::Bytes(_))
}
pub fn is_int(&self) -> bool {
matches!(self, Self::Integer(_))
}
}
pub fn find_info_dict_bytes(data: &[u8]) -> Result<&[u8]> {
let root = BencodeValue::parse_exact(data)?;
let dict = root.as_dict().ok_or_else(|| {
EngineError::protocol(ProtocolErrorKind::InvalidTorrent, "Root is not a dict")
})?;
if !dict.contains_key(b"info".as_slice()) {
return Err(EngineError::protocol(
ProtocolErrorKind::InvalidTorrent,
"Missing 'info' key",
));
}
let info_key = b"4:info";
let mut pos = 0;
while pos < data.len() {
if data[pos..].starts_with(info_key) {
let info_start = pos + info_key.len();
if info_start < data.len() && data[info_start] == b'd' {
let result = BencodeValue::parse(&data[info_start..])?;
let info_len = data.len() - info_start - result.remaining.len();
return Ok(&data[info_start..info_start + info_len]);
}
}
pos += 1;
}
Err(EngineError::protocol(
ProtocolErrorKind::InvalidTorrent,
"Could not locate info dict bytes",
))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_integer() {
let result = BencodeValue::parse(b"i42e").unwrap();
assert_eq!(result.value, BencodeValue::Integer(42));
assert!(result.remaining.is_empty());
let result = BencodeValue::parse(b"i-42e").unwrap();
assert_eq!(result.value, BencodeValue::Integer(-42));
let result = BencodeValue::parse(b"i0e").unwrap();
assert_eq!(result.value, BencodeValue::Integer(0));
assert!(BencodeValue::parse(b"i03e").is_err());
assert!(BencodeValue::parse(b"i-0e").is_err());
}
#[test]
fn test_parse_bytes() {
let result = BencodeValue::parse(b"4:spam").unwrap();
assert_eq!(result.value, BencodeValue::Bytes(b"spam".to_vec()));
assert!(result.remaining.is_empty());
let result = BencodeValue::parse(b"0:").unwrap();
assert_eq!(result.value, BencodeValue::Bytes(vec![]));
let data = b"5:\x00\x01\x02\x03\x04";
let result = BencodeValue::parse(data).unwrap();
assert_eq!(result.value, BencodeValue::Bytes(vec![0, 1, 2, 3, 4]));
}
#[test]
fn test_parse_list() {
let result = BencodeValue::parse(b"l4:spami42ee").unwrap();
if let BencodeValue::List(items) = result.value {
assert_eq!(items.len(), 2);
assert_eq!(items[0], BencodeValue::Bytes(b"spam".to_vec()));
assert_eq!(items[1], BencodeValue::Integer(42));
} else {
panic!("Expected list");
}
let result = BencodeValue::parse(b"le").unwrap();
assert_eq!(result.value, BencodeValue::List(vec![]));
let result = BencodeValue::parse(b"ll4:spamee").unwrap();
if let BencodeValue::List(items) = result.value {
assert_eq!(items.len(), 1);
assert!(matches!(&items[0], BencodeValue::List(_)));
} else {
panic!("Expected list");
}
}
#[test]
fn test_parse_dict() {
let result = BencodeValue::parse(b"d3:cow3:moo4:spam4:eggse").unwrap();
if let BencodeValue::Dict(d) = &result.value {
assert_eq!(d.len(), 2);
assert_eq!(
d.get(b"cow".as_slice()),
Some(&BencodeValue::Bytes(b"moo".to_vec()))
);
assert_eq!(
d.get(b"spam".as_slice()),
Some(&BencodeValue::Bytes(b"eggs".to_vec()))
);
} else {
panic!("Expected dict");
}
let result = BencodeValue::parse(b"de").unwrap();
assert_eq!(result.value, BencodeValue::Dict(BTreeMap::new()));
}
#[test]
fn test_encode() {
let v = BencodeValue::Integer(42);
assert_eq!(v.encode(), b"i42e");
let v = BencodeValue::Bytes(b"spam".to_vec());
assert_eq!(v.encode(), b"4:spam");
let v = BencodeValue::List(vec![
BencodeValue::Bytes(b"spam".to_vec()),
BencodeValue::Integer(42),
]);
assert_eq!(v.encode(), b"l4:spami42ee");
let mut d = BTreeMap::new();
d.insert(b"cow".to_vec(), BencodeValue::Bytes(b"moo".to_vec()));
d.insert(b"spam".to_vec(), BencodeValue::Bytes(b"eggs".to_vec()));
let v = BencodeValue::Dict(d);
assert_eq!(v.encode(), b"d3:cow3:moo4:spam4:eggse");
}
#[test]
fn test_roundtrip() {
let original = b"d5:itemsli1ei2ei3ee4:name4:test5:valuei42ee";
let value = BencodeValue::parse_exact(original).unwrap();
let encoded = value.encode();
assert_eq!(encoded, original.to_vec());
assert_eq!(value.get("name").and_then(|v| v.as_string()), Some("test"));
assert_eq!(value.get("value").and_then(|v| v.as_int()), Some(42));
assert_eq!(
value
.get("items")
.and_then(|v| v.as_list())
.map(|l| l.len()),
Some(3)
);
}
#[test]
fn test_accessor_methods() {
let data = b"d4:listli1ei2ei3ee4:name4:test3:numi42ee";
let value = BencodeValue::parse_exact(data).unwrap();
assert_eq!(value.get("num").and_then(|v| v.as_int()), Some(42));
assert_eq!(value.get("name").and_then(|v| v.as_string()), Some("test"));
assert_eq!(
value.get("list").and_then(|v| v.as_list()).map(|l| l.len()),
Some(3)
);
assert!(value.get("missing").is_none());
}
}