// irontide-dht 1.0.1
//
// Kademlia DHT for BitTorrent (BEP 5)
#![allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::cast_sign_loss,
    reason = "M175: peer store — token bytes packed by spec; remaining narrow casts test fixtures"
)]

//! Per-info_hash peer storage and token generation/validation.
//!
//! Tokens are generated per-IP using a rotating secret so that
//! `announce_peer` requests can be validated without persistent state.

use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::time::{Duration, Instant};

use irontide_core::{Id20, sha1};

/// How long a stored peer stays visible after it was last added.
const PEER_EXPIRY: Duration = Duration::from_secs(30 * 60);

/// Interval after which the token secret is rotated.
const TOKEN_ROTATION: Duration = Duration::from_secs(5 * 60);

/// Stores peers per `info_hash` and generates/validates announce tokens.
///
/// Tokens are derived from a secret that is rotated lazily — only when a
/// token is generated or validated — once `TOKEN_ROTATION` has elapsed.
/// Tokens made with either the current or the previous secret are accepted.
/// Stored peers are hidden from queries once older than `PEER_EXPIRY` and are
/// physically removed by `cleanup`.
#[derive(Debug)]
pub struct PeerStore {
    /// Current secret; new tokens are derived from it.
    secret: [u8; 20],
    /// Previous secret; tokens derived from it still validate.
    prev_secret: [u8; 20],
    /// Start of the current secret's rotation period; compared against
    /// `TOKEN_ROTATION` to decide when to rotate.
    secret_created: Instant,
    /// Peers per `info_hash`, each with the time it was last added.
    peers: HashMap<Id20, Vec<StoredPeer>>,
}

/// A single peer recorded for an `info_hash`.
#[derive(Debug, Clone)]
struct StoredPeer {
    /// Peer socket address as passed to `add_peer`.
    addr: SocketAddr,
    /// When the peer was added or last refreshed by `add_peer`; used for expiry.
    added: Instant,
}

impl PeerStore {
    /// Create an empty peer store with a fresh token secret.
    #[must_use]
    pub fn new() -> Self {
        let secret = generate_secret();
        Self {
            secret,
            prev_secret: secret,
            secret_created: Instant::now(),
            peers: HashMap::new(),
        }
    }

    /// Generate a token for the given IP address.
    pub fn generate_token(&mut self, ip: &IpAddr) -> Vec<u8> {
        self.maybe_rotate();
        make_token(&self.secret, ip)
    }

    /// Validate a token for the given IP address.
    pub fn validate_token(&mut self, token: &[u8], ip: &IpAddr) -> bool {
        self.maybe_rotate();
        let current = make_token(&self.secret, ip);
        let previous = make_token(&self.prev_secret, ip);
        token == current.as_slice() || token == previous.as_slice()
    }

    /// Add a peer for an `info_hash`.
    pub fn add_peer(&mut self, info_hash: Id20, addr: SocketAddr) {
        let peers = self.peers.entry(info_hash).or_default();
        // Update existing or add new
        if let Some(existing) = peers.iter_mut().find(|p| p.addr == addr) {
            existing.added = Instant::now();
        } else {
            peers.push(StoredPeer {
                addr,
                added: Instant::now(),
            });
        }
    }

    /// Get peers for an `info_hash` (up to `max` results).
    ///
    /// M175 BUG FIX: previously used `Instant::now() - PEER_EXPIRY`, which
    /// panics when process uptime < `PEER_EXPIRY` (typically 30 min). Compare
    /// elapsed time forward via `Instant::duration_since` instead.
    #[must_use]
    pub fn get_peers(&self, info_hash: &Id20, max: usize) -> Vec<SocketAddr> {
        self.peers
            .get(info_hash)
            .map(|peers| {
                let now = Instant::now();
                peers
                    .iter()
                    .filter(|p| now.duration_since(p.added) <= PEER_EXPIRY)
                    .take(max)
                    .map(|p| p.addr)
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Get all non-expired peer addresses for an `info_hash` (for bloom filter generation).
    #[must_use]
    pub fn all_peers(&self, info_hash: &Id20) -> Vec<SocketAddr> {
        self.peers
            .get(info_hash)
            .map(|peers| {
                let now = Instant::now();
                peers
                    .iter()
                    .filter(|p| now.duration_since(p.added) <= PEER_EXPIRY)
                    .map(|p| p.addr)
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Remove expired peers from all `info_hashes`.
    pub fn cleanup(&mut self) {
        let now = Instant::now();
        self.peers.retain(|_, peers| {
            peers.retain(|p| now.duration_since(p.added) <= PEER_EXPIRY);
            !peers.is_empty()
        });
    }

    /// Number of `info_hashes` with stored peers.
    #[must_use]
    pub fn info_hash_count(&self) -> usize {
        self.peers.len()
    }

    /// Total number of stored peers across all `info_hashes`.
    #[must_use]
    pub fn peer_count(&self) -> usize {
        self.peers.values().map(std::vec::Vec::len).sum()
    }

    /// Return up to `max` randomly sampled info hashes from the store.
    ///
    /// Uses Fisher-Yates partial shuffle on the key vector.
    /// If the store has fewer than `max` info hashes, returns all of them.
    #[must_use]
    pub fn random_info_hashes(&self, max: usize) -> Vec<Id20> {
        let keys: Vec<Id20> = self.peers.keys().copied().collect();
        let count = keys.len().min(max);
        if count == 0 {
            return Vec::new();
        }
        if count == keys.len() {
            return keys;
        }

        // Partial Fisher-Yates shuffle using the thread-local xorshift
        let mut keys = keys;
        for i in 0..count {
            let j = i + (xorshift_next() as usize % (keys.len() - i));
            keys.swap(i, j);
        }
        keys.truncate(count);
        keys
    }

    fn maybe_rotate(&mut self) {
        if self.secret_created.elapsed() >= TOKEN_ROTATION {
            self.prev_secret = self.secret;
            self.secret = generate_secret();
            self.secret_created = Instant::now();
        }
    }
}

impl Default for PeerStore {
    fn default() -> Self {
        Self::new()
    }
}

/// Derive the announce token for `ip` under `secret`.
///
/// The token is the first 8 bytes of `sha1(secret || octets)`, where `octets`
/// is the 4-byte IPv4 or 16-byte IPv6 address.
fn make_token(secret: &[u8; 20], ip: &IpAddr) -> Vec<u8> {
    let preimage: Vec<u8> = match ip {
        IpAddr::V4(addr) => [secret.as_slice(), &addr.octets()[..]].concat(),
        IpAddr::V6(addr) => [secret.as_slice(), &addr.octets()[..]].concat(),
    };
    let digest = sha1(&preimage);
    // Only a truncated 8-byte prefix of the digest is used as the token.
    digest.0[..8].to_vec()
}

/// Next value from a per-thread xorshift64 generator.
///
/// The state is seeded lazily, once per thread, from the wall clock XOR'd with
/// a fixed constant. Not suitable for anything security-sensitive.
fn xorshift_next() -> u64 {
    use std::cell::Cell;
    use std::time::SystemTime;

    thread_local! {
        static STATE: Cell<u64> = Cell::new({
            let since_epoch = SystemTime::now()
                .duration_since(SystemTime::UNIX_EPOCH)
                .unwrap_or_default();
            // Fixed mixing constant so generators seeded from the same clock
            // reading still diverge.
            (since_epoch.as_nanos() as u64) ^ 0x517c_c1b7_2722_0a95
        });
    }

    let step = |mut v: u64| {
        v ^= v << 13;
        v ^= v >> 7;
        v ^= v << 17;
        v
    };

    STATE.with(|cell| {
        let next = step(cell.get());
        cell.set(next);
        next
    })
}

/// Produce a fresh 20-byte secret for token generation.
///
/// Previously this was a wall-clock-seeded xorshift. That is predictable from
/// an estimate of process start time. It also degenerates to an all-zero
/// secret forever if the clock reads before the Unix epoch, because the seed
/// falls back to 0 and 0 is xorshift's fixed point.
///
/// The bytes are now produced by SipHash keyed through `RandomState`. The
/// standard library seeds those keys from OS randomness and varies them for
/// every new `RandomState`. A per-thread call counter and the wall clock are
/// hashed in as well, so consecutive secrets differ even on platforms without
/// OS randomness. This is still not a CSPRNG, but the output cannot be
/// reconstructed from the process start time.
fn generate_secret() -> [u8; 20] {
    use std::cell::Cell;
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::SystemTime;

    thread_local! {
        static CALLS: Cell<u64> = const { Cell::new(0) };
    }

    let call = CALLS.with(|c| {
        let n = c.get().wrapping_add(1);
        c.set(n);
        n
    });
    let nanos = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();

    let keyed = RandomState::new();
    let mut secret = [0u8; 20];
    // 20 bytes = 8 + 8 + 4; each chunk comes from an independent hash output.
    for (block, chunk) in secret.chunks_mut(8).enumerate() {
        let mut hasher = keyed.build_hasher();
        hasher.write_u64(call);
        hasher.write_u128(nanos);
        hasher.write_usize(block);
        let word = hasher.finish().to_le_bytes();
        chunk.copy_from_slice(&word[..chunk.len()]);
    }
    secret
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn token_validates_same_ip() {
        let mut store = PeerStore::new();
        let ip: IpAddr = "192.168.1.1".parse().unwrap();
        let token = store.generate_token(&ip);
        assert!(store.validate_token(&token, &ip));
    }

    #[test]
    fn token_rejects_different_ip() {
        let mut store = PeerStore::new();
        let ip1: IpAddr = "192.168.1.1".parse().unwrap();
        let ip2: IpAddr = "10.0.0.1".parse().unwrap();
        let token = store.generate_token(&ip1);
        assert!(!store.validate_token(&token, &ip2));
    }

    #[test]
    fn token_is_eight_bytes_and_stable_within_period() {
        let mut store = PeerStore::new();
        let ip: IpAddr = "192.168.1.1".parse().unwrap();
        let first = store.generate_token(&ip);
        let second = store.generate_token(&ip);
        assert_eq!(first.len(), 8);
        assert_eq!(first, second);
    }

    #[test]
    fn token_validates_ipv6() {
        let mut store = PeerStore::new();
        let ip: IpAddr = "2001:db8::1".parse().unwrap();
        let other: IpAddr = "2001:db8::2".parse().unwrap();
        let token = store.generate_token(&ip);
        assert!(store.validate_token(&token, &ip));
        assert!(!store.validate_token(&token, &other));
    }

    #[test]
    fn token_rejects_empty_and_garbage() {
        let mut store = PeerStore::new();
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        assert!(!store.validate_token(&[], &ip));
        assert!(!store.validate_token(&[0u8; 8], &ip));
    }

    /// A token issued just before a rotation must still validate (via the
    /// previous secret), and tokens generated afterwards must differ.
    #[test]
    fn token_survives_one_rotation() {
        let mut store = PeerStore::new();
        let ip: IpAddr = "192.168.1.1".parse().unwrap();
        let token = store.generate_token(&ip);
        let Some(past) = Instant::now().checked_sub(TOKEN_ROTATION + Duration::from_secs(1))
        else {
            return; // host uptime too short to fabricate a past instant
        };
        store.secret_created = past;
        assert!(store.validate_token(&token, &ip));
        assert_ne!(store.generate_token(&ip), token);
    }

    /// M175 regression: `get_peers` / `all_peers` / `cleanup` previously used
    /// `Instant::now() - PEER_EXPIRY` (typically 30 min), which panics when
    /// process uptime is below the threshold. Test process uptime is
    /// near zero — these calls must not panic.
    #[test]
    fn peer_store_queries_do_not_panic_on_fresh_process() {
        let mut store = PeerStore::new();
        let hash = Id20::from_hex("aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d").unwrap();
        let addr: SocketAddr = "10.0.0.1:6881".parse().unwrap();
        store.add_peer(hash, addr);

        // Each call would have panicked under the old `Instant::now() - PEER_EXPIRY` form.
        let _ = store.get_peers(&hash, 10);
        let _ = store.all_peers(&hash);
        store.cleanup();
    }

    #[test]
    fn add_and_get_peers() {
        let mut store = PeerStore::new();
        let hash = Id20::from_hex("aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d").unwrap();
        let addr: SocketAddr = "10.0.0.1:6881".parse().unwrap();

        store.add_peer(hash, addr);
        let peers = store.get_peers(&hash, 10);
        assert_eq!(peers.len(), 1);
        assert_eq!(peers[0], addr);
    }

    #[test]
    fn get_peers_unknown_hash() {
        let store = PeerStore::new();
        let hash = Id20::ZERO;
        let peers = store.get_peers(&hash, 10);
        assert!(peers.is_empty());
    }

    #[test]
    fn duplicate_peer_updates() {
        let mut store = PeerStore::new();
        let hash = Id20::ZERO;
        let addr: SocketAddr = "10.0.0.1:6881".parse().unwrap();

        store.add_peer(hash, addr);
        store.add_peer(hash, addr); // duplicate
        assert_eq!(store.peer_count(), 1);
    }

    #[test]
    fn cleanup_preserves_recent_peers() {
        let mut store = PeerStore::new();
        let hash = Id20::ZERO;
        store.add_peer(hash, "10.0.0.1:6881".parse().unwrap());
        store.cleanup();
        assert_eq!(store.peer_count(), 1);
    }

    /// Peers older than `PEER_EXPIRY` are hidden by queries and removed by `cleanup`.
    #[test]
    fn expired_peers_are_hidden_and_cleaned() {
        let Some(stale) = Instant::now().checked_sub(PEER_EXPIRY + Duration::from_secs(1))
        else {
            return; // host uptime too short to fabricate a stale timestamp
        };
        let mut store = PeerStore::new();
        let hash = Id20::ZERO;
        let fresh: SocketAddr = "10.0.0.1:6881".parse().unwrap();
        let old: SocketAddr = "10.0.0.2:6881".parse().unwrap();
        store.add_peer(hash, fresh);
        store
            .peers
            .get_mut(&hash)
            .unwrap()
            .push(StoredPeer { addr: old, added: stale });

        assert_eq!(store.get_peers(&hash, 10), vec![fresh]);
        assert_eq!(store.all_peers(&hash), vec![fresh]);
        assert_eq!(store.peer_count(), 2);
        store.cleanup();
        assert_eq!(store.peer_count(), 1);
    }

    #[test]
    fn random_info_hashes_empty_store() {
        let store = PeerStore::new();
        let samples = store.random_info_hashes(20);
        assert!(samples.is_empty());
    }

    #[test]
    fn random_info_hashes_returns_up_to_max() {
        let mut store = PeerStore::new();
        for i in 0..5u8 {
            let mut hash_bytes = [0u8; 20];
            hash_bytes[0] = i;
            store.add_peer(
                Id20(hash_bytes),
                format!("10.0.0.{i}:6881").parse().unwrap(),
            );
        }
        // Ask for fewer than total
        let samples = store.random_info_hashes(3);
        assert_eq!(samples.len(), 3);
        // Ask for more than total
        let samples = store.random_info_hashes(20);
        assert_eq!(samples.len(), 5);
        // Ask for none
        assert!(store.random_info_hashes(0).is_empty());
    }

    #[test]
    fn random_info_hashes_all_valid() {
        let mut store = PeerStore::new();
        let mut expected = std::collections::HashSet::new();
        for i in 0..10u8 {
            let mut hash_bytes = [0u8; 20];
            hash_bytes[0] = i;
            let id = Id20(hash_bytes);
            expected.insert(id);
            store.add_peer(id, format!("10.0.0.{i}:6881").parse().unwrap());
        }
        let samples = store.random_info_hashes(10);
        assert_eq!(samples.len(), 10);
        for sample in &samples {
            assert!(expected.contains(sample), "unexpected info hash in sample");
        }
        let unique: std::collections::HashSet<_> = samples.iter().copied().collect();
        assert_eq!(unique.len(), samples.len(), "sample contains duplicates");
    }

    /// The partial Fisher-Yates path (max < store size) must not repeat hashes.
    #[test]
    fn random_info_hashes_partial_sample_is_distinct() {
        let mut store = PeerStore::new();
        for i in 0..10u8 {
            let mut hash_bytes = [0u8; 20];
            hash_bytes[0] = i;
            store.add_peer(Id20(hash_bytes), format!("10.0.0.{i}:6881").parse().unwrap());
        }
        for _ in 0..16 {
            let samples = store.random_info_hashes(4);
            assert_eq!(samples.len(), 4);
            let unique: std::collections::HashSet<_> = samples.iter().copied().collect();
            assert_eq!(unique.len(), 4, "partial sample contains duplicates");
        }
    }
}