use std::collections::{HashMap, VecDeque};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use crate::error::Error;
use crate::peer::{PeerConnection, PeerId, PeerMessage};
/// Number of times a peer is put back on the pending queue after failed connection
/// attempts before it is discarded.
const MAX_RETRIES: u32 = 3;
/// How long a peer must wait after a failed connection attempt before it may be retried.
const PEER_COOLDOWN: Duration = Duration::from_secs(30);

/// Per-peer bookkeeping of failed connection attempts.
#[derive(Debug)]
struct BackoffState {
    /// Number of failed connection attempts recorded so far.
    attempts: u32,
    /// Earliest instant at which the peer may be dialed again.
    cooldown_until: Instant,
}

impl BackoffState {
    /// Creates a state with no failures recorded and no cooldown in effect.
    ///
    /// Callers record the failure that caused the insertion with [`BackoffState::increment`].
    /// Starting at zero, rather than one, means that failure is counted exactly once.
    /// Otherwise the first failure is double-counted and the peer loses one retry.
    fn new() -> Self {
        BackoffState {
            attempts: 0,
            cooldown_until: Instant::now(),
        }
    }

    /// Records one more failed attempt and starts a fresh `PEER_COOLDOWN` window from now.
    fn increment(&mut self) {
        self.attempts += 1;
        self.cooldown_until = Instant::now() + PEER_COOLDOWN;
    }
}
/// Tracks open peer connections, the queue of addresses waiting to be dialed, and
/// per-address retry backoff for one info hash.
pub(crate) struct PeerManager {
    /// Open connections keyed by remote address, each shared behind an async mutex.
    connections: HashMap<SocketAddr, Arc<Mutex<PeerConnection>>>,
    /// Addresses waiting to be dialed, consumed from the front.
    pending: VecDeque<SocketAddr>,
    /// Failure count and cooldown for addresses whose connection attempts failed.
    backoff: HashMap<SocketAddr, BackoffState>,
    /// Upper bound on the number of open connections.
    max_connections: u32,
    /// Info hash passed to every outgoing connection.
    info_hash: [u8; 20],
    /// Our own peer id, passed to every outgoing connection.
    peer_id: PeerId,
}
impl PeerManager {
    /// Creates a manager with no connections, an empty pending queue and no backoff state.
    pub fn new(info_hash: [u8; 20], peer_id: PeerId, max_connections: u32) -> Self {
        PeerManager {
            peer_id,
            info_hash,
            connections: HashMap::new(),
            pending: VecDeque::new(),
            max_connections,
            backoff: HashMap::new(),
        }
    }

    /// Queues `addrs` for connection.
    ///
    /// Addresses that are already connected or already pending are skipped. This includes
    /// duplicates within `addrs` itself.
    pub fn add_peers(&mut self, addrs: Vec<SocketAddr>) {
        for addr in addrs {
            if !self.connections.contains_key(&addr) && !self.pending.contains(&addr) {
                self.pending.push_back(addr);
            }
        }
    }

    /// Sends `msg` to the peer at `addr`.
    ///
    /// If there is no open connection to `addr`, the message is dropped and `Ok(())` is
    /// returned.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `PeerConnection::send`.
    pub async fn send_to(&self, addr: &SocketAddr, msg: &PeerMessage) -> Result<(), Error> {
        if let Some(conn) = self.connections.get(addr) {
            // tokio's Mutex, so holding the guard across the `send` await is allowed.
            let mut guard = conn.lock().await;
            guard.send(msg).await
        } else {
            Ok(())
        }
    }

    /// Forgets the connection to `addr`, if any.
    ///
    /// The address is not re-queued and its backoff state is not touched. Clones of the
    /// connection handed out by `connection` stay valid until they are dropped.
    pub fn remove_peer(&mut self, addr: &SocketAddr) {
        tracing::debug!("peer disconnected: {}", addr);
        self.connections.remove(addr);
    }

    /// Number of currently open connections.
    pub fn num_connections(&self) -> usize {
        self.connections.len()
    }

    /// Dials pending peers concurrently to fill the free connection slots.
    ///
    /// Up to `max_connections - num_connections()` pending peers whose cooldown has expired
    /// are dialed in parallel via `tokio::spawn`. A successful connection is stored and its
    /// backoff state is cleared. A failure goes through the retry/backoff policy; this
    /// includes a connect task that panicked or was cancelled.
    ///
    /// Returns the addresses that this call connected successfully.
    pub async fn connect_pending(&mut self) -> Vec<SocketAddr> {
        let free_slots = (self.max_connections as usize).saturating_sub(self.connections.len());
        let batch = self.take_eligible(free_slots);
        if batch.is_empty() {
            return Vec::new();
        }

        // Keep each handle paired with its address. A task that fails to complete can then
        // still be attributed to a peer instead of silently losing it.
        let info_hash = self.info_hash;
        let peer_id = self.peer_id;
        let handles: Vec<_> = batch
            .into_iter()
            .map(move |addr| {
                let handle = tokio::spawn(async move {
                    PeerConnection::connect(addr, info_hash, peer_id).await
                });
                (addr, handle)
            })
            .collect();

        let mut connected = Vec::new();
        for (addr, handle) in handles {
            match handle.await {
                Ok(Ok(conn)) => {
                    tracing::info!("peer connected: {}", addr);
                    self.connections.insert(addr, Arc::new(Mutex::new(conn)));
                    self.backoff.remove(&addr);
                    connected.push(addr);
                }
                Ok(Err(_)) => self.record_failure(addr),
                Err(join_err) => {
                    tracing::warn!(
                        "connect task for peer {} did not complete: {}",
                        addr,
                        join_err
                    );
                    self.record_failure(addr);
                }
            }
        }
        connected
    }

    /// Removes and returns up to `limit` pending peers that are not in a cooldown window.
    ///
    /// Peers still cooling down move to the back of the queue, keeping their relative
    /// order. The scan continues past them, so cooling peers at the front cannot use up
    /// slots that eligible peers further back could fill.
    fn take_eligible(&mut self, limit: usize) -> Vec<SocketAddr> {
        let now = Instant::now();
        let mut eligible = Vec::with_capacity(limit.min(self.pending.len()));
        // Collected separately so the scan never revisits a peer it already deferred.
        let mut cooling = Vec::new();
        while eligible.len() < limit {
            let addr = match self.pending.pop_front() {
                Some(addr) => addr,
                None => break,
            };
            match self.backoff.get(&addr) {
                Some(state) if state.cooldown_until > now => cooling.push(addr),
                _ => eligible.push(addr),
            }
        }
        self.pending.extend(cooling);
        eligible
    }

    /// Records a failed connection attempt to `addr` and applies the retry policy.
    ///
    /// While the recorded attempt count is below `MAX_RETRIES`, the failure is counted and
    /// the peer is re-queued under a fresh cooldown. Otherwise its backoff state is dropped
    /// and the peer is discarded.
    fn record_failure(&mut self, addr: SocketAddr) {
        let state = self.backoff.entry(addr).or_insert_with(BackoffState::new);
        if state.attempts < MAX_RETRIES {
            state.increment();
            tracing::debug!(
                "re-enqueuing peer {} (attempt {}/{}, cooldown {}s)",
                addr,
                state.attempts,
                MAX_RETRIES,
                PEER_COOLDOWN.as_secs()
            );
            self.pending.push_back(addr);
        } else {
            tracing::debug!("peer {}: max retries reached, discarding", addr);
            self.backoff.remove(&addr);
        }
    }

    /// Returns a shared handle to the open connection to `addr`, if there is one.
    pub fn connection(&self, addr: &SocketAddr) -> Option<Arc<Mutex<PeerConnection>>> {
        self.connections.get(addr).cloned()
    }

    /// Returns the addresses of all open connections, in no particular order.
    pub fn connection_addrs(&self) -> Vec<SocketAddr> {
        self.connections.keys().copied().collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback address `127.0.0.n:6881`, distinct for each `n`.
    fn test_addr(n: u8) -> SocketAddr {
        SocketAddr::new(
            std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, n)),
            6881,
        )
    }

    /// Runs `fut` to completion on a fresh single-threaded tokio runtime.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to build test runtime")
            .block_on(fut)
    }

    /// Snapshot of the pending queue, front first.
    fn pending_addrs(pm: &PeerManager) -> Vec<SocketAddr> {
        pm.pending.iter().copied().collect()
    }

    #[test]
    fn new_creates_empty() {
        let pm = PeerManager::new([0u8; 20], PeerId::random(), 10);
        assert_eq!(pm.num_connections(), 0);
        assert!(pm.connection_addrs().is_empty());
        assert!(pm.pending.is_empty());
        assert!(pm.backoff.is_empty());
    }

    #[test]
    fn add_peers_to_pending() {
        let mut pm = PeerManager::new([0u8; 20], PeerId::random(), 10);
        pm.add_peers(vec![test_addr(1), test_addr(2)]);
        assert_eq!(pm.num_connections(), 0);
        assert_eq!(pending_addrs(&pm), vec![test_addr(1), test_addr(2)]);
    }

    #[test]
    fn add_peers_skips_duplicates() {
        let mut pm = PeerManager::new([0u8; 20], PeerId::random(), 10);
        pm.add_peers(vec![test_addr(1), test_addr(1), test_addr(2)]);
        pm.add_peers(vec![test_addr(2), test_addr(3)]);
        assert_eq!(
            pending_addrs(&pm),
            vec![test_addr(1), test_addr(2), test_addr(3)]
        );
    }

    #[test]
    fn connect_pending_at_capacity_is_noop() {
        let mut pm = PeerManager {
            peer_id: PeerId::random(),
            info_hash: [0u8; 20],
            connections: HashMap::new(),
            pending: vec![test_addr(1)].into_iter().collect(),
            max_connections: 0,
            backoff: HashMap::new(),
        };
        let connected = block_on(pm.connect_pending());
        assert!(connected.is_empty());
        assert_eq!(pending_addrs(&pm), vec![test_addr(1)]);
        assert_eq!(pm.num_connections(), 0);
    }

    #[test]
    fn connect_pending_defers_peers_in_cooldown() {
        let addr = test_addr(1);
        let mut pm = PeerManager::new([0u8; 20], PeerId::random(), 10);
        pm.add_peers(vec![addr]);
        pm.backoff.insert(
            addr,
            BackoffState {
                attempts: 1,
                cooldown_until: Instant::now() + PEER_COOLDOWN,
            },
        );
        let connected = block_on(pm.connect_pending());
        assert!(connected.is_empty());
        assert_eq!(pending_addrs(&pm), vec![addr]);
        assert!(pm.backoff.contains_key(&addr));
        assert_eq!(pm.num_connections(), 0);
    }

    #[test]
    fn remove_peer_nonexistent() {
        let mut pm = PeerManager::new([0u8; 20], PeerId::random(), 10);
        pm.remove_peer(&test_addr(99));
        assert_eq!(pm.num_connections(), 0);
    }

    #[test]
    fn connection_addrs_empty() {
        let pm = PeerManager::new([0u8; 20], PeerId::random(), 10);
        assert!(pm.connection_addrs().is_empty());
        assert!(pm.connection(&test_addr(1)).is_none());
    }
}