Skip to main content

pollen_crdt/
merkle.rs

1//! Merkle tree for anti-entropy synchronization.
2
3use bytes::Bytes;
4use std::collections::BTreeMap;
5
/// Simple Merkle tree for detecting data differences between replicas.
///
/// Only a fixed-size 32-byte digest is stored per key, never the value itself.
#[derive(Debug, Clone)]
pub struct MerkleTree {
    /// Leaf digests keyed by entry key. `BTreeMap` keeps them in sorted key
    /// order, so every traversal of the leaves is deterministic.
    leaves: BTreeMap<String, [u8; 32]>,
    /// Number of levels served by `level_hashes`: level `l` has `2^l` buckets,
    /// and requests for `level >= levels` return nothing.
    levels: usize,
}
13
14impl MerkleTree {
15    /// Create a new empty Merkle tree.
16    pub fn new() -> Self {
17        Self {
18            leaves: BTreeMap::new(),
19            levels: 16, // 16 levels = 65536 buckets
20        }
21    }
22
23    /// Insert or update a key-value pair.
24    pub fn insert(&mut self, key: &str, value: &[u8]) {
25        let hash = Self::hash(value);
26        self.leaves.insert(key.to_string(), hash);
27    }
28
29    /// Remove a key.
30    pub fn remove(&mut self, key: &str) {
31        self.leaves.remove(key);
32    }
33
34    /// Get the root hash.
35    pub fn root_hash(&self) -> Bytes {
36        if self.leaves.is_empty() {
37            return Bytes::from_static(&[0u8; 32]);
38        }
39
40        let hashes: Vec<_> = self.leaves.values().cloned().collect();
41        let combined = Self::combine_hashes(&hashes);
42        Bytes::copy_from_slice(&combined)
43    }
44
45    /// Get hashes for a specific level (for incremental sync).
46    pub fn level_hashes(&self, level: usize) -> Vec<(String, Bytes)> {
47        if level >= self.levels {
48            return vec![];
49        }
50
51        let bucket_count = 1 << level;
52        let mut buckets: Vec<Vec<[u8; 32]>> = vec![vec![]; bucket_count];
53
54        for (key, hash) in &self.leaves {
55            let bucket = Self::key_to_bucket(key, bucket_count);
56            buckets[bucket].push(*hash);
57        }
58
59        buckets
60            .into_iter()
61            .enumerate()
62            .map(|(i, hashes)| {
63                let combined = if hashes.is_empty() {
64                    [0u8; 32]
65                } else {
66                    Self::combine_hashes(&hashes)
67                };
68                (format!("{:x}", i), Bytes::copy_from_slice(&combined))
69            })
70            .collect()
71    }
72
73    /// Find keys in a specific range.
74    pub fn keys_in_range(&self, start: &str, end: &str) -> Vec<String> {
75        self.leaves
76            .range(start.to_string()..end.to_string())
77            .map(|(k, _)| k.clone())
78            .collect()
79    }
80
81    /// Get all keys.
82    pub fn keys(&self) -> Vec<String> {
83        self.leaves.keys().cloned().collect()
84    }
85
86    /// Hash a value using a simple hash function.
87    fn hash(data: &[u8]) -> [u8; 32] {
88        use std::collections::hash_map::DefaultHasher;
89        use std::hash::{Hash, Hasher};
90
91        let mut hasher = DefaultHasher::new();
92        data.hash(&mut hasher);
93        let h1 = hasher.finish();
94
95        let mut hasher2 = DefaultHasher::new();
96        h1.hash(&mut hasher2);
97        let h2 = hasher2.finish();
98
99        let mut hasher3 = DefaultHasher::new();
100        h2.hash(&mut hasher3);
101        let h3 = hasher3.finish();
102
103        let mut hasher4 = DefaultHasher::new();
104        h3.hash(&mut hasher4);
105        let h4 = hasher4.finish();
106
107        let mut result = [0u8; 32];
108        result[0..8].copy_from_slice(&h1.to_le_bytes());
109        result[8..16].copy_from_slice(&h2.to_le_bytes());
110        result[16..24].copy_from_slice(&h3.to_le_bytes());
111        result[24..32].copy_from_slice(&h4.to_le_bytes());
112        result
113    }
114
115    /// Combine multiple hashes into one.
116    fn combine_hashes(hashes: &[[u8; 32]]) -> [u8; 32] {
117        if hashes.is_empty() {
118            return [0u8; 32];
119        }
120
121        let mut combined = hashes[0];
122        for hash in &hashes[1..] {
123            for i in 0..32 {
124                combined[i] ^= hash[i];
125            }
126        }
127        combined
128    }
129
130    /// Map a key to a bucket index.
131    fn key_to_bucket(key: &str, bucket_count: usize) -> usize {
132        use std::collections::hash_map::DefaultHasher;
133        use std::hash::{Hash, Hasher};
134
135        let mut hasher = DefaultHasher::new();
136        key.hash(&mut hasher);
137        (hasher.finish() as usize) % bucket_count
138    }
139}
140
141impl Default for MerkleTree {
142    fn default() -> Self {
143        Self::new()
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tree by inserting `(key, value)` pairs in the order given.
    fn tree_of(entries: &[(&str, &str)]) -> MerkleTree {
        let mut tree = MerkleTree::new();
        for &(key, value) in entries {
            tree.insert(key, value.as_bytes());
        }
        tree
    }

    /// An empty tree still reports a full-width 32-byte root.
    #[test]
    fn test_empty_tree() {
        assert_eq!(MerkleTree::new().root_hash().len(), 32);
    }

    /// Adding a second entry must change the root.
    #[test]
    fn test_insert_and_hash() {
        let mut tree = tree_of(&[("key1", "value1")]);
        let before = tree.root_hash();

        tree.insert("key2", b"value2");
        let after = tree.root_hash();

        assert_ne!(before, after);
    }

    /// Insertion order must not affect the root.
    #[test]
    fn test_same_content_same_hash() {
        let forward = tree_of(&[("key1", "value1"), ("key2", "value2")]);
        let backward = tree_of(&[("key2", "value2"), ("key1", "value1")]);
        assert_eq!(forward.root_hash(), backward.root_hash());
    }

    /// Inserting and then removing a key restores the earlier root.
    #[test]
    fn test_remove() {
        let mut tree = tree_of(&[("key1", "value1")]);
        let baseline = tree.root_hash();

        tree.insert("key2", b"value2");
        tree.remove("key2");

        assert_eq!(tree.root_hash(), baseline);
    }
}