use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::task::JoinSet;
use crate::bencode::{decode as bencode_decode, encode as bencode_encode};
use crate::error::{Error, ErrorKind};
use crate::metainfo::Metainfo;
use crate::peer::metadata::{
METADATA_PIECE_SIZE, MetadataData, MetadataRequest, UT_METADATA_EXT, UT_METADATA_ID,
};
use crate::peer::{ExtensionNegotiation, PeerConnection, PeerId, PeerMessage};
use crate::storage::{FileStorageFactory, StorageFactory};
use super::{InfoHash, Session};
/// Upper bound on how many metadata download attempts `download_metadata_from_peers` runs at
/// once.
const MAX_METADATA_PEERS: usize = 8;

/// Configures and starts a torrent that is already registered with a [`Session`].
///
/// Obtained from the session; consumed by [`TorrentBuilder::start`].
pub struct TorrentBuilder<'s> {
    /// Info hash of the torrent; used as the key into the session's torrent map.
    pub(crate) info_hash: InfoHash,
    /// Session that owns the torrent's handle.
    session: &'s Session,
    /// Peer addresses supplied alongside the torrent (used for metadata fetch and seeding the
    /// peer manager).
    magnet_peers: Vec<SocketAddr>,
    /// Whether metadata resolution has already been performed for this builder.
    metadata_resolved: bool,
    /// Factory used to create storage on `start`; `None` leaves the torrent inactive.
    storage_factory: Option<Arc<dyn StorageFactory>>,
}
impl<'s> TorrentBuilder<'s> {
    /// Creates a builder for the torrent registered under `info_hash`.
    ///
    /// `metadata_resolved` says whether metadata resolution can be skipped. `magnet_peers` are
    /// peer addresses (presumably from the magnet link — confirm with callers). They are used to
    /// fetch metadata and are later handed to the torrent's peer manager by [`Self::start`].
    pub(crate) fn new(
        session: &'s Session, info_hash: InfoHash, metadata_resolved: bool,
        magnet_peers: Vec<SocketAddr>,
    ) -> Self {
        TorrentBuilder {
            session,
            info_hash,
            storage_factory: None,
            metadata_resolved,
            magnet_peers,
        }
    }

    /// Returns the info hash of the torrent this builder configures.
    pub fn info_hash(&self) -> InfoHash {
        self.info_hash
    }

    /// Downloads the torrent's info dictionary from the magnet peers if it is still missing.
    ///
    /// Metadata is treated as missing while the registered metainfo has `piece_length == 0`.
    /// In that case, if magnet peers are known, the metadata is fetched from them and replaces
    /// the session's metainfo for this torrent. The builder is marked resolved afterwards.
    ///
    /// # Errors
    /// - `ErrorKind::InvalidInput` if the torrent is not registered with the session.
    /// - Any error from the metadata download or from parsing the downloaded bytes.
    pub async fn resolve_metadata(mut self) -> Result<Self, Error> {
        if self.metadata_resolved {
            return Ok(self);
        }
        let needs_resolve = {
            let torrents = self.session.torrents().read().unwrap();
            let Some(handle) = torrents.get(&self.info_hash) else {
                return Err(Error::new(ErrorKind::InvalidInput));
            };
            handle.metainfo.info.piece_length == 0
        };
        // NOTE(review): with no magnet peers the metainfo is left as-is (`piece_length == 0`)
        // yet the builder is still marked resolved — confirm another component fills it in.
        if needs_resolve && !self.magnet_peers.is_empty() {
            // Borrow (don't take) the peer list: `start` still needs it to seed the peer manager.
            let meta_bytes = download_metadata_from_peers(
                self.info_hash,
                &self.magnet_peers,
                PeerId::random(),
            )
            .await?;
            let new_meta = Metainfo::try_from(&meta_bytes[..])?;
            {
                let mut torrents = self.session.torrents().write().unwrap();
                let Some(handle) = torrents.get_mut(&self.info_hash) else {
                    return Err(Error::new(ErrorKind::InvalidInput));
                };
                handle.metainfo = new_meta;
            }
        }
        self.metadata_resolved = true;
        Ok(self)
    }

    /// Uses storage created by `FileStorageFactory::new(dir)`, replacing any previously
    /// configured storage factory.
    pub fn download_dir(mut self, dir: impl Into<PathBuf>) -> Self {
        self.storage_factory = Some(Arc::new(FileStorageFactory::new(dir)));
        self
    }

    /// Uses `factory` to create the torrent's storage, replacing any previously configured one.
    pub fn storage(mut self, factory: Arc<dyn StorageFactory>) -> Self {
        self.storage_factory = Some(factory);
        self
    }

    /// Resolves metadata if needed, hands the magnet peers to the torrent's peer manager, and
    /// activates the torrent using storage from the configured factory.
    ///
    /// If no storage factory was configured, the torrent is not activated and its info hash is
    /// returned as-is.
    ///
    /// # Errors
    /// - `ErrorKind::InvalidInput` if the torrent is not registered, or if the session's
    ///   `max_active_torrents` limit (when non-zero) has been reached.
    /// - Any error from metadata resolution or from creating/preparing storage.
    pub async fn start(mut self) -> Result<InfoHash, Error> {
        if !self.metadata_resolved {
            self = self.resolve_metadata().await?;
        }
        if !self.magnet_peers.is_empty() {
            let peer_mgr = {
                let torrents = self.session.torrents().read().unwrap();
                torrents.get(&self.info_hash).map(|h| h.peer_mgr.clone())
            };
            if let Some(peer_mgr) = peer_mgr {
                peer_mgr
                    .write()
                    .await
                    .add_peers(std::mem::take(&mut self.magnet_peers));
            }
        }

        let limit = self.session.config().max_active_torrents;
        // A limit of 0 means "unlimited".
        let limit_reached = |active: usize| limit > 0 && active >= limit;
        {
            // Early check so we don't create/prepare storage for a torrent we can't activate.
            let torrents = self.session.torrents().read().unwrap();
            if limit_reached(torrents.values().filter(|h| h.task.is_some()).count()) {
                return Err(Error::new(ErrorKind::InvalidInput));
            }
        }

        let Some(factory) = self.storage_factory.clone() else {
            return Ok(self.info_hash);
        };
        let info = {
            let torrents = self.session.torrents().read().unwrap();
            let Some(handle) = torrents.get(&self.info_hash) else {
                return Err(Error::new(ErrorKind::InvalidInput));
            };
            handle.metainfo.info.clone()
        };
        let storage = factory.create(&info).await?;
        storage.prepare().await?;
        {
            let mut torrents = self.session.torrents().write().unwrap();
            // Re-check under the write lock: other torrents may have been activated while the
            // storage was being created above, so the early check alone is racy.
            if limit_reached(torrents.values().filter(|h| h.task.is_some()).count()) {
                return Err(Error::new(ErrorKind::InvalidInput));
            }
            let Some(handle) = torrents.get_mut(&self.info_hash) else {
                return Err(Error::new(ErrorKind::InvalidInput));
            };
            handle.activate(storage, self.session.config());
        }
        Ok(self.info_hash)
    }
}
/// Fetches the torrent metadata from the first peer in `addrs` able to supply it.
///
/// At most [`MAX_METADATA_PEERS`] attempts run concurrently. Whenever an attempt fails, the next
/// untried address is started, so every address eventually gets a chance. The first successful
/// download is returned. Any attempts still running are aborted when the `JoinSet` is dropped.
///
/// # Errors
/// Returns `ErrorKind::PeerConnectionClosed` if `addrs` is empty or every attempt fails.
async fn download_metadata_from_peers(
    info_hash: [u8; 20], addrs: &[SocketAddr], our_peer_id: PeerId,
) -> Result<Vec<u8>, Error> {
    let mut untried = addrs.iter().copied();
    let mut set = JoinSet::new();
    for addr in untried.by_ref().take(MAX_METADATA_PEERS) {
        set.spawn(download_metadata_from_peer(addr, info_hash, our_peer_id));
    }
    while let Some(result) = set.join_next().await {
        match result {
            Ok(Ok(bytes)) => return Ok(bytes),
            Ok(Err(e)) => {
                tracing::debug!("peer metadata download attempt failed: {}", e);
            }
            Err(join_err) => {
                tracing::warn!("metadata download task panicked: {}", join_err);
            }
        }
        // A slot freed up: give the next untried address a chance.
        if let Some(addr) = untried.next() {
            set.spawn(download_metadata_from_peer(addr, info_hash, our_peer_id));
        }
    }
    Err(Error::new(ErrorKind::PeerConnectionClosed))
}
/// Downloads the torrent metadata from a single peer via the `ut_metadata` extension
/// (BEP 9) negotiated over the extension protocol (BEP 10).
///
/// The flow is: connect, exchange extended handshakes, then request every metadata piece in
/// order. Each reply is validated against the request before it is copied into the result
/// buffer.
///
/// NOTE(review): the assembled bytes are not checked against `info_hash` here. BEP 9 requires
/// verifying the SHA-1 of the metadata; confirm that a caller does it. Likewise there is no
/// timeout here, so a silent peer may stall this attempt indefinitely.
///
/// # Errors
/// Returns connection/IO errors from `PeerConnection`. Returns
/// `ErrorKind::PeerInvalidExtendedMessage` for any protocol violation: no or disabled
/// `ut_metadata` support, a missing or implausible `metadata_size`, a rejected request, or a
/// reply whose piece index or length does not match the request.
async fn download_metadata_from_peer(
    addr: SocketAddr, info_hash: [u8; 20], our_peer_id: PeerId,
) -> Result<Vec<u8>, Error> {
    // `metadata_size` comes from an untrusted peer and sizes an allocation, so bound it.
    const MAX_METADATA_SIZE: u64 = 64 * 1024 * 1024;

    let conn = PeerConnection::connect(addr, info_hash, our_peer_id).await?;

    // Advertise ut_metadata under our local id; per BEP 10 the peer tags the ut_metadata
    // messages it sends *to us* with this id.
    let mut our_neg = ExtensionNegotiation::new();
    our_neg.add_extension(UT_METADATA_EXT, UT_METADATA_ID);
    let handshake_data = our_neg.to_bencode();
    let handshake_bytes = bencode_encode(&handshake_data);
    conn.send(&PeerMessage::Extended {
        ext_id: 0,
        data: handshake_bytes,
    })
    .await?;

    // Wait for the peer's extended handshake, skipping ordinary wire messages.
    let (remote_ext_id, metadata_size) = loop {
        match conn.recv().await? {
            PeerMessage::Extended { ext_id: 0, data } => {
                let (ben, _) = bencode_decode(&data)
                    .map_err(|_| Error::new(ErrorKind::PeerInvalidExtendedMessage))?;
                let neg = ExtensionNegotiation::from_bencode(&ben)
                    .map_err(|_| Error::new(ErrorKind::PeerInvalidExtendedMessage))?;
                // BEP 10: an id of 0 means the peer has disabled the extension.
                let ext_id = neg.m.get(UT_METADATA_EXT).copied().filter(|&id| id != 0);
                let size = neg.metadata_size.map(|s| s as u64);
                break (ext_id, size);
            }
            PeerMessage::KeepAlive
            | PeerMessage::Bitfield(_)
            | PeerMessage::Unchoke
            | PeerMessage::Have(_) => continue,
            _ => return Err(Error::new(ErrorKind::PeerInvalidExtendedMessage)),
        }
    };

    // The peer's id for ut_metadata: used when sending requests *to* the peer.
    let ext_id = remote_ext_id.ok_or_else(|| Error::new(ErrorKind::PeerInvalidExtendedMessage))?;
    let total_size =
        metadata_size.ok_or_else(|| Error::new(ErrorKind::PeerInvalidExtendedMessage))?;
    if total_size == 0 || total_size > MAX_METADATA_SIZE {
        return Err(Error::new(ErrorKind::PeerInvalidExtendedMessage));
    }

    let num_pieces = total_size.div_ceil(METADATA_PIECE_SIZE);
    let piece_size = METADATA_PIECE_SIZE as usize;
    let total_len = total_size as usize;
    let mut buf = vec![0u8; total_len];
    for piece_idx in 0..num_pieces as u32 {
        let req = MetadataRequest { piece: piece_idx };
        let req_ben = req.to_bencode();
        conn.send(&PeerMessage::Extended {
            ext_id,
            data: bencode_encode(&req_ben),
        })
        .await?;

        // Wait for the ut_metadata reply; peers may interleave ordinary wire messages.
        let data = loop {
            match conn.recv().await? {
                // Spec-conformant peers use our advertised id; also tolerate peers that echo
                // their own id back.
                PeerMessage::Extended {
                    ext_id: resp_id,
                    data,
                } if resp_id == UT_METADATA_ID || resp_id == ext_id => break data,
                PeerMessage::KeepAlive
                | PeerMessage::Bitfield(_)
                | PeerMessage::Unchoke
                | PeerMessage::Have(_) => continue,
                _ => return Err(Error::new(ErrorKind::PeerInvalidExtendedMessage)),
            }
        };

        let (dict, raw_data) = split_bep9_data(&data)?;
        let (ben, _) = bencode_decode(&dict)
            .map_err(|_| Error::new(ErrorKind::PeerInvalidExtendedMessage))?;
        if MetadataData::is_reject(&ben) {
            return Err(Error::new(ErrorKind::PeerInvalidExtendedMessage));
        }
        let piece = MetadataData::from_bencode(&ben, raw_data)?;

        // Validate the reply against what we asked for. Otherwise a bogus piece index could
        // slice out of bounds, and a wrong length would make `copy_from_slice` panic.
        // `piece_idx < num_pieces` guarantees `offset < total_len`.
        let offset = piece_idx as usize * piece_size;
        let expected_len = piece_size.min(total_len - offset);
        if piece.piece as usize != piece_idx as usize || piece.data.len() != expected_len {
            return Err(Error::new(ErrorKind::PeerInvalidExtendedMessage));
        }
        buf[offset..offset + expected_len].copy_from_slice(&piece.data);
    }
    Ok(buf)
}
fn split_bep9_data(data: &[u8]) -> Result<(Vec<u8>, Vec<u8>), Error> {
let (_, rest) =
bencode_decode(data).map_err(|_| Error::new(ErrorKind::PeerInvalidExtendedMessage))?;
let dict_len = data.len() - rest.len();
Ok((data[..dict_len].to_vec(), rest.to_vec()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A data message is a bencoded header followed by the raw piece bytes.
    #[test]
    fn split_dict_with_data() {
        let meta = MetadataData {
            piece: 0,
            total_size: 42,
            data: b"hello".to_vec(),
        };
        let header = bencode_encode(&meta.to_bencode_with_data());
        let wire = [header.as_slice(), meta.data.as_slice()].concat();

        let (dict, raw) = split_bep9_data(&wire).unwrap();
        assert_eq!(dict, header);
        assert_eq!(raw, meta.data);
    }

    /// With nothing after the dictionary, the raw part comes back empty.
    #[test]
    fn split_empty_raw_data() {
        let meta = MetadataData {
            piece: 0,
            total_size: 0,
            data: Vec::new(),
        };
        let wire = bencode_encode(&meta.to_bencode_with_data());

        let (dict, raw) = split_bep9_data(&wire).unwrap();
        assert_eq!(dict, wire);
        assert!(raw.is_empty());
    }

    /// A dictionary missing its closing `e` is rejected.
    #[test]
    fn split_truncated_dict_errors() {
        let truncated: &[u8] = b"d8:msg_typei1e5:piecei0e";
        assert!(split_bep9_data(truncated).is_err());
    }

    /// Input that is not bencode at all is rejected.
    #[test]
    fn split_plain_bytes_errors() {
        let outcome = split_bep9_data(b"not a dict");
        assert!(outcome.is_err());
    }
}