use std::collections::HashMap;
use std::time::Duration;
use sha1::{Digest, Sha1};
use tokio::sync::RwLock;
use crate::error::{EngineError, ProtocolErrorKind, Result};
use crate::torrent::bencode::BencodeValue;
use crate::torrent::metainfo::{Metainfo, Sha1Hash};
pub const METADATA_PIECE_SIZE: usize = 16 * 1024;
pub const METADATA_EXTENSION_NAME: &str = "ut_metadata";
pub const OUR_METADATA_EXTENSION_ID: u8 = 2;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MetadataMessageType {
Request = 0,
Data = 1,
Reject = 2,
}
impl MetadataMessageType {
pub fn from_u8(v: u8) -> Option<Self> {
match v {
0 => Some(Self::Request),
1 => Some(Self::Data),
2 => Some(Self::Reject),
_ => None,
}
}
}
#[derive(Debug, Clone)]
pub struct MetadataMessage {
pub msg_type: MetadataMessageType,
pub piece: usize,
pub total_size: Option<usize>,
pub data: Option<Vec<u8>>,
}
impl MetadataMessage {
pub fn request(piece: usize) -> Self {
Self {
msg_type: MetadataMessageType::Request,
piece,
total_size: None,
data: None,
}
}
pub fn data(piece: usize, total_size: usize, data: Vec<u8>) -> Self {
Self {
msg_type: MetadataMessageType::Data,
piece,
total_size: Some(total_size),
data: Some(data),
}
}
pub fn reject(piece: usize) -> Self {
Self {
msg_type: MetadataMessageType::Reject,
piece,
total_size: None,
data: None,
}
}
pub fn encode(&self) -> Vec<u8> {
let mut dict = std::collections::BTreeMap::new();
dict.insert(
b"msg_type".to_vec(),
BencodeValue::Integer(self.msg_type as i64),
);
dict.insert(b"piece".to_vec(), BencodeValue::Integer(self.piece as i64));
if let Some(size) = self.total_size {
dict.insert(b"total_size".to_vec(), BencodeValue::Integer(size as i64));
}
let mut encoded = BencodeValue::Dict(dict).encode();
if let Some(ref data) = self.data {
encoded.extend_from_slice(data);
}
encoded
}
pub fn parse(data: &[u8]) -> Result<Self> {
let parse_result = BencodeValue::parse(data)?;
let consumed = data.len() - parse_result.remaining.len();
let dict = parse_result.value.as_dict().ok_or_else(|| {
EngineError::protocol(
ProtocolErrorKind::MetadataError,
"Metadata message must be a dict",
)
})?;
let msg_type = dict
.get(b"msg_type".as_slice())
.and_then(|v: &BencodeValue| v.as_uint())
.and_then(|v| MetadataMessageType::from_u8(v as u8))
.ok_or_else(|| {
EngineError::protocol(
ProtocolErrorKind::MetadataError,
"Invalid or missing msg_type",
)
})?;
let piece = dict
.get(b"piece".as_slice())
.and_then(|v: &BencodeValue| v.as_uint())
.map(|v| v as usize)
.ok_or_else(|| {
EngineError::protocol(ProtocolErrorKind::MetadataError, "Invalid or missing piece")
})?;
let total_size = dict
.get(b"total_size".as_slice())
.and_then(|v: &BencodeValue| v.as_uint())
.map(|v| v as usize);
let piece_data = if msg_type == MetadataMessageType::Data && consumed < data.len() {
Some(data[consumed..].to_vec())
} else {
None
};
Ok(Self {
msg_type,
piece,
total_size,
data: piece_data,
})
}
}
pub struct MetadataFetcher {
info_hash: Sha1Hash,
total_size: RwLock<Option<usize>>,
pieces: RwLock<HashMap<usize, Vec<u8>>>,
pending_requests: RwLock<HashMap<usize, std::time::Instant>>,
metadata: RwLock<Option<Vec<u8>>>,
}
impl MetadataFetcher {
pub fn new(info_hash: Sha1Hash) -> Self {
Self {
info_hash,
total_size: RwLock::new(None),
pieces: RwLock::new(HashMap::new()),
pending_requests: RwLock::new(HashMap::new()),
metadata: RwLock::new(None),
}
}
pub async fn is_complete(&self) -> bool {
self.metadata.read().await.is_some()
}
pub async fn num_pieces(&self) -> Option<usize> {
let size = (*self.total_size.read().await)?;
Some(size.div_ceil(METADATA_PIECE_SIZE))
}
pub async fn get_needed_pieces(&self) -> Vec<usize> {
let Some(num_pieces) = self.num_pieces().await else {
return vec![0];
};
let pieces = self.pieces.read().await;
let pending = self.pending_requests.read().await;
let now = std::time::Instant::now();
let timeout = Duration::from_secs(30);
(0..num_pieces)
.filter(|i| {
!pieces.contains_key(i)
&& !pending
.get(i)
.map(|t| now.duration_since(*t) < timeout)
.unwrap_or(false)
})
.collect()
}
pub async fn mark_requested(&self, piece: usize) {
self.pending_requests
.write()
.await
.insert(piece, std::time::Instant::now());
}
pub async fn process_message(&self, msg: MetadataMessage) -> Result<bool> {
match msg.msg_type {
MetadataMessageType::Request => {
Ok(false)
}
MetadataMessageType::Reject => {
self.pending_requests.write().await.remove(&msg.piece);
Ok(false)
}
MetadataMessageType::Data => {
let Some(total_size) = msg.total_size else {
return Err(EngineError::protocol(
ProtocolErrorKind::MetadataError,
"Data message missing total_size",
));
};
let Some(data) = msg.data else {
return Err(EngineError::protocol(
ProtocolErrorKind::MetadataError,
"Data message missing piece data",
));
};
let expected_size = if msg.piece < (total_size / METADATA_PIECE_SIZE) {
METADATA_PIECE_SIZE
} else {
total_size % METADATA_PIECE_SIZE
};
if data.len() != expected_size && expected_size != 0 {
return Err(EngineError::protocol(
ProtocolErrorKind::MetadataError,
format!(
"Piece {} has wrong size: {} (expected {})",
msg.piece,
data.len(),
expected_size
),
));
}
{
let mut size = self.total_size.write().await;
if size.is_none() {
*size = Some(total_size);
}
}
{
let mut pieces = self.pieces.write().await;
pieces.insert(msg.piece, data);
self.pending_requests.write().await.remove(&msg.piece);
}
if let Some(num_pieces) = self.num_pieces().await {
let pieces = self.pieces.read().await;
if pieces.len() >= num_pieces {
drop(pieces);
return self.assemble_and_verify().await;
}
}
Ok(false)
}
}
}
async fn assemble_and_verify(&self) -> Result<bool> {
let Some(total_size) = *self.total_size.read().await else {
return Ok(false);
};
let num_pieces = total_size.div_ceil(METADATA_PIECE_SIZE);
let pieces = self.pieces.read().await;
let mut assembled = Vec::with_capacity(total_size);
for i in 0..num_pieces {
let Some(piece) = pieces.get(&i) else {
return Ok(false); };
assembled.extend_from_slice(piece);
}
assembled.truncate(total_size);
let mut hasher = Sha1::new();
hasher.update(&assembled);
let hash: [u8; 20] = hasher.finalize().into();
if hash != self.info_hash {
tracing::warn!(
"Metadata info_hash mismatch: expected {:?}, got {:?}",
self.info_hash,
hash
);
drop(pieces);
self.pieces.write().await.clear();
return Ok(false);
}
*self.metadata.write().await = Some(assembled);
Ok(true)
}
pub async fn get_metadata(&self) -> Option<Vec<u8>> {
self.metadata.read().await.clone()
}
pub async fn parse_metainfo(&self) -> Result<Option<Metainfo>> {
let Some(data) = self.get_metadata().await else {
return Ok(None);
};
let info_value = BencodeValue::parse_exact(&data)?;
let mut torrent_dict = std::collections::BTreeMap::new();
torrent_dict.insert(b"info".to_vec(), info_value);
let torrent_bytes = BencodeValue::Dict(torrent_dict).encode();
let metainfo = Metainfo::parse(&torrent_bytes)?;
Ok(Some(metainfo))
}
pub async fn raw_torrent_bytes(&self) -> Option<Vec<u8>> {
let data = self.get_metadata().await?;
let info_value = BencodeValue::parse_exact(&data).ok()?;
let mut torrent_dict = std::collections::BTreeMap::new();
torrent_dict.insert(b"info".to_vec(), info_value);
Some(BencodeValue::Dict(torrent_dict).encode())
}
pub async fn total_size(&self) -> Option<usize> {
*self.total_size.read().await
}
pub async fn received_count(&self) -> usize {
self.pieces.read().await.len()
}
}
pub fn build_metadata_extension_handshake(
metadata_size: Option<usize>,
) -> std::collections::BTreeMap<Vec<u8>, BencodeValue> {
let mut m = std::collections::BTreeMap::new();
m.insert(
b"ut_metadata".to_vec(),
BencodeValue::Integer(OUR_METADATA_EXTENSION_ID as i64),
);
let mut result = std::collections::BTreeMap::new();
result.insert(b"m".to_vec(), BencodeValue::Dict(m));
if let Some(size) = metadata_size {
result.insert(
b"metadata_size".to_vec(),
BencodeValue::Integer(size as i64),
);
}
result
}
pub fn parse_metadata_size(handshake_data: &[u8]) -> Option<usize> {
let value = BencodeValue::parse_exact(handshake_data).ok()?;
let dict = value.as_dict()?;
dict.get(b"metadata_size".as_slice())
.and_then(|v| v.as_uint())
.map(|v| v as usize)
}
pub fn parse_metadata_extension_id(handshake_data: &[u8]) -> Option<u8> {
let value = BencodeValue::parse_exact(handshake_data).ok()?;
let dict = value.as_dict()?;
let m = dict.get(b"m".as_slice())?.as_dict()?;
m.get(b"ut_metadata".as_slice())
.and_then(|v| v.as_uint())
.map(|v| v as u8)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_metadata_message_request() {
let msg = MetadataMessage::request(5);
let encoded = msg.encode();
let parsed = MetadataMessage::parse(&encoded).unwrap();
assert_eq!(parsed.msg_type, MetadataMessageType::Request);
assert_eq!(parsed.piece, 5);
assert!(parsed.data.is_none());
}
#[test]
fn test_metadata_message_data() {
let data = vec![1, 2, 3, 4, 5];
let msg = MetadataMessage::data(0, 5, data.clone());
let encoded = msg.encode();
let parsed = MetadataMessage::parse(&encoded).unwrap();
assert_eq!(parsed.msg_type, MetadataMessageType::Data);
assert_eq!(parsed.piece, 0);
assert_eq!(parsed.total_size, Some(5));
assert_eq!(parsed.data, Some(data));
}
#[test]
fn test_metadata_message_reject() {
let msg = MetadataMessage::reject(3);
let encoded = msg.encode();
let parsed = MetadataMessage::parse(&encoded).unwrap();
assert_eq!(parsed.msg_type, MetadataMessageType::Reject);
assert_eq!(parsed.piece, 3);
}
#[tokio::test]
async fn test_metadata_fetcher_pieces() {
let test_metadata = b"d4:name4:test12:piece lengthi16384ee";
let mut hasher = Sha1::new();
hasher.update(test_metadata);
let info_hash: [u8; 20] = hasher.finalize().into();
let fetcher = MetadataFetcher::new(info_hash);
let msg = MetadataMessage::data(0, test_metadata.len(), test_metadata.to_vec());
let complete = fetcher.process_message(msg).await.unwrap();
assert!(complete);
assert!(fetcher.is_complete().await);
let metadata = fetcher.get_metadata().await.unwrap();
assert_eq!(metadata, test_metadata);
}
#[tokio::test]
async fn test_metadata_fetcher_wrong_hash() {
let wrong_hash = [0u8; 20]; let fetcher = MetadataFetcher::new(wrong_hash);
let test_metadata = b"d4:name4:test12:piece lengthi16384ee";
let msg = MetadataMessage::data(0, test_metadata.len(), test_metadata.to_vec());
let complete = fetcher.process_message(msg).await.unwrap();
assert!(!complete);
assert!(!fetcher.is_complete().await);
}
#[test]
fn test_parse_metadata_extension_id() {
let handshake = b"d1:md11:ut_metadatai2eee";
let id = parse_metadata_extension_id(handshake);
assert_eq!(id, Some(2));
}
#[test]
fn test_parse_metadata_size() {
let handshake = b"d13:metadata_sizei12345ee";
let size = parse_metadata_size(handshake);
assert_eq!(size, Some(12345));
}
}