use std::collections::BTreeMap;
use bytes::Bytes;
use super::error::PeerError;
use crate::bencode::{decode, encode, Value};
/// Size in bytes (16 KiB) of every metadata piece except possibly the last one.
pub const METADATA_PIECE_SIZE: usize = 16 * 1024;
/// Kind of a metadata extension message, as carried in the `msg_type` key of
/// the message dictionary.
///
/// The explicit discriminants are the exact integers written on the wire.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MetadataMessageType {
    /// Asks for a metadata piece (wire value 0).
    Request = 0,
    /// Carries a metadata piece plus the total metadata size (wire value 1).
    Data = 1,
    /// Declines a request for a metadata piece (wire value 2).
    Reject = 2,
}
impl MetadataMessageType {
    /// Maps a wire `msg_type` value to its message type.
    ///
    /// Returns `None` for any value other than 0, 1 or 2.
    pub fn from_byte(b: u8) -> Option<Self> {
        // Indexed by wire value, so position i holds the variant with discriminant i.
        const BY_WIRE_VALUE: [MetadataMessageType; 3] = [
            MetadataMessageType::Request,
            MetadataMessageType::Data,
            MetadataMessageType::Reject,
        ];
        BY_WIRE_VALUE.get(usize::from(b)).copied()
    }

    /// Returns the integer used for this message type on the wire.
    pub fn as_byte(&self) -> u8 {
        let variant = *self;
        variant as u8
    }
}
/// One metadata extension message: a bencoded dictionary optionally followed
/// by raw piece bytes.
#[derive(Clone, Debug)]
pub struct MetadataMessage {
    /// Whether this is a request, data or reject message.
    pub msg_type: MetadataMessageType,
    /// Index of the metadata piece this message refers to.
    pub piece: u32,
    /// Total metadata size, written as the `total_size` dictionary key when present.
    pub total_size: Option<u32>,
    /// Raw piece bytes, appended verbatim after the bencoded dictionary.
    pub data: Option<Bytes>,
}
impl MetadataMessage {
    /// Builds a `Request` message for metadata piece `piece`.
    pub fn request(piece: u32) -> Self {
        Self {
            msg_type: MetadataMessageType::Request,
            piece,
            total_size: None,
            data: None,
        }
    }

    /// Builds a `Data` message carrying `data` as the contents of piece `piece`,
    /// together with the total metadata size `total_size`.
    pub fn data(piece: u32, total_size: u32, data: Bytes) -> Self {
        Self {
            msg_type: MetadataMessageType::Data,
            piece,
            total_size: Some(total_size),
            data: Some(data),
        }
    }

    /// Builds a `Reject` message for metadata piece `piece`.
    pub fn reject(piece: u32) -> Self {
        Self {
            msg_type: MetadataMessageType::Reject,
            piece,
            total_size: None,
            data: None,
        }
    }

    /// Serializes the message for the wire.
    ///
    /// The output is a bencoded dictionary with `msg_type`, `piece` and (when
    /// set) `total_size`, immediately followed by the raw `data` bytes if any.
    ///
    /// # Errors
    /// Propagates any error returned by the bencode encoder.
    pub fn encode(&self) -> Result<Bytes, PeerError> {
        let mut dict = BTreeMap::new();
        dict.insert(
            Bytes::from_static(b"msg_type"),
            Value::Integer(i64::from(self.msg_type.as_byte())),
        );
        dict.insert(
            Bytes::from_static(b"piece"),
            Value::Integer(i64::from(self.piece)),
        );
        if let Some(total_size) = self.total_size {
            dict.insert(
                Bytes::from_static(b"total_size"),
                Value::Integer(i64::from(total_size)),
            );
        }
        let encoded_dict = encode(&Value::Dict(dict))?;
        // The piece payload is not bencoded; it is appended verbatim after the dictionary.
        if let Some(ref data) = self.data {
            let mut result = Vec::with_capacity(encoded_dict.len() + data.len());
            result.extend_from_slice(&encoded_dict);
            result.extend_from_slice(data);
            Ok(Bytes::from(result))
        } else {
            Ok(Bytes::from(encoded_dict))
        }
    }

    /// Parses a message from `payload`.
    ///
    /// The leading bencoded dictionary is located with `find_dict_end` and
    /// decoded. Any bytes after it become `data`, but only for `Data` messages.
    ///
    /// # Errors
    /// Returns `PeerError::Extension` when the payload is not a dictionary, when
    /// `msg_type` or `piece` is missing, or when `msg_type`, `piece` or
    /// `total_size` does not fit its field type. Errors from the dictionary scan
    /// and from the bencode decoder are propagated.
    pub fn decode(payload: &[u8]) -> Result<Self, PeerError> {
        let dict_end = find_dict_end(payload)?;
        let value = decode(&payload[..dict_end])?;
        let dict = value
            .as_dict()
            .ok_or_else(|| PeerError::Extension("expected dict".into()))?;

        // Integers come straight from the payload, so narrow them with range
        // checks: an `as` cast would silently truncate (e.g. 257 -> 1 == Data).
        let raw_msg_type = dict
            .get(b"msg_type".as_slice())
            .and_then(|v| v.as_integer())
            .ok_or_else(|| PeerError::Extension("missing msg_type".into()))?;
        let msg_type = u8::try_from(raw_msg_type)
            .ok()
            .and_then(MetadataMessageType::from_byte)
            .ok_or_else(|| PeerError::Extension("invalid msg_type".into()))?;

        let raw_piece = dict
            .get(b"piece".as_slice())
            .and_then(|v| v.as_integer())
            .ok_or_else(|| PeerError::Extension("missing piece".into()))?;
        let piece = u32::try_from(raw_piece)
            .map_err(|_| PeerError::Extension("invalid piece".into()))?;

        // `total_size` is optional; a non-integer value is treated as absent,
        // but an out-of-range integer is rejected rather than wrapped.
        let total_size = dict
            .get(b"total_size".as_slice())
            .and_then(|v| v.as_integer())
            .map(u32::try_from)
            .transpose()
            .map_err(|_| PeerError::Extension("invalid total_size".into()))?;

        let data = if msg_type == MetadataMessageType::Data && dict_end < payload.len() {
            Some(Bytes::copy_from_slice(&payload[dict_end..]))
        } else {
            None
        };
        Ok(Self {
            msg_type,
            piece,
            total_size,
            data,
        })
    }
}
/// Returns the offset just past the bencoded dictionary that starts `payload`.
///
/// Data messages append raw bytes after the dictionary, so its extent is
/// found by a structural scan rather than by decoding the whole payload. The
/// scan only tracks nesting depth, integer terminators and string lengths;
/// full validation is left to the bencode decoder that runs on the prefix.
///
/// # Errors
/// Returns `PeerError::Extension` in these cases:
/// - the payload does not start with `d`;
/// - a byte cannot begin a bencode token;
/// - a string length is malformed or overflows `usize`;
/// - the input ends before the dictionary is closed.
fn find_dict_end(payload: &[u8]) -> Result<usize, PeerError> {
    if payload.first() != Some(&b'd') {
        return Err(PeerError::Extension("payload must start with 'd'".into()));
    }
    // The leading 'd' raises depth to 1 and we return as soon as it drops back
    // to 0, so the decrement below can never underflow.
    let mut depth: usize = 0;
    let mut i: usize = 0;
    while i < payload.len() {
        match payload[i] {
            b'd' | b'l' => {
                depth += 1;
                i += 1;
            }
            b'e' => {
                depth -= 1;
                i += 1;
                if depth == 0 {
                    return Ok(i);
                }
            }
            b'i' => {
                // Skip the integer body and its closing 'e'. Without an 'e',
                // `i` runs past the end and we report "unterminated dict".
                i += 1;
                while i < payload.len() && payload[i] != b'e' {
                    i += 1;
                }
                i += 1;
            }
            b'0'..=b'9' => {
                let len_start = i;
                while i < payload.len() && payload[i] != b':' {
                    i += 1;
                }
                let len_str = std::str::from_utf8(&payload[len_start..i])
                    .map_err(|_| PeerError::Extension("invalid string length".into()))?;
                let len: usize = len_str
                    .parse()
                    .map_err(|_| PeerError::Extension("invalid string length".into()))?;
                // Step over the ':' and the string bytes. `len` comes from the
                // input and can be close to usize::MAX. Plain `+=` would then
                // panic in debug builds and wrap (rescanning earlier bytes) in
                // release builds.
                i = i
                    .checked_add(1)
                    .and_then(|after_colon| after_colon.checked_add(len))
                    .ok_or_else(|| PeerError::Extension("invalid string length".into()))?;
            }
            _ => {
                return Err(PeerError::Extension("invalid bencode".into()));
            }
        }
    }
    Err(PeerError::Extension("unterminated dict".into()))
}
pub fn metadata_piece_count(metadata_size: usize) -> usize {
metadata_size.div_ceil(METADATA_PIECE_SIZE)
}
/// Length in bytes of metadata piece `piece`, given `total_size` bytes of metadata.
///
/// Pieces are `METADATA_PIECE_SIZE` bytes except possibly the last one.
/// Returns 0 for any piece that starts at or past the end of the metadata.
/// This includes indices whose byte offset does not fit in `usize`, which is
/// possible on 32-bit targets with, e.g., a `piece` value decoded from the wire.
pub fn metadata_piece_size(piece: u32, total_size: usize) -> usize {
    // Checked arithmetic: a wrapped offset could look in-range and yield a
    // non-zero size for a piece that does not exist.
    let offset = usize::try_from(piece)
        .ok()
        .and_then(|index| index.checked_mul(METADATA_PIECE_SIZE));
    match offset {
        Some(offset) if offset < total_size => (total_size - offset).min(METADATA_PIECE_SIZE),
        _ => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_request_encode_decode() {
        let msg = MetadataMessage::request(5);
        let encoded = msg.encode().unwrap();
        let decoded = MetadataMessage::decode(&encoded).unwrap();
        assert_eq!(decoded.msg_type, MetadataMessageType::Request);
        assert_eq!(decoded.piece, 5);
        assert!(decoded.total_size.is_none());
        assert!(decoded.data.is_none());
    }

    #[test]
    fn test_data_encode_decode() {
        let data = Bytes::from(vec![1, 2, 3, 4, 5]);
        let msg = MetadataMessage::data(2, 1000, data.clone());
        let encoded = msg.encode().unwrap();
        let decoded = MetadataMessage::decode(&encoded).unwrap();
        assert_eq!(decoded.msg_type, MetadataMessageType::Data);
        assert_eq!(decoded.piece, 2);
        assert_eq!(decoded.total_size, Some(1000));
        assert_eq!(decoded.data, Some(data));
    }

    #[test]
    fn test_data_payload_resembling_bencode_roundtrips() {
        // The raw piece bytes follow the dictionary and may themselves look
        // like bencode; they must be returned untouched.
        let data = Bytes::from_static(b"d1:xi0ee");
        let msg = MetadataMessage::data(0, 8, data.clone());
        let decoded = MetadataMessage::decode(&msg.encode().unwrap()).unwrap();
        assert_eq!(decoded.msg_type, MetadataMessageType::Data);
        assert_eq!(decoded.data, Some(data));
    }

    #[test]
    fn test_reject_encode_decode() {
        let msg = MetadataMessage::reject(10);
        let encoded = msg.encode().unwrap();
        let decoded = MetadataMessage::decode(&encoded).unwrap();
        assert_eq!(decoded.msg_type, MetadataMessageType::Reject);
        assert_eq!(decoded.piece, 10);
    }

    #[test]
    fn test_message_type_from_byte() {
        for t in [
            MetadataMessageType::Request,
            MetadataMessageType::Data,
            MetadataMessageType::Reject,
        ] {
            assert_eq!(MetadataMessageType::from_byte(t.as_byte()), Some(t));
        }
        assert_eq!(MetadataMessageType::from_byte(3), None);
        assert_eq!(MetadataMessageType::from_byte(u8::MAX), None);
    }

    #[test]
    fn test_find_dict_end_ignores_trailing_bytes() {
        assert_eq!(find_dict_end(b"d1:ad1:bi1eeeXYZ").unwrap(), 13);
    }

    #[test]
    fn test_find_dict_end_skips_string_contents() {
        // Key "e" and value "d" must not be mistaken for structure bytes.
        assert_eq!(find_dict_end(b"d1:e1:detail").unwrap(), 8);
    }

    #[test]
    fn test_find_dict_end_rejects_malformed() {
        assert!(find_dict_end(b"").is_err());
        assert!(find_dict_end(b"l1:ae").is_err());
        assert!(find_dict_end(b"d1:a").is_err());
        assert!(find_dict_end(b"dx").is_err());
    }

    #[test]
    fn test_metadata_piece_count() {
        assert_eq!(metadata_piece_count(0), 0);
        assert_eq!(metadata_piece_count(1), 1);
        assert_eq!(metadata_piece_count(16384), 1);
        assert_eq!(metadata_piece_count(16385), 2);
        assert_eq!(metadata_piece_count(32768), 2);
        assert_eq!(metadata_piece_count(50000), 4);
    }

    #[test]
    fn test_metadata_piece_size() {
        assert_eq!(metadata_piece_size(0, 20000), METADATA_PIECE_SIZE);
        assert_eq!(metadata_piece_size(1, 20000), 20000 - METADATA_PIECE_SIZE);
        assert_eq!(metadata_piece_size(2, 20000), 0);
        assert_eq!(metadata_piece_size(0, 100), 100);
        assert_eq!(metadata_piece_size(0, 0), 0);
        assert_eq!(
            metadata_piece_size(1, 2 * METADATA_PIECE_SIZE),
            METADATA_PIECE_SIZE
        );
    }
}