pollen-router 0.1.0

Consistent hashing task router for Pollen
Documentation
//! Consistent hash ring implementation.

use pollen_types::NodeId;
use std::collections::BTreeMap;
use std::hash::{Hash, Hasher};

/// Consistent hash ring with virtual nodes.
///
/// Every physical node is placed on the ring at `replicas` positions
/// ("virtual nodes"). A key is served by the node that owns the first ring
/// position at or after the key's hash, wrapping around to the lowest
/// position when the hash lies past the last one.
#[derive(Debug, Clone)]
pub struct HashRing {
    /// Ring positions mapped to the physical node that owns each of them.
    ring: BTreeMap<u64, NodeId>,
    /// Number of virtual nodes per physical node.
    replicas: usize,
    /// Physical nodes, in the order they were added.
    nodes: Vec<NodeId>,
}

impl HashRing {
    /// Create an empty hash ring.
    ///
    /// `replicas` is the number of virtual nodes placed on the ring for every
    /// physical node. A value of `0` is raised to `1`: with no virtual nodes a
    /// node would be registered (and counted by [`HashRing::len`]) but could
    /// never be returned by a lookup.
    pub fn new(replicas: usize) -> Self {
        Self {
            ring: BTreeMap::new(),
            replicas: replicas.max(1),
            nodes: Vec::new(),
        }
    }

    /// Add a node to the ring.
    ///
    /// Adding a node that is already present is a no-op. If one of the node's
    /// virtual positions collides with an existing position, the new node
    /// takes that position over.
    pub fn add(&mut self, node: NodeId) {
        if self.nodes.contains(&node) {
            return;
        }

        self.nodes.push(node);

        for replica in 0..self.replicas {
            let position = self.replica_position(node, replica);
            self.ring.insert(position, node);
        }
    }

    /// Remove a node from the ring.
    ///
    /// Removing a node that is not present is a no-op. Only ring positions
    /// currently owned by `node` are deleted, so a position that another node
    /// took over through a hash collision is left in place.
    pub fn remove(&mut self, node: NodeId) {
        if !self.nodes.contains(&node) {
            return;
        }

        self.nodes.retain(|n| *n != node);

        for replica in 0..self.replicas {
            let position = self.replica_position(node, replica);
            // Guard against evicting a different node that owns this position.
            if self.ring.get(&position) == Some(&node) {
                self.ring.remove(&position);
            }
        }
    }

    /// Clear all nodes from the ring.
    pub fn clear(&mut self) {
        self.ring.clear();
        self.nodes.clear();
    }

    /// Get the node responsible for a key.
    ///
    /// Returns the owner of the first ring position at or after the key's
    /// hash, wrapping around to the lowest position if there is none.
    /// Returns `None` when the ring holds no positions.
    pub fn get(&self, key: &[u8]) -> Option<&NodeId> {
        let hash = self.hash(key);

        self.ring
            .range(hash..)
            .next()
            // Past the last position: wrap around to the start of the ring.
            .or_else(|| self.ring.iter().next())
            .map(|(_, node)| node)
    }

    /// Get N unique nodes for a key (for replication).
    ///
    /// Walks the ring clockwise from the key's hash and collects distinct
    /// physical nodes in the order they are first encountered. The result
    /// holds at most `n` nodes and never more than the number of physical
    /// nodes; it is empty when the ring is empty or `n` is `0`.
    pub fn get_n(&self, key: &[u8], n: usize) -> Vec<NodeId> {
        if self.ring.is_empty() || n == 0 {
            return vec![];
        }

        let hash = self.hash(key);
        let limit = n.min(self.nodes.len());
        let mut seen = std::collections::HashSet::with_capacity(limit);
        let mut result = Vec::with_capacity(limit);

        // Positions at or after the hash first, then the wrapped-around part.
        result.extend(
            self.ring
                .range(hash..)
                .chain(self.ring.range(..hash))
                .map(|(_, node)| *node)
                // `insert` returns false for a node that was already yielded.
                .filter(|node| seen.insert(*node))
                .take(limit),
        );

        result
    }

    /// Check if the ring is empty.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Get the number of physical nodes.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Get all physical nodes, in the order they were added.
    pub fn nodes(&self) -> &[NodeId] {
        &self.nodes
    }

    /// Compute the ring position of one virtual node of `node`.
    ///
    /// `add` and `remove` must derive identical positions, so both go
    /// through this single definition of the virtual-node key format.
    fn replica_position(&self, node: NodeId, replica: usize) -> u64 {
        let key = format!("{}:{}", node, replica);
        self.hash(key.as_bytes())
    }

    /// Hash a key to a ring position.
    ///
    /// NOTE(review): `DefaultHasher::new()` always starts from the same state,
    /// so positions are reproducible within one build, but the standard
    /// library does not specify the algorithm and may change it between Rust
    /// releases. If ring positions must agree between processes built with
    /// different toolchains, a hash with a specified algorithm is needed.
    fn hash(&self, key: &[u8]) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::{HashMap, HashSet};

    /// A ring without nodes answers no lookups and reports itself empty.
    #[test]
    fn test_empty_ring() {
        let ring = HashRing::new(10);
        assert!(ring.get(b"test").is_none());
        assert!(ring.is_empty());
    }

    /// With one node on the ring, every key resolves to that node.
    #[test]
    fn test_single_node() {
        let node = NodeId::from_raw(1);
        let mut ring = HashRing::new(10);
        ring.add(node);

        for key in ["key1", "key2", "key3"].iter() {
            assert_eq!(ring.get(key.as_bytes()), Some(&node));
        }
    }

    /// 1000 keys over three nodes: every node receives at least one key.
    #[test]
    fn test_multiple_nodes() {
        let nodes: Vec<NodeId> = (1..=3).map(NodeId::from_raw).collect();
        let mut ring = HashRing::new(100);
        for &node in &nodes {
            ring.add(node);
        }

        assert_eq!(ring.len(), 3);

        // Count how many keys land on each node.
        let mut hits: HashMap<NodeId, usize> = HashMap::new();
        for i in 0..1000 {
            let key = format!("key{}", i);
            if let Some(owner) = ring.get(key.as_bytes()).copied() {
                *hits.entry(owner).or_default() += 1;
            }
        }

        for node in &nodes {
            let count = hits.get(node).copied().unwrap_or(0);
            assert!(count > 0);
        }
    }

    /// `get_n` returns the requested number of distinct nodes.
    #[test]
    fn test_get_n() {
        let nodes: Vec<NodeId> = (1..=5).map(NodeId::from_raw).collect();
        let mut ring = HashRing::new(100);
        for &node in &nodes {
            ring.add(node);
        }

        let replicas = ring.get_n(b"test", 3);
        assert_eq!(replicas.len(), 3);

        // Collapsing into a set must not lose any entry.
        let distinct = replicas.iter().collect::<HashSet<_>>();
        assert_eq!(distinct.len(), 3);
    }

    /// After removing one of two nodes, lookups resolve to the survivor.
    #[test]
    fn test_node_removal() {
        let (node1, node2) = (NodeId::from_raw(1), NodeId::from_raw(2));
        let mut ring = HashRing::new(100);
        ring.add(node1);
        ring.add(node2);

        ring.remove(node1);

        assert_eq!(ring.len(), 1);
        assert_eq!(ring.get(b"test"), Some(&node2));
    }

    /// Adding a fourth node leaves at least 70 of 100 keys on their old node.
    #[test]
    fn test_consistent_hashing() {
        let nodes: Vec<NodeId> = (1..=3).map(NodeId::from_raw).collect();
        let mut ring = HashRing::new(100);
        for &node in &nodes {
            ring.add(node);
        }

        // Snapshot the owner of every key before the topology changes.
        let before: Vec<(String, NodeId)> = (0..100)
            .map(|i| format!("key{}", i))
            .filter_map(|key| {
                let owner = ring.get(key.as_bytes()).copied();
                owner.map(|owner| (key, owner))
            })
            .collect();

        ring.add(NodeId::from_raw(4));

        // Count the keys whose owner is the same after the addition.
        let unchanged = before
            .iter()
            .filter(|(key, old_owner)| ring.get(key.as_bytes()) == Some(old_owner))
            .count();

        assert!(unchanged >= 70);
    }
}