mod kbucket;
pub mod krpc;
use self::kbucket::KBucket;
use std::net::SocketAddr;
use rand::RngExt;
/// A DHT contact: a 160-bit node ID together with its network address.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Node {
/// 20-byte node ID; distances to other IDs are measured with the XOR metric.
pub id: [u8; 20],
/// Socket address at which this node can be reached.
pub addr: SocketAddr,
}
/// Kademlia-style routing table: known nodes grouped into k-buckets according
/// to how many leading bits their ID shares with our own `node_id`.
pub struct RoutingTable {
/// Our own 160-bit node ID; every bucket placement is computed relative to it.
pub node_id: [u8; 20],
/// `NUM_BUCKETS` buckets, indexed by `bucket_index(&node_id, &other.id)`.
buckets: Vec<KBucket>,
}
/// Number of k-buckets: one per bit of a 20-byte (160-bit) node ID, so every
/// possible shared-prefix length has its own bucket.
const NUM_BUCKETS: usize = 20 * 8;
impl Default for RoutingTable {
fn default() -> Self {
Self::new()
}
}
impl RoutingTable {
pub fn new() -> Self {
RoutingTable {
node_id: generate_node_id(),
buckets: (0..NUM_BUCKETS).map(|_| KBucket::new()).collect(),
}
}
pub fn with_id(node_id: [u8; 20]) -> Self {
RoutingTable {
node_id,
buckets: (0..NUM_BUCKETS).map(|_| KBucket::new()).collect(),
}
}
pub fn insert(&mut self, node: Node) -> bool {
tracing::debug!("DHT insert: {}", node.addr);
let bucket_idx = bucket_index(&self.node_id, &node.id);
self.buckets[bucket_idx].insert(node)
}
pub fn find_closest(&self, target: &[u8; 20], count: usize) -> Vec<Node> {
let mut all_nodes: Vec<&Node> = self.buckets.iter().flat_map(|b| b.iter()).collect();
all_nodes.sort_by_key(|n| xor_distance(&n.id, target));
all_nodes.into_iter().take(count).cloned().collect()
}
pub fn num_nodes(&self) -> usize {
self.buckets.iter().map(|b| b.len()).sum()
}
pub fn generate_node_id() -> [u8; 20] {
generate_node_id()
}
}
/// Computes the XOR distance between two 160-bit IDs.
///
/// The result is an ID-sized byte array. Arrays compare lexicographically, so
/// with the most significant byte first a smaller result means a closer node.
/// `RoutingTable::find_closest` relies on this when it sorts by this key.
fn xor_distance(a: &[u8; 20], b: &[u8; 20]) -> [u8; 20] {
    std::array::from_fn(|i| a[i] ^ b[i])
}
/// Returns the k-bucket index for `node_id` relative to `our_id`.
///
/// The index is the number of leading bits the two IDs share. Equivalently,
/// it is the position (0 = most significant bit) of the first differing bit,
/// so it always lies in `0..160`. Bucket 0 therefore holds the half of the ID
/// space farthest from `our_id`.
///
/// Identical IDs have no differing bit and are mapped to bucket 0 as well, so
/// callers should not insert their own ID.
fn bucket_index(our_id: &[u8; 20], node_id: &[u8; 20]) -> usize {
    our_id
        .iter()
        .zip(node_id)
        .enumerate()
        .find_map(|(byte_idx, (ours, theirs))| {
            let diff = ours ^ theirs;
            // Leading zeros of the first non-zero XOR byte = shared bits within it.
            (diff != 0).then(|| byte_idx * 8 + diff.leading_zeros() as usize)
        })
        .unwrap_or(0)
}
pub fn generate_node_id() -> [u8; 20] {
use sha1::{Digest, Sha1};
let seed: u64 = rand::rng().random();
let mut hasher = Sha1::new();
hasher.update(seed.to_be_bytes());
hasher.update(b"torrent-rs-dht-node");
let result = hasher.finalize();
let mut id = [0u8; 20];
id.copy_from_slice(&result);
id
}
// Unit tests for the XOR metric, bucket placement and routing-table behaviour.
#[cfg(test)]
mod tests {
use super::*;
// XOR distance of an ID to itself is all zeros.
#[test]
fn xor_distance_same() {
let id = [0x42u8; 20];
assert_eq!(xor_distance(&id, &id), [0u8; 20]);
}
// A single differing bit shows up in the matching byte of the distance.
#[test]
fn xor_distance_different() {
let a = [0x00u8; 20];
let mut b = [0x00u8; 20];
b[0] = 0x01;
let dist = xor_distance(&a, &b);
assert_eq!(dist[0], 0x01);
}
// 0x80 differs from 0x00 in the most significant bit, so no prefix is shared → bucket 0.
#[test]
fn bucket_index_first_byte_diff() {
let our = [0x00u8; 20];
let node = [0x80u8; 20]; assert_eq!(bucket_index(&our, &node), 0);
}
// First byte equal (8 shared bits) plus 7 leading zeros of 0x01 → bucket 15.
#[test]
fn bucket_index_second_byte() {
let our = [0x00u8; 20];
let mut node = [0x00u8; 20];
node[1] = 0x01;
let idx = bucket_index(&our, &node);
assert_eq!(idx, 15);
}
// A node with a different ID is stored and counted.
#[test]
fn routing_table_insert() {
let mut rt = RoutingTable::new();
let node = Node {
id: [0x01u8; 20],
addr: "127.0.0.1:6881".parse().unwrap(),
};
assert!(rt.insert(node));
assert_eq!(rt.num_nodes(), 1);
}
// With our ID = 0xFF.., every id[0] in 0x00..0x0F differs in the top bit, so
// all 16 candidates target bucket 0. Only the result length is checked here;
// the capacity limit (see routing_table_max_per_bucket) still leaves at least 4.
#[test]
fn routing_table_find_closest() {
let mut rt = RoutingTable::with_id([0xFFu8; 20]);
for i in 0..16 {
let mut id = [0u8; 20];
id[0] = i;
rt.insert(Node {
id,
addr: "127.0.0.1:6881".parse().unwrap(),
});
}
let target = [0x0Au8; 20];
let closest = rt.find_closest(&target, 4);
assert_eq!(closest.len(), 4);
}
// 12 IDs with first byte 0x80..0x8B all fall into bucket 0 relative to the
// zero ID; only 8 are retained, which pins the per-bucket capacity at 8.
#[test]
fn routing_table_max_per_bucket() {
let mut rt = RoutingTable::with_id([0x00u8; 20]);
for i in 0..12 {
let mut id = [0x00u8; 20];
id[0] = 0x80 + i; rt.insert(Node {
id,
addr: "127.0.0.1:6881".parse().unwrap(),
});
}
assert_eq!(rt.num_nodes(), 8);
}
// Two consecutive random IDs are expected to differ (probabilistic, but a
// collision is astronomically unlikely).
#[test]
fn node_id_generation() {
let id1 = generate_node_id();
let id2 = generate_node_id();
assert_eq!(id1.len(), 20);
assert_eq!(id2.len(), 20);
assert_ne!(id1, id2);
}
// IDs 0x01..0x08 in byte 0 are spread over buckets 4..7 (none over capacity).
// Relative to a zero target their XOR distance equals id[0], so the nearest
// four must come back in the order 1, 2, 3, 4.
#[test]
fn find_closest_returns_correct_order() {
let mut rt = RoutingTable::with_id([0x00u8; 20]);
for i in 1u8..=8 {
let mut id = [0u8; 20];
id[0] = i;
rt.insert(Node {
id,
addr: "127.0.0.1:6881".parse().unwrap(),
});
}
let target = [0x00u8; 20];
let closest = rt.find_closest(&target, 4);
assert_eq!(closest.len(), 4);
for (i, node) in closest.iter().enumerate() {
assert_eq!(node.id[0], (i + 1) as u8);
}
}
// Asking for more nodes than are stored returns everything available.
#[test]
fn find_closest_count_exceeds_total() {
let mut rt = RoutingTable::with_id([0u8; 20]);
rt.insert(Node {
id: [0x01u8; 20],
addr: "127.0.0.1:6881".parse().unwrap(),
});
rt.insert(Node {
id: [0x02u8; 20],
addr: "127.0.0.1:6882".parse().unwrap(),
});
let closest = rt.find_closest(&[0x00u8; 20], 10);
assert_eq!(closest.len(), 2);
}
// Differing only in the least significant bit shares 159 bits → last bucket (159).
#[test]
fn bucket_index_last_bit_diff() {
let our = [0x00u8; 20];
let mut node = [0x00u8; 20];
node[19] = 0x01; assert_eq!(bucket_index(&our, &node), 159);
}
// Identical IDs have no differing bit; bucket_index falls back to 0 for them.
#[test]
fn bucket_index_same_id() {
let id = [0x42u8; 20];
assert_eq!(bucket_index(&id, &id), 0);
}
}