use std::collections::{HashMap, HashSet};
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::sync::Arc;
use std::time::{Duration, Instant};
use rand::RngExt;
use tokio::sync::{Mutex, RwLock, mpsc};
use crate::dht::{DhtRpc, generate_node_id, get_peers, krpc};
use crate::error::Error;
use crate::metainfo::Metainfo;
use crate::peer::{PeerConnection, PeerId, PeerMessage};
use crate::piece::{EndGame, PieceManager, PieceSelector};
use crate::storage::{FileStorage, Storage};
use crate::tracker::{AnnounceEvent, AnnounceRequest, Tracker};
use super::peer_manager::PeerManager;
use super::torrent::TorrentCommand;
use super::upload::UploadManager;
use super::{TorrentState, TorrentStatus};
/// Event forwarded from a per-peer reader task to the download loop over
/// `DownloadLoop::peer_msg_tx`.
pub(crate) enum PeerEvent {
/// A message successfully received from the peer's connection.
Message(PeerMessage),
/// The connection's `recv` returned an error; the reader task has exited.
Disconnected,
}
/// Protocol state the download loop keeps for one connected peer.
pub(crate) struct PeerInfo {
    /// Pieces the peer has advertised, one flag per piece; empty until filled
    /// from the peer's messages.
    bitfield: Vec<bool>,
    /// `true` while the remote peer is choking us.
    am_choked: bool,
    // No code in this module reads this flag yet; it is kept for protocol
    // bookkeeping, hence the lint allowance.
    #[allow(dead_code)]
    am_interested: bool,
    /// `true` after the peer sent `Interested`, `false` after `NotInterested`.
    peer_interested: bool,
    /// Counter of block payload bytes sent to this peer.
    uploaded_bytes: u64,
    /// Counter of block payload bytes received from this peer.
    downloaded_bytes: u64,
}

impl PeerInfo {
    /// State for a freshly connected peer: it is choking us, neither side is
    /// interested, no pieces are known and both byte counters are zero.
    fn new() -> Self {
        Self {
            am_choked: true,
            am_interested: false,
            peer_interested: false,
            bitfield: Vec::new(),
            uploaded_bytes: 0,
            downloaded_bytes: 0,
        }
    }
}
/// Bookkeeping for a piece whose blocks have been requested but not yet verified.
pub(crate) struct ActiveDownload {
/// Piece index; the same value is this entry's key in `DownloadLoop::active_downloads`.
// Only tests read this field at the moment, hence the lint allowance.
#[allow(dead_code)]
index: u32,
/// Buffer for the whole piece; each received block is copied in at its `begin` offset.
data: Vec<u8>,
/// One flag per block, set once that block has arrived.
received: Vec<bool>,
/// Length of every block except possibly the last (the loop builds requests with 16 KiB).
block_size: u32,
/// Number of blocks in the piece (the loop sizes `received` with the same value).
// Only tests read this field at the moment, hence the lint allowance.
#[allow(dead_code)]
num_blocks: usize,
/// Peers that blocks of this piece were requested from.
requested_from: HashSet<SocketAddr>,
}
/// State owned by one torrent's download task.
///
/// Control commands arrive on `control_rx`, peer traffic arrives on `peer_msg_rx`
/// from per-connection reader tasks, and periodic work runs on a one-second tick.
pub(crate) struct DownloadLoop {
/// Info-hash sent in tracker announces and DHT `get_peers` queries.
pub info_hash: [u8; 20],
/// Parsed torrent metadata (piece length, piece hashes, total size).
pub metainfo: Metainfo,
/// Storage that received blocks are written to and requested pieces are read from.
pub storage: Arc<FileStorage>,
/// Shared piece-completion state (what we have, what is missing, progress).
pub piece_mgr: Arc<RwLock<PieceManager>>,
/// Shared peer connection manager (pending connects, sends, removal).
pub peer_mgr: Arc<RwLock<PeerManager>>,
/// Status snapshot published for observers (state, progress, rates, peer count).
pub status: Arc<RwLock<TorrentStatus>>,
/// Pause / resume / cancel commands; a closed channel is treated like cancel.
pub control_rx: mpsc::Receiver<TorrentCommand>,
/// Our peer id, sent in tracker announces.
pub(crate) peer_id: PeerId,
/// Listening port reported to the tracker.
pub(crate) listen_port: u16,
/// Tracker to announce to; `None` disables announcing.
pub(crate) tracker: Option<Tracker>,
/// When the next regular announce is due; `None` means "announce now".
pub(crate) next_announce: Option<Instant>,
/// Set after the first successful announce so later ones omit the `Started` event.
pub(crate) has_announced: bool,
/// Set once the `Completed` announce has been attempted (even if it failed).
pub(crate) announced_completed: bool,
/// Per-peer protocol state for connections registered with this loop.
pub(crate) peers: HashMap<SocketAddr, PeerInfo>,
/// Pieces currently being downloaded, keyed by piece index.
pub(crate) active_downloads: HashMap<u32, ActiveDownload>,
/// Strategy that picks the next piece from a peer's bitfield; swapped for
/// `EndGame` near completion.
pub(crate) selector: Box<dyn PieceSelector>,
/// Receiving end of the channel fed by the per-peer reader tasks.
pub(crate) peer_msg_rx: mpsc::UnboundedReceiver<(SocketAddr, PeerEvent)>,
/// Sending end cloned into each reader task.
pub(crate) peer_msg_tx: mpsc::UnboundedSender<(SocketAddr, PeerEvent)>,
/// Tracks which peers we have unchoked and the upload-slot limit.
pub(crate) upload_mgr: Arc<RwLock<UploadManager>>,
/// Running total of block payload bytes received from peers.
pub(crate) total_downloaded: u64,
/// Running total of block payload bytes sent to peers.
pub(crate) total_uploaded: u64,
/// `total_downloaded` at the previous tick, used to derive the per-tick rate.
pub(crate) last_downloaded: u64,
/// `total_uploaded` at the previous tick, used to derive the per-tick rate.
pub(crate) last_uploaded: u64,
/// Number of ticks run so far; the choke round runs on every 10th.
pub(crate) tick_count: usize,
/// Whether to start DHT at startup and search it on ticks.
pub(crate) enable_dht: bool,
/// In-memory copies of piece data used to answer peer `Request` messages
/// without re-reading storage.
pub(crate) piece_cache: HashMap<u32, Arc<Vec<u8>>>,
/// DHT RPC endpoint; `Some` only after DHT initialisation succeeded.
pub(crate) dht_rpc: Option<Arc<DhtRpc>>,
/// Our DHT node id, generated when DHT starts.
pub(crate) dht_node_id: [u8; 20],
/// When the next DHT search is due; `None` means no search is scheduled.
pub(crate) next_dht_search: Option<Instant>,
}
/// Once fewer than this many pieces are missing, piece selection switches to the
/// `EndGame` selector and the chosen piece is requested from every unchoked peer
/// that has it.
const ENDGAME_THRESHOLD: usize = 10;
/// Well-known DHT nodes as `(host, port)`; each DHT search resolves them and sends
/// a `get_peers` query for this torrent's info-hash.
const DHT_BOOTSTRAP: &[(&str, u16)] = &[
("router.bittorrent.com", 6881),
("dht.transmissionbt.com", 6881),
];
impl DownloadLoop {
    /// Upper bound on the number of pieces kept in `piece_cache`.
    ///
    /// Without a bound every verified piece would stay in memory for the lifetime
    /// of the loop, i.e. the whole torrent would end up resident.
    const PIECE_CACHE_CAPACITY: usize = 16;

    /// Runs the loop until a `Cancel` command arrives or the control channel closes.
    ///
    /// Three event sources are multiplexed: control commands, events forwarded by
    /// the per-peer reader tasks, and a one-second maintenance tick
    /// ([`Self::tick`]) that is skipped while the torrent is paused.
    pub async fn run(&mut self) {
        if self.enable_dht {
            self.init_dht().await;
        }
        self.status.write().await.state = TorrentState::Downloading;

        // An `Interval` keeps its schedule across loop iterations. A fresh `sleep`
        // per iteration would be restarted by every peer message, so a busy swarm
        // could starve the tick indefinitely.
        let mut ticker = tokio::time::interval(Duration::from_secs(1));
        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);

        loop {
            tokio::select! {
                cmd = self.control_rx.recv() => {
                    match cmd {
                        Some(TorrentCommand::Pause) => {
                            self.status.write().await.state = TorrentState::Paused;
                        }
                        Some(TorrentCommand::Resume) => {
                            self.status.write().await.state = TorrentState::Downloading;
                        }
                        Some(TorrentCommand::Cancel) | None => {
                            // Best effort: we stop whether or not the tracker hears us.
                            let _ = self.announce_to_tracker(AnnounceEvent::Stopped).await;
                            break;
                        }
                    }
                }
                Some((addr, event)) = self.peer_msg_rx.recv() => {
                    self.handle_peer_event(addr, event).await;
                }
                _ = ticker.tick() => {
                    // While paused, no connections, piece requests or announces are made.
                    if !self.is_paused().await {
                        if let Err(e) = self.tick().await {
                            tracing::warn!("download tick failed: {}", e);
                            self.status.write().await.state = TorrentState::Error;
                        }
                    }
                }
            }
        }
    }

    /// Returns `true` while the shared status says the torrent is paused.
    async fn is_paused(&self) -> bool {
        matches!(self.status.read().await.state, TorrentState::Paused)
    }

    /// Periodic maintenance, driven once per second by [`Self::run`].
    ///
    /// In order, it:
    /// 1. re-announces to the tracker when due and starts a DHT search when due;
    /// 2. registers newly connected peers and requests a piece when none is in flight;
    /// 3. refreshes the shared status;
    /// 4. sends `Completed` once every piece is present;
    /// 5. re-runs the choke algorithm every tenth tick.
    ///
    /// # Errors
    /// Propagates failures from piece selection or the choke round.
    async fn tick(&mut self) -> Result<(), Error> {
        tracing::debug!("download tick");
        self.announce_if_needed().await;
        if self.enable_dht {
            self.dht_search_if_needed();
        }
        self.register_new_connections().await;
        // Only one piece is in flight at a time (served by several peers in endgame).
        if self.active_downloads.is_empty() {
            self.maybe_request_piece().await?;
        }
        self.tick_count += 1;
        let complete = self.refresh_status().await;
        if complete && !self.announced_completed {
            let _ = self.announce_to_tracker(AnnounceEvent::Completed).await;
            self.announced_completed = true;
        }
        if self.tick_count.is_multiple_of(10) {
            self.run_choke_unchoke().await?;
        }
        Ok(())
    }

    /// Finishes pending connection attempts and wires up every peer that connected.
    ///
    /// The bitfield and `Interested` are sent *before* the reader task is spawned.
    /// That task keeps the connection mutex locked while it waits in `recv()`.
    /// A peer we cannot even greet is dropped instead of failing the whole tick.
    async fn register_new_connections(&mut self) {
        // NOTE(review): the peer-manager write lock is held while `connect_pending`
        // awaits; if that performs network I/O it blocks every other `peer_mgr`
        // user meanwhile — confirm against PeerManager.
        let newly_connected = {
            let mut pm = self.peer_mgr.write().await;
            pm.connect_pending().await
        };
        for addr in &newly_connected {
            let addr = *addr;
            let conn = self.peer_mgr.read().await.connection(&addr);
            let Some(conn_arc) = conn else {
                continue;
            };
            self.peers.insert(addr, PeerInfo::new());
            if let Err(e) = self.send_bitfield(addr).await {
                tracing::debug!("dropping peer {}: initial messages failed: {}", addr, e);
                self.drop_peer(addr).await;
                continue;
            }
            self.spawn_peer_reader(addr, conn_arc);
        }
    }

    /// Publishes progress, peer count and per-tick transfer rates, switching the
    /// state to `Seeding` once no piece is missing.
    ///
    /// Returns `true` when every piece is present. All inputs are gathered before
    /// the status lock is taken, so it is never held together with another lock.
    async fn refresh_status(&mut self) -> bool {
        let (progress, complete) = {
            let pm = self.piece_mgr.read().await;
            (pm.progress(), pm.missing_pieces().is_empty())
        };
        let num_peers = self.peer_mgr.read().await.num_connections();
        {
            let mut status = self.status.write().await;
            status.progress = progress;
            status.num_peers = num_peers;
            // Bytes since the previous tick; ticks are ~1 s apart.
            status.download_rate = (self.total_downloaded - self.last_downloaded) as f64;
            status.upload_rate = (self.total_uploaded - self.last_uploaded) as f64;
            if complete {
                status.state = TorrentState::Seeding;
            }
        }
        self.last_downloaded = self.total_downloaded;
        self.last_uploaded = self.total_uploaded;
        complete
    }

    /// Routes one event from a peer reader task.
    ///
    /// On a disconnect, or on any error while handling a message, the peer is
    /// dropped (see [`Self::drop_peer`]).
    async fn handle_peer_event(&mut self, addr: SocketAddr, event: PeerEvent) {
        match event {
            PeerEvent::Disconnected => self.drop_peer(addr).await,
            PeerEvent::Message(msg) => {
                if let Err(e) = self.handle_peer_message(addr, msg).await {
                    tracing::debug!("dropping peer {}: {}", addr, e);
                    self.drop_peer(addr).await;
                }
            }
        }
    }

    /// Forgets `addr` everywhere: local peer state, the peer manager, and every
    /// in-flight piece it was the last remaining source for.
    async fn drop_peer(&mut self, addr: SocketAddr) {
        self.peers.remove(&addr);
        self.peer_mgr.write().await.remove_peer(&addr);
        self.release_peer_downloads(&addr);
    }

    /// Detaches `addr` from every active download. A download with no peer left
    /// serving it is abandoned, so the next tick can request the piece again.
    fn release_peer_downloads(&mut self, addr: &SocketAddr) {
        self.active_downloads.retain(|_, dl| {
            dl.requested_from.remove(addr);
            !dl.requested_from.is_empty()
        });
    }

    /// Applies one wire message from `addr` to the loop state.
    ///
    /// Messages from peers not registered in `peers` are ignored.
    ///
    /// # Errors
    /// Returns storage read/write errors and send errors. The caller drops the
    /// peer when this fails.
    async fn handle_peer_message(
        &mut self,
        addr: SocketAddr,
        msg: PeerMessage,
    ) -> Result<(), Error> {
        let Some(peer) = self.peers.get_mut(&addr) else {
            return Ok(());
        };
        match msg {
            PeerMessage::KeepAlive => {}
            PeerMessage::Choke => {
                peer.am_choked = true;
                // A choking peer discards our outstanding requests (BEP 3). Release
                // them, or the single in-flight piece would wait forever.
                self.release_peer_downloads(&addr);
            }
            PeerMessage::Unchoke => {
                peer.am_choked = false;
            }
            PeerMessage::Interested => {
                peer.peer_interested = true;
            }
            PeerMessage::NotInterested => {
                peer.peer_interested = false;
            }
            PeerMessage::Have(index) => {
                // A peer with nothing at handshake time may skip `Bitfield`; start
                // from an all-false map so its `Have`s are still recorded.
                if peer.bitfield.is_empty() {
                    peer.bitfield = vec![false; self.metainfo.info.num_pieces()];
                }
                if let Some(slot) = peer.bitfield.get_mut(index as usize) {
                    *slot = true;
                }
            }
            PeerMessage::Bitfield(bytes) => {
                let num_pieces = self.metainfo.info.num_pieces();
                peer.bitfield = parse_bitfield(&bytes, num_pieces);
            }
            PeerMessage::Piece { index, begin, data } => {
                let received = data.len() as u64;
                peer.downloaded_bytes += received;
                self.total_downloaded += received;
                // Only blocks we asked for reach storage. An unsolicited block could
                // otherwise overwrite a piece that is already verified.
                if !self.record_block(index, begin, &data) {
                    return Ok(());
                }
                self.storage.write_block(index, begin, &data).await?;
                let piece_complete = self
                    .active_downloads
                    .get(&index)
                    .is_some_and(|dl| dl.received.iter().all(|&done| done));
                if piece_complete && self.verify_and_complete_piece(index).await? {
                    self.broadcast_have(index).await?;
                }
            }
            PeerMessage::Request {
                index,
                begin,
                length,
            } => {
                // Requests from peers we are choking are ignored.
                if !self.upload_mgr.read().await.is_unchoked(&addr) {
                    return Ok(());
                }
                // Never serve a piece we have not verified. Storage can hold partial
                // blocks of pieces that are still in flight.
                let have_piece = self
                    .piece_mgr
                    .read()
                    .await
                    .bitfield()
                    .get(index as usize)
                    .copied()
                    .unwrap_or(false);
                if !have_piece {
                    return Ok(());
                }
                let cached = self.piece_cache.get(&index).cloned();
                let piece_data = match cached {
                    Some(piece) => piece,
                    None => {
                        let mut buf = vec![0u8; self.piece_len_for_index(index) as usize];
                        self.storage.read_piece(index, &mut buf).await?;
                        let piece = Arc::new(buf);
                        // Further block requests for this piece then skip the storage read.
                        self.cache_piece(index, Arc::clone(&piece));
                        piece
                    }
                };
                let start = begin as usize;
                let end = start.saturating_add(length as usize).min(piece_data.len());
                if start < end {
                    let sent = (end - start) as u64;
                    let reply = PeerMessage::Piece {
                        index,
                        begin,
                        data: piece_data[start..end].to_vec(),
                    };
                    self.peer_mgr.read().await.send_to(&addr, &reply).await?;
                    self.total_uploaded += sent;
                    if let Some(p) = self.peers.get_mut(&addr) {
                        p.uploaded_bytes += sent;
                    }
                }
            }
            PeerMessage::Cancel { .. } | PeerMessage::Port(_) => {
                // Requests are answered immediately (nothing to cancel), and the
                // peer's DHT port is not used.
            }
        }
        Ok(())
    }

    /// Copies one received block into its active download.
    ///
    /// Returns `true` only when all of these hold:
    /// - the piece is in flight;
    /// - `begin` lies on a block boundary;
    /// - the block has exactly the expected length;
    /// - the block had not arrived yet.
    ///
    /// Anything else is ignored, so a misbehaving peer cannot mark a block done
    /// with short data.
    fn record_block(&mut self, index: u32, begin: u32, block: &[u8]) -> bool {
        let Some(dl) = self.active_downloads.get_mut(&index) else {
            return false;
        };
        if dl.block_size == 0 || begin % dl.block_size != 0 {
            return false;
        }
        let block_idx = (begin / dl.block_size) as usize;
        if dl.received.get(block_idx) != Some(&false) {
            return false;
        }
        let start = begin as usize;
        let expected_len = (dl.block_size as usize).min(dl.data.len().saturating_sub(start));
        if expected_len == 0 || block.len() != expected_len {
            return false;
        }
        dl.data[start..start + expected_len].copy_from_slice(block);
        dl.received[block_idx] = true;
        true
    }

    /// Picks a missing piece and requests it from unchoked peers that have it.
    ///
    /// Normally one peer is asked. In endgame (fewer than `ENDGAME_THRESHOLD`
    /// pieces missing), every such peer is asked. A peer whose requests cannot be
    /// sent is dropped rather than failing the tick.
    async fn maybe_request_piece(&mut self) -> Result<(), Error> {
        let missing = {
            let pm = self.piece_mgr.read().await;
            pm.missing_pieces()
        };
        if missing.is_empty() {
            return Ok(());
        }
        let in_endgame = missing.len() < ENDGAME_THRESHOLD;
        if in_endgame {
            self.selector = Box::new(EndGame);
        }
        let local_bf = {
            let pm = self.piece_mgr.read().await;
            pm.bitfield().to_vec()
        };

        // Take the first piece the selector finds among usable peers.
        let mut chosen: Option<u32> = None;
        for peer in self.peers.values() {
            if peer.am_choked || peer.bitfield.is_empty() {
                continue;
            }
            if let Some(idx) = self.selector.select(&peer.bitfield, &local_bf) {
                chosen = Some(idx);
                break;
            }
        }
        let Some(idx) = chosen else {
            return Ok(());
        };

        let mut targets: Vec<SocketAddr> = self
            .peers
            .iter()
            .filter(|(_, p)| {
                !p.am_choked && p.bitfield.get(idx as usize).copied().unwrap_or(false)
            })
            .map(|(addr, _)| *addr)
            .collect();
        if !in_endgame {
            targets.truncate(1);
        }
        for addr in targets {
            if let Err(e) = self.request_piece_from(&addr, idx).await {
                tracing::debug!("dropping peer {}: requesting piece {} failed: {}", addr, idx, e);
                self.drop_peer(addr).await;
            }
        }
        Ok(())
    }

    /// Sends `Request` messages to `addr` for every block of piece `index` not
    /// yet received, and records `addr` as a source for that piece.
    ///
    /// Repeated calls for the same piece (endgame) add to the existing download
    /// instead of replacing it. Blocks already received are therefore kept, and
    /// every requesting peer is tracked.
    ///
    /// # Errors
    /// Returns the first send error. A new download is only registered after all
    /// of its requests were sent.
    async fn request_piece_from(&mut self, addr: &SocketAddr, index: u32) -> Result<(), Error> {
        const BLOCK_SIZE: u32 = 16 * 1024;
        let piece_len = self.piece_len_for_index(index);
        let num_blocks = piece_len.div_ceil(u64::from(BLOCK_SIZE)) as usize;
        let pending: Vec<usize> = match self.active_downloads.get(&index) {
            Some(dl) => dl
                .received
                .iter()
                .enumerate()
                .filter_map(|(i, done)| (!*done).then_some(i))
                .collect(),
            None => (0..num_blocks).collect(),
        };
        {
            let pm = self.peer_mgr.read().await;
            for block_idx in pending {
                let begin = block_idx as u32 * BLOCK_SIZE;
                // The final block of the final piece may be shorter than BLOCK_SIZE.
                let length = u64::from(BLOCK_SIZE).min(piece_len - u64::from(begin)) as u32;
                let msg = PeerMessage::Request {
                    index,
                    begin,
                    length,
                };
                pm.send_to(addr, &msg).await?;
            }
        }
        self.active_downloads
            .entry(index)
            .or_insert_with(|| ActiveDownload {
                index,
                data: vec![0u8; piece_len as usize],
                received: vec![false; num_blocks],
                block_size: BLOCK_SIZE,
                num_blocks,
                requested_from: HashSet::new(),
            })
            .requested_from
            .insert(*addr);
        Ok(())
    }

    /// Hash-checks a fully received piece against the metainfo. On success, the
    /// piece is marked as had and cached.
    ///
    /// The active download is removed either way; a failed piece is re-requested
    /// on a later tick. Returns `Ok(true)` only when the hash matched.
    async fn verify_and_complete_piece(&mut self, index: u32) -> Result<bool, Error> {
        let Some(dl) = self.active_downloads.remove(&index) else {
            return Ok(false);
        };
        let mut data = dl.data;
        data.truncate(self.piece_len_for_index(index) as usize);
        let expected = self.metainfo.info.pieces[index as usize];
        if !verify_piece_hash(&data, expected) {
            tracing::warn!("piece {} failed hash verification", index);
            return Ok(false);
        }
        self.piece_mgr.write().await.set_piece(index);
        self.cache_piece(index, Arc::new(data));
        Ok(true)
    }

    /// Inserts a piece into `piece_cache`. When the cache is full, an arbitrary
    /// entry is evicted first.
    fn cache_piece(&mut self, index: u32, data: Arc<Vec<u8>>) {
        if !self.piece_cache.contains_key(&index)
            && self.piece_cache.len() >= Self::PIECE_CACHE_CAPACITY
        {
            let victim = self.piece_cache.keys().next().copied();
            if let Some(victim) = victim {
                self.piece_cache.remove(&victim);
            }
        }
        self.piece_cache.insert(index, data);
    }

    /// Sends `Have(index)` to every connected peer, ignoring per-peer send failures.
    async fn broadcast_have(&self, index: u32) -> Result<(), Error> {
        let have = PeerMessage::Have(index);
        let pm = self.peer_mgr.read().await;
        for addr in pm.connection_addrs() {
            // Best effort: a broken connection is reported by its reader task.
            let _ = pm.send_to(&addr, &have).await;
        }
        Ok(())
    }

    /// Announces to the tracker when the next announce is due. The first
    /// successful announce carries `Started`; later ones carry no event.
    async fn announce_if_needed(&mut self) {
        if self.tracker.is_none() {
            return;
        }
        if self.next_announce.is_some_and(|due| Instant::now() < due) {
            return;
        }
        let event = if self.has_announced {
            AnnounceEvent::None
        } else {
            AnnounceEvent::Started
        };
        // Failures are logged inside `announce_to_tracker`, which also schedules a retry.
        if self.announce_to_tracker(event).await.is_ok() {
            self.has_announced = true;
        }
    }

    /// Sends one announce with `event` and feeds any returned peers to the peer manager.
    ///
    /// On success, the next announce is scheduled after the tracker's `interval`,
    /// raised to `min_interval` when that is larger. On failure, a retry is
    /// scheduled in 30 s.
    ///
    /// # Errors
    /// Returns the tracker error after logging it. Does nothing when no tracker
    /// is configured.
    async fn announce_to_tracker(&mut self, event: AnnounceEvent) -> Result<(), Error> {
        tracing::debug!("announcing to tracker (event: {:?})", event);
        let Some(tracker) = self.tracker.as_ref() else {
            return Ok(());
        };
        let (downloaded, left) = {
            let pm = self.piece_mgr.read().await;
            let total_size = self.metainfo.info.total_size();
            let have = pm.completed_pieces().len() as u64;
            // Counting every piece as `piece_length` overshoots when the shorter
            // final piece is among them; clamp so `downloaded` never exceeds the size.
            let verified = (have * self.metainfo.info.piece_length).min(total_size);
            (verified, total_size - verified)
        };
        let mut req = AnnounceRequest::new(self.info_hash, self.peer_id, self.listen_port);
        req.downloaded = downloaded;
        req.uploaded = self.total_uploaded;
        req.left = left;
        req.event = event;
        match tracker.announce(&req).await {
            Ok(resp) => {
                tracing::debug!("tracker announce: {} peers", resp.peers.len());
                // `interval` is the regular period. `min_interval` is only a floor
                // on how often we may re-announce, not the period itself.
                let interval = match resp.min_interval {
                    Some(min) => resp.interval.max(min),
                    None => resp.interval,
                };
                self.next_announce = Some(Instant::now() + Duration::from_secs(interval as u64));
                if !resp.peers.is_empty() {
                    self.peer_mgr.write().await.add_peers(resp.peers);
                }
                Ok(())
            }
            Err(e) => {
                self.next_announce = Some(Instant::now() + Duration::from_secs(30));
                tracing::warn!("tracker announce failed: {}", e);
                Err(e)
            }
        }
    }

    /// One round of the BEP 3 choking algorithm.
    ///
    /// Among interested peers, the `max_uploads - 1` best reciprocators are
    /// unchoked. While downloading, those are the peers that sent us the most;
    /// while seeding, the peers we sent the most. One more interested peer is
    /// unchoked at random (optimistic unchoke), and every other peer is choked.
    /// Byte counters are reset afterwards, so each round ranks by what happened
    /// since the previous round.
    async fn run_choke_unchoke(&mut self) -> Result<(), Error> {
        let max_uploads = self.upload_mgr.read().await.max_uploads();
        if max_uploads == 0 {
            return Ok(());
        }
        let max_uploads = max_uploads as usize;
        let seeding = self.piece_mgr.read().await.missing_pieces().is_empty();

        let mut ranked: Vec<(SocketAddr, u64)> = self
            .peers
            .iter()
            .filter(|(_, info)| info.peer_interested)
            .map(|(addr, info)| {
                let score = if seeding {
                    info.uploaded_bytes
                } else {
                    info.downloaded_bytes
                };
                (*addr, score)
            })
            .collect();
        ranked.sort_unstable_by_key(|&(_, score)| std::cmp::Reverse(score));

        let regular = (max_uploads - 1).min(ranked.len());
        let mut to_unchoke: HashSet<SocketAddr> =
            ranked[..regular].iter().map(|&(addr, _)| addr).collect();
        let others = &ranked[regular..];
        if !others.is_empty() {
            let pick = rand::rng().random_range(0..others.len());
            to_unchoke.insert(others[pick].0);
        }

        for info in self.peers.values_mut() {
            info.uploaded_bytes = 0;
            info.downloaded_bytes = 0;
        }

        let mut um = self.upload_mgr.write().await;
        let pm = self.peer_mgr.read().await;
        for addr in &to_unchoke {
            if !um.is_unchoked(addr) {
                um.unchoke(*addr);
                let _ = pm.send_to(addr, &PeerMessage::Unchoke).await;
            }
        }
        let to_choke: Vec<SocketAddr> = um
            .unchoked_peers()
            .copied()
            .filter(|addr| !to_unchoke.contains(addr))
            .collect();
        for addr in to_choke {
            um.choke(&addr);
            let _ = pm.send_to(&addr, &PeerMessage::Choke).await;
        }
        Ok(())
    }

    /// Creates the DHT RPC endpoint on an ephemeral IPv4 port and schedules the
    /// first search. On failure, DHT stays unused for this loop.
    async fn init_dht(&mut self) {
        let bind_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0));
        match DhtRpc::new(bind_addr).await {
            Ok(rpc) => {
                self.dht_node_id = generate_node_id();
                self.dht_rpc = Some(Arc::new(rpc));
                self.next_dht_search = Some(Instant::now());
            }
            Err(_) => {
                // Without `dht_rpc`, `dht_search_if_needed` never spawns a search.
                tracing::warn!("DHT initialisation failed; continuing without DHT");
            }
        }
    }

    /// When a DHT search is due, spawns a background task that queries every
    /// bootstrap node with `get_peers` and hands any returned peers to the peer
    /// manager. The next search is scheduled 30 s later.
    fn dht_search_if_needed(&mut self) {
        let Some(due) = self.next_dht_search else {
            return;
        };
        if Instant::now() < due {
            return;
        }
        let Some(rpc) = self.dht_rpc.as_ref().map(Arc::clone) else {
            return;
        };
        self.next_dht_search = Some(Instant::now() + Duration::from_secs(30));
        let info_hash = self.info_hash;
        let node_id = self.dht_node_id;
        let peer_mgr = Arc::clone(&self.peer_mgr);
        tokio::spawn(async move {
            for &(host, port) in DHT_BOOTSTRAP {
                let addr = match tokio::net::lookup_host((host, port)).await {
                    Ok(mut addrs) => match addrs.next() {
                        Some(a) => a,
                        None => continue,
                    },
                    Err(e) => {
                        tracing::debug!("DHT bootstrap lookup of {} failed: {}", host, e);
                        continue;
                    }
                };
                let tid: krpc::TransactionId = rand::random();
                match get_peers(&rpc, addr, tid, &node_id, &info_hash).await {
                    Ok(krpc::GetPeersResult::Values { peers, .. }) => {
                        if !peers.is_empty() {
                            peer_mgr.write().await.add_peers(peers);
                        }
                    }
                    Ok(krpc::GetPeersResult::Nodes(_nodes)) => {
                        // NOTE(review): returned nodes are not queried further. Since
                        // bootstrap routers usually answer with nodes only, this search
                        // rarely yields peers without an iterative lookup.
                    }
                    Err(_) => continue,
                }
            }
        });
    }

    /// Spawns a task that forwards every message received on `conn_arc` to the
    /// loop. A `Disconnected` event is sent and the task exits on the first
    /// receive error; the task also exits if the loop's receiver is gone.
    fn spawn_peer_reader(&self, addr: SocketAddr, conn_arc: Arc<Mutex<PeerConnection>>) {
        let tx = self.peer_msg_tx.clone();
        tokio::spawn(async move {
            loop {
                // NOTE(review): the connection mutex stays locked for the whole
                // `recv().await`. If `PeerManager::send_to` locks the same mutex
                // (presumably — confirm), sends to this peer wait for its next message.
                let received = conn_arc.lock().await.recv().await;
                let event = match received {
                    Ok(msg) => PeerEvent::Message(msg),
                    Err(_) => {
                        let _ = tx.send((addr, PeerEvent::Disconnected));
                        break;
                    }
                };
                if tx.send((addr, event)).is_err() {
                    break;
                }
            }
        });
    }

    /// Sends our bitfield (when the piece manager produces a non-empty one),
    /// followed by `Interested`, to `addr`.
    async fn send_bitfield(&self, addr: SocketAddr) -> Result<(), Error> {
        let bitfield = self.piece_mgr.read().await.to_bitfield();
        let pm = self.peer_mgr.read().await;
        if !bitfield.is_empty() {
            pm.send_to(&addr, &PeerMessage::Bitfield(bitfield)).await?;
        }
        pm.send_to(&addr, &PeerMessage::Interested).await?;
        Ok(())
    }

    /// Length in bytes of piece `index`.
    ///
    /// This is `piece_length` for every piece but the last, the remainder of
    /// `total_size` for the last, and 0 for an out-of-range index.
    fn piece_len_for_index(&self, index: u32) -> u64 {
        let info = &self.metainfo.info;
        let idx = u64::from(index);
        let num_pieces = info.num_pieces() as u64;
        if idx >= num_pieces {
            return 0;
        }
        if idx + 1 == num_pieces {
            // saturating: inconsistent metadata must not underflow.
            info.total_size().saturating_sub(idx * info.piece_length)
        } else {
            info.piece_length
        }
    }
}
/// Returns `true` when the SHA-1 digest of `data` equals `expected` (the
/// 20-byte hash listed for the piece in the metainfo).
fn verify_piece_hash(data: &[u8], expected: [u8; 20]) -> bool {
    use sha1::{Digest, Sha1};
    let digest: [u8; 20] = Sha1::digest(data).into();
    digest == expected
}
#[cfg(test)]
mod unit_tests {
    use super::*;

    /// A new `PeerInfo` starts choked and uninterested, with no known pieces
    /// and zero byte counters.
    #[test]
    fn peer_info_default_state() {
        let PeerInfo {
            bitfield,
            am_choked,
            am_interested,
            peer_interested,
            uploaded_bytes,
            downloaded_bytes,
        } = PeerInfo::new();
        assert!(am_choked);
        assert!(!am_interested);
        assert!(!peer_interested);
        assert!(bitfield.is_empty());
        assert_eq!(uploaded_bytes, 0);
        assert_eq!(downloaded_bytes, 0);
    }

    /// Every field of a hand-built `ActiveDownload` reads back as written.
    #[test]
    fn active_download_has_expected_fields() {
        let ActiveDownload {
            index,
            data,
            received,
            block_size,
            num_blocks,
            requested_from,
        } = ActiveDownload {
            index: 42,
            data: vec![0u8; 16000],
            received: vec![false; 1],
            block_size: 16384,
            num_blocks: 1,
            requested_from: HashSet::new(),
        };
        assert_eq!(index, 42);
        assert_eq!(num_blocks, 1);
        assert_eq!(block_size, 16384);
        assert_eq!(data.len(), 16000);
        assert_eq!(received.len(), 1);
        assert_eq!(requested_from.len(), 0);
    }
}
/// Expands a wire-format bitfield into one flag per piece.
///
/// Bit 7 (the high bit) of byte 0 is piece 0, and so on. Pieces past the end of
/// `bytes` are reported as missing, and bits beyond `num_pieces` are ignored.
fn parse_bitfield(bytes: &[u8], num_pieces: usize) -> Vec<bool> {
    (0..num_pieces)
        .map(|piece| {
            bytes
                .get(piece / 8)
                .is_some_and(|&byte| byte & (0x80 >> (piece % 8)) != 0)
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// SHA-1 of `data`, computed independently of `verify_piece_hash` so the
    /// tests compare against a reference value.
    fn sha1_of(data: &[u8]) -> [u8; 20] {
        use sha1::{Digest, Sha1};
        let mut h = Sha1::new();
        h.update(data);
        h.finalize().into()
    }

    #[test]
    fn verify_piece_hash_match() {
        let data = b"hello world test piece data";
        assert!(verify_piece_hash(data, sha1_of(data)));
    }

    #[test]
    fn verify_piece_hash_mismatch() {
        let data = b"hello world";
        let expected = [0xFFu8; 20];
        assert!(!verify_piece_hash(data, expected));
    }

    #[test]
    fn verify_piece_hash_empty() {
        assert!(verify_piece_hash(b"", sha1_of(b"")));
    }

    #[test]
    fn verify_piece_hash_binary_data() {
        let data = [0x00u8, 0xFF, 0x42, 0x7F, 0x80];
        assert!(verify_piece_hash(&data, sha1_of(&data)));
    }

    #[test]
    fn verify_piece_hash_wrong_hash() {
        let data = b"correct data";
        let wrong_hash = sha1_of(b"wrong data");
        assert!(!verify_piece_hash(data, wrong_hash));
    }

    #[test]
    fn parse_bitfield_all_set() {
        let bytes = vec![0xFF, 0xFF];
        let bf = parse_bitfield(&bytes, 16);
        assert_eq!(bf.len(), 16);
        assert!(bf.iter().all(|&b| b));
    }

    #[test]
    fn parse_bitfield_none_set() {
        let bytes = vec![0x00, 0x00];
        let bf = parse_bitfield(&bytes, 16);
        assert_eq!(bf.len(), 16);
        assert!(bf.iter().all(|&b| !b));
    }

    #[test]
    fn parse_bitfield_partial() {
        let bytes = vec![0x80, 0x00];
        let bf = parse_bitfield(&bytes, 16);
        assert_eq!(bf.len(), 16);
        assert!(bf[0]);
        assert!(!bf[1]);
        assert!(!bf[7]);
        assert!(!bf[8]);
    }

    #[test]
    fn parse_bitfield_shorter_than_requested() {
        let bytes = vec![0xFF];
        let bf = parse_bitfield(&bytes, 16);
        assert_eq!(bf.len(), 16);
        assert!(bf[0..8].iter().all(|&b| b));
        assert!(bf[8..16].iter().all(|&b| !b));
    }

    /// A piece count that is not a multiple of 8 uses only the leading bits of
    /// the last byte.
    #[test]
    fn parse_bitfield_unaligned_piece_count() {
        assert_eq!(parse_bitfield(&[0b1010_0000], 3), vec![true, false, true]);
    }

    /// Set bits beyond `num_pieces` (spare bits or extra bytes) are ignored.
    #[test]
    fn parse_bitfield_ignores_extra_bits() {
        assert_eq!(parse_bitfield(&[0xFF, 0xFF], 4), vec![true; 4]);
    }

    #[test]
    fn parse_bitfield_zero_pieces() {
        assert!(parse_bitfield(&[0xFF], 0).is_empty());
    }

    #[test]
    fn parse_bitfield_empty_input() {
        assert_eq!(parse_bitfield(&[], 5), vec![false; 5]);
    }
}