Skip to main content

citadel_sync/
crdt.rs

1use crate::hlc::HlcTimestamp;
2use crate::node_id::NodeId;
3
4/// Per-entry CRDT metadata for LWW (Last-Writer-Wins) conflict resolution.
5///
6/// 20 bytes on wire: entry_kind (1B) + padding (3B) + HLC timestamp (12B) + NodeId (8B).
7///
8/// Conflict resolution: higher timestamp wins, NodeId tiebreaker.
9/// This forms a join-semilattice with a total order, guaranteeing:
10/// - Commutativity: merge(a, b) == merge(b, a)
11/// - Associativity: merge(merge(a, b), c) == merge(a, merge(b, c))
12/// - Idempotency: merge(a, a) == a
13#[derive(Clone, Copy, PartialEq, Eq, Hash)]
14pub struct CrdtMeta {
15    pub timestamp: HlcTimestamp,
16    pub node_id: NodeId,
17}
18
/// Number of bytes `CrdtMeta` occupies on the wire:
/// HLC timestamp (12B) followed by NodeId (8B), 20 bytes in total.
pub const CRDT_META_SIZE: usize = 12 + 8;

/// Number of bytes in the header of a CRDT-encoded value:
/// entry kind (1B) + padding (3B) + metadata (20B), 24 bytes in total.
pub const CRDT_HEADER_SIZE: usize = 1 + 3 + CRDT_META_SIZE;
24
/// Discriminates what a CRDT entry represents.
///
/// The numeric discriminant is the byte written as the first header byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum EntryKind {
    /// A key-value write; the user value follows the header.
    Put = 0,
    /// A logical delete (tombstone); no user value follows the header.
    Tombstone = 1,
}

impl EntryKind {
    /// Maps a raw kind byte back to its variant.
    ///
    /// Returns `None` for any byte that is not the discriminant of a variant.
    pub fn from_u8(v: u8) -> Option<Self> {
        // Tie the accepted bytes to the enum discriminants instead of repeating literals.
        const PUT: u8 = EntryKind::Put as u8;
        const TOMBSTONE: u8 = EntryKind::Tombstone as u8;

        match v {
            PUT => Some(EntryKind::Put),
            TOMBSTONE => Some(EntryKind::Tombstone),
            _ => None,
        }
    }
}
44
45impl CrdtMeta {
46    #[inline]
47    pub fn new(timestamp: HlcTimestamp, node_id: NodeId) -> Self {
48        Self { timestamp, node_id }
49    }
50
51    pub fn to_bytes(&self) -> [u8; CRDT_META_SIZE] {
52        let mut buf = [0u8; CRDT_META_SIZE];
53        let ts_bytes = self.timestamp.to_bytes();
54        let nid_bytes = self.node_id.to_bytes();
55        buf[0..12].copy_from_slice(&ts_bytes);
56        buf[12..20].copy_from_slice(&nid_bytes);
57        buf
58    }
59
60    pub fn from_bytes(b: &[u8; CRDT_META_SIZE]) -> Self {
61        let ts = HlcTimestamp::from_bytes(b[0..12].try_into().unwrap());
62        let nid = NodeId::from_bytes(b[12..20].try_into().unwrap());
63        Self {
64            timestamp: ts,
65            node_id: nid,
66        }
67    }
68
69    /// Higher timestamp wins, NodeId tiebreaker.
70    #[inline]
71    pub fn lww_cmp(&self, other: &Self) -> std::cmp::Ordering {
72        self.timestamp
73            .cmp(&other.timestamp)
74            .then(self.node_id.cmp(&other.node_id))
75    }
76
77    #[inline]
78    pub fn wins_over(&self, other: &Self) -> bool {
79        self.lww_cmp(other) == std::cmp::Ordering::Greater
80    }
81}
82
83impl std::fmt::Debug for CrdtMeta {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        write!(f, "CrdtMeta({:?}, {:?})", self.timestamp, self.node_id)
86    }
87}
88
89/// Encode a user value with CRDT header.
90///
91/// Format: `[entry_kind: u8][_pad: 3B][HLC: 12B][NodeId: 8B][user_value...]`
92///
93/// Total header: 24 bytes. User value follows immediately after.
94pub fn encode_lww_value(meta: &CrdtMeta, kind: EntryKind, user_value: &[u8]) -> Vec<u8> {
95    let user_len = if kind == EntryKind::Tombstone {
96        0
97    } else {
98        user_value.len()
99    };
100    let mut buf = Vec::with_capacity(CRDT_HEADER_SIZE + user_len);
101    buf.push(kind as u8);
102    buf.extend_from_slice(&[0u8; 3]); // padding
103    buf.extend_from_slice(&meta.to_bytes());
104    if kind == EntryKind::Put {
105        buf.extend_from_slice(user_value);
106    }
107    buf
108}
109
110#[derive(Debug)]
111pub struct DecodedValue<'a> {
112    pub meta: CrdtMeta,
113    pub kind: EntryKind,
114    pub user_value: &'a [u8],
115}
116
117pub fn decode_lww_value(data: &[u8]) -> Result<DecodedValue<'_>, DecodeError> {
118    if data.len() < CRDT_HEADER_SIZE {
119        return Err(DecodeError::TooShort {
120            expected: CRDT_HEADER_SIZE,
121            actual: data.len(),
122        });
123    }
124
125    let kind = EntryKind::from_u8(data[0]).ok_or(DecodeError::InvalidEntryKind(data[0]))?;
126    // bytes 1..4 are padding (ignored on read)
127    let meta_bytes: &[u8; CRDT_META_SIZE] = data[4..24].try_into().unwrap();
128    let meta = CrdtMeta::from_bytes(meta_bytes);
129
130    let user_value = if kind == EntryKind::Tombstone {
131        &data[CRDT_HEADER_SIZE..CRDT_HEADER_SIZE] // empty slice
132    } else {
133        &data[CRDT_HEADER_SIZE..]
134    };
135
136    Ok(DecodedValue {
137        meta,
138        kind,
139        user_value,
140    })
141}
142
143#[derive(Debug, thiserror::Error)]
144pub enum DecodeError {
145    #[error("CRDT value too short: expected at least {expected} bytes, got {actual}")]
146    TooShort { expected: usize, actual: usize },
147
148    #[error("invalid CRDT entry kind: {0}")]
149    InvalidEntryKind(u8),
150}
151
152/// Merge two CRDT entries for the same key.
153///
154/// Returns which side wins using LWW resolution:
155/// higher timestamp wins, NodeId tiebreaker.
156/// The entry kind (Put vs Tombstone) does NOT affect the merge -
157/// a tombstone with a higher timestamp defeats a put with a lower one.
158#[derive(Debug, Clone, Copy, PartialEq, Eq)]
159pub enum MergeResult {
160    /// Keep the local entry.
161    Local,
162    /// Take the remote entry.
163    Remote,
164    /// Both entries are identical.
165    Equal,
166}
167
168pub fn lww_merge(local: &CrdtMeta, remote: &CrdtMeta) -> MergeResult {
169    match local.lww_cmp(remote) {
170        std::cmp::Ordering::Greater => MergeResult::Local,
171        std::cmp::Ordering::Less => MergeResult::Remote,
172        std::cmp::Ordering::Equal => MergeResult::Equal,
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hlc::HlcTimestamp;
    use crate::node_id::NodeId;

    /// Nanoseconds per second; wall-clock arguments to `meta` are in nanoseconds.
    const SECOND: i64 = 1_000_000_000;

    /// Shorthand constructor for test metadata.
    fn meta(wall_ns: i64, logical: i32, node: u64) -> CrdtMeta {
        CrdtMeta::new(HlcTimestamp::new(wall_ns, logical), NodeId::from_u64(node))
    }

    /// Resolves one pairwise merge to the surviving metadata (local on ties).
    fn pick_winner(local: &CrdtMeta, remote: &CrdtMeta) -> CrdtMeta {
        match lww_merge(local, remote) {
            MergeResult::Local | MergeResult::Equal => *local,
            MergeResult::Remote => *remote,
        }
    }

    /// Folds entries left to right with `lww_merge`, adopting the remote side
    /// whenever it wins. Panics if the sequence is empty.
    fn fold_winner<'a>(entries: impl IntoIterator<Item = &'a CrdtMeta>) -> CrdtMeta {
        let mut iter = entries.into_iter();
        let mut winner = *iter.next().expect("fold_winner requires at least one entry");
        for e in iter {
            if lww_merge(&winner, e) == MergeResult::Remote {
                winner = *e;
            }
        }
        winner
    }

    // ── CrdtMeta basics ──────────────────────────────────────────────

    #[test]
    fn meta_new_and_accessors() {
        let ts = HlcTimestamp::new(1000 * SECOND, 5);
        let nid = NodeId::from_u64(42);
        let m = CrdtMeta::new(ts, nid);
        assert_eq!(m.timestamp, ts);
        assert_eq!(m.node_id, nid);
    }

    #[test]
    fn meta_bytes_roundtrip() {
        let m = meta(1000 * SECOND, 42, 0xDEADBEEF);
        let bytes = m.to_bytes();
        assert_eq!(bytes.len(), CRDT_META_SIZE);
        let m2 = CrdtMeta::from_bytes(&bytes);
        assert_eq!(m, m2);
    }

    #[test]
    fn meta_bytes_roundtrip_zero() {
        let m = meta(0, 0, 0);
        let bytes = m.to_bytes();
        let m2 = CrdtMeta::from_bytes(&bytes);
        assert_eq!(m, m2);
    }

    #[test]
    fn meta_bytes_roundtrip_max() {
        let m = meta(i64::MAX, i32::MAX, u64::MAX);
        let bytes = m.to_bytes();
        let m2 = CrdtMeta::from_bytes(&bytes);
        assert_eq!(m, m2);
    }

    #[test]
    fn meta_debug_format() {
        let m = meta(1_000_000_000, 5, 255);
        let s = format!("{m:?}");
        assert!(s.contains("CrdtMeta"));
        assert!(s.contains("HLC"));
        assert!(s.contains("NodeId"));
    }

    // ── LWW comparison ───────────────────────────────────────────────

    #[test]
    fn lww_higher_timestamp_wins() {
        let a = meta(1000 * SECOND, 0, 1);
        let b = meta(1001 * SECOND, 0, 1);
        assert!(b.wins_over(&a));
        assert!(!a.wins_over(&b));
    }

    #[test]
    fn lww_higher_logical_wins() {
        let a = meta(1000 * SECOND, 5, 1);
        let b = meta(1000 * SECOND, 6, 1);
        assert!(b.wins_over(&a));
        assert!(!a.wins_over(&b));
    }

    #[test]
    fn lww_node_id_tiebreaker() {
        let a = meta(1000 * SECOND, 5, 100);
        let b = meta(1000 * SECOND, 5, 200);
        assert!(b.wins_over(&a));
        assert!(!a.wins_over(&b));
    }

    #[test]
    fn lww_equal_entries() {
        let a = meta(1000 * SECOND, 5, 100);
        let b = meta(1000 * SECOND, 5, 100);
        assert!(!a.wins_over(&b));
        assert!(!b.wins_over(&a));
        assert_eq!(a.lww_cmp(&b), std::cmp::Ordering::Equal);
    }

    #[test]
    fn lww_timestamp_dominates_node_id() {
        // Even with lower node_id, higher timestamp wins
        let a = meta(1001 * SECOND, 0, 1);
        let b = meta(1000 * SECOND, 0, u64::MAX);
        assert!(a.wins_over(&b));
    }

    // ── LWW merge function ───────────────────────────────────────────

    #[test]
    fn merge_local_wins() {
        let local = meta(1001 * SECOND, 0, 1);
        let remote = meta(1000 * SECOND, 0, 1);
        assert_eq!(lww_merge(&local, &remote), MergeResult::Local);
    }

    #[test]
    fn merge_remote_wins() {
        let local = meta(1000 * SECOND, 0, 1);
        let remote = meta(1001 * SECOND, 0, 1);
        assert_eq!(lww_merge(&local, &remote), MergeResult::Remote);
    }

    #[test]
    fn merge_equal() {
        let local = meta(1000 * SECOND, 5, 100);
        let remote = meta(1000 * SECOND, 5, 100);
        assert_eq!(lww_merge(&local, &remote), MergeResult::Equal);
    }

    // ── CRDT properties ──────────────────────────────────────────────

    #[test]
    fn merge_commutativity() {
        let entries = [
            meta(1000 * SECOND, 0, 1),
            meta(1000 * SECOND, 0, 2),
            meta(1001 * SECOND, 0, 1),
            meta(1000 * SECOND, 1, 1),
        ];

        for a in &entries {
            for b in &entries {
                let ab = lww_merge(a, b);
                let ba = lww_merge(b, a);
                // Commutativity: merge(a,b) mirror equals merge(b,a)
                match (ab, ba) {
                    (MergeResult::Local, MergeResult::Remote) => {}
                    (MergeResult::Remote, MergeResult::Local) => {}
                    (MergeResult::Equal, MergeResult::Equal) => {}
                    _ => panic!("commutativity violated for {a:?} vs {b:?}: {ab:?} vs {ba:?}"),
                }
            }
        }
    }

    #[test]
    fn merge_associativity() {
        // For three entries, the winner should be the same regardless of merge order.
        let a = meta(1000 * SECOND, 0, 1);
        let b = meta(1001 * SECOND, 5, 2);
        let c = meta(1001 * SECOND, 5, 3);

        // Winner is c (same timestamp as b, higher node_id).
        assert_eq!(pick_winner(&pick_winner(&a, &b), &c), c);
        assert_eq!(pick_winner(&a, &pick_winner(&b, &c)), c);

        // Exhaustively check every ordered triple drawn from a small set that
        // mixes wall-time, logical-counter and node-id differences.
        let entries = [
            a,
            b,
            c,
            meta(1000 * SECOND, 1, 1),
            meta(1001 * SECOND, 0, 1),
        ];
        for x in &entries {
            for y in &entries {
                for z in &entries {
                    let xy_z = pick_winner(&pick_winner(x, y), z);
                    let x_yz = pick_winner(x, &pick_winner(y, z));
                    assert_eq!(
                        xy_z, x_yz,
                        "associativity violated for {x:?}, {y:?}, {z:?}"
                    );
                }
            }
        }
    }

    #[test]
    fn merge_idempotency() {
        let a = meta(1000 * SECOND, 5, 42);
        assert_eq!(lww_merge(&a, &a), MergeResult::Equal);
    }

    // ── EntryKind ────────────────────────────────────────────────────

    #[test]
    fn entry_kind_roundtrip() {
        assert_eq!(EntryKind::from_u8(0), Some(EntryKind::Put));
        assert_eq!(EntryKind::from_u8(1), Some(EntryKind::Tombstone));
        assert_eq!(EntryKind::from_u8(2), None);
        assert_eq!(EntryKind::from_u8(255), None);
    }

    // ── Value encoding ───────────────────────────────────────────────

    #[test]
    fn encode_decode_put_roundtrip() {
        let m = meta(1000 * SECOND, 5, 42);
        let user_val = b"hello world";
        let encoded = encode_lww_value(&m, EntryKind::Put, user_val);

        assert_eq!(encoded.len(), CRDT_HEADER_SIZE + user_val.len());

        let decoded = decode_lww_value(&encoded).unwrap();
        assert_eq!(decoded.meta, m);
        assert_eq!(decoded.kind, EntryKind::Put);
        assert_eq!(decoded.user_value, user_val);
    }

    #[test]
    fn encode_decode_tombstone_roundtrip() {
        let m = meta(1000 * SECOND, 5, 42);
        let encoded = encode_lww_value(&m, EntryKind::Tombstone, b"");

        assert_eq!(encoded.len(), CRDT_HEADER_SIZE);

        let decoded = decode_lww_value(&encoded).unwrap();
        assert_eq!(decoded.meta, m);
        assert_eq!(decoded.kind, EntryKind::Tombstone);
        assert_eq!(decoded.user_value.len(), 0);
    }

    #[test]
    fn encode_tombstone_ignores_user_value() {
        let m = meta(1000 * SECOND, 5, 42);
        // Even if user_value is non-empty, tombstone encoding ignores it
        let encoded = encode_lww_value(&m, EntryKind::Tombstone, b"should be ignored");
        assert_eq!(encoded.len(), CRDT_HEADER_SIZE);
    }

    #[test]
    fn encode_decode_empty_value() {
        let m = meta(1000 * SECOND, 0, 1);
        let encoded = encode_lww_value(&m, EntryKind::Put, b"");

        assert_eq!(encoded.len(), CRDT_HEADER_SIZE);

        let decoded = decode_lww_value(&encoded).unwrap();
        assert_eq!(decoded.kind, EntryKind::Put);
        assert_eq!(decoded.user_value.len(), 0);
    }

    #[test]
    fn encode_decode_large_value() {
        let m = meta(1000 * SECOND, 0, 1);
        let user_val = vec![0xAB; 4096];
        let encoded = encode_lww_value(&m, EntryKind::Put, &user_val);

        assert_eq!(encoded.len(), CRDT_HEADER_SIZE + 4096);

        let decoded = decode_lww_value(&encoded).unwrap();
        assert_eq!(decoded.user_value, &user_val[..]);
    }

    #[test]
    fn decode_too_short() {
        let err = decode_lww_value(&[0u8; 10]).unwrap_err();
        assert!(matches!(err, DecodeError::TooShort { .. }));
    }

    #[test]
    fn decode_invalid_entry_kind() {
        let mut data = [0u8; CRDT_HEADER_SIZE];
        data[0] = 255; // invalid
        let err = decode_lww_value(&data).unwrap_err();
        assert!(matches!(err, DecodeError::InvalidEntryKind(255)));
    }

    #[test]
    fn header_size_constant() {
        assert_eq!(CRDT_HEADER_SIZE, 24);
        assert_eq!(CRDT_META_SIZE, 20);
        // 1 (kind) + 3 (pad) + 12 (HLC) + 8 (NodeId) = 24
        assert_eq!(1 + 3 + 12 + 8, CRDT_HEADER_SIZE);
    }

    // ── Encoding preserves metadata across merge ─────────────────────

    #[test]
    fn merge_encoded_values() {
        let local_meta = meta(1000 * SECOND, 0, 1);
        let remote_meta = meta(1001 * SECOND, 0, 2);

        let local_encoded = encode_lww_value(&local_meta, EntryKind::Put, b"local");
        let remote_encoded = encode_lww_value(&remote_meta, EntryKind::Put, b"remote");

        let local_decoded = decode_lww_value(&local_encoded).unwrap();
        let remote_decoded = decode_lww_value(&remote_encoded).unwrap();

        let result = lww_merge(&local_decoded.meta, &remote_decoded.meta);
        assert_eq!(result, MergeResult::Remote);
    }

    #[test]
    fn tombstone_wins_over_put_with_lower_timestamp() {
        let put_meta = meta(1000 * SECOND, 0, 1);
        let del_meta = meta(1001 * SECOND, 0, 1);

        let put_encoded = encode_lww_value(&put_meta, EntryKind::Put, b"value");
        let del_encoded = encode_lww_value(&del_meta, EntryKind::Tombstone, b"");

        let put_decoded = decode_lww_value(&put_encoded).unwrap();
        let del_decoded = decode_lww_value(&del_encoded).unwrap();

        // Tombstone has higher timestamp - it wins
        let result = lww_merge(&put_decoded.meta, &del_decoded.meta);
        assert_eq!(result, MergeResult::Remote);
        assert_eq!(del_decoded.kind, EntryKind::Tombstone);
    }

    #[test]
    fn put_wins_over_tombstone_with_lower_timestamp() {
        let del_meta = meta(1000 * SECOND, 0, 1);
        let put_meta = meta(1001 * SECOND, 0, 1);

        let del_encoded = encode_lww_value(&del_meta, EntryKind::Tombstone, b"");
        let put_encoded = encode_lww_value(&put_meta, EntryKind::Put, b"value");

        let del_decoded = decode_lww_value(&del_encoded).unwrap();
        let put_decoded = decode_lww_value(&put_encoded).unwrap();

        // Put has higher timestamp - it wins over the tombstone
        let result = lww_merge(&del_decoded.meta, &put_decoded.meta);
        assert_eq!(result, MergeResult::Remote);
        assert_eq!(put_decoded.kind, EntryKind::Put);
    }

    // ── Binary format verification ───────────────────────────────────

    #[test]
    fn encoded_format_put() {
        let m = CrdtMeta::new(
            HlcTimestamp::new(0x0102_0304_0506_0708, 0x090A0B0C),
            NodeId::from_u64(0x1112_1314_1516_1718),
        );
        let encoded = encode_lww_value(&m, EntryKind::Put, b"\xAA\xBB");

        // kind=0, pad=[0,0,0], HLC=8B+4B, NodeId=8B, value=2B
        assert_eq!(encoded[0], 0x00); // Put
        assert_eq!(&encoded[1..4], &[0, 0, 0]); // padding
        // HLC wall_time big-endian
        assert_eq!(
            &encoded[4..12],
            &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]
        );
        // HLC logical big-endian
        assert_eq!(&encoded[12..16], &[0x09, 0x0A, 0x0B, 0x0C]);
        // NodeId big-endian
        assert_eq!(
            &encoded[16..24],
            &[0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18]
        );
        // User value
        assert_eq!(&encoded[24..26], &[0xAA, 0xBB]);
    }

    #[test]
    fn encoded_format_tombstone() {
        let m = meta(1000 * SECOND, 0, 1);
        let encoded = encode_lww_value(&m, EntryKind::Tombstone, b"");
        assert_eq!(encoded[0], 0x01); // Tombstone
        assert_eq!(encoded.len(), CRDT_HEADER_SIZE);
    }

    // ── Stress: many merges ──────────────────────────────────────────

    #[test]
    fn merge_many_entries_finds_latest() {
        let entries: Vec<CrdtMeta> = (0..100)
            .map(|i| meta(1000 * SECOND + i as i64, 0, i as u64))
            .collect();

        let winner = fold_winner(&entries);

        // Last entry should win (highest timestamp and node_id)
        assert_eq!(winner.timestamp.wall_time(), 1000 * SECOND + 99);
        assert_eq!(winner.node_id.as_u64(), 99);
    }

    #[test]
    fn merge_reverse_order_same_result() {
        let entries: Vec<CrdtMeta> = (0..100)
            .map(|i| meta(1000 * SECOND + i as i64, 0, i as u64))
            .collect();

        let fwd_winner = fold_winner(&entries);
        let rev_winner = fold_winner(entries.iter().rev());

        assert_eq!(fwd_winner, rev_winner);
    }

    #[test]
    fn merge_shuffled_order_same_result() {
        use std::collections::BTreeSet;

        // 7 is coprime with 50, so `i * 7 % 50` gives every entry a distinct
        // wall time and makes insertion order differ from timestamp order.
        let entries: Vec<CrdtMeta> = (0..50)
            .map(|i| meta(1000 * SECOND + (i * 7 % 50) as i64, 0, i as u64))
            .collect();
        let n = entries.len();

        // Find absolute winner (max by lww_cmp)
        let expected = *entries.iter().max_by(|a, b| a.lww_cmp(b)).unwrap();
        // i = 7 maps to offset 49, the largest wall time in the set.
        assert_eq!(expected.timestamp.wall_time(), 1000 * SECOND + 49);
        assert_eq!(expected.node_id.as_u64(), 7);

        // Insertion order and its reverse.
        assert_eq!(fold_winner(&entries), expected);
        assert_eq!(fold_winner(entries.iter().rev()), expected);

        // Deterministic "shuffle": stride 13 is coprime with 50, so the index
        // sequence is a permutation that visits every entry exactly once.
        let stride_indices: Vec<usize> = (0..n).map(|k| (k * 13 + 5) % n).collect();
        let distinct: BTreeSet<usize> = stride_indices.iter().copied().collect();
        assert_eq!(distinct.len(), n, "stride must visit every entry exactly once");
        assert_eq!(
            fold_winner(stride_indices.iter().map(|&i| &entries[i])),
            expected
        );

        // LWW-sorted order, ascending and descending.
        let mut ascending = entries.clone();
        ascending.sort_by(|a, b| a.lww_cmp(b));
        assert_eq!(fold_winner(&ascending), expected);
        assert_eq!(fold_winner(ascending.iter().rev()), expected);
    }
}