use std::collections::HashSet;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;
use sha1::{Digest, Sha1};
use crate::error::Error;
use crate::peer::PeerMessage;
use crate::piece::EndGame;
use super::DownloadLoop;
use super::types::{ActiveDownload, BLOCK_SIZE};
impl DownloadLoop {
/// Issues at most one new block request to every peer that can accept one.
///
/// The pass first builds a per-piece availability count from the bitfields of
/// peers whose `am_choked` flag is clear, snapshots our own bitfield, and
/// installs the `EndGame` selector once the number of missing pieces is
/// non-zero and below `endgame_threshold`. Then, for each peer reporting
/// `can_request()`, it prefers a block from a piece already present in
/// `active_downloads`; failing that, and while fewer than
/// `max_concurrent_pieces` pieces are active, it asks the selector for a new
/// piece to start.
///
/// A newly selected piece is only started when the peer's own bitfield
/// advertises it, matching the check `find_block_for_peer` applies to pieces
/// that are already in progress.
///
/// # Errors
///
/// Propagates the first error returned by `send_to`; peers later in the
/// iteration are not serviced in that call.
pub(super) async fn fill_pipelines(&mut self) -> Result<(), Error> {
    let num_pieces = self.metainfo.info.num_pieces();

    // Per-piece count of usable peers advertising it; handed to the selector.
    let mut availability = vec![0usize; num_pieces];
    for peer in self.peers.values() {
        if peer.am_choked || peer.bitfield.is_empty() {
            continue;
        }
        // `take` ignores bitfield entries beyond the torrent's piece count.
        for (i, &has) in peer.bitfield.iter().enumerate().take(num_pieces) {
            if has {
                availability[i] += 1;
            }
        }
    }

    // Copy our bitfield so the piece-manager guard is dropped before any send.
    let our_bf = {
        let pm = self.piece_mgr.read().await;
        pm.bitfield().to_vec()
    };
    let missing_count = our_bf.iter().filter(|&&b| !b).count();
    let in_endgame = missing_count > 0 && missing_count < self.endgame_threshold;
    if in_endgame {
        self.selector = Box::new(EndGame);
    }

    // Snapshot the keys: the loop body needs mutable access to `self.peers`.
    let peer_addrs: Vec<SocketAddr> = self.peers.keys().copied().collect();
    for addr in peer_addrs {
        let can_req = self.peers.get(&addr).is_some_and(|p| p.can_request());
        if !can_req {
            continue;
        }

        let (index, begin) = if let Some(blk) = self.find_block_for_peer(&addr, in_endgame) {
            blk
        } else if self.active_downloads.len() < self.max_concurrent_pieces {
            let peer_has_bitfield = self
                .peers
                .get(&addr)
                .is_some_and(|p| !p.bitfield.is_empty());
            if !peer_has_bitfield {
                continue;
            }
            let Some(idx) = self.selector.select(&our_bf, &availability) else {
                continue;
            };
            if self.active_downloads.contains_key(&idx) {
                continue;
            }
            // The selector only sees our bitfield and the aggregate
            // availability, so confirm that this particular peer advertises
            // the piece before requesting it from them.
            let piece_slot = idx as usize;
            let peer_has_piece = self
                .peers
                .get(&addr)
                .is_some_and(|p| piece_slot < p.bitfield.len() && p.bitfield[piece_slot]);
            if !peer_has_piece {
                continue;
            }
            let piece_len = self.piece_len_for_index(idx);
            if piece_len == 0 {
                continue;
            }
            let dl = ActiveDownload::new(idx, piece_len, BLOCK_SIZE);
            // A non-empty piece is expected to expose a first block; skip the
            // peer rather than panic if it does not.
            let Some(blk_begin) = dl.next_unrequested() else {
                continue;
            };
            self.active_downloads.insert(idx, dl);
            (idx, blk_begin)
        } else {
            continue;
        };

        let Some(dl) = self.active_downloads.get(&index) else {
            continue;
        };
        let len = dl.block_len(begin);
        if len == 0 {
            continue;
        }

        let msg = PeerMessage::Request {
            index,
            begin,
            length: len,
        };
        self.peer_mgr.read().await.send_to(&addr, &msg).await?;

        // Record the request only after the send succeeded.
        if let Some(peer) = self.peers.get_mut(&addr) {
            peer.push_request(index, begin);
        }
        if let Some(dl) = self.active_downloads.get_mut(&index) {
            dl.mark_requested(begin, addr);
        }
    }
    Ok(())
}
/// Picks a block of an in-progress piece that `addr` could be asked for.
///
/// Returns `(piece index, block offset)`, or `None` when the peer is unknown,
/// has an empty bitfield, or no active piece it advertises has a usable block.
///
/// For each active piece the peer advertises, the first unrequested block
/// wins. When `in_endgame` is set and the piece has no unrequested block, a
/// block currently assigned to a *different* peer may be returned so it can be
/// requested redundantly. Blocks that already sit in this peer's own pipeline
/// are skipped, so the same request is not sent to the same peer twice.
pub(super) fn find_block_for_peer(
    &self,
    addr: &SocketAddr,
    in_endgame: bool,
) -> Option<(u32, u32)> {
    let peer = self.peers.get(addr)?;
    if peer.bitfield.is_empty() {
        return None;
    }

    for (idx, dl) in &self.active_downloads {
        let piece_slot = *idx as usize;
        let peer_has_piece = piece_slot < peer.bitfield.len() && peer.bitfield[piece_slot];
        if !peer_has_piece {
            continue;
        }

        if let Some(begin) = dl.next_unrequested() {
            return Some((*idx, begin));
        }
        if !in_endgame {
            continue;
        }

        for (block_i, assigned) in dl.requested.iter().enumerate() {
            // Only blocks that some other peer currently holds are candidates.
            if assigned.is_none() || assigned.as_ref() == Some(addr) {
                continue;
            }
            let begin = block_i as u32 * dl.block_size;
            // NOTE(review): `requested` appears to track one peer per block,
            // so after a reassignment it no longer names this peer even if
            // their request is still outstanding — consult the pipeline too.
            let already_in_flight = peer
                .pipeline
                .iter()
                .any(|slot| matches!(slot, Some((i, b, _)) if *i == *idx && *b == begin));
            if !already_in_flight {
                return Some((*idx, begin));
            }
        }
    }
    None
}
/// Drops requests older than `request_timeout` and disconnects silent peers.
///
/// Every pipeline slot whose age exceeds the timeout is emptied, and the
/// corresponding block is released for re-requesting — but only if the block
/// is still assigned to the peer whose request expired, so a timeout never
/// erases an assignment that has since moved to another peer.
///
/// A peer that had at least one outstanding request and saw *all* of them
/// expire in this pass is treated as dead: every block assigned to it is
/// released, and it is removed from both `peers` and the peer manager.
pub(super) async fn expire_stale_requests(&mut self) {
    let now = Instant::now();
    let timeout = self.request_timeout;
    let mut dead_peers = Vec::new();

    for (addr, peer) in &mut self.peers {
        let had_requests = peer.pipeline.iter().any(Option::is_some);
        if !had_requests {
            continue;
        }

        let mut all_expired = true;
        for slot in &mut peer.pipeline {
            let Some((index, begin, sent_at)) = *slot else {
                continue;
            };
            if now.duration_since(sent_at) <= timeout {
                all_expired = false;
                continue;
            }
            if let Some(dl) = self.active_downloads.get_mut(&index) {
                let block_idx = (begin / dl.block_size) as usize;
                // Release the block only while it still belongs to this peer.
                if block_idx < dl.requested.len() && dl.requested[block_idx] == Some(*addr) {
                    dl.requested[block_idx] = None;
                }
            }
            *slot = None;
        }

        if all_expired {
            dead_peers.push(*addr);
        }
    }

    if dead_peers.is_empty() {
        return;
    }

    // Free every block still assigned to a peer that is about to be removed.
    for addr in &dead_peers {
        for dl in self.active_downloads.values_mut() {
            for assigned in &mut dl.requested {
                if *assigned == Some(*addr) {
                    *assigned = None;
                }
            }
        }
    }

    // Take the write lock once for the whole batch of removals.
    let mut peer_mgr = self.peer_mgr.write().await;
    for addr in &dead_peers {
        self.peers.remove(addr);
        peer_mgr.remove_peer(addr);
    }
}
/// Checks a downloaded piece against its expected SHA-1 and finalizes it.
///
/// Returns `Ok(true)` when the hash matches: the piece is marked in the piece
/// manager, a copy of its bytes is appended to `piece_cache` (evicting the
/// oldest entry when the cache is at capacity), and the active download is
/// removed.
///
/// Returns `Ok(false)` when there is no expected hash or no active download
/// for `index`, or when the hash does not match. On a mismatch every peer
/// recorded in the download's `requested` list receives one corruption
/// strike, peers at or above `corrupt_ban_threshold` are banned, and the
/// download is discarded. Blocks that banned peers hold in other active
/// downloads are released, mirroring the dead-peer cleanup in
/// `expire_stale_requests`.
///
/// # Panics
///
/// Slicing `dl.data[..piece_len]` panics if the buffer is shorter than the
/// piece length — presumably `ActiveDownload::new` sizes it to the piece;
/// TODO confirm.
pub(super) async fn verify_and_complete_piece(&mut self, index: u32) -> Result<bool, Error> {
    let piece_len = self.piece_len_for_index(index) as usize;
    let expected = match self.metainfo.info.pieces.get(index as usize) {
        Some(h) => *h,
        None => return Ok(false),
    };
    let Some(dl) = self.active_downloads.get(&index) else {
        return Ok(false);
    };
    let piece: &[u8] = &dl.data[..piece_len];

    if verify_piece_hash(piece, expected) {
        let data = piece.to_vec();
        self.piece_mgr.write().await.set_piece(index);

        // Evict the oldest entry once the cache is full. The emptiness guard
        // keeps a `piece_cache_size` of zero from calling `remove(0)` on an
        // empty vector, which would panic.
        if self.piece_cache.len() >= self.piece_cache_size && !self.piece_cache.is_empty() {
            self.piece_cache.remove(0);
        }
        self.piece_cache.push((index, Arc::new(data)));
        self.active_downloads.remove(&index);
        return Ok(true);
    }

    // Hash mismatch: one strike per distinct peer that supplied blocks.
    let mut penalized: HashSet<SocketAddr> = HashSet::new();
    for addr in dl.requested.iter().flatten() {
        if !penalized.insert(*addr) {
            continue;
        }
        if let Some(peer) = self.peers.get_mut(addr) {
            peer.corrupt_blocks += 1;
            tracing::warn!(
                "peer {} sent corrupt data ({} strikes)",
                addr,
                peer.corrupt_blocks
            );
        }
    }

    let ban_threshold = self.corrupt_ban_threshold;
    let ban: Vec<SocketAddr> = self
        .peers
        .iter()
        .filter(|(_, peer)| peer.corrupt_blocks >= ban_threshold)
        .map(|(addr, _)| *addr)
        .collect();

    // Discard the corrupt piece so it can be downloaded again from scratch.
    self.active_downloads.remove(&index);

    if !ban.is_empty() {
        // Without this, blocks assigned to a banned peer would stay assigned
        // to a peer that no longer exists.
        for addr in &ban {
            for other in self.active_downloads.values_mut() {
                for assigned in &mut other.requested {
                    if *assigned == Some(*addr) {
                        *assigned = None;
                    }
                }
            }
        }

        let mut peer_mgr = self.peer_mgr.write().await;
        for addr in &ban {
            tracing::warn!("banning peer {} for repeated corrupt data", addr);
            self.peers.remove(addr);
            peer_mgr.remove_peer(addr);
        }
    }
    Ok(false)
}
/// Announces piece `index` with a `Have` message to every address returned by
/// the peer manager's `connection_addrs()`.
///
/// Delivery is best-effort: a failed send to one peer is ignored so the
/// remaining peers still receive the announcement, and the function always
/// returns `Ok(())`.
pub(super) async fn broadcast_have(&self, index: u32) -> Result<(), Error> {
    let announcement = PeerMessage::Have(index);
    let manager = self.peer_mgr.read().await;
    for peer_addr in manager.connection_addrs() {
        // Deliberately discard the per-peer send result.
        let _ = manager.send_to(&peer_addr, &announcement).await;
    }
    Ok(())
}
/// Returns the length in bytes of piece `index`, or `0` if the index is out of
/// range.
///
/// Every piece except the last is `piece_length` bytes long. The last piece
/// covers whatever remains of `total_size()` after the preceding pieces,
/// capped at `piece_length`.
///
/// The metainfo comes from an external torrent file, so the arithmetic
/// saturates rather than overflowing: if `total_size()` is smaller than the
/// start offset of the last piece, the result is `0` instead of a wrapped or
/// panicking subtraction.
pub(super) fn piece_len_for_index(&self, index: u32) -> u64 {
    let idx = u64::from(index);
    let num_pieces = self.metainfo.info.num_pieces() as u64;
    if idx >= num_pieces {
        return 0;
    }

    let piece_length = self.metainfo.info.piece_length;
    if idx + 1 < num_pieces {
        return piece_length;
    }

    // Last piece: the remainder of the payload, never more than a full piece.
    let start = idx.saturating_mul(piece_length);
    self.metainfo
        .info
        .total_size()
        .saturating_sub(start)
        .min(piece_length)
}
}
/// Returns `true` when the SHA-1 digest of `data` equals `expected`.
pub(super) fn verify_piece_hash(data: &[u8], expected: [u8; 20]) -> bool {
    // One-shot hashing: equivalent to `new` + `update` + `finalize`.
    let digest: [u8; 20] = Sha1::digest(data).into();
    digest == expected
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference SHA-1 of `bytes`, computed directly with the `Sha1` hasher.
    fn sha1_of(bytes: &[u8]) -> [u8; 20] {
        let mut hasher = Sha1::new();
        hasher.update(bytes);
        hasher.finalize().into()
    }

    #[test]
    fn verify_piece_hash_match() {
        let payload = b"hello world test piece data";
        assert!(verify_piece_hash(payload, sha1_of(payload)));
    }

    #[test]
    fn verify_piece_hash_mismatch() {
        let payload = b"hello world";
        let bogus = [0xFFu8; 20];
        assert!(!verify_piece_hash(payload, bogus));
    }

    #[test]
    fn verify_piece_hash_empty() {
        let payload = b"";
        assert!(verify_piece_hash(payload, sha1_of(b"")));
    }

    #[test]
    fn verify_piece_hash_binary_data() {
        let payload = [0x00u8, 0xFF, 0x42, 0x7F, 0x80];
        assert!(verify_piece_hash(&payload, sha1_of(&payload)));
    }

    #[test]
    fn verify_piece_hash_wrong_hash() {
        let payload = b"correct data";
        let other_payload = b"wrong data";
        // The digest of different bytes must not validate `payload`.
        assert!(!verify_piece_hash(payload, sha1_of(other_payload)));
    }

    #[test]
    fn block_len_for_short_last_block() {
        // A 50000-byte piece leaves 848 bytes after three full 16384-byte
        // blocks (49152 bytes).
        let piece_len: u64 = 50000;
        let block_size: u32 = 16384;
        let tail = piece_len.saturating_sub(49152);
        let last_block = tail.min(u64::from(block_size)) as u32;
        assert_eq!(last_block, 848);
    }
}