use std::collections::VecDeque;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs};
use std::num::NonZeroU32;
use std::time::Duration;
use anyhow::Context;
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
use tokio::net::UdpSocket;
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use crate::requests::{GetPeersQueryMessage, KrpcMessage, KrpcMessageType};
use crate::util::generate_node_id;
/// Size of the datagram receive buffer. 64 KiB is at least as large as the
/// biggest possible UDP payload, so received datagrams are never truncated.
const DEFAULT_BUF_SIZE: usize = 64 * 1024;
/// How long a `get_peers` query may stay unanswered before it stops counting as
/// in flight; also used as the period of the cleanup timer.
const INFLIGHT_REQUEST_TIMEOUT_SECS: f64 = 1.0;
/// Looks up peers for one info hash on the DHT by sending the same pre-serialized
/// `get_peers` query to every node it learns about.
#[derive(Debug)]
pub struct DhtRequester {
    /// Receive buffer reused for every incoming datagram.
    read_buf: Vec<u8>,
    /// Bencoded `get_peers` query, built once in `new` and sent to each node.
    serialized_get_peers_message: Vec<u8>,
    /// Nodes waiting to be queried, in the order they were discovered.
    node_queue: VecDeque<SocketAddrV4>,
    /// Nodes that have already been queried; consulted so they are not re-enqueued.
    processed_nodes: Vec<SocketAddrV4>,
    /// IPs already forwarded to the peer queue (peers are de-duplicated by IP only).
    seen_peers: Vec<Ipv4Addr>,
    /// Outstanding queries as (send time, node address), appended in send order.
    inflight_requests: VecDeque<(Instant, SocketAddrV4)>,
    /// Limits how often queries are sent.
    rate_limiter: DefaultDirectRateLimiter,
}
impl DhtRequester {
pub fn new(bootstrap_node_addrs: Option<Vec<SocketAddrV4>>, info_hash: [u8; 20]) -> anyhow::Result<Self> {
let node_queue = VecDeque::from(bootstrap_node_addrs.unwrap_or_else(|| {
["dht.transmissionbt.com:6881", "dht.libtorrent.org:25401"]
.iter()
.filter_map(|&node| {
node.to_socket_addrs().ok().and_then(|mut socket_addrs| {
socket_addrs.find_map(|addr| {
if let SocketAddr::V4(v4_addr) = addr {
Some(v4_addr)
} else {
None
}
})
})
})
.collect()
}));
if node_queue.is_empty() {
anyhow::bail!("No suitable bootstrap DHT nodes found");
}
let message = KrpcMessage {
transaction_id: "10".into(),
message_type: KrpcMessageType::Query {
name: "get_peers".into(),
query: GetPeersQueryMessage {
id: generate_node_id(),
info_hash,
},
},
};
let rate_limiter = RateLimiter::direct(Quota::per_second(NonZeroU32::new(200).unwrap()));
let mut serialized_message = Vec::new();
serde_bencode::to_writer(&message, &mut serialized_message)
.context("failed to serialize a get_peers DHT message")?;
Ok(DhtRequester {
read_buf: vec![0; DEFAULT_BUF_SIZE],
serialized_get_peers_message: serialized_message,
node_queue,
processed_nodes: Vec::new(),
inflight_requests: VecDeque::new(),
seen_peers: Vec::new(),
rate_limiter,
})
}
#[tracing::instrument(level = "debug", err, skip_all)]
pub async fn process_dht_nodes(
&mut self,
mut cancellation: oneshot::Receiver<()>,
peer_queue_sender: mpsc::Sender<SocketAddrV4>,
) -> anyhow::Result<()> {
let socket = UdpSocket::bind("0.0.0.0:6881").await.context("binding UDP socket")?;
let mut request_cleanup_interval =
tokio::time::interval(Duration::from_secs_f64(INFLIGHT_REQUEST_TIMEOUT_SECS));
'main: loop {
tokio::select! {
result = socket.recv_from(&mut self.read_buf) => {
let (read_bytes, from_node) = result.context("receiving a message from a node")?;
let from_node = match from_node {
SocketAddr::V4(addr) => addr,
_ => {
tracing::debug!("received a message from node using IPv6: {}", from_node);
continue;
}
};
self.inflight_requests.retain(|(_, node_addr)| node_addr != &from_node);
tracing::trace!("Received a response from DHT node: {}", from_node);
match serde_bencode::from_bytes::<KrpcMessage>(&self.read_buf[..read_bytes]) {
Ok(get_peers_response) => {
match get_peers_response.message_type {
KrpcMessageType::Query { query, .. } => {
tracing::warn!(
addr = %from_node,
"Unexpected response from DHT node. Expected get_peers response, got Query: {:?}",
query
)
},
KrpcMessageType::Response { response }=> {
if !response.nodes.is_some() && !response.values.is_some() {
tracing::debug!(
addr = %from_node,
"Bad get_peers response from node: no nodes or peers",
)
}
if let Some(nodes) = response.nodes {
for compact_node_info in nodes.chunks(26) {
let compact_node_addr = &compact_node_info[20..];
let ip = TryInto::<[u8; 4]>::try_into(&compact_node_addr[..4]).context("converting IP from slice to array")?;
let port = ((compact_node_addr[4] as u16) << 8) | compact_node_addr[5] as u16;
let ip_addr = Ipv4Addr::from(ip);
let node_addr = SocketAddrV4::new(ip_addr, port);
if !self.processed_nodes.contains(&node_addr) && !self.node_queue.contains(&node_addr) {
self.node_queue.push_back(node_addr);
}
}
}
if let Some(peers) = response.values {
for compact_peer_addr in peers.iter() {
let ip = TryInto::<[u8; 4]>::try_into(&compact_peer_addr.0[..4]).context("converting IP from slice to array")?;
let port = ((compact_peer_addr.0[4] as u16) << 8) | compact_peer_addr.0[5] as u16;
let ip_addr = Ipv4Addr::from(ip);
if !self.seen_peers.contains(&ip_addr) {
self.seen_peers.push(ip_addr);
let peer_addr = SocketAddrV4::new(ip_addr, port);
match peer_queue_sender.send(peer_addr).await {
Ok(_) => {},
Err(_) => {
tracing::debug!("Receiving half of the peer queue sender was dropped, shutting down...");
break 'main;
}
};
}
}
}
},
KrpcMessageType::Error { error } => {
tracing::debug!(
addr = %from_node,
"Got Error from DHT node: {:?}",
error
)
}
}
}
Err(e) => tracing::debug!(addr = %from_node, "An error happened while decoding a message from DHT node: {}", e)}
}
_ = self.rate_limiter.until_ready(), if !self.node_queue.is_empty() => {
self
.query_next_node(&socket)
.await
.context("querying next node in the queue")?;
}
_ = request_cleanup_interval.tick() => {}
_ = &mut cancellation => {
tracing::debug!("Cancellation requsted. Requester is exiting");
break;
}
}
if self
.inflight_requests
.front()
.is_some_and(|(req_time, _)| req_time.elapsed().as_secs_f64() > INFLIGHT_REQUEST_TIMEOUT_SECS)
{
self.inflight_requests
.retain(|(req_time, _)| req_time.elapsed().as_secs_f64() > INFLIGHT_REQUEST_TIMEOUT_SECS);
}
if self.node_queue.is_empty() && self.inflight_requests.is_empty() {
tracing::debug!("No DHT nodes left to query. Requester is exiting");
break;
}
}
Ok(())
}
#[tracing::instrument(level = "debug", err, skip_all)]
pub async fn query_next_node(&mut self, socket: &UdpSocket) -> anyhow::Result<bool> {
let Some(next_node) = self.node_queue.pop_front() else {
return Ok(false);
};
self.processed_nodes.push(next_node);
let request_time = Instant::now();
self.inflight_requests.push_back((request_time, next_node));
socket
.send_to(&self.serialized_get_peers_message, next_node)
.await
.with_context(|| format!("sending a get_peers message to '{}'", next_node))?;
Ok(true)
}
}