Skip to main content

oxigdal_edge/
conflict.rs

1//! CRDT-based conflict resolution for distributed edge nodes
2//!
3//! Provides Conflict-free Replicated Data Types (CRDTs) for automatic
4//! conflict resolution in distributed edge computing environments.
5
6use ahash::AHashMap;
7use serde::{Deserialize, Serialize};
8use std::cmp::Ordering;
9use std::collections::{HashMap, HashSet};
10use std::fmt;
11
12/// Vector clock for tracking causality
13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
14pub struct VectorClock {
15    clock: AHashMap<String, u64>,
16}
17
18impl VectorClock {
19    /// Create a new vector clock
20    pub fn new() -> Self {
21        Self {
22            clock: AHashMap::new(),
23        }
24    }
25
26    /// Increment clock for node
27    pub fn increment(&mut self, node_id: &str) {
28        let counter = self.clock.entry(node_id.to_string()).or_insert(0);
29        *counter = counter.saturating_add(1);
30    }
31
32    /// Get clock value for node
33    pub fn get(&self, node_id: &str) -> u64 {
34        self.clock.get(node_id).copied().unwrap_or(0)
35    }
36
37    /// Merge with another vector clock
38    pub fn merge(&mut self, other: &VectorClock) {
39        for (node_id, &other_count) in &other.clock {
40            let count = self.clock.entry(node_id.clone()).or_insert(0);
41            *count = (*count).max(other_count);
42        }
43    }
44
45    /// Compare vector clocks for causality
46    pub fn compare(&self, other: &VectorClock) -> ClockOrdering {
47        let mut less = false;
48        let mut greater = false;
49
50        // Check all nodes in self
51        for (node_id, &self_count) in &self.clock {
52            let other_count = other.get(node_id);
53            match self_count.cmp(&other_count) {
54                Ordering::Less => less = true,
55                Ordering::Greater => greater = true,
56                Ordering::Equal => {}
57            }
58        }
59
60        // Check nodes only in other
61        for node_id in other.clock.keys() {
62            if !self.clock.contains_key(node_id) {
63                less = true;
64            }
65        }
66
67        match (less, greater) {
68            (true, false) => ClockOrdering::Before,
69            (false, true) => ClockOrdering::After,
70            (false, false) => ClockOrdering::Equal,
71            (true, true) => ClockOrdering::Concurrent,
72        }
73    }
74
75    /// Check if this clock is concurrent with another
76    pub fn is_concurrent(&self, other: &VectorClock) -> bool {
77        matches!(self.compare(other), ClockOrdering::Concurrent)
78    }
79
80    /// Check if this clock happens before another
81    pub fn happens_before(&self, other: &VectorClock) -> bool {
82        matches!(self.compare(other), ClockOrdering::Before)
83    }
84}
85
86impl Default for VectorClock {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
/// Causal relationship between two vector clocks, as reported by
/// [`VectorClock::compare`] from the receiver's point of view.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClockOrdering {
    /// The receiver happened strictly before the argument.
    Before,
    /// The receiver happened strictly after the argument.
    After,
    /// Both clocks record exactly the same events.
    Equal,
    /// Neither clock dominates the other (a conflict).
    Concurrent,
}
104
105/// Last-Write-Wins Register CRDT
106#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct LwwRegister<T> {
108    value: T,
109    timestamp: VectorClock,
110    logical_time: u64,
111    node_id: String,
112}
113
114impl<T: Clone> LwwRegister<T> {
115    /// Create new LWW register
116    pub fn new(value: T, node_id: String) -> Self {
117        let mut timestamp = VectorClock::new();
118        timestamp.increment(&node_id);
119        Self {
120            value,
121            timestamp,
122            logical_time: 1,
123            node_id,
124        }
125    }
126
127    /// Get current value
128    pub fn value(&self) -> &T {
129        &self.value
130    }
131
132    /// Update value
133    pub fn update(&mut self, value: T) {
134        self.value = value;
135        self.timestamp.increment(&self.node_id);
136        self.logical_time += 1;
137    }
138
139    /// Merge with another register (conflict resolution)
140    pub fn merge(&mut self, other: &LwwRegister<T>) {
141        match self.timestamp.compare(&other.timestamp) {
142            ClockOrdering::Before => {
143                self.value = other.value.clone();
144                self.timestamp = other.timestamp.clone();
145                self.logical_time = other.logical_time;
146            }
147            ClockOrdering::Concurrent => {
148                // For concurrent updates, use logical time as tie-breaker
149                // If logical times are equal, use node_id
150                let should_adopt_other = other.logical_time > self.logical_time
151                    || (other.logical_time == self.logical_time && self.node_id < other.node_id);
152
153                if should_adopt_other {
154                    self.value = other.value.clone();
155                    self.timestamp = other.timestamp.clone();
156                    self.logical_time = other.logical_time;
157                }
158            }
159            ClockOrdering::After | ClockOrdering::Equal => {
160                // Keep current value
161            }
162        }
163    }
164}
165
166/// Grow-only Set CRDT
167#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct GSet<T: Eq + std::hash::Hash> {
169    elements: HashSet<T>,
170}
171
172impl<T: Eq + std::hash::Hash> GSet<T> {
173    /// Create new G-Set
174    pub fn new() -> Self {
175        Self {
176            elements: HashSet::new(),
177        }
178    }
179
180    /// Add element to set
181    pub fn insert(&mut self, element: T) {
182        self.elements.insert(element);
183    }
184
185    /// Check if set contains element
186    pub fn contains(&self, element: &T) -> bool {
187        self.elements.contains(element)
188    }
189
190    /// Get set size
191    pub fn len(&self) -> usize {
192        self.elements.len()
193    }
194
195    /// Check if set is empty
196    pub fn is_empty(&self) -> bool {
197        self.elements.is_empty()
198    }
199
200    /// Merge with another G-Set
201    pub fn merge(&mut self, other: &GSet<T>)
202    where
203        T: Clone,
204    {
205        for element in &other.elements {
206            self.elements.insert(element.clone());
207        }
208    }
209}
210
211impl<T: Eq + std::hash::Hash> Default for GSet<T> {
212    fn default() -> Self {
213        Self::new()
214    }
215}
216
217/// Two-Phase Set CRDT (supports both add and remove)
218#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct TwoPhaseSet<T: Eq + std::hash::Hash + Clone> {
220    added: HashSet<T>,
221    removed: HashSet<T>,
222}
223
224impl<T: Eq + std::hash::Hash + Clone> TwoPhaseSet<T> {
225    /// Create new Two-Phase Set
226    pub fn new() -> Self {
227        Self {
228            added: HashSet::new(),
229            removed: HashSet::new(),
230        }
231    }
232
233    /// Add element to set
234    pub fn insert(&mut self, element: T) {
235        if !self.removed.contains(&element) {
236            self.added.insert(element);
237        }
238    }
239
240    /// Remove element from set
241    pub fn remove(&mut self, element: &T) -> bool {
242        if self.added.contains(element) {
243            self.removed.insert(element.clone());
244            true
245        } else {
246            false
247        }
248    }
249
250    /// Check if set contains element
251    pub fn contains(&self, element: &T) -> bool {
252        self.added.contains(element) && !self.removed.contains(element)
253    }
254
255    /// Get visible elements
256    pub fn elements(&self) -> impl Iterator<Item = &T> {
257        self.added.iter().filter(|e| !self.removed.contains(e))
258    }
259
260    /// Get set size
261    pub fn len(&self) -> usize {
262        self.elements().count()
263    }
264
265    /// Check if set is empty
266    pub fn is_empty(&self) -> bool {
267        self.len() == 0
268    }
269
270    /// Merge with another Two-Phase Set
271    pub fn merge(&mut self, other: &TwoPhaseSet<T>) {
272        for element in &other.added {
273            self.added.insert(element.clone());
274        }
275        for element in &other.removed {
276            self.removed.insert(element.clone());
277        }
278    }
279}
280
281impl<T: Eq + std::hash::Hash + Clone> Default for TwoPhaseSet<T> {
282    fn default() -> Self {
283        Self::new()
284    }
285}
286
287/// CRDT Map combining multiple CRDTs
288pub type CrdtSet<T> = TwoPhaseSet<T>;
289
290/// CRDT Map for key-value storage
291#[derive(Debug, Clone, Serialize, Deserialize)]
292pub struct CrdtMap<K, V>
293where
294    K: Eq + std::hash::Hash + Clone,
295    V: Clone,
296{
297    entries: HashMap<K, LwwRegister<V>>,
298    node_id: String,
299}
300
301impl<K, V> CrdtMap<K, V>
302where
303    K: Eq + std::hash::Hash + Clone,
304    V: Clone,
305{
306    /// Create new CRDT map
307    pub fn new(node_id: String) -> Self {
308        Self {
309            entries: HashMap::new(),
310            node_id,
311        }
312    }
313
314    /// Insert or update key-value pair
315    pub fn insert(&mut self, key: K, value: V) {
316        if let Some(register) = self.entries.get_mut(&key) {
317            register.update(value);
318        } else {
319            let register = LwwRegister::new(value, self.node_id.clone());
320            self.entries.insert(key, register);
321        }
322    }
323
324    /// Get value for key
325    pub fn get(&self, key: &K) -> Option<&V> {
326        self.entries.get(key).map(|r| r.value())
327    }
328
329    /// Check if map contains key
330    pub fn contains_key(&self, key: &K) -> bool {
331        self.entries.contains_key(key)
332    }
333
334    /// Get map size
335    pub fn len(&self) -> usize {
336        self.entries.len()
337    }
338
339    /// Check if map is empty
340    pub fn is_empty(&self) -> bool {
341        self.entries.is_empty()
342    }
343
344    /// Iterate over key-value pairs
345    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> {
346        self.entries.iter().map(|(k, v)| (k, v.value()))
347    }
348
349    /// Merge with another CRDT map
350    pub fn merge(&mut self, other: &CrdtMap<K, V>) {
351        for (key, other_register) in &other.entries {
352            if let Some(register) = self.entries.get_mut(key) {
353                register.merge(other_register);
354            } else {
355                self.entries.insert(key.clone(), other_register.clone());
356            }
357        }
358    }
359}
360
/// Conflict resolver for edge nodes.
///
/// Holds the id of the local replica and hands it to the CRDTs it creates.
#[derive(Debug, Clone)]
pub struct ConflictResolver {
    /// Id of the local node.
    node_id: String,
}
365
366impl ConflictResolver {
367    /// Create new conflict resolver
368    pub fn new(node_id: String) -> Self {
369        Self { node_id }
370    }
371
372    /// Create CRDT map
373    pub fn create_map<K, V>(&self) -> CrdtMap<K, V>
374    where
375        K: Eq + std::hash::Hash + Clone,
376        V: Clone,
377    {
378        CrdtMap::new(self.node_id.clone())
379    }
380
381    /// Create CRDT set
382    pub fn create_set<T: Eq + std::hash::Hash + Clone>(&self) -> CrdtSet<T> {
383        CrdtSet::new()
384    }
385
386    /// Resolve conflict between two values using Last-Write-Wins
387    pub fn resolve_lww<T: Clone>(
388        &self,
389        local: &T,
390        local_clock: &VectorClock,
391        remote: &T,
392        remote_clock: &VectorClock,
393    ) -> T {
394        match local_clock.compare(remote_clock) {
395            ClockOrdering::Before => remote.clone(),
396            ClockOrdering::After | ClockOrdering::Equal => local.clone(),
397            ClockOrdering::Concurrent => {
398                // Tie-break deterministically
399                if self.node_id.as_str() < "remote" {
400                    remote.clone()
401                } else {
402                    local.clone()
403                }
404            }
405        }
406    }
407
408    /// Get node ID
409    pub fn node_id(&self) -> &str {
410        &self.node_id
411    }
412}
413
414impl fmt::Display for VectorClock {
415    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
416        write!(f, "{{")?;
417        for (i, (node, count)) in self.clock.iter().enumerate() {
418            if i > 0 {
419                write!(f, ", ")?;
420            }
421            write!(f, "{}: {}", node, count)?;
422        }
423        write!(f, "}}")
424    }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vector_clock_increment() {
        let mut clock = VectorClock::new();
        clock.increment("node1");
        clock.increment("node1");
        clock.increment("node2");

        assert_eq!(clock.get("node1"), 2);
        assert_eq!(clock.get("node2"), 1);
        assert_eq!(clock.get("node3"), 0);
    }

    #[test]
    fn test_vector_clock_merge() {
        let mut clock1 = VectorClock::new();
        clock1.increment("node1");
        clock1.increment("node1");

        let mut clock2 = VectorClock::new();
        clock2.increment("node2");

        clock1.merge(&clock2);
        assert_eq!(clock1.get("node1"), 2);
        assert_eq!(clock1.get("node2"), 1);
    }

    #[test]
    fn test_vector_clock_compare() {
        let mut clock1 = VectorClock::new();
        clock1.increment("node1");

        let mut clock2 = VectorClock::new();
        clock2.increment("node1");
        clock2.increment("node1");

        assert_eq!(clock1.compare(&clock2), ClockOrdering::Before);
        assert_eq!(clock2.compare(&clock1), ClockOrdering::After);
        assert_eq!(clock1.compare(&clock1.clone()), ClockOrdering::Equal);

        let mut clock3 = VectorClock::new();
        clock3.increment("node2");

        // Concurrency must be reported from both sides.
        assert_eq!(clock1.compare(&clock3), ClockOrdering::Concurrent);
        assert_eq!(clock3.compare(&clock1), ClockOrdering::Concurrent);
    }

    #[test]
    fn test_lww_register() {
        let mut reg1 = LwwRegister::new(42, "node1".to_string());
        let mut reg2 = LwwRegister::new(100, "node2".to_string());

        reg1.update(50);
        let reg2_before = reg2.clone();
        reg2.merge(&reg1);
        reg1.merge(&reg2_before);

        // The later write (higher logical time) wins on both replicas.
        assert_eq!(*reg2.value(), 50);
        assert_eq!(*reg1.value(), 50);
    }

    #[test]
    fn test_gset() {
        let mut set1 = GSet::new();
        set1.insert(1);
        set1.insert(2);

        let mut set2 = GSet::new();
        set2.insert(2);
        set2.insert(3);

        set1.merge(&set2);

        assert_eq!(set1.len(), 3);
        assert!(set1.contains(&1));
        assert!(set1.contains(&2));
        assert!(set1.contains(&3));
    }

    #[test]
    fn test_two_phase_set() {
        let mut set = TwoPhaseSet::new();
        set.insert(1);
        set.insert(2);
        set.insert(3);

        assert_eq!(set.len(), 3);
        assert!(set.contains(&2));

        assert!(set.remove(&2));
        assert_eq!(set.len(), 2);
        assert!(!set.contains(&2));

        // Elements that were never added cannot be removed.
        assert!(!set.remove(&99));

        // Once removed, cannot be added again
        set.insert(2);
        assert!(!set.contains(&2));
    }

    #[test]
    fn test_two_phase_set_merge() {
        let mut set1 = TwoPhaseSet::new();
        set1.insert(1);
        set1.insert(2);

        let mut set2 = TwoPhaseSet::new();
        set2.insert(2);
        set2.insert(3);
        set2.remove(&2);

        set1.merge(&set2);

        assert!(set1.contains(&1));
        assert!(!set1.contains(&2)); // Removed in set2
        assert!(set1.contains(&3));
    }

    #[test]
    fn test_crdt_map() {
        let mut map1 = CrdtMap::new("node1".to_string());
        map1.insert("key1", 100);
        map1.insert("key2", 200);

        let mut map2 = CrdtMap::new("node2".to_string());
        map2.insert("key2", 250);
        map2.insert("key3", 300);

        let map1_before = map1.clone();
        map1.merge(&map2);
        map2.merge(&map1_before);

        assert_eq!(map1.get(&"key1"), Some(&100));
        assert_eq!(map1.get(&"key3"), Some(&300));
        // key2 was written concurrently with equal logical time on both
        // nodes; the node-id tie-break ("node1" < "node2") picks node2's value.
        assert_eq!(map1.get(&"key2"), Some(&250));

        // Both replicas converge to the same contents.
        assert_eq!(map1.len(), 3);
        assert_eq!(map2.len(), 3);
        for key in ["key1", "key2", "key3"] {
            assert_eq!(map1.get(&key), map2.get(&key));
        }
    }

    #[test]
    fn test_conflict_resolver() {
        let resolver = ConflictResolver::new("node1".to_string());
        assert_eq!(resolver.node_id(), "node1");

        let map: CrdtMap<String, i32> = resolver.create_map();
        assert!(map.is_empty());

        let set: CrdtSet<i32> = resolver.create_set();
        assert!(set.is_empty());
    }
}
569}