Skip to main content

citadel_sync/
crdt.rs

1use crate::hlc::HlcTimestamp;
2use crate::node_id::NodeId;
3
4/// Per-entry CRDT metadata for LWW (Last-Writer-Wins) conflict resolution.
5///
6/// 20 bytes on wire: entry_kind (1B) + padding (3B) + HLC timestamp (12B) + NodeId (8B).
7///
8/// Conflict resolution: higher timestamp wins, NodeId tiebreaker.
9/// This forms a join-semilattice with a total order, guaranteeing:
10/// - Commutativity: merge(a, b) == merge(b, a)
11/// - Associativity: merge(merge(a, b), c) == merge(a, merge(b, c))
12/// - Idempotency: merge(a, a) == a
13#[derive(Clone, Copy, PartialEq, Eq, Hash)]
14pub struct CrdtMeta {
15    pub timestamp: HlcTimestamp,
16    pub node_id: NodeId,
17}
18
/// Wire size of CrdtMeta: HLC (12B) + NodeId (8B) = 20 bytes.
pub const CRDT_META_SIZE: usize = 12 + 8;

/// Wire size of a full CRDT-encoded value header: kind (1B) + padding (3B) + meta (20B) = 24 bytes.
///
/// Derived from [`CRDT_META_SIZE`] so the two constants cannot drift apart.
pub const CRDT_HEADER_SIZE: usize = 1 + 3 + CRDT_META_SIZE;
24
/// Kind of a CRDT entry, stored as the first byte of an encoded value.
///
/// The `#[repr(u8)]` discriminants below are the on-wire byte values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum EntryKind {
    /// Key-value write; the user value follows the header.
    Put = 0,
    /// Logical delete (tombstone); no user value follows.
    Tombstone = 1,
}

impl EntryKind {
    /// Parses a wire byte, returning `None` for any byte that is not a known discriminant.
    pub fn from_u8(v: u8) -> Option<Self> {
        // Match arms come straight from the enum's discriminants, so decoding
        // stays in lock-step with the `#[repr(u8)]` values above.
        const PUT: u8 = EntryKind::Put as u8;
        const TOMBSTONE: u8 = EntryKind::Tombstone as u8;

        match v {
            PUT => Some(Self::Put),
            TOMBSTONE => Some(Self::Tombstone),
            _ => None,
        }
    }
}
44
45impl CrdtMeta {
46    #[inline]
47    pub fn new(timestamp: HlcTimestamp, node_id: NodeId) -> Self {
48        Self { timestamp, node_id }
49    }
50
51    pub fn to_bytes(&self) -> [u8; CRDT_META_SIZE] {
52        let mut buf = [0u8; CRDT_META_SIZE];
53        let ts_bytes = self.timestamp.to_bytes();
54        let nid_bytes = self.node_id.to_bytes();
55        buf[0..12].copy_from_slice(&ts_bytes);
56        buf[12..20].copy_from_slice(&nid_bytes);
57        buf
58    }
59
60    pub fn from_bytes(b: &[u8; CRDT_META_SIZE]) -> Self {
61        let ts = HlcTimestamp::from_bytes(b[0..12].try_into().unwrap());
62        let nid = NodeId::from_bytes(b[12..20].try_into().unwrap());
63        Self {
64            timestamp: ts,
65            node_id: nid,
66        }
67    }
68
69    /// Higher timestamp wins, NodeId tiebreaker.
70    #[inline]
71    pub fn lww_cmp(&self, other: &Self) -> std::cmp::Ordering {
72        self.timestamp
73            .cmp(&other.timestamp)
74            .then(self.node_id.cmp(&other.node_id))
75    }
76
77    #[inline]
78    pub fn wins_over(&self, other: &Self) -> bool {
79        self.lww_cmp(other) == std::cmp::Ordering::Greater
80    }
81}
82
83impl std::fmt::Debug for CrdtMeta {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        write!(f, "CrdtMeta({:?}, {:?})", self.timestamp, self.node_id)
86    }
87}
88
89/// Encode a user value with CRDT header.
90///
91/// Format: `[entry_kind: u8][_pad: 3B][HLC: 12B][NodeId: 8B][user_value...]`
92///
93/// Total header: 24 bytes. User value follows immediately after.
94pub fn encode_lww_value(meta: &CrdtMeta, kind: EntryKind, user_value: &[u8]) -> Vec<u8> {
95    let user_len = if kind == EntryKind::Tombstone {
96        0
97    } else {
98        user_value.len()
99    };
100    let mut buf = Vec::with_capacity(CRDT_HEADER_SIZE + user_len);
101    buf.push(kind as u8);
102    buf.extend_from_slice(&[0u8; 3]); // padding
103    buf.extend_from_slice(&meta.to_bytes());
104    if kind == EntryKind::Put {
105        buf.extend_from_slice(user_value);
106    }
107    buf
108}
109
110#[derive(Debug)]
111pub struct DecodedValue<'a> {
112    pub meta: CrdtMeta,
113    pub kind: EntryKind,
114    pub user_value: &'a [u8],
115}
116
117pub fn decode_lww_value(data: &[u8]) -> Result<DecodedValue<'_>, DecodeError> {
118    if data.len() < CRDT_HEADER_SIZE {
119        return Err(DecodeError::TooShort {
120            expected: CRDT_HEADER_SIZE,
121            actual: data.len(),
122        });
123    }
124
125    let kind = EntryKind::from_u8(data[0]).ok_or(DecodeError::InvalidEntryKind(data[0]))?;
126    // bytes 1..4 are padding (ignored on read)
127    let meta_bytes: &[u8; CRDT_META_SIZE] = data[4..24].try_into().unwrap();
128    let meta = CrdtMeta::from_bytes(meta_bytes);
129
130    let user_value = if kind == EntryKind::Tombstone {
131        &data[CRDT_HEADER_SIZE..CRDT_HEADER_SIZE] // empty slice
132    } else {
133        &data[CRDT_HEADER_SIZE..]
134    };
135
136    Ok(DecodedValue {
137        meta,
138        kind,
139        user_value,
140    })
141}
142
143#[derive(Debug, thiserror::Error)]
144pub enum DecodeError {
145    #[error("CRDT value too short: expected at least {expected} bytes, got {actual}")]
146    TooShort { expected: usize, actual: usize },
147
148    #[error("invalid CRDT entry kind: {0}")]
149    InvalidEntryKind(u8),
150}
151
/// Outcome of an LWW merge ([`lww_merge`]): which side of a same-key conflict to keep.
///
/// Only the LWW metadata (timestamp, then NodeId) is compared. The entry kind
/// (Put vs Tombstone) plays no role, so a tombstone with a higher timestamp
/// defeats a put with a lower one.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergeResult {
    /// Local metadata wins; keep the local entry.
    Local,
    /// Remote metadata wins; take the remote entry.
    Remote,
    /// Both sides carry identical metadata (same timestamp and node id), so LWW
    /// cannot distinguish them.
    Equal,
}
167
168pub fn lww_merge(local: &CrdtMeta, remote: &CrdtMeta) -> MergeResult {
169    match local.lww_cmp(remote) {
170        std::cmp::Ordering::Greater => MergeResult::Local,
171        std::cmp::Ordering::Less => MergeResult::Remote,
172        std::cmp::Ordering::Equal => MergeResult::Equal,
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hlc::HlcTimestamp;
    use crate::node_id::NodeId;

    const SECOND: i64 = 1_000_000_000;

    /// Shorthand constructor: HLC (wall_ns, logical) + node id.
    fn meta(wall_ns: i64, logical: i32, node: u64) -> CrdtMeta {
        CrdtMeta::new(HlcTimestamp::new(wall_ns, logical), NodeId::from_u64(node))
    }

    /// LWW winner of two entries via `lww_merge`; keeps `local` on `Equal`.
    fn pick(local: &CrdtMeta, remote: &CrdtMeta) -> CrdtMeta {
        match lww_merge(local, remote) {
            MergeResult::Local | MergeResult::Equal => *local,
            MergeResult::Remote => *remote,
        }
    }

    /// Left-folds `entries` through `pick`, starting from the first entry.
    fn fold_winner<'a>(entries: impl IntoIterator<Item = &'a CrdtMeta>) -> CrdtMeta {
        let mut iter = entries.into_iter();
        let first = *iter.next().expect("fold_winner needs at least one entry");
        iter.fold(first, |acc, e| pick(&acc, e))
    }

    #[test]
    fn meta_new_and_accessors() {
        let ts = HlcTimestamp::new(1000 * SECOND, 5);
        let nid = NodeId::from_u64(42);
        let m = CrdtMeta::new(ts, nid);
        assert_eq!(m.timestamp, ts);
        assert_eq!(m.node_id, nid);
    }

    #[test]
    fn meta_bytes_roundtrip() {
        let m = meta(1000 * SECOND, 42, 0xDEADBEEF);
        let bytes = m.to_bytes();
        assert_eq!(bytes.len(), CRDT_META_SIZE);
        let m2 = CrdtMeta::from_bytes(&bytes);
        assert_eq!(m, m2);
    }

    #[test]
    fn meta_bytes_roundtrip_zero() {
        let m = meta(0, 0, 0);
        let bytes = m.to_bytes();
        let m2 = CrdtMeta::from_bytes(&bytes);
        assert_eq!(m, m2);
    }

    #[test]
    fn meta_bytes_roundtrip_max() {
        let m = meta(i64::MAX, i32::MAX, u64::MAX);
        let bytes = m.to_bytes();
        let m2 = CrdtMeta::from_bytes(&bytes);
        assert_eq!(m, m2);
    }

    #[test]
    fn meta_debug_format() {
        let m = meta(1_000_000_000, 5, 255);
        let s = format!("{m:?}");
        assert!(s.contains("CrdtMeta"));
        assert!(s.contains("HLC"));
        assert!(s.contains("NodeId"));
    }

    #[test]
    fn lww_higher_timestamp_wins() {
        let a = meta(1000 * SECOND, 0, 1);
        let b = meta(1001 * SECOND, 0, 1);
        assert!(b.wins_over(&a));
        assert!(!a.wins_over(&b));
    }

    #[test]
    fn lww_higher_logical_wins() {
        let a = meta(1000 * SECOND, 5, 1);
        let b = meta(1000 * SECOND, 6, 1);
        assert!(b.wins_over(&a));
        assert!(!a.wins_over(&b));
    }

    #[test]
    fn lww_node_id_tiebreaker() {
        let a = meta(1000 * SECOND, 5, 100);
        let b = meta(1000 * SECOND, 5, 200);
        assert!(b.wins_over(&a));
        assert!(!a.wins_over(&b));
    }

    #[test]
    fn lww_equal_entries() {
        let a = meta(1000 * SECOND, 5, 100);
        let b = meta(1000 * SECOND, 5, 100);
        assert!(!a.wins_over(&b));
        assert!(!b.wins_over(&a));
        assert_eq!(a.lww_cmp(&b), std::cmp::Ordering::Equal);
    }

    #[test]
    fn lww_timestamp_dominates_node_id() {
        // Even with lower node_id, higher timestamp wins
        let a = meta(1001 * SECOND, 0, 1);
        let b = meta(1000 * SECOND, 0, u64::MAX);
        assert!(a.wins_over(&b));
    }

    #[test]
    fn merge_local_wins() {
        let local = meta(1001 * SECOND, 0, 1);
        let remote = meta(1000 * SECOND, 0, 1);
        assert_eq!(lww_merge(&local, &remote), MergeResult::Local);
    }

    #[test]
    fn merge_remote_wins() {
        let local = meta(1000 * SECOND, 0, 1);
        let remote = meta(1001 * SECOND, 0, 1);
        assert_eq!(lww_merge(&local, &remote), MergeResult::Remote);
    }

    #[test]
    fn merge_equal() {
        let local = meta(1000 * SECOND, 5, 100);
        let remote = meta(1000 * SECOND, 5, 100);
        assert_eq!(lww_merge(&local, &remote), MergeResult::Equal);
    }

    #[test]
    fn merge_commutativity() {
        let entries = [
            meta(1000 * SECOND, 0, 1),
            meta(1000 * SECOND, 0, 2),
            meta(1001 * SECOND, 0, 1),
            meta(1000 * SECOND, 1, 1),
        ];

        for a in &entries {
            for b in &entries {
                let ab = lww_merge(a, b);
                let ba = lww_merge(b, a);
                // Commutativity: merge(a,b) mirror equals merge(b,a)
                match (ab, ba) {
                    (MergeResult::Local, MergeResult::Remote) => {}
                    (MergeResult::Remote, MergeResult::Local) => {}
                    (MergeResult::Equal, MergeResult::Equal) => {}
                    _ => panic!("commutativity violated for {a:?} vs {b:?}: {ab:?} vs {ba:?}"),
                }
            }
        }
    }

    #[test]
    fn merge_associativity() {
        // Specific case: c wins (same timestamp as b, higher node_id), whatever the grouping.
        let a = meta(1000 * SECOND, 0, 1);
        let b = meta(1001 * SECOND, 5, 2);
        let c = meta(1001 * SECOND, 5, 3);

        let ab_c = pick(&pick(&a, &b), &c);
        let a_bc = pick(&a, &pick(&b, &c));
        assert_eq!(ab_c, a_bc, "associativity violated");
        assert_eq!(ab_c, c);

        // Exhaustive check over every ordered triple drawn from a small set,
        // including repeated elements and exact ties.
        let entries = [a, b, c, meta(1000 * SECOND, 1, 1)];
        for x in &entries {
            for y in &entries {
                for z in &entries {
                    let left = pick(&pick(x, y), z);
                    let right = pick(x, &pick(y, z));
                    assert_eq!(left, right, "associativity violated for {x:?}, {y:?}, {z:?}");
                }
            }
        }
    }

    #[test]
    fn merge_idempotency() {
        let a = meta(1000 * SECOND, 5, 42);
        assert_eq!(lww_merge(&a, &a), MergeResult::Equal);
    }

    #[test]
    fn entry_kind_roundtrip() {
        assert_eq!(EntryKind::from_u8(0), Some(EntryKind::Put));
        assert_eq!(EntryKind::from_u8(1), Some(EntryKind::Tombstone));
        assert_eq!(EntryKind::from_u8(2), None);
        assert_eq!(EntryKind::from_u8(255), None);
    }

    #[test]
    fn encode_decode_put_roundtrip() {
        let m = meta(1000 * SECOND, 5, 42);
        let user_val = b"hello world";
        let encoded = encode_lww_value(&m, EntryKind::Put, user_val);

        assert_eq!(encoded.len(), CRDT_HEADER_SIZE + user_val.len());

        let decoded = decode_lww_value(&encoded).unwrap();
        assert_eq!(decoded.meta, m);
        assert_eq!(decoded.kind, EntryKind::Put);
        assert_eq!(decoded.user_value, user_val);
    }

    #[test]
    fn encode_decode_tombstone_roundtrip() {
        let m = meta(1000 * SECOND, 5, 42);
        let encoded = encode_lww_value(&m, EntryKind::Tombstone, b"");

        assert_eq!(encoded.len(), CRDT_HEADER_SIZE);

        let decoded = decode_lww_value(&encoded).unwrap();
        assert_eq!(decoded.meta, m);
        assert_eq!(decoded.kind, EntryKind::Tombstone);
        assert_eq!(decoded.user_value.len(), 0);
    }

    #[test]
    fn encode_tombstone_ignores_user_value() {
        let m = meta(1000 * SECOND, 5, 42);
        // Even if user_value is non-empty, tombstone encoding ignores it
        let encoded = encode_lww_value(&m, EntryKind::Tombstone, b"should be ignored");
        assert_eq!(encoded.len(), CRDT_HEADER_SIZE);
    }

    #[test]
    fn encode_decode_empty_value() {
        let m = meta(1000 * SECOND, 0, 1);
        let encoded = encode_lww_value(&m, EntryKind::Put, b"");

        assert_eq!(encoded.len(), CRDT_HEADER_SIZE);

        let decoded = decode_lww_value(&encoded).unwrap();
        assert_eq!(decoded.kind, EntryKind::Put);
        assert_eq!(decoded.user_value.len(), 0);
    }

    #[test]
    fn encode_decode_large_value() {
        let m = meta(1000 * SECOND, 0, 1);
        let user_val = vec![0xAB; 4096];
        let encoded = encode_lww_value(&m, EntryKind::Put, &user_val);

        assert_eq!(encoded.len(), CRDT_HEADER_SIZE + 4096);

        let decoded = decode_lww_value(&encoded).unwrap();
        assert_eq!(decoded.user_value, &user_val[..]);
    }

    #[test]
    fn decode_ignores_padding_bytes() {
        let m = meta(1000 * SECOND, 5, 42);
        let mut encoded = encode_lww_value(&m, EntryKind::Put, b"v");
        encoded[1..4].copy_from_slice(&[0xFF; 3]);

        let decoded = decode_lww_value(&encoded).unwrap();
        assert_eq!(decoded.meta, m);
        assert_eq!(decoded.kind, EntryKind::Put);
        assert_eq!(decoded.user_value, b"v");
    }

    #[test]
    fn decode_tombstone_ignores_trailing_bytes() {
        let m = meta(1000 * SECOND, 5, 42);
        let mut encoded = encode_lww_value(&m, EntryKind::Tombstone, b"");
        encoded.extend_from_slice(b"junk");

        let decoded = decode_lww_value(&encoded).unwrap();
        assert_eq!(decoded.meta, m);
        assert_eq!(decoded.kind, EntryKind::Tombstone);
        assert!(decoded.user_value.is_empty());
    }

    #[test]
    fn decode_too_short() {
        let err = decode_lww_value(&[0u8; 10]).unwrap_err();
        assert!(matches!(err, DecodeError::TooShort { .. }));

        // Boundary: one byte short of a full header.
        let err = decode_lww_value(&[0u8; CRDT_HEADER_SIZE - 1]).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::TooShort { expected: CRDT_HEADER_SIZE, actual }
                if actual == CRDT_HEADER_SIZE - 1
        ));

        // Empty input.
        let err = decode_lww_value(&[]).unwrap_err();
        assert!(matches!(err, DecodeError::TooShort { actual: 0, .. }));
    }

    #[test]
    fn decode_invalid_entry_kind() {
        let mut data = [0u8; CRDT_HEADER_SIZE];
        data[0] = 255; // invalid
        let err = decode_lww_value(&data).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidEntryKind(255)));
    }

    #[test]
    fn header_size_constant() {
        assert_eq!(CRDT_HEADER_SIZE, 24);
        assert_eq!(CRDT_META_SIZE, 20);
        // 1 (kind) + 3 (pad) + 12 (HLC) + 8 (NodeId) = 24
        assert_eq!(1 + 3 + 12 + 8, CRDT_HEADER_SIZE);
    }

    #[test]
    fn merge_encoded_values() {
        let local_meta = meta(1000 * SECOND, 0, 1);
        let remote_meta = meta(1001 * SECOND, 0, 2);

        let local_encoded = encode_lww_value(&local_meta, EntryKind::Put, b"local");
        let remote_encoded = encode_lww_value(&remote_meta, EntryKind::Put, b"remote");

        let local_decoded = decode_lww_value(&local_encoded).unwrap();
        let remote_decoded = decode_lww_value(&remote_encoded).unwrap();

        let result = lww_merge(&local_decoded.meta, &remote_decoded.meta);
        assert_eq!(result, MergeResult::Remote);
    }

    #[test]
    fn tombstone_wins_over_put_with_lower_timestamp() {
        let put_meta = meta(1000 * SECOND, 0, 1);
        let del_meta = meta(1001 * SECOND, 0, 1);

        let put_encoded = encode_lww_value(&put_meta, EntryKind::Put, b"value");
        let del_encoded = encode_lww_value(&del_meta, EntryKind::Tombstone, b"");

        let put_decoded = decode_lww_value(&put_encoded).unwrap();
        let del_decoded = decode_lww_value(&del_encoded).unwrap();

        // Tombstone has higher timestamp - it wins
        let result = lww_merge(&put_decoded.meta, &del_decoded.meta);
        assert_eq!(result, MergeResult::Remote);
        assert_eq!(del_decoded.kind, EntryKind::Tombstone);
    }

    #[test]
    fn put_wins_over_tombstone_with_lower_timestamp() {
        let del_meta = meta(1000 * SECOND, 0, 1);
        let put_meta = meta(1001 * SECOND, 0, 1);

        let del_encoded = encode_lww_value(&del_meta, EntryKind::Tombstone, b"");
        let put_encoded = encode_lww_value(&put_meta, EntryKind::Put, b"value");

        let del_decoded = decode_lww_value(&del_encoded).unwrap();
        let put_decoded = decode_lww_value(&put_encoded).unwrap();

        // Put has higher timestamp - it wins over the tombstone
        let result = lww_merge(&del_decoded.meta, &put_decoded.meta);
        assert_eq!(result, MergeResult::Remote);
        assert_eq!(put_decoded.kind, EntryKind::Put);
    }

    #[test]
    fn encoded_format_put() {
        let m = CrdtMeta::new(
            HlcTimestamp::new(0x0102_0304_0506_0708, 0x090A0B0C),
            NodeId::from_u64(0x1112_1314_1516_1718),
        );
        let encoded = encode_lww_value(&m, EntryKind::Put, b"\xAA\xBB");

        // kind=0, pad=[0,0,0], HLC=8B+4B, NodeId=8B, value=2B
        assert_eq!(encoded[0], 0x00); // Put
        assert_eq!(&encoded[1..4], &[0, 0, 0]); // padding
        // HLC wall_time big-endian
        assert_eq!(
            &encoded[4..12],
            &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]
        );
        // HLC logical big-endian
        assert_eq!(&encoded[12..16], &[0x09, 0x0A, 0x0B, 0x0C]);
        // NodeId big-endian
        assert_eq!(
            &encoded[16..24],
            &[0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18]
        );
        // User value
        assert_eq!(&encoded[24..26], &[0xAA, 0xBB]);
    }

    #[test]
    fn encoded_format_tombstone() {
        let m = meta(1000 * SECOND, 0, 1);
        let encoded = encode_lww_value(&m, EntryKind::Tombstone, b"");
        assert_eq!(encoded[0], 0x01); // Tombstone
        assert_eq!(encoded.len(), CRDT_HEADER_SIZE);
    }

    #[test]
    fn merge_many_entries_finds_latest() {
        let entries: Vec<CrdtMeta> = (0..100)
            .map(|i| meta(1000 * SECOND + i as i64, 0, i as u64))
            .collect();

        let winner = fold_winner(&entries);

        // Last entry should win (highest timestamp and node_id)
        assert_eq!(winner.timestamp.wall_time(), 1000 * SECOND + 99);
        assert_eq!(winner.node_id.as_u64(), 99);
    }

    #[test]
    fn merge_reverse_order_same_result() {
        let entries: Vec<CrdtMeta> = (0..100)
            .map(|i| meta(1000 * SECOND + i as i64, 0, i as u64))
            .collect();

        let fwd_winner = fold_winner(&entries);
        let rev_winner = fold_winner(entries.iter().rev());

        assert_eq!(fwd_winner, rev_winner);
    }

    #[test]
    fn merge_shuffled_order_same_result() {
        // i * 7 % 50 is a permutation of 0..50 (7 and 50 are coprime), so the
        // timestamps are unique but not in insertion order.
        let entries: Vec<CrdtMeta> = (0..50)
            .map(|i| meta(1000 * SECOND + (i * 7 % 50) as i64, 0, i as u64))
            .collect();

        let expected = *entries.iter().max_by(|a, b| a.lww_cmp(b)).unwrap();

        // Insertion order.
        assert_eq!(fold_winner(&entries), expected);

        // Reverse insertion order.
        assert_eq!(fold_winner(entries.iter().rev()), expected);

        // Ascending wall-time order (a different permutation from insertion order).
        let mut by_wall_time = entries.clone();
        by_wall_time.sort_by_key(|e| e.timestamp.wall_time());
        assert_eq!(fold_winner(&by_wall_time), expected);
    }
}