irontide-dht 0.165.0

Kademlia DHT for BitTorrent (BEP 5)
Documentation
//! Per-info_hash peer storage and token generation/validation.
//!
//! Tokens are generated per-IP using a rotating secret so that
//! `announce_peer` requests can be validated without persistent state.

use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::time::{Duration, Instant};

use irontide_core::{Id20, sha1};

/// Lifetime of a stored peer entry (30 minutes).
///
/// An entry whose last announce is at least this old is treated as expired:
/// lookups skip it and `PeerStore::cleanup` removes it.
const PEER_EXPIRY: Duration = Duration::from_secs(60 * 30);

/// Rotation period of the token secret (5 minutes).
///
/// Rotation is lazy: once the current secret is at least this old, the next
/// token operation on the store replaces it.
const TOKEN_ROTATION: Duration = Duration::from_secs(60 * 5);

/// Stores peers per info_hash and generates/validates announce tokens.
///
/// Tokens are derived from a secret plus the requester's IP address, so the
/// store keeps no per-token state. Two secrets are held at any time: tokens
/// are issued from the current one, and validation accepts a match against
/// either the current or the previous one.
#[derive(Debug)]
pub struct PeerStore {
    /// Current secret for token generation.
    secret: [u8; 20],
    /// Previous secret (still valid for token validation).
    /// Starts out equal to `secret` in a freshly created store.
    prev_secret: [u8; 20],
    /// When the current secret was created; its age is compared against
    /// `TOKEN_ROTATION` to decide when to rotate.
    secret_created: Instant,
    /// Peers per info_hash. Each list holds at most one entry per socket
    /// address; a repeated announce refreshes the existing entry's timestamp.
    peers: HashMap<Id20, Vec<StoredPeer>>,
}

/// A single announced peer together with the time of its latest announce.
#[derive(Debug, Clone)]
struct StoredPeer {
    /// Socket address the peer announced; also used to de-duplicate entries
    /// within one info_hash.
    addr: SocketAddr,
    /// When this entry was inserted or last refreshed by a repeated announce.
    /// Compared against `PEER_EXPIRY` to decide whether the entry is live.
    added: Instant,
}

impl PeerStore {
    /// Create an empty peer store with a fresh token secret.
    pub fn new() -> Self {
        let secret = generate_secret();
        PeerStore {
            secret,
            prev_secret: secret,
            secret_created: Instant::now(),
            peers: HashMap::new(),
        }
    }

    /// Generate a token for the given IP address.
    pub fn generate_token(&mut self, ip: &IpAddr) -> Vec<u8> {
        self.maybe_rotate();
        make_token(&self.secret, ip)
    }

    /// Validate a token for the given IP address.
    pub fn validate_token(&mut self, token: &[u8], ip: &IpAddr) -> bool {
        self.maybe_rotate();
        let current = make_token(&self.secret, ip);
        let previous = make_token(&self.prev_secret, ip);
        token == current.as_slice() || token == previous.as_slice()
    }

    /// Add a peer for an info_hash.
    pub fn add_peer(&mut self, info_hash: Id20, addr: SocketAddr) {
        let peers = self.peers.entry(info_hash).or_default();
        // Update existing or add new
        if let Some(existing) = peers.iter_mut().find(|p| p.addr == addr) {
            existing.added = Instant::now();
        } else {
            peers.push(StoredPeer {
                addr,
                added: Instant::now(),
            });
        }
    }

    /// Get peers for an info_hash (up to `max` results).
    pub fn get_peers(&self, info_hash: &Id20, max: usize) -> Vec<SocketAddr> {
        self.peers
            .get(info_hash)
            .map(|peers| {
                let cutoff = Instant::now() - PEER_EXPIRY;
                peers
                    .iter()
                    .filter(|p| p.added > cutoff)
                    .take(max)
                    .map(|p| p.addr)
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Get all non-expired peer addresses for an info_hash (for bloom filter generation).
    pub fn all_peers(&self, info_hash: &Id20) -> Vec<SocketAddr> {
        self.peers
            .get(info_hash)
            .map(|peers| {
                let cutoff = Instant::now() - PEER_EXPIRY;
                peers
                    .iter()
                    .filter(|p| p.added > cutoff)
                    .map(|p| p.addr)
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Remove expired peers from all info_hashes.
    pub fn cleanup(&mut self) {
        let cutoff = Instant::now() - PEER_EXPIRY;
        self.peers.retain(|_, peers| {
            peers.retain(|p| p.added > cutoff);
            !peers.is_empty()
        });
    }

    /// Number of info_hashes with stored peers.
    pub fn info_hash_count(&self) -> usize {
        self.peers.len()
    }

    /// Total number of stored peers across all info_hashes.
    pub fn peer_count(&self) -> usize {
        self.peers.values().map(|p| p.len()).sum()
    }

    /// Return up to `max` randomly sampled info hashes from the store.
    ///
    /// Uses Fisher-Yates partial shuffle on the key vector.
    /// If the store has fewer than `max` info hashes, returns all of them.
    pub fn random_info_hashes(&self, max: usize) -> Vec<Id20> {
        let keys: Vec<Id20> = self.peers.keys().copied().collect();
        let count = keys.len().min(max);
        if count == 0 {
            return Vec::new();
        }
        if count == keys.len() {
            return keys;
        }

        // Partial Fisher-Yates shuffle using the thread-local xorshift
        let mut keys = keys;
        for i in 0..count {
            let j = i + (xorshift_next() as usize % (keys.len() - i));
            keys.swap(i, j);
        }
        keys.truncate(count);
        keys
    }

    fn maybe_rotate(&mut self) {
        if self.secret_created.elapsed() >= TOKEN_ROTATION {
            self.prev_secret = self.secret;
            self.secret = generate_secret();
            self.secret_created = Instant::now();
        }
    }
}

impl Default for PeerStore {
    fn default() -> Self {
        Self::new()
    }
}

/// Derive the announce token for `ip` under `secret`.
///
/// The token is the first 8 bytes of `sha1(secret || ip_octets)`, where
/// `ip_octets` is 4 bytes for IPv4 and 16 bytes for IPv6.
fn make_token(secret: &[u8; 20], ip: &IpAddr) -> Vec<u8> {
    // Room for the secret plus the longer (IPv6) address form.
    let mut preimage = Vec::with_capacity(secret.len() + 16);
    preimage.extend_from_slice(secret);
    match ip {
        IpAddr::V4(v4) => preimage.extend_from_slice(&v4.octets()),
        IpAddr::V6(v6) => preimage.extend_from_slice(&v6.octets()),
    }

    let digest = sha1(&preimage);
    // 8-byte token is sufficient
    digest.0[..8].to_vec()
}

/// Thread-local xorshift64 PRNG. Returns a random u64.
///
/// Each thread keeps its own state, seeded on first use from the wall clock
/// (nanoseconds since the Unix epoch, truncated to 64 bits). Every call
/// advances the state by one xorshift step (shifts 13 / 7 / 17) and returns
/// the new state.
fn xorshift_next() -> u64 {
    use std::cell::Cell;
    use std::time::SystemTime;

    /// Fixed constant folded into the clock reading so the stream does not
    /// start from the bare timestamp.
    const SEED_MIX: u64 = 0x517cc1b727220a95;

    /// Initial state: clock nanoseconds XOR `SEED_MIX`. A clock set before
    /// the epoch counts as zero nanoseconds.
    fn clock_seed() -> u64 {
        let since_epoch = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap_or_default();
        (since_epoch.as_nanos() as u64) ^ SEED_MIX
    }

    /// One xorshift64 step.
    fn step(mut word: u64) -> u64 {
        word ^= word << 13;
        word ^= word >> 7;
        word ^= word << 17;
        word
    }

    thread_local! {
        static STATE: Cell<u64> = Cell::new(clock_seed());
    }

    STATE.with(|cell| {
        let advanced = step(cell.get());
        cell.set(advanced);
        advanced
    })
}

/// Produce a fresh 20-byte token secret.
///
/// Each 8-byte chunk of the secret is the output of a newly created
/// [`RandomState`](std::collections::hash_map::RandomState) hasher, which the
/// standard library initialises with random keys. The current wall-clock
/// nanoseconds and the chunk index are fed in as additional input.
///
/// The secret therefore no longer depends on the clock alone: a purely
/// time-seeded generator can be reconstructed by anyone who can estimate when
/// it was seeded, and it degenerates to an all-zero secret if the clock read
/// fails and the seed becomes zero.
///
/// NOTE(review): this is a std-only best effort, not a vetted CSPRNG. If the
/// crate can take a dependency on an OS randomness source, prefer that.
fn generate_secret() -> [u8; 20] {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::SystemTime;

    // Extra input only; a failed clock read (time before the epoch) just
    // contributes zero and does not weaken the random hasher keys.
    let nanos = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();

    let mut secret = [0u8; 20];
    // 20 bytes split as 8 + 8 + 4; the last chunk takes a prefix of its word.
    for (index, chunk) in secret.chunks_mut(8).enumerate() {
        let mut hasher = RandomState::new().build_hasher();
        hasher.write_u128(nanos);
        hasher.write_usize(index);
        let word = hasher.finish().to_le_bytes();
        let len = chunk.len();
        chunk.copy_from_slice(&word[..len]);
    }
    secret
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parse an IP literal used as a token subject.
    fn ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    /// Parse a `host:port` literal used as a peer endpoint.
    fn endpoint(text: &str) -> SocketAddr {
        text.parse().unwrap()
    }

    /// Build a store holding `n` info hashes (first byte = index, rest zero),
    /// each with a single peer, and return the hashes alongside it.
    fn seeded_store(n: u8) -> (PeerStore, Vec<Id20>) {
        let mut store = PeerStore::new();
        let mut ids = Vec::new();
        for i in 0..n {
            let mut hash_bytes = [0u8; 20];
            hash_bytes[0] = i;
            let id = Id20(hash_bytes);
            store.add_peer(id, endpoint(&format!("10.0.0.{}:6881", i)));
            ids.push(id);
        }
        (store, ids)
    }

    /// A token issued for an IP validates for that same IP.
    #[test]
    fn token_validates_same_ip() {
        let mut store = PeerStore::new();
        let requester = ip("192.168.1.1");
        let token = store.generate_token(&requester);
        assert!(store.validate_token(&token, &requester));
    }

    /// A token issued for one IP is rejected for another.
    #[test]
    fn token_rejects_different_ip() {
        let mut store = PeerStore::new();
        let issued_to = ip("192.168.1.1");
        let presented_by = ip("10.0.0.1");
        let token = store.generate_token(&issued_to);
        assert!(!store.validate_token(&token, &presented_by));
    }

    /// A stored peer comes back from a lookup on its info hash.
    #[test]
    fn add_and_get_peers() {
        let mut store = PeerStore::new();
        let hash = Id20::from_hex("aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d").unwrap();
        let addr = endpoint("10.0.0.1:6881");

        store.add_peer(hash, addr);

        assert_eq!(store.get_peers(&hash, 10), vec![addr]);
    }

    /// Looking up an info hash that was never announced yields nothing.
    #[test]
    fn get_peers_unknown_hash() {
        let store = PeerStore::new();
        assert!(store.get_peers(&Id20::ZERO, 10).is_empty());
    }

    /// Announcing the same address twice keeps a single entry.
    #[test]
    fn duplicate_peer_updates() {
        let mut store = PeerStore::new();
        let addr = endpoint("10.0.0.1:6881");

        for _ in 0..2 {
            store.add_peer(Id20::ZERO, addr);
        }

        assert_eq!(store.peer_count(), 1);
    }

    /// Cleanup leaves a just-announced peer in place.
    #[test]
    fn cleanup_preserves_recent_peers() {
        let mut store = PeerStore::new();
        store.add_peer(Id20::ZERO, endpoint("10.0.0.1:6881"));
        store.cleanup();
        assert_eq!(store.peer_count(), 1);
    }

    /// Sampling an empty store returns no hashes.
    #[test]
    fn random_info_hashes_empty_store() {
        let store = PeerStore::new();
        assert!(store.random_info_hashes(20).is_empty());
    }

    /// The sample size is capped by both `max` and the number of stored hashes.
    #[test]
    fn random_info_hashes_returns_up_to_max() {
        let (store, _) = seeded_store(5);
        // Ask for fewer than total
        assert_eq!(store.random_info_hashes(3).len(), 3);
        // Ask for more than total
        assert_eq!(store.random_info_hashes(20).len(), 5);
    }

    /// Every sampled hash is one that was actually stored.
    #[test]
    fn random_info_hashes_all_valid() {
        let (store, ids) = seeded_store(10);
        let expected: std::collections::HashSet<Id20> = ids.into_iter().collect();

        let samples = store.random_info_hashes(10);

        assert_eq!(samples.len(), 10);
        for sample in &samples {
            assert!(expected.contains(sample), "unexpected info hash in sample");
        }
    }
}