use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::{Semaphore, mpsc};
use tracing::trace;
use crate::peer_state::PeerSource;
use crate::peer_states::PeerStates;
use crate::session::{SharedBanManager, SharedIpFilter};
/// A request, produced by [`peer_adder_task`], to open a connection to one peer.
///
/// The semaphore permit travels with the request, so the capacity slot it
/// represents is returned to the semaphore only when the consumer drops it.
#[derive(Debug)]
pub(crate) struct ConnectPeer {
    /// Address to dial.
    pub addr: SocketAddr,
    /// Source recorded for this address in `PeerStates`, or `PeerSource::Dht`
    /// when none was recorded.
    pub source: PeerSource,
    /// Capacity slot reserved for this connection attempt; dropping it frees
    /// the slot for the next queued peer.
    pub permit: tokio::sync::OwnedSemaphorePermit,
}
pub(crate) async fn peer_adder_task(
mut rx: mpsc::UnboundedReceiver<SocketAddr>,
semaphore: Arc<Semaphore>,
peer_states: Arc<PeerStates>,
ban_manager: SharedBanManager,
ip_filter: SharedIpFilter,
connect_tx: mpsc::Sender<ConnectPeer>,
) {
loop {
let Some(addr) = rx.recv().await else { return };
if addr.port() == 0 {
trace!(%addr, "peer_adder: port zero, skipping");
continue;
}
if ban_manager.read().is_banned(&addr.ip()) {
trace!(%addr, "peer_adder: banned, skipping");
continue;
}
if ip_filter.read().is_blocked(addr.ip()) {
trace!(%addr, "peer_adder: IP-filtered, skipping");
continue;
}
if peer_states.is_live(&addr) {
trace!(%addr, "peer_adder: already live, skipping");
continue;
}
if peer_states.is_eviction_banned(&addr) {
trace!(%addr, "peer_adder: eviction-banned, skipping");
continue;
}
let Ok(permit) = semaphore.clone().acquire_owned().await else {
return;
};
if !peer_states.mark_connecting(addr) {
drop(permit);
continue;
}
let source = peer_states.source(&addr).unwrap_or(PeerSource::Dht);
if connect_tx
.send(ConnectPeer {
addr,
source,
permit,
})
.await
.is_err()
{
return; }
}
}
/// Tests for `peer_adder_task`, driven through a real `PeerStates` queue.
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};
    use std::time::Duration;
    use crate::ban::{BanConfig, BanManager};
    use crate::ip_filter::IpFilter;

    /// How long a test waits for a `ConnectPeer` it expects to arrive.
    const EXPECT_WITHIN: Duration = Duration::from_secs(1);
    /// How long a test waits before concluding that no `ConnectPeer` is coming.
    const QUIET_FOR: Duration = Duration::from_millis(100);

    fn test_ban_manager() -> SharedBanManager {
        Arc::new(parking_lot::RwLock::new(BanManager::new(BanConfig::default())))
    }

    fn test_ip_filter() -> SharedIpFilter {
        Arc::new(parking_lot::RwLock::new(IpFilter::new()))
    }

    /// `203.0.113.1:port` (TEST-NET-3 documentation range).
    fn test_addr(port: u16) -> SocketAddr {
        test_addr_ip(1, port)
    }

    /// `203.0.113.<last_octet>:port` (TEST-NET-3 documentation range).
    fn test_addr_ip(last_octet: u8, port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, last_octet)), port)
    }

    /// Spawns `peer_adder_task` fed by a fresh `PeerStates` queue and returns
    /// the handles a test needs to drive and observe it.
    fn spawn_adder(
        semaphore: Arc<Semaphore>,
    ) -> (
        Arc<PeerStates>,
        mpsc::Receiver<ConnectPeer>,
        SharedBanManager,
        SharedIpFilter,
        tokio::task::JoinHandle<()>,
    ) {
        let (queue_tx, queue_rx) = mpsc::unbounded_channel();
        let peer_states = Arc::new(PeerStates::new(queue_tx));
        let (connect_tx, connect_rx) = mpsc::channel(64);
        let ban_manager = test_ban_manager();
        let ip_filter = test_ip_filter();
        let handle = tokio::spawn(peer_adder_task(
            queue_rx,
            semaphore,
            Arc::clone(&peer_states),
            Arc::clone(&ban_manager),
            Arc::clone(&ip_filter),
            connect_tx,
        ));
        (peer_states, connect_rx, ban_manager, ip_filter, handle)
    }

    /// Waits up to [`EXPECT_WITHIN`] for the next `ConnectPeer`, panicking with
    /// `timeout_msg` on timeout or "channel closed" if the adder has exited.
    async fn recv_connect(
        connect_rx: &mut mpsc::Receiver<ConnectPeer>,
        timeout_msg: &str,
    ) -> ConnectPeer {
        tokio::time::timeout(EXPECT_WITHIN, connect_rx.recv())
            .await
            .expect(timeout_msg)
            .expect("channel closed")
    }

    /// Asserts that nothing arrives on `connect_rx` within `within`. A closed
    /// channel also fails the assertion, because `recv` then returns promptly.
    async fn assert_no_connect(
        connect_rx: &mut mpsc::Receiver<ConnectPeer>,
        within: Duration,
        msg: &str,
    ) {
        let result = tokio::time::timeout(within, connect_rx.recv()).await;
        assert!(result.is_err(), "{msg}");
    }

    #[tokio::test]
    async fn adder_connects_on_permit_available() {
        let (peer_states, mut connect_rx, ..) = spawn_adder(Arc::new(Semaphore::new(10)));
        let addr = test_addr(6881);
        peer_states.add_if_not_seen(addr, PeerSource::Tracker);
        let cp = recv_connect(&mut connect_rx, "timed out").await;
        assert_eq!(cp.addr, addr);
        assert_eq!(cp.source, PeerSource::Tracker);
    }

    #[tokio::test]
    async fn adder_blocks_at_capacity() {
        let (peer_states, mut connect_rx, ..) = spawn_adder(Arc::new(Semaphore::new(1)));
        let addr1 = test_addr_ip(1, 6881);
        let addr2 = test_addr_ip(2, 6882);
        peer_states.add_if_not_seen(addr1, PeerSource::Dht);
        let cp1 = recv_connect(&mut connect_rx, "timed out").await;
        assert_eq!(cp1.addr, addr1);
        peer_states.add_if_not_seen(addr2, PeerSource::Dht);
        assert_no_connect(&mut connect_rx, QUIET_FOR, "should have timed out").await;
        // Releasing the only permit must unblock the queued peer.
        drop(cp1.permit);
        let cp2 = recv_connect(&mut connect_rx, "timed out").await;
        assert_eq!(cp2.addr, addr2);
    }

    #[tokio::test]
    async fn adder_deduplicates_peers() {
        let (peer_states, mut connect_rx, ..) = spawn_adder(Arc::new(Semaphore::new(10)));
        let addr = test_addr(6881);
        assert!(peer_states.add_if_not_seen(addr, PeerSource::Tracker));
        assert!(!peer_states.add_if_not_seen(addr, PeerSource::Dht));
        let cp = recv_connect(&mut connect_rx, "timed out").await;
        assert_eq!(cp.addr, addr);
        assert_no_connect(
            &mut connect_rx,
            QUIET_FOR,
            "duplicate should not produce ConnectPeer",
        )
        .await;
    }

    #[tokio::test]
    async fn adder_skips_banned_peers() {
        let (peer_states, mut connect_rx, ban_manager, ..) =
            spawn_adder(Arc::new(Semaphore::new(10)));
        let addr = test_addr(6881);
        ban_manager.write().ban(addr.ip());
        peer_states.add_if_not_seen(addr, PeerSource::Tracker);
        assert_no_connect(&mut connect_rx, QUIET_FOR, "banned peer should be skipped").await;
    }

    #[tokio::test]
    async fn adder_skips_port_zero() {
        let (peer_states, mut connect_rx, ..) = spawn_adder(Arc::new(Semaphore::new(10)));
        let addr = test_addr(0);
        peer_states.add_if_not_seen(addr, PeerSource::Tracker);
        assert_no_connect(&mut connect_rx, QUIET_FOR, "port-zero peer should be skipped").await;
    }

    #[tokio::test]
    async fn adder_skips_ip_filtered() {
        let (peer_states, mut connect_rx, _, ip_filter, ..) =
            spawn_adder(Arc::new(Semaphore::new(10)));
        let addr = test_addr(6881);
        ip_filter.write().add_rule(
            IpAddr::V4(Ipv4Addr::new(203, 0, 113, 0)),
            IpAddr::V4(Ipv4Addr::new(203, 0, 113, 255)),
            1,
        );
        peer_states.add_if_not_seen(addr, PeerSource::Tracker);
        assert_no_connect(&mut connect_rx, QUIET_FOR, "IP-filtered peer should be skipped").await;
    }

    #[tokio::test]
    async fn adder_revalidates_after_permit_wait() {
        let (peer_states, mut connect_rx, ..) = spawn_adder(Arc::new(Semaphore::new(1)));
        let addr1 = test_addr_ip(1, 6881);
        let addr2 = test_addr_ip(2, 6882);
        peer_states.add_if_not_seen(addr1, PeerSource::Dht);
        let cp1 = recv_connect(&mut connect_rx, "timed out").await;
        peer_states.add_if_not_seen(addr2, PeerSource::Dht);
        // addr2 becomes live through another path while the adder holds no permit.
        peer_states.mark_connecting(addr2);
        peer_states.mark_live(addr2);
        drop(cp1.permit);
        assert_no_connect(
            &mut connect_rx,
            Duration::from_millis(200),
            "peer connected during wait should be skipped after revalidation",
        )
        .await;
    }

    #[tokio::test]
    async fn adder_exits_on_connect_channel_close() {
        let (peer_states, connect_rx, _, _, handle) = spawn_adder(Arc::new(Semaphore::new(10)));
        peer_states.add_if_not_seen(test_addr(6881), PeerSource::Tracker);
        drop(connect_rx);
        tokio::time::timeout(EXPECT_WITHIN, handle)
            .await
            .expect("timed out")
            .expect("task panicked");
    }

    #[tokio::test]
    async fn adder_exits_on_semaphore_close() {
        let sem = Arc::new(Semaphore::new(10));
        let sem_clone = Arc::clone(&sem);
        // `_connect_rx` stays bound so the exit is caused by the semaphore, not the channel.
        let (peer_states, _connect_rx, _, _, handle) = spawn_adder(sem);
        sem_clone.close();
        peer_states.add_if_not_seen(test_addr(6881), PeerSource::Tracker);
        tokio::time::timeout(EXPECT_WITHIN, handle)
            .await
            .expect("timed out")
            .expect("task panicked");
    }

    #[tokio::test]
    async fn dead_peer_requeued_after_backoff() {
        let (peer_states, mut connect_rx, ..) = spawn_adder(Arc::new(Semaphore::new(10)));
        let addr = test_addr(6881);
        peer_states.add_if_not_seen(addr, PeerSource::Tracker);
        let cp = recv_connect(&mut connect_rx, "timed out").await;
        assert_eq!(cp.addr, addr);
        peer_states.mark_live(addr);
        let backoff = peer_states.mark_dead(addr);
        assert!(backoff.is_some(), "mark_dead should return backoff duration");
        assert!(peer_states.mark_queued_for_retry(addr));
        let cp2 = recv_connect(&mut connect_rx, "retried peer should pass after backoff").await;
        assert_eq!(cp2.addr, addr);
    }

    #[test]
    fn backoff_increases_exponentially() {
        // NOTE(review): this re-derives the backoff formula locally rather than
        // calling into `PeerStates`, so it does not guard the production code;
        // consider asserting on the durations returned by `mark_dead` instead.
        let backoff_secs =
            |attempt: u32| 10u64.saturating_mul(6u64.saturating_pow(attempt)).min(3600);
        let expected = [10u64, 60, 360, 2160, 3600];
        for (attempt, want) in (0u32..).zip(expected) {
            let got = backoff_secs(attempt);
            assert_eq!(got, want, "attempt {attempt}: expected {want}s, got {got}s");
        }
        for attempt in 5u32..=10 {
            let got = backoff_secs(attempt);
            assert_eq!(got, 3600, "attempt {attempt} should cap at 3600s");
        }
    }

    #[tokio::test]
    async fn retried_peer_passes_adder_checks() {
        let (peer_states, mut connect_rx, ..) = spawn_adder(Arc::new(Semaphore::new(10)));
        let addr = test_addr(6881);
        peer_states.add_if_not_seen(addr, PeerSource::Tracker);
        let cp = recv_connect(&mut connect_rx, "timed out").await;
        assert_eq!(cp.addr, addr);
        assert!(!peer_states.add_if_not_seen(addr, PeerSource::Dht));
        peer_states.mark_live(addr);
        let _ = peer_states.mark_dead(addr);
        assert!(peer_states.mark_queued_for_retry(addr));
        let cp2 =
            recv_connect(&mut connect_rx, "retried peer with expired backoff should pass").await;
        assert_eq!(cp2.addr, addr);
    }

    #[tokio::test]
    async fn counter_increments_on_new_peer() {
        let (peer_states, mut connect_rx, ..) = spawn_adder(Arc::new(Semaphore::new(10)));
        peer_states.add_if_not_seen(test_addr_ip(1, 6881), PeerSource::Tracker);
        peer_states.add_if_not_seen(test_addr_ip(2, 6882), PeerSource::Dht);
        peer_states.add_if_not_seen(test_addr_ip(3, 6883), PeerSource::Tracker);
        for _ in 0..3 {
            recv_connect(&mut connect_rx, "timed out").await;
        }
        assert_eq!(peer_states.stats.snapshot().known, 3);
    }

    #[tokio::test]
    async fn counter_ignores_duplicates() {
        let (peer_states, mut connect_rx, ..) = spawn_adder(Arc::new(Semaphore::new(10)));
        let addr = test_addr(6881);
        peer_states.add_if_not_seen(addr, PeerSource::Tracker);
        recv_connect(&mut connect_rx, "timed out").await;
        assert!(!peer_states.add_if_not_seen(addr, PeerSource::Dht));
        // A fresh address proves the adder has processed everything queued before it.
        let sentinel = test_addr_ip(99, 9999);
        peer_states.add_if_not_seen(sentinel, PeerSource::Tracker);
        recv_connect(&mut connect_rx, "timed out waiting for sentinel").await;
        assert_eq!(peer_states.stats.snapshot().known, 2);
    }
}