#![allow(
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_sign_loss,
clippy::unchecked_time_subtraction,
reason = "M175: DHT iterative lookup — node counts bounded by ALPHA fan-out; remaining time-sub sites are test fixtures"
)]
use std::collections::HashSet;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU16, Ordering};
use std::time::{Duration, Instant};
use dashmap::DashMap;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
use tokio::net::UdpSocket;
use tokio::sync::{mpsc, oneshot};
use tracing::{debug, trace};
use irontide_core::{AddressFamily, Id20};
#[cfg(test)]
use crate::compact::CompactNodeInfo;
use crate::krpc::{
GetPeersResponse, KrpcBody, KrpcMessage, KrpcQuery, KrpcResponse, TransactionId,
};
use crate::routing_table::RoutingTable;
use crate::actor::{PendingQuery, PendingQueryKind, SharedRateLimiter};
/// Limits for a single [`DhtLookup`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct LookupConfig {
    /// Deepest hop count at which a node may still be queried; routing-table
    /// roots are depth 0 and each response's contacts are one level deeper.
    pub max_depth: u8,
    /// Maximum number of nodes the lookup tracks; when exceeded, the
    /// lowest-ranked tracked node is dropped.
    pub max_nodes: usize,
}
/// Per-node bookkeeping for a lookup; nodes are matched by `addr`.
struct TrackedNode {
    /// Node id as reported by whoever told us about the node.
    id: Id20,
    /// Address the node is queried at.
    addr: SocketAddr,
    /// Hop count from the routing-table roots (roots are depth 0).
    depth: u8,
    /// When the latest query to this node was spawned; `None` before the
    /// first query and while a re-query is being armed.
    last_queried: Option<Instant>,
    /// Whether the latest query got a `get_peers` response (cleared on error).
    responded: bool,
    /// Whether that response carried at least one peer (cleared on error).
    returned_peers: bool,
}
/// State of one iterative `get_peers` lookup for `target`.
///
/// Created with [`DhtLookup::new`] and driven by [`DhtLookup::run`], which
/// stops once the receiver of `peer_tx` is dropped. Fields drop in declaration
/// order, so the result senders are released after the shared handles.
pub(crate) struct DhtLookup {
    /// Info-hash being searched; sent as `info_hash` in every query and used
    /// as the XOR-distance reference when ranking nodes.
    target: Id20,
    /// Depth and size limits for this lookup.
    config: LookupConfig,
    /// Nodes the lookup is tracking, matched by address; capped at
    /// `config.max_nodes`.
    nodes: Vec<TrackedNode>,
    /// Addresses handed to a query; an entry is removed when its node is
    /// re-armed for another query. NOTE(review): nothing in this file reads it.
    queried_addrs: HashSet<SocketAddr>,
    /// Only nodes whose address belongs to this family are tracked.
    address_family: AddressFamily,
    /// Socket the `get_peers` queries are sent from.
    socket: Arc<UdpSocket>,
    /// In-flight queries keyed by transaction id; the reply arrives through the
    /// entry's `response_tx` (presumably filled in by the DHT actor — confirm).
    pending: Arc<DashMap<u16, PendingQuery>>,
    /// Limiter every query awaits before it is sent.
    rate_limiter: Arc<SharedRateLimiter>,
    /// Routing table the lookup draws its closest-node roots from.
    routing_table: Arc<parking_lot::RwLock<RoutingTable>>,
    /// Shared transaction-id counter; id 0 is skipped.
    next_txn_id: Arc<AtomicU16>,
    /// Our node id: sent as the query `id` and never tracked as a contact.
    own_id: Id20,
    /// Batches of peers found by the lookup; the lookup stops once this
    /// channel's receiver is dropped.
    peer_tx: mpsc::UnboundedSender<Vec<SocketAddr>>,
    /// `(target, sender_id, sender_addr, token)` for each response with a token.
    token_tx: mpsc::UnboundedSender<(Id20, Id20, SocketAddr, Vec<u8>)>,
    /// Every `(id, addr)` contact seen in a response.
    node_tx: mpsc::UnboundedSender<(Id20, SocketAddr)>,
    /// Consecutive re-injection ticks that queried no roots; drives backoff.
    empty_inject_count: u32,
    /// Copied into each outgoing message's `read_only` flag (encoded as `ro`).
    read_only_mode: bool,
    /// Copied into each `get_peers` query's `want` field.
    want: Option<Vec<crate::krpc::WantFamily>>,
}
/// Base period for re-injecting routing-table roots: the full wait after a
/// tick injects eight or more roots, one eighth of it per root otherwise.
const REQUERY_INTERVAL: Duration = Duration::from_secs(60);
/// Minimum age of a node's last query before it may be queued again.
const REQUERY_DEDUP: Duration = Duration::from_secs(60);
/// How long a single `get_peers` query waits for its response.
const QUERY_TIMEOUT: Duration = Duration::from_secs(10);
/// What one `get_peers` query resolves to: the queried address, plus either the
/// responder's id and response, or `Err(())` when the query failed.
type QueryOutcome = (SocketAddr, Result<(Id20, GetPeersResponse), ()>);

/// Boxed, `Send` future driving a single in-flight `get_peers` query.
type QueryFuture = std::pin::Pin<Box<dyn std::future::Future<Output = QueryOutcome> + Send>>;
impl DhtLookup {
#[allow(clippy::too_many_arguments)]
pub fn new(
target: Id20,
config: LookupConfig,
address_family: AddressFamily,
socket: Arc<UdpSocket>,
pending: Arc<DashMap<u16, PendingQuery>>,
rate_limiter: Arc<SharedRateLimiter>,
routing_table: Arc<parking_lot::RwLock<RoutingTable>>,
next_txn_id: Arc<AtomicU16>,
own_id: Id20,
peer_tx: mpsc::UnboundedSender<Vec<SocketAddr>>,
token_tx: mpsc::UnboundedSender<(Id20, Id20, SocketAddr, Vec<u8>)>,
node_tx: mpsc::UnboundedSender<(Id20, SocketAddr)>,
read_only_mode: bool,
want: Option<Vec<crate::krpc::WantFamily>>,
) -> Self {
Self {
target,
config,
nodes: Vec::with_capacity(256),
queried_addrs: HashSet::new(),
address_family,
socket,
pending,
rate_limiter,
routing_table,
next_txn_id,
own_id,
peer_tx,
token_tx,
node_tx,
empty_inject_count: 0,
read_only_mode,
want,
}
}
pub async fn run(mut self) {
let mut futures: FuturesUnordered<QueryFuture> = FuturesUnordered::new();
self.inject_roots(&mut futures);
let requery_sleep = tokio::time::sleep(Duration::from_secs(1));
tokio::pin!(requery_sleep);
loop {
if self.peer_tx.is_closed() {
debug!(
target = %self.target,
"DhtLookup: peer channel closed, shutting down"
);
break;
}
tokio::select! {
() = &mut requery_sleep => {
let injected = self.inject_roots(&mut futures);
let next_delay = if injected == 0 {
self.empty_inject_count =
self.empty_inject_count.saturating_add(1);
Duration::from_secs(
(1u64 << self.empty_inject_count
.saturating_sub(1)
.min(4))
.min(15),
)
} else {
self.empty_inject_count = 0;
if injected < 8 {
REQUERY_INTERVAL / 8 * injected.min(8)
} else {
REQUERY_INTERVAL
}
};
requery_sleep
.as_mut()
.reset(tokio::time::Instant::now() + next_delay);
}
result = futures.next(), if !futures.is_empty() => {
match result {
Some((addr, Ok((sender_id, gp)))) => {
self.process_response(addr, sender_id, &gp, &mut futures);
}
Some((addr, Err(()))) => {
self.mark_error(addr);
}
None => {
}
}
}
() = tokio::time::sleep(Duration::from_secs(1)), if futures.is_empty() => {}
}
}
}
fn inject_roots(&mut self, futures: &mut FuturesUnordered<QueryFuture>) -> u32 {
let roots = self.routing_table.read().closest(&self.target, 8);
let mut spawned = 0u32;
for node in roots {
if self.maybe_add_node(node.id, node.addr, 0) {
futures.push(self.spawn_query(node.addr, Some(node.id)));
spawned = spawned.saturating_add(1);
}
}
if spawned > 0 {
trace!(
target = %self.target,
spawned,
total_nodes = self.nodes.len(),
"DhtLookup: injected roots"
);
}
spawned
}
fn maybe_add_node(&mut self, id: Id20, addr: SocketAddr, depth: u8) -> bool {
if depth > self.config.max_depth {
return false;
}
let family_ok = match self.address_family {
AddressFamily::V4 => addr.is_ipv4(),
AddressFamily::V6 => addr.is_ipv6(),
};
if !family_ok {
return false;
}
if id == self.own_id {
return false;
}
if let Some(existing) = self.nodes.iter().find(|n| n.addr == addr) {
if let Some(last) = existing.last_queried {
if last.elapsed() < REQUERY_DEDUP {
return false;
}
} else {
return false;
}
if let Some(existing) = self.nodes.iter_mut().find(|n| n.addr == addr) {
existing.last_queried = None; }
self.queried_addrs.remove(&addr);
return true;
}
self.nodes.push(TrackedNode {
id,
addr,
depth,
returned_peers: false,
responded: false,
last_queried: None,
});
if self.nodes.len() > self.config.max_nodes {
let target = self.target;
self.sort_nodes(&target);
self.nodes.pop();
if !self.nodes.iter().any(|n| n.addr == addr) {
return false;
}
}
true
}
fn spawn_query(&mut self, addr: SocketAddr, node_id: Option<Id20>) -> QueryFuture {
self.queried_addrs.insert(addr);
if let Some(node) = self.nodes.iter_mut().find(|n| n.addr == addr) {
node.last_queried = Some(Instant::now());
}
let socket = self.socket.clone();
let pending = self.pending.clone();
let rate_limiter = self.rate_limiter.clone();
let next_txn_id = self.next_txn_id.clone();
let own_id = self.own_id;
let target = self.target;
let read_only = self.read_only_mode;
let want = self.want.clone();
Box::pin(async move {
rate_limiter.acquire().await;
let txn = next_txn_id.fetch_add(1, Ordering::Relaxed);
let txn = if txn == 0 {
next_txn_id.fetch_add(1, Ordering::Relaxed)
} else {
txn
};
let (tx, rx) = oneshot::channel();
let msg = KrpcMessage {
transaction_id: TransactionId::from_u16(txn),
body: KrpcBody::Query(KrpcQuery::GetPeers {
id: own_id,
info_hash: target,
noseed: None,
scrape: None,
want: want.clone(),
}),
sender_ip: None,
read_only,
};
let Ok(bytes) = msg.to_bytes() else {
return (addr, Err(()));
};
pending.insert(
txn,
PendingQuery {
sent_at: Instant::now(),
addr,
kind: PendingQueryKind::GetPeers { info_hash: target },
node_id,
response_tx: Some(tx),
},
);
if socket.send_to(&bytes, addr).await.is_err() {
pending.remove(&txn);
return (addr, Err(()));
}
if let Ok(Ok(resp)) = tokio::time::timeout(QUERY_TIMEOUT, rx).await {
if let KrpcResponse::GetPeers(gp) = resp.response {
(addr, Ok((resp.sender_id, gp)))
} else {
(addr, Err(()))
}
} else {
pending.remove(&txn);
(addr, Err(()))
}
})
}
fn process_response(
&mut self,
addr: SocketAddr,
sender_id: Id20,
gp: &GetPeersResponse,
futures: &mut FuturesUnordered<QueryFuture>,
) {
if let Some(node) = self.nodes.iter_mut().find(|n| n.addr == addr) {
node.responded = true;
node.returned_peers = !gp.peers.is_empty();
}
if !gp.peers.is_empty() {
let _ = self.peer_tx.send(gp.peers.clone());
}
if let Some(token) = &gp.token {
let _ = self
.token_tx
.send((self.target, sender_id, addr, token.clone()));
}
let depth = self
.nodes
.iter()
.find(|n| n.addr == addr)
.map_or(0, |n| n.depth);
let new_depth = depth.saturating_add(1);
for node in &gp.nodes {
self.forward_node(node.id, node.addr);
if new_depth <= self.config.max_depth
&& self.maybe_add_node(node.id, node.addr, new_depth)
{
futures.push(self.spawn_query(node.addr, Some(node.id)));
}
}
for node in &gp.nodes6 {
self.forward_node(node.id, node.addr);
if new_depth <= self.config.max_depth
&& self.maybe_add_node(node.id, node.addr, new_depth)
{
futures.push(self.spawn_query(node.addr, Some(node.id)));
}
}
}
fn forward_node(&self, id: Id20, addr: SocketAddr) {
let _ = self.node_tx.send((id, addr));
}
fn mark_error(&mut self, addr: SocketAddr) {
if let Some(node) = self.nodes.iter_mut().find(|n| n.addr == addr) {
node.responded = false;
node.returned_peers = false;
}
}
fn sort_nodes(&mut self, target: &Id20) {
self.nodes.sort_by(|a, b| {
b.returned_peers
.cmp(&a.returned_peers)
.then_with(|| {
b.responded.cmp(&a.responded)
})
.then_with(|| {
a.id.xor_distance(target).cmp(&b.id.xor_distance(target))
})
.then_with(|| {
match (a.last_queried, b.last_queried) {
(Some(a_t), Some(b_t)) => b_t.cmp(&a_t),
(Some(_), None) => std::cmp::Ordering::Less,
(None, Some(_)) => std::cmp::Ordering::Greater,
(None, None) => std::cmp::Ordering::Equal,
}
})
});
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A lookup plus the receiving ends of its channels. Holding the receivers
    /// keeps the lookup's channels open for as long as the harness lives.
    struct Harness {
        lookup: DhtLookup,
        peer_rx: mpsc::UnboundedReceiver<Vec<SocketAddr>>,
        token_rx: mpsc::UnboundedReceiver<(Id20, Id20, SocketAddr, Vec<u8>)>,
        _node_rx: mpsc::UnboundedReceiver<(Id20, SocketAddr)>,
    }

    /// Builds an IPv4 lookup (`max_depth = 4`, not read-only, no `want`) on a
    /// fresh loopback socket, sharing `routing_table` with the caller.
    async fn new_lookup(
        target: Id20,
        own_id: Id20,
        max_nodes: usize,
        routing_table: Arc<parking_lot::RwLock<RoutingTable>>,
    ) -> Harness {
        let (peer_tx, peer_rx) = mpsc::unbounded_channel();
        let (token_tx, token_rx) = mpsc::unbounded_channel();
        let (node_tx, node_rx) = mpsc::unbounded_channel();
        let socket = Arc::new(UdpSocket::bind("127.0.0.1:0").await.expect("bind"));
        let lookup = DhtLookup::new(
            target,
            LookupConfig {
                max_depth: 4,
                max_nodes,
            },
            AddressFamily::V4,
            socket,
            Arc::new(DashMap::new()),
            Arc::new(SharedRateLimiter::new(250)),
            routing_table,
            Arc::new(AtomicU16::new(1)),
            own_id,
            peer_tx,
            token_tx,
            node_tx,
            false,
            None,
        );
        Harness {
            lookup,
            peer_rx,
            token_rx,
            _node_rx: node_rx,
        }
    }

    /// [`new_lookup`] with `max_nodes = 256` over a routing table the test
    /// does not touch afterwards.
    async fn default_lookup(target: Id20, own_id: Id20, rt: RoutingTable) -> Harness {
        new_lookup(target, own_id, 256, Arc::new(parking_lot::RwLock::new(rt))).await
    }

    fn make_id(byte: u8) -> Id20 {
        let mut bytes = [0u8; 20];
        bytes[19] = byte;
        Id20(bytes)
    }

    fn sock(s: &str) -> SocketAddr {
        s.parse().expect("parse")
    }

    fn tracked(byte: u8, port: u16, responded: bool, returned_peers: bool) -> TrackedNode {
        TrackedNode {
            id: make_id(byte),
            addr: SocketAddr::from(([127, 0, 0, 1], port)),
            depth: 0,
            returned_peers,
            responded,
            last_queried: None,
        }
    }

    fn ports(nodes: &[TrackedNode]) -> Vec<u16> {
        nodes.iter().map(|n| n.addr.port()).collect()
    }

    fn response(
        id: Id20,
        token: Option<&[u8]>,
        peers: Vec<SocketAddr>,
        nodes: Vec<CompactNodeInfo>,
    ) -> GetPeersResponse {
        GetPeersResponse {
            id,
            token: token.map(<[u8]>::to_vec),
            peers,
            nodes,
            nodes6: Vec::new(),
            bfpe: None,
            bfsd: None,
        }
    }

    #[tokio::test]
    async fn tracked_node_multi_key_sort() {
        let target = Id20::ZERO;
        let mut h = default_lookup(target, make_id(0xFF), RoutingTable::new(Id20::ZERO)).await;

        // Peers beat plain responses, which beat silence; ties break on distance.
        h.lookup.nodes = vec![
            tracked(0xFF, 1, true, false),
            tracked(0x01, 2, false, false),
            tracked(0x02, 3, true, true),
            tracked(0xFE, 4, true, true),
        ];
        h.lookup.sort_nodes(&target);
        assert_eq!(ports(&h.lookup.nodes), [3, 4, 1, 2]);

        // Built forward from `now` so no `Instant` subtraction can underflow.
        let now = Instant::now();
        let later = now + Duration::from_secs(30);

        // XOR distance outranks recency.
        let mut near = tracked(0x10, 10, true, true);
        near.last_queried = Some(now);
        let mut far = tracked(0x10, 11, true, true);
        far.id.0[18] = 1;
        far.last_queried = Some(later);
        h.lookup.nodes = vec![far, near];
        h.lookup.sort_nodes(&target);
        assert_eq!(ports(&h.lookup.nodes), [10, 11], "closer node should sort first");

        // With every other key tied, the most recently queried node wins.
        let mut stale = tracked(0x10, 20, true, true);
        stale.last_queried = Some(now);
        let mut fresh = tracked(0x10, 21, true, true);
        fresh.last_queried = Some(later);
        h.lookup.nodes = vec![stale, fresh];
        h.lookup.sort_nodes(&target);
        assert_eq!(ports(&h.lookup.nodes), [21, 20], "fresher node should sort first");

        // Queried nodes sort ahead of never-queried ones.
        let mut queried = tracked(0x10, 30, true, true);
        queried.last_queried = Some(now);
        h.lookup.nodes = vec![tracked(0x10, 31, true, true), queried];
        h.lookup.sort_nodes(&target);
        assert_eq!(ports(&h.lookup.nodes), [30, 31]);
    }

    #[tokio::test]
    async fn should_request_depth_limit() {
        let mut h =
            default_lookup(Id20::ZERO, make_id(0xFF), RoutingTable::new(Id20::ZERO)).await;
        assert!(h.lookup.maybe_add_node(make_id(1), sock("1.2.3.4:1000"), 4));
        assert!(!h.lookup.maybe_add_node(make_id(2), sock("1.2.3.5:1001"), 5));
    }

    #[tokio::test]
    async fn should_request_requery_dedup() {
        let mut h =
            default_lookup(Id20::ZERO, make_id(0xFF), RoutingTable::new(Id20::ZERO)).await;
        let addr = sock("1.2.3.4:1000");
        let id = make_id(1);
        assert!(h.lookup.maybe_add_node(id, addr, 0));
        h.lookup
            .nodes
            .iter_mut()
            .find(|n| n.addr == addr)
            .expect("node was just added")
            .last_queried = Some(Instant::now());
        assert!(!h.lookup.maybe_add_node(id, addr, 0));
    }

    #[tokio::test]
    async fn should_request_evict_self_rejected() {
        let own_id = make_id(0xAA);
        let mut h = default_lookup(Id20::ZERO, own_id, RoutingTable::new(own_id)).await;
        assert!(!h.lookup.maybe_add_node(own_id, sock("1.2.3.4:1000"), 0));
    }

    #[tokio::test]
    async fn max_nodes_capacity_eviction() {
        let rt = Arc::new(parking_lot::RwLock::new(RoutingTable::new(make_id(0xFF))));
        let mut h = new_lookup(Id20::ZERO, make_id(0xFF), 3, rt).await;
        let a3 = sock("1.0.0.3:3");
        assert!(h.lookup.maybe_add_node(make_id(0x80), sock("1.0.0.1:1"), 0));
        assert!(h.lookup.maybe_add_node(make_id(0x90), sock("1.0.0.2:2"), 0));
        assert!(h.lookup.maybe_add_node(make_id(0xA0), a3, 0));
        assert_eq!(h.lookup.nodes.len(), 3);

        // A closer newcomer displaces the farthest tracked node.
        let a4 = sock("1.0.0.4:4");
        assert!(h.lookup.maybe_add_node(make_id(0x01), a4, 0));
        assert_eq!(h.lookup.nodes.len(), 3);
        assert!(h.lookup.nodes.iter().any(|n| n.addr == a4));
        assert!(!h.lookup.nodes.iter().any(|n| n.addr == a3));
    }

    #[tokio::test]
    async fn next_queries_no_alpha_cap() {
        let mut h =
            default_lookup(Id20::ZERO, make_id(0xFF), RoutingTable::new(Id20::ZERO)).await;
        let accepted = (1..=10u8)
            .filter(|&i| {
                let addr = SocketAddr::from(([10, 0, 0, i], 6880 + u16::from(i)));
                h.lookup.maybe_add_node(make_id(i), addr, 0)
            })
            .count();
        assert_eq!(accepted, 10);
        assert_eq!(h.lookup.nodes.len(), 10);
    }

    #[tokio::test]
    async fn lookup_streams_peers_to_caller() {
        let mut h =
            default_lookup(Id20::ZERO, make_id(0xFF), RoutingTable::new(Id20::ZERO)).await;
        let addr = sock("1.2.3.4:6881");
        h.lookup.maybe_add_node(make_id(1), addr, 0);
        let peer_addr = sock("5.6.7.8:9999");
        let gp = response(make_id(1), None, vec![peer_addr], Vec::new());
        let mut futures: FuturesUnordered<QueryFuture> = FuturesUnordered::new();
        h.lookup.process_response(addr, make_id(1), &gp, &mut futures);
        let received = h.peer_rx.try_recv().expect("should have received peers");
        assert_eq!(received, vec![peer_addr]);
    }

    #[tokio::test]
    async fn lookup_sends_tokens_to_actor() {
        let target = make_id(0x42);
        let mut h = default_lookup(target, make_id(0xFF), RoutingTable::new(Id20::ZERO)).await;
        let addr = sock("1.2.3.4:6881");
        let sender_id = make_id(1);
        h.lookup.maybe_add_node(sender_id, addr, 0);
        let gp = response(sender_id, Some(&b"my_token"[..]), Vec::new(), Vec::new());
        let mut futures: FuturesUnordered<QueryFuture> = FuturesUnordered::new();
        h.lookup.process_response(addr, sender_id, &gp, &mut futures);
        let (ih, nid, tkn_addr, tkn) = h.token_rx.try_recv().expect("should have received token");
        assert_eq!(ih, target);
        assert_eq!(nid, sender_id);
        assert_eq!(tkn_addr, addr);
        assert_eq!(tkn, b"my_token");
    }

    #[tokio::test]
    async fn lookup_completes_single_round() {
        let mut h =
            default_lookup(Id20::ZERO, make_id(0xFF), RoutingTable::new(Id20::ZERO)).await;
        let addr = sock("1.2.3.4:6881");
        h.lookup.maybe_add_node(make_id(1), addr, 0);
        let contacts = vec![
            CompactNodeInfo {
                id: make_id(2),
                addr: sock("1.2.3.5:6882"),
            },
            CompactNodeInfo {
                id: make_id(3),
                addr: sock("1.2.3.6:6883"),
            },
        ];
        let gp = response(make_id(1), None, Vec::new(), contacts);
        let mut futures: FuturesUnordered<QueryFuture> = FuturesUnordered::new();
        h.lookup.process_response(addr, make_id(1), &gp, &mut futures);
        assert_eq!(h.lookup.nodes.len(), 3);
        assert_eq!(futures.len(), 2);
        // Contacts sit one level below the depth-0 sender.
        assert!(
            h.lookup
                .nodes
                .iter()
                .filter(|n| n.addr != addr)
                .all(|n| n.depth == 1)
        );
    }

    #[tokio::test]
    async fn lookup_response_channel_closed() {
        let Harness {
            lookup,
            peer_rx,
            token_rx: _token_rx,
            _node_rx,
        } = default_lookup(Id20::ZERO, make_id(0xFF), RoutingTable::new(Id20::ZERO)).await;
        drop(peer_rx);
        let handle = tokio::spawn(lookup.run());
        let result = tokio::time::timeout(Duration::from_secs(3), handle).await;
        assert!(result.is_ok(), "lookup should have exited within 3 seconds");
    }

    // Only checks that the pending map type starts empty for an unknown id;
    // stale-response handling itself lives outside this module.
    #[test]
    fn lookup_stale_response_dropped() {
        let pending: Arc<DashMap<u16, PendingQuery>> = Arc::new(DashMap::new());
        assert!(pending.get(&999).is_none());
    }

    #[test]
    fn lookup_rate_limiter_backpressure() {
        let limiter = SharedRateLimiter::new(2);
        assert!(limiter.try_acquire());
        assert!(limiter.try_acquire());
        assert!(!limiter.try_acquire());
    }

    #[tokio::test]
    async fn lookup_requery_reinjects_roots() {
        let mut rt = RoutingTable::new(Id20::ZERO);
        rt.insert(make_id(1), sock("10.0.0.1:6881"));
        let mut h = default_lookup(make_id(0x42), make_id(0xFF), rt).await;
        let mut futures: FuturesUnordered<QueryFuture> = FuturesUnordered::new();
        h.lookup.inject_roots(&mut futures);
        assert_eq!(h.lookup.nodes.len(), 1);
        // The root was just queried, so a second pass neither re-adds nor re-queries it.
        h.lookup.inject_roots(&mut futures);
        assert_eq!(h.lookup.nodes.len(), 1);
        assert_eq!(futures.len(), 1);
    }

    #[tokio::test]
    async fn lookup_accumulates_across_requery() {
        let mut rt = RoutingTable::new(Id20::ZERO);
        rt.insert(make_id(1), sock("10.0.0.1:6881"));
        let rt = Arc::new(parking_lot::RwLock::new(rt));
        let mut h = new_lookup(make_id(0x42), make_id(0xFF), 256, Arc::clone(&rt)).await;
        let mut futures: FuturesUnordered<QueryFuture> = FuturesUnordered::new();
        h.lookup.inject_roots(&mut futures);
        assert_eq!(h.lookup.nodes.len(), 1);
        rt.write().insert(make_id(2), sock("10.0.0.2:6882"));
        h.lookup.inject_roots(&mut futures);
        assert_eq!(h.lookup.nodes.len(), 2);
    }

    #[tokio::test]
    async fn announce_tokens_available_after_lookup() {
        let target = make_id(0x42);
        let mut h = default_lookup(target, make_id(0xFF), RoutingTable::new(Id20::ZERO)).await;
        let (addr1, addr2) = (sock("1.2.3.4:6881"), sock("5.6.7.8:6882"));
        let (id1, id2) = (make_id(1), make_id(2));
        h.lookup.maybe_add_node(id1, addr1, 0);
        h.lookup.maybe_add_node(id2, addr2, 0);
        let gp1 = response(id1, Some(&b"token_a"[..]), Vec::new(), Vec::new());
        let gp2 = response(id2, Some(&b"token_b"[..]), Vec::new(), Vec::new());
        let mut futures: FuturesUnordered<QueryFuture> = FuturesUnordered::new();
        h.lookup.process_response(addr1, id1, &gp1, &mut futures);
        h.lookup.process_response(addr2, id2, &gp2, &mut futures);
        let mut tokens = Vec::new();
        while let Ok(t) = h.token_rx.try_recv() {
            tokens.push(t);
        }
        assert_eq!(tokens.len(), 2);
        assert!(tokens.iter().any(|(_, _, _, t)| t == b"token_a"));
        assert!(tokens.iter().any(|(_, _, _, t)| t == b"token_b"));
    }

    // Mirrors the empty-injection backoff formula used by `DhtLookup`:
    // 1, 2, 4, 8 seconds, then capped at 15.
    #[test]
    fn inject_roots_returns_count_and_adaptive_backoff() {
        let mut count = 0u32;
        for expected in [1u64, 2, 4, 8, 15, 15] {
            count = count.saturating_add(1);
            let delay_secs = (1u64 << count.saturating_sub(1).min(4)).min(15);
            assert_eq!(delay_secs, expected, "backoff step {count}");
        }
    }

    #[tokio::test]
    async fn inject_roots_returns_nonzero_for_populated_table() {
        let mut rt = RoutingTable::new(Id20::ZERO);
        for (byte, addr) in [(1, "10.0.0.1:6881"), (2, "10.0.0.2:6882"), (3, "10.0.0.3:6883")] {
            rt.insert(make_id(byte), sock(addr));
        }
        let mut h = default_lookup(make_id(0x42), make_id(0xFF), rt).await;
        let mut futures: FuturesUnordered<QueryFuture> = FuturesUnordered::new();
        let injected = h.lookup.inject_roots(&mut futures);
        assert_eq!(injected, 3);
        assert_eq!(h.lookup.empty_inject_count, 0);
        let injected2 = h.lookup.inject_roots(&mut futures);
        assert_eq!(injected2, 0);
    }

    #[tokio::test]
    async fn empty_inject_count_increments_on_empty_roots() {
        let mut h =
            default_lookup(make_id(0x42), make_id(0xFF), RoutingTable::new(Id20::ZERO)).await;
        let mut futures: FuturesUnordered<QueryFuture> = FuturesUnordered::new();
        assert_eq!(h.lookup.empty_inject_count, 0);
        let injected = h.lookup.inject_roots(&mut futures);
        assert_eq!(injected, 0);
        // `inject_roots` leaves the streak alone; the scheduler updates it.
        assert_eq!(h.lookup.empty_inject_count, 0);
        if injected == 0 {
            h.lookup.empty_inject_count = h.lookup.empty_inject_count.saturating_add(1);
        }
        assert_eq!(h.lookup.empty_inject_count, 1);
        h.lookup.empty_inject_count = 5;
        if injected > 0 {
            h.lookup.empty_inject_count = 0;
        }
        assert_eq!(h.lookup.empty_inject_count, 5);
    }

    #[tokio::test]
    async fn dht_lookup_query_includes_ro() {
        let target = Id20::ZERO;
        let own_id = make_id(0xFF);
        let msg = KrpcMessage {
            transaction_id: TransactionId::from_u16(99),
            body: KrpcBody::Query(KrpcQuery::GetPeers {
                id: own_id,
                info_hash: target,
                noseed: None,
                scrape: None,
                want: None,
            }),
            sender_ip: None,
            read_only: true,
        };
        let bytes = msg.to_bytes().unwrap();
        let decoded = KrpcMessage::from_bytes(&bytes).unwrap();
        assert!(decoded.read_only, "get_peers query should carry ro flag");
        let raw: irontide_bencode::BencodeValue = irontide_bencode::from_bytes(&bytes).unwrap();
        let dict = raw.as_dict().unwrap();
        assert!(
            dict.contains_key(&b"ro"[..]),
            "encoded bytes should contain ro key"
        );
        let ro_val = dict.get(&b"ro"[..]).unwrap().as_int().unwrap();
        assert_eq!(ro_val, 1);
        let msg_normal = KrpcMessage {
            transaction_id: TransactionId::from_u16(100),
            body: KrpcBody::Query(KrpcQuery::GetPeers {
                id: own_id,
                info_hash: target,
                noseed: None,
                scrape: None,
                want: None,
            }),
            sender_ip: None,
            read_only: false,
        };
        let bytes_normal = msg_normal.to_bytes().unwrap();
        let raw_normal: irontide_bencode::BencodeValue =
            irontide_bencode::from_bytes(&bytes_normal).unwrap();
        let dict_normal = raw_normal.as_dict().unwrap();
        assert!(
            !dict_normal.contains_key(&b"ro"[..]),
            "non-read-only query should not contain ro key"
        );
    }
}