// libsession 0.1.3
//
// Session messenger core library - cryptography, config management, networking
// Documentation
//! Snode pool management: caching, strike tracking, swarm lookups, and disk persistence.
//!
//! Port of `session::network::SnodePool` data structures.
//! The async networking (refresh logic) is omitted since it needs live integration testing.

use std::collections::HashMap;
use std::path::PathBuf;

use crate::network::key_types::Ed25519Pubkey;
use crate::network::service_node::ServiceNode;
use crate::network::swarm::{self, SwarmId};

/// Configuration for the snode pool, extracted from NetworkConfig.
#[derive(Debug, Clone)]
pub struct SnodePoolConfig {
    pub cache_directory: Option<PathBuf>,
    pub fallback_snode_pool_path: Option<PathBuf>,
    pub cache_expiration_secs: u64,
    pub cache_min_lifetime_ms: u64,
    pub enforce_subnet_diversity: bool,
    pub cache_min_size: usize,
    pub cache_min_swarm_size: usize,
    pub cache_num_nodes_to_use_for_refresh: u8,
    pub cache_min_num_refresh_presence_to_include_node: u8,
    pub cache_node_strike_threshold: u16,
}

impl Default for SnodePoolConfig {
    fn default() -> Self {
        Self {
            cache_directory: None,
            fallback_snode_pool_path: None,
            cache_expiration_secs: 2 * 60 * 60,
            cache_min_lifetime_ms: 2000,
            enforce_subnet_diversity: true,
            cache_min_size: 12,
            cache_min_swarm_size: 3,
            cache_num_nodes_to_use_for_refresh: 3,
            cache_min_num_refresh_presence_to_include_node: 2,
            cache_node_strike_threshold: 3,
        }
    }
}

/// The snode pool: maintains a cache of service nodes with strike tracking
/// and swarm caching.
///
/// Nodes are identified by their `ed25519_pubkey` throughout: duplicate
/// detection, exclusion, strike counting and eviction all compare that field.
pub struct SnodePool {
    /// Settings supplied to `new`.
    config: SnodePoolConfig,
    /// The cached service nodes, in insertion order.
    snode_cache: Vec<ServiceNode>,
    /// Swarms derived from `snode_cache` via `swarm::generate_swarms`; rebuilt
    /// after the cache is modified and emptied by `clear_cache`.
    all_swarms: Vec<(SwarmId, Vec<ServiceNode>)>,
    /// Count of non-permanent failures per node pubkey; an entry is dropped
    /// when its node is evicted or when all strikes are cleared.
    snode_strikes: HashMap<Ed25519Pubkey, u16>,
    /// `<cache_directory>/snode_pool_cache`, or `None` when no cache directory
    /// was configured (disk persistence then does nothing).
    cache_file_path: Option<PathBuf>,
}

impl SnodePool {
    pub fn new(config: SnodePoolConfig) -> Self {
        let cache_file_path = config
            .cache_directory
            .as_ref()
            .map(|dir| dir.join("snode_pool_cache"));

        Self {
            config,
            snode_cache: Vec::new(),
            all_swarms: Vec::new(),
            snode_strikes: HashMap::new(),
            cache_file_path,
        }
    }

    /// Returns the number of nodes in the pool.
    pub fn size(&self) -> usize {
        self.snode_cache.len()
    }

    /// Returns true if the pool is empty.
    pub fn is_empty(&self) -> bool {
        self.snode_cache.is_empty()
    }

    /// Adds nodes to the cache and regenerates swarms.
    pub fn add_nodes(&mut self, nodes: Vec<ServiceNode>) {
        // Deduplicate by pubkey
        for node in nodes {
            if !self
                .snode_cache
                .iter()
                .any(|n| n.ed25519_pubkey == node.ed25519_pubkey)
            {
                self.snode_cache.push(node);
            }
        }
        self.regenerate_swarms();
    }

    /// Replaces the entire cache with the given nodes.
    pub fn set_nodes(&mut self, nodes: Vec<ServiceNode>) {
        self.snode_cache = nodes;
        self.regenerate_swarms();
    }

    /// Clears the cache entirely.
    pub fn clear_cache(&mut self) {
        self.snode_cache.clear();
        self.all_swarms.clear();
    }

    /// Picks up to `count` cached nodes at random, skipping every node whose
    /// `ed25519_pubkey` appears in `exclude`.
    ///
    /// Fewer than `count` nodes are returned when not enough candidates remain.
    /// The returned nodes are clones; the cache itself is not modified.
    pub fn get_unused_nodes(
        &self,
        count: usize,
        exclude: &[ServiceNode],
    ) -> Vec<ServiceNode> {
        use rand::seq::SliceRandom;

        // Gather references to every node that is not on the exclusion list.
        let mut candidates: Vec<&ServiceNode> = Vec::new();
        for node in &self.snode_cache {
            let excluded = exclude
                .iter()
                .any(|e| e.ed25519_pubkey == node.ed25519_pubkey);
            if !excluded {
                candidates.push(node);
            }
        }

        // Shuffle all candidates, then keep only the leading `count` of them.
        let mut rng = rand::rng();
        candidates.shuffle(&mut rng);
        candidates.truncate(count);

        candidates.into_iter().cloned().collect()
    }

    /// Looks up the swarm for `swarm_pubkey` among the cached swarms.
    ///
    /// Returns `None` straight away when no swarms are cached; otherwise the
    /// result of `swarm::get_swarm` over the cached swarm list is returned.
    pub fn get_swarm(
        &self,
        swarm_pubkey: &crate::network::key_types::X25519Pubkey,
    ) -> Option<(SwarmId, Vec<ServiceNode>)> {
        let have_swarms = !self.all_swarms.is_empty();
        if have_swarms {
            swarm::get_swarm(swarm_pubkey, &self.all_swarms)
        } else {
            None
        }
    }

    // -----------------------------------------------------------------------
    // Strike tracking
    // -----------------------------------------------------------------------

    /// Registers a failure for the node identified by `pubkey`.
    ///
    /// * `permanent == true`: the node is evicted immediately and no strike is
    ///   counted.
    /// * `permanent == false`: the node's strike count goes up by one, and the
    ///   node is evicted once the count reaches
    ///   `config.cache_node_strike_threshold`.
    ///
    /// Eviction removes every cache entry with that pubkey, drops its strike
    /// record and regenerates the swarms. A strike is recorded even when
    /// `pubkey` is not currently in the cache.
    pub fn record_node_failure(&mut self, pubkey: &Ed25519Pubkey, permanent: bool) {
        // Short-circuit: a permanent failure never touches the strike counter.
        let evict = permanent || {
            let strikes = self.snode_strikes.entry(*pubkey).or_insert(0);
            *strikes += 1;
            *strikes >= self.config.cache_node_strike_threshold
        };

        if !evict {
            return;
        }

        self.snode_cache
            .retain(|cached| cached.ed25519_pubkey != *pubkey);
        self.snode_strikes.remove(pubkey);
        self.regenerate_swarms();
    }

    /// Current number of recorded strikes for `pubkey`; `0` when the node has
    /// no strike record.
    pub fn node_strike_count(&self, pubkey: &Ed25519Pubkey) -> u16 {
        match self.snode_strikes.get(pubkey) {
            Some(strikes) => *strikes,
            None => 0,
        }
    }

    /// Clears all strike records.
    pub fn clear_node_strikes(&mut self) {
        self.snode_strikes.clear();
    }

    // -----------------------------------------------------------------------
    // Disk persistence
    // -----------------------------------------------------------------------

    /// Saves the cache to disk in pipe-delimited format, one node per line.
    ///
    /// Does nothing and returns `Ok(())` when no cache directory was configured.
    /// Missing parent directories are created before the file is written, and
    /// an existing cache file is overwritten.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while creating the directories or writing
    /// the file.
    pub fn save_to_disk(&self) -> std::io::Result<()> {
        let path = match &self.cache_file_path {
            Some(p) => p,
            None => return Ok(()),
        };

        let mut content = String::new();
        for node in &self.snode_cache {
            content.push_str(&node.to_disk());
            // `load_from_disk` parses one node per line, so every record must be
            // newline-terminated. NOTE(review): `to_disk()` may already append
            // the newline — this only adds one when it is missing.
            if !content.is_empty() && !content.ends_with('\n') {
                content.push('\n');
            }
        }

        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        std::fs::write(path, content)
    }

    /// Loads the cache from disk, replacing the in-memory cache and
    /// regenerating swarms.
    ///
    /// Returns `Ok(())` without touching the cache when no cache directory was
    /// configured or when the cache file does not exist. Blank lines are
    /// skipped; a line that fails to parse is logged as a warning and skipped,
    /// so a partly corrupt file still yields the nodes that could be read.
    ///
    /// # Errors
    ///
    /// Returns any I/O error other than "not found" raised while reading the
    /// file (for example a permission error or invalid UTF-8).
    pub fn load_from_disk(&mut self) -> std::io::Result<()> {
        // Read directly instead of checking `exists()` first: that avoids a
        // check-then-use race and stops permission errors from being mistaken
        // for "no cache file".
        let content = match &self.cache_file_path {
            None => return Ok(()),
            Some(path) => match std::fs::read_to_string(path) {
                Ok(text) => text,
                Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()),
                Err(e) => return Err(e),
            },
        };

        let mut nodes = Vec::new();
        for line in content.lines().filter(|l| !l.trim().is_empty()) {
            match ServiceNode::from_disk(line) {
                Ok(node) => nodes.push(node),
                Err(e) => {
                    tracing::warn!("Failed to parse cached snode: {}", e);
                }
            }
        }

        self.snode_cache = nodes;
        self.regenerate_swarms();
        Ok(())
    }

    // -----------------------------------------------------------------------
    // Internal helpers
    // -----------------------------------------------------------------------

    /// Rebuilds `all_swarms` from the current contents of `snode_cache`.
    fn regenerate_swarms(&mut self) {
        let rebuilt = swarm::generate_swarms(&self.snode_cache);
        self.all_swarms = rebuilt;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::network::key_types::Ed25519Pubkey;

    /// Builds a test node whose pubkey's first byte and last IP octet are `id`.
    fn make_node(id: u8, swarm_id: SwarmId) -> ServiceNode {
        let mut key_bytes = [0u8; 32];
        key_bytes[0] = id;

        let ed25519_pubkey = Ed25519Pubkey(key_bytes);
        let ip = [1, 2, 3, id];

        ServiceNode {
            ed25519_pubkey,
            ip,
            https_port: 443,
            omq_port: 22000,
            storage_server_version: [2, 11, 0],
            swarm_id,
            requested_unlock_height: 0,
        }
    }

    /// A pool built from the default configuration.
    fn default_pool() -> SnodePool {
        SnodePool::new(SnodePoolConfig::default())
    }

    #[test]
    fn test_add_and_size() {
        let mut pool = default_pool();
        assert!(pool.is_empty());

        let first_batch = vec![make_node(1, 100), make_node(2, 100)];
        pool.add_nodes(first_batch);
        assert_eq!(pool.size(), 2);

        // Re-adding a known pubkey must leave the size unchanged.
        let repeat = vec![make_node(1, 100)];
        pool.add_nodes(repeat);
        assert_eq!(pool.size(), 2);
    }

    #[test]
    fn test_get_unused_nodes() {
        let mut pool = default_pool();
        let seeded = vec![make_node(1, 100), make_node(2, 100), make_node(3, 200)];
        pool.add_nodes(seeded);

        let excluded = vec![make_node(1, 100)];
        let picked = pool.get_unused_nodes(10, &excluded);

        assert_eq!(picked.len(), 2);
        let banned_key = excluded[0].ed25519_pubkey;
        assert!(picked.iter().all(|n| n.ed25519_pubkey != banned_key));
    }

    #[test]
    fn test_strike_tracking() {
        let config = SnodePoolConfig {
            cache_node_strike_threshold: 3,
            ..Default::default()
        };
        let mut pool = SnodePool::new(config);

        let node = make_node(1, 100);
        let key = node.ed25519_pubkey;
        pool.add_nodes(vec![node]);

        // The first two strikes are counted but do not evict the node.
        for expected in 1..=2u16 {
            pool.record_node_failure(&key, false);
            assert_eq!(pool.node_strike_count(&key), expected);
            assert_eq!(pool.size(), 1);
        }

        // Third strike removes the node
        pool.record_node_failure(&key, false);
        assert_eq!(pool.size(), 0);
    }

    #[test]
    fn test_permanent_failure() {
        let mut pool = default_pool();
        let node = make_node(1, 100);
        let key = node.ed25519_pubkey;
        pool.add_nodes(vec![node]);

        pool.record_node_failure(&key, true);
        assert_eq!(pool.size(), 0);
    }

    #[test]
    fn test_clear_cache() {
        let mut pool = default_pool();
        let seeded = vec![make_node(1, 100), make_node(2, 200)];
        pool.add_nodes(seeded);
        assert_eq!(pool.size(), 2);

        pool.clear_cache();
        assert!(pool.is_empty());
    }

    #[test]
    fn test_clear_strikes() {
        let mut pool = default_pool();
        let node = make_node(1, 100);
        let key = node.ed25519_pubkey;
        pool.add_nodes(vec![node]);

        pool.record_node_failure(&key, false);
        assert_eq!(pool.node_strike_count(&key), 1);

        pool.clear_node_strikes();
        assert_eq!(pool.node_strike_count(&key), 0);
    }
}