gcache 0.0.1

A cache group for accurate remote data access.
Documentation
use std::collections::hash_map;
use std::hash::{Hash, Hasher};
use std::marker::PhantomData;

/// A consistent-hash ring that maps hashable lookup values onto nodes of type `K`.
///
/// `H` builds the hashers used both to place virtual nodes on the ring and to hash
/// lookup values; it defaults to the standard library's `DefaultHasher`.
pub struct ConsistentHash<K, H = hash_map::DefaultHasher>
where
    K: Hash,
    H: Hasher,
{
    ring: Ring<K, H>,
}

/// One virtual node (replica) of a real node: field `0` is the replica id, field `1` the key.
///
/// The derived `Hash` feeds both fields to the hasher, so replicas of the same key produce
/// different hash inputs and, barring collisions, different ring positions.
/// (The name's spelling is kept as-is because other code refers to it.)
#[derive(Hash)]
struct VitualNode<K>(u8, K)
where
    K: Hash;

/// A virtual node paired with its precomputed position on the ring.
#[derive(Hash)]
struct RingEntry<K>
where
    K: Hash,
{
    /// Hash of `node`; entries are ordered on the ring by this value.
    sum: u64,
    node: VitualNode<K>,
}

/// The ring itself: a vector of virtual-node entries kept ordered by `sum`.
///
/// `H` is never stored; it only names the hasher type that ring operations instantiate.
struct Ring<K, H>
where
    K: Hash,
    H: Hasher,
{
    // Type marker for `H`; the constructor sets it to `None` and nothing reads it.
    _h: Option<PhantomData<H>>,
    /// Entries in non-decreasing order of `sum`.
    pub ring: Vec<RingEntry<K>>,
}

impl<K: Hash + Ord, H: Hasher + Default> ConsistentHash<K, H> {
    /// Creates a ring with no nodes.
    ///
    /// Nodes must be added with `add` before `pick` can return anything.
    pub fn new() -> Self {
        Self { ring: Ring::new() }
    }
}

// A public `new()` with no arguments should be mirrored by `Default` so the type works
// with `Default::default()`, `#[derive(Default)]` containers, `unwrap_or_default`, etc.
impl<K: Hash + Ord, H: Hasher + Default> Default for ConsistentHash<K, H> {
    /// Equivalent to `ConsistentHash::new()`.
    fn default() -> Self {
        Self::new()
    }
}

impl<K: Hash + Ord + Clone, H: Hasher + Default> ConsistentHash<K, H> {
    /// Adds node `k` to the ring as `replicas` virtual nodes with replica ids `0..replicas`.
    ///
    /// Each replica is hashed separately, so more replicas place `k` at more ring positions.
    /// Passing `0` adds nothing.
    pub fn add(&mut self, k: K, replicas: u8) {
        for id in 0..replicas {
            self.ring.add(VitualNode(id, k.clone()))
        }
    }

    /// Removes every virtual node whose key equals `q`.
    pub fn remove<Q: ?Sized + PartialEq<K>>(&mut self, q: &Q) {
        self.ring.remove(q)
    }

    /// Returns the node responsible for `q`.
    ///
    /// # Panics
    ///
    /// Panics if the ring has no nodes (none were added, or all were removed).
    /// Use [`Self::try_pick`] to handle that case without panicking.
    pub fn pick<Q: Hash + ?Sized>(&self, q: &Q) -> &K {
        self.try_pick(q)
            .expect("ConsistentHash::pick called on an empty ring")
    }

    /// Returns the node responsible for `q`, or `None` when the ring has no nodes.
    pub fn try_pick<Q: Hash + ?Sized>(&self, q: &Q) -> Option<&K> {
        // Check up front: the ring's lookup reduces the index modulo its length, which
        // would otherwise fail with an opaque division-by-zero panic on an empty ring.
        if self.ring.ring.is_empty() {
            None
        } else {
            Some(self.ring.pick_node(q))
        }
    }
}

impl<K: Hash> VitualNode<K> {
    fn new<Q: Into<K>>(id: u8, k: Q) -> Self {
        Self(id, k.into())
    }
}

impl<K: Hash + Ord, H: Hasher + Default> Ring<K, H> {
    /// Creates an empty ring.
    pub fn new() -> Self {
        Self {
            _h: None,
            ring: Vec::new(),
        }
    }

    /// Inserts `node` at the ring position given by its hash.
    ///
    /// Entries stay sorted by `(sum, key)`. Adding an entry whose `sum` and key both equal
    /// an existing entry's is a no-op, so re-adding the same virtual node does not duplicate
    /// it. Distinct keys whose hashes collide are all kept, ordered by key. As a result, the
    /// final ring does not depend on the order in which nodes were added.
    pub fn add(&mut self, node: VitualNode<K>) {
        let sum = Self::hash_of(&node);
        // First slot whose entry is not smaller than the new one in (sum, key) order.
        let index = self
            .ring
            .partition_point(|e| (e.sum, &e.node.1) < (sum, &node.1));
        if let Some(existing) = self.ring.get(index) {
            if existing.sum == sum && existing.node.1 == node.1 {
                return;
            }
        }
        self.ring.insert(index, RingEntry { sum, node });
    }

    /// Removes every entry whose key equals `k`, preserving the order of the rest.
    pub fn remove<Q: ?Sized + PartialEq<K>>(&mut self, k: &Q) {
        self.ring.retain(|e| k != &e.node.1);
    }

    /// Returns the key of the first entry at or after the hash of `q`, wrapping around to
    /// the first entry when the hash is past the last one.
    ///
    /// # Panics
    ///
    /// Panics if the ring is empty; use [`Ring::try_pick_node`] to avoid that.
    pub fn pick_node<Q: Hash + ?Sized>(&self, q: &Q) -> &K {
        self.try_pick_node(q)
            .expect("cannot pick a node from an empty ring")
    }

    /// Like [`Ring::pick_node`], but returns `None` when the ring is empty.
    pub fn try_pick_node<Q: Hash + ?Sized>(&self, q: &Q) -> Option<&K> {
        let index = self.find(Self::hash_of(q));
        // `find` returns `len` when the hash is past every entry; the ring wraps to the start.
        self.ring
            .get(index)
            .or_else(|| self.ring.first())
            .map(|e| &e.node.1)
    }

    /// Hashes `value` with a fresh `H`; the result is its position on the ring.
    fn hash_of<T: Hash + ?Sized>(value: &T) -> u64 {
        let mut hasher = H::default();
        value.hash(&mut hasher);
        hasher.finish()
    }

    /// Returns the index of the first entry whose `sum` is `>= sum`, or `self.ring.len()`
    /// when every entry is smaller (including when the ring is empty).
    fn find(&self, sum: u64) -> usize {
        self.ring.partition_point(|e| e.sum < sum)
    }
}

#[cfg(test)]
mod tests {
    use std::collections::hash_map::DefaultHasher;
    use std::collections::HashMap;

    use bytes::Bytes;

    use super::{ConsistentHash, Ring, RingEntry, VitualNode};

    // Builds an entry with a hand-picked `sum`, so `find` can be checked without hashing.
    fn entry_at(sum: u64, key: &str) -> RingEntry<String> {
        RingEntry {
            sum,
            node: VitualNode::new(0, key),
        }
    }

    // Resets every tally to zero, then counts which known node wins for each integer in
    // `1..round`. Results for nodes not present in `counts` are ignored.
    fn tally(map: &ConsistentHash<Bytes>, counts: &mut HashMap<Bytes, i32>, round: i32) {
        for c in counts.values_mut() {
            *c = 0;
        }
        for i in 1..round {
            if let Some(c) = counts.get_mut(map.pick(&i)) {
                *c += 1;
            }
        }
    }

    #[test]
    fn test_ring_find() {
        let ring = Ring::<String, DefaultHasher> {
            _h: None,
            ring: [1, 2, 4, 6].iter().map(|&sum| entry_at(sum, "")).collect(),
        };
        // (queried sum, expected index)
        let cases = [
            (0, 0),
            (1, 0),
            (2, 1),
            (3, 2),
            (4, 2),
            (5, 3),
            (6, 3),
            (7, 4),
            (100, 4),
        ];
        for &(sum, expected) in cases.iter() {
            assert_eq!(ring.find(sum), expected, "find({})", sum);
        }
    }

    #[test]
    fn test_ring() {
        let mut ring = Ring::<String, DefaultHasher>::new();
        // The final (0, "8082") repeats an earlier entry and must not be stored twice.
        let additions = [
            (0, "8080"),
            (1, "8080"),
            (0, "8081"),
            (1, "8081"),
            (1, "8082"),
            (0, "8082"),
            (0, "8082"),
        ];
        for &(id, key) in additions.iter() {
            ring.add(VitualNode::new(id, key));
        }
        assert_eq!(ring.ring.len(), 6);
        assert!(ring.ring.windows(2).all(|pair| pair[0].sum < pair[1].sum));

        for (key, remaining) in [("8080", 4), ("8081", 2), ("8082", 0)].iter() {
            ring.remove(*key);
            assert_eq!(ring.ring.len(), *remaining);
        }
    }

    #[test]
    fn test_consistent_hash() {
        let round: i32 = 1_000_000;
        let mut map = ConsistentHash::<Bytes>::new();
        for port in ["8080", "8081", "8082"].iter() {
            map.add(Bytes::from(*port), 255);
        }
        assert_eq!(map.ring.ring.len(), 255 * 3);

        let mut counts: HashMap<Bytes, i32> = [b"8080", b"8081", b"8082"]
            .iter()
            .map(|node| (Bytes::from_static(*node), 0))
            .collect();

        // With three nodes, each should receive roughly a third of the lookups.
        tally(&map, &mut counts, round);
        for (node, &hits) in &counts {
            let name = std::str::from_utf8(node).unwrap();
            assert!(hits > round * 3 / 10, "count of {}: {}", name, hits);
            assert!(hits < round * 11 / 30, "count of {}: {}", name, hits);
        }

        // After removing one node, the remaining two should split the lookups roughly evenly.
        map.remove("8082");
        tally(&map, &mut counts, round);
        for (node, &hits) in &counts {
            let name = std::str::from_utf8(node).unwrap();
            if name == "8082" {
                assert_eq!(hits, 0)
            } else {
                assert!(hits > round * 9 / 20, "count of {}: {}", name, hits);
                assert!(hits < round * 11 / 20, "count of {}: {}", name, hits);
            }
        }
    }
}