// citadel_sync/patch.rs

1use crate::crdt::{CrdtMeta, EntryKind, CRDT_HEADER_SIZE, CRDT_META_SIZE};
2use crate::diff::DiffResult;
3use crate::node_id::NodeId;
4
/// Magic number at the start of every serialized patch: the ASCII bytes
/// "SYNC" read as a big-endian `u32` (0x53594E43). It is written to the wire
/// in little-endian byte order.
const PATCH_MAGIC: u32 = u32::from_be_bytes(*b"SYNC");
/// Wire-format version written by `serialize` and required by `deserialize`.
const PATCH_VERSION: u8 = 1;

/// Header flag bit: each entry is followed by a CRDT metadata block.
const FLAG_HAS_CRDT: u8 = 1 << 0;
9
10/// A single entry in a sync patch.
11#[derive(Debug, Clone, PartialEq, Eq)]
12pub struct PatchEntry {
13    pub key: Vec<u8>,
14    pub value: Vec<u8>,
15    pub kind: EntryKind,
16    pub crdt_meta: Option<CrdtMeta>,
17}
18
19/// A serializable sync patch containing entries to apply to a target database.
20#[derive(Debug, Clone)]
21pub struct SyncPatch {
22    pub source_node: NodeId,
23    pub entries: Vec<PatchEntry>,
24    pub crdt_aware: bool,
25}
26
27/// Errors from patch serialization/deserialization.
28#[derive(Debug, thiserror::Error)]
29pub enum PatchError {
30    #[error("invalid patch magic: expected {expected:#010x}, got {actual:#010x}")]
31    InvalidMagic { expected: u32, actual: u32 },
32
33    #[error("unsupported patch version: {0}")]
34    UnsupportedVersion(u8),
35
36    #[error("patch data truncated: expected at least {expected} bytes, got {actual}")]
37    Truncated { expected: usize, actual: usize },
38
39    #[error("invalid entry kind: {0}")]
40    InvalidEntryKind(u8),
41}
42
43impl SyncPatch {
44    /// Build a SyncPatch from a DiffResult.
45    ///
46    /// If `crdt_aware` is true, values are expected to contain CRDT headers
47    /// and entries will carry CrdtMeta extracted from the value prefix.
48    pub fn from_diff(source_node: NodeId, diff: &DiffResult, crdt_aware: bool) -> Self {
49        let entries = diff
50            .entries
51            .iter()
52            .map(|e| {
53                if crdt_aware && e.value.len() >= CRDT_HEADER_SIZE {
54                    // Try to decode CRDT header from the value
55                    if let Ok(decoded) = crate::crdt::decode_lww_value(&e.value) {
56                        return PatchEntry {
57                            key: e.key.clone(),
58                            value: e.value.clone(),
59                            kind: decoded.kind,
60                            crdt_meta: Some(decoded.meta),
61                        };
62                    }
63                }
64                PatchEntry {
65                    key: e.key.clone(),
66                    value: e.value.clone(),
67                    kind: EntryKind::Put,
68                    crdt_meta: None,
69                }
70            })
71            .collect();
72
73        SyncPatch {
74            source_node,
75            entries,
76            crdt_aware,
77        }
78    }
79
80    /// Create an empty patch.
81    pub fn empty(source_node: NodeId) -> Self {
82        SyncPatch {
83            source_node,
84            entries: Vec::new(),
85            crdt_aware: false,
86        }
87    }
88
89    pub fn len(&self) -> usize {
90        self.entries.len()
91    }
92
93    pub fn is_empty(&self) -> bool {
94        self.entries.is_empty()
95    }
96
97    /// Serialize to binary wire format.
98    ///
99    /// Format:
100    /// ```text
101    /// [magic: u32 LE][version: u8][flags: u8][source_node: 8B][entry_count: u32 LE]
102    /// Per entry:
103    ///   [key_len: u16 LE][value_len: u32 LE][kind: u8]
104    ///   [crdt_meta: 20B]  (if FLAG_HAS_CRDT)
105    ///   [key: key_len bytes][value: value_len bytes]
106    /// ```
107    pub fn serialize(&self) -> Vec<u8> {
108        let flags = if self.crdt_aware { FLAG_HAS_CRDT } else { 0 };
109
110        // Estimate capacity
111        let header_size = 4 + 1 + 1 + 8 + 4; // 18
112        let per_entry_overhead = 2 + 4 + 1 + if self.crdt_aware { CRDT_META_SIZE } else { 0 };
113        let data_size: usize = self
114            .entries
115            .iter()
116            .map(|e| per_entry_overhead + e.key.len() + e.value.len())
117            .sum();
118
119        let mut buf = Vec::with_capacity(header_size + data_size);
120
121        // Header
122        buf.extend_from_slice(&PATCH_MAGIC.to_le_bytes());
123        buf.push(PATCH_VERSION);
124        buf.push(flags);
125        buf.extend_from_slice(&self.source_node.to_bytes());
126        buf.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
127
128        // Entries
129        for entry in &self.entries {
130            buf.extend_from_slice(&(entry.key.len() as u16).to_le_bytes());
131            buf.extend_from_slice(&(entry.value.len() as u32).to_le_bytes());
132            buf.push(entry.kind as u8);
133            if self.crdt_aware {
134                if let Some(ref meta) = entry.crdt_meta {
135                    buf.extend_from_slice(&meta.to_bytes());
136                } else {
137                    buf.extend_from_slice(&[0u8; CRDT_META_SIZE]);
138                }
139            }
140            buf.extend_from_slice(&entry.key);
141            buf.extend_from_slice(&entry.value);
142        }
143
144        buf
145    }
146
147    /// Deserialize from binary wire format.
148    pub fn deserialize(data: &[u8]) -> Result<Self, PatchError> {
149        let header_size = 4 + 1 + 1 + 8 + 4; // 18 bytes
150        if data.len() < header_size {
151            return Err(PatchError::Truncated {
152                expected: header_size,
153                actual: data.len(),
154            });
155        }
156
157        let magic = u32::from_le_bytes(data[0..4].try_into().unwrap());
158        if magic != PATCH_MAGIC {
159            return Err(PatchError::InvalidMagic {
160                expected: PATCH_MAGIC,
161                actual: magic,
162            });
163        }
164
165        let version = data[4];
166        if version != PATCH_VERSION {
167            return Err(PatchError::UnsupportedVersion(version));
168        }
169
170        let flags = data[5];
171        let crdt_aware = (flags & FLAG_HAS_CRDT) != 0;
172        let source_node = NodeId::from_bytes(data[6..14].try_into().unwrap());
173        let entry_count = u32::from_le_bytes(data[14..18].try_into().unwrap()) as usize;
174
175        let mut entries = Vec::with_capacity(entry_count);
176        let mut pos = header_size;
177
178        for _ in 0..entry_count {
179            // key_len (2) + value_len (4) + kind (1) = 7
180            let entry_header = 7 + if crdt_aware { CRDT_META_SIZE } else { 0 };
181            if pos + entry_header > data.len() {
182                return Err(PatchError::Truncated {
183                    expected: pos + entry_header,
184                    actual: data.len(),
185                });
186            }
187
188            let key_len = u16::from_le_bytes(data[pos..pos + 2].try_into().unwrap()) as usize;
189            let value_len = u32::from_le_bytes(data[pos + 2..pos + 6].try_into().unwrap()) as usize;
190            let kind_byte = data[pos + 6];
191            let kind =
192                EntryKind::from_u8(kind_byte).ok_or(PatchError::InvalidEntryKind(kind_byte))?;
193            pos += 7;
194
195            let crdt_meta = if crdt_aware {
196                let meta_bytes: &[u8; CRDT_META_SIZE] =
197                    data[pos..pos + CRDT_META_SIZE].try_into().unwrap();
198                pos += CRDT_META_SIZE;
199                Some(CrdtMeta::from_bytes(meta_bytes))
200            } else {
201                None
202            };
203
204            if pos + key_len + value_len > data.len() {
205                return Err(PatchError::Truncated {
206                    expected: pos + key_len + value_len,
207                    actual: data.len(),
208                });
209            }
210
211            let key = data[pos..pos + key_len].to_vec();
212            pos += key_len;
213            let value = data[pos..pos + value_len].to_vec();
214            pos += value_len;
215
216            entries.push(PatchEntry {
217                key,
218                value,
219                kind,
220                crdt_meta,
221            });
222        }
223
224        Ok(SyncPatch {
225            source_node,
226            entries,
227            crdt_aware,
228        })
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hlc::HlcTimestamp;

    /// CRDT metadata from raw timestamp parts and a numeric node id.
    fn meta(wall_ns: i64, logical: i32, node: u64) -> CrdtMeta {
        let timestamp = HlcTimestamp::new(wall_ns, logical);
        CrdtMeta::new(timestamp, NodeId::from_u64(node))
    }

    /// A `Put` entry without CRDT metadata.
    fn put(key: &[u8], value: &[u8]) -> PatchEntry {
        PatchEntry {
            key: key.to_vec(),
            value: value.to_vec(),
            kind: EntryKind::Put,
            crdt_meta: None,
        }
    }

    /// A non-CRDT patch originating from the node with the given numeric id.
    fn plain_patch(node: u64, entries: Vec<PatchEntry>) -> SyncPatch {
        SyncPatch {
            source_node: NodeId::from_u64(node),
            entries,
            crdt_aware: false,
        }
    }

    /// Serialize a patch and decode the resulting bytes again.
    fn roundtrip(patch: &SyncPatch) -> SyncPatch {
        let bytes = patch.serialize();
        SyncPatch::deserialize(&bytes).unwrap()
    }

    /// A diff entry with `val_type` 0.
    fn diff_entry(key: &[u8], value: Vec<u8>) -> crate::diff::DiffEntry {
        crate::diff::DiffEntry {
            key: key.to_vec(),
            value,
            val_type: 0,
        }
    }

    #[test]
    fn empty_patch_roundtrip() {
        let decoded = roundtrip(&SyncPatch::empty(NodeId::from_u64(42)));
        assert!(decoded.is_empty());
        assert_eq!(decoded.source_node, NodeId::from_u64(42));
        assert!(!decoded.crdt_aware);
    }

    #[test]
    fn patch_with_entries_roundtrip() {
        let patch = plain_patch(1, vec![put(b"key1", b"value1"), put(b"key2", b"value2")]);

        let decoded = roundtrip(&patch);
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded.entries[0].key, b"key1");
        assert_eq!(decoded.entries[0].value, b"value1");
        assert_eq!(decoded.entries[1].key, b"key2");
        assert_eq!(decoded.entries[1].value, b"value2");
    }

    #[test]
    fn crdt_patch_roundtrip() {
        let m = meta(1_000_000_000, 5, 42);
        let live = PatchEntry {
            key: b"key1".to_vec(),
            value: b"value1".to_vec(),
            kind: EntryKind::Put,
            crdt_meta: Some(m),
        };
        let deleted = PatchEntry {
            key: b"key2".to_vec(),
            value: Vec::new(),
            kind: EntryKind::Tombstone,
            crdt_meta: Some(m),
        };
        let patch = SyncPatch {
            source_node: NodeId::from_u64(1),
            entries: vec![live, deleted],
            crdt_aware: true,
        };

        let decoded = roundtrip(&patch);
        assert_eq!(decoded.len(), 2);
        assert!(decoded.crdt_aware);
        assert_eq!(decoded.entries[0].crdt_meta, Some(m));
        assert_eq!(decoded.entries[0].kind, EntryKind::Put);
        assert_eq!(decoded.entries[1].crdt_meta, Some(m));
        assert_eq!(decoded.entries[1].kind, EntryKind::Tombstone);
    }

    #[test]
    fn large_values_roundtrip() {
        let big_key = vec![0xAA; 2048];
        let big_val = vec![0xBB; 8192];
        let patch = plain_patch(99, vec![put(&big_key, &big_val)]);

        let decoded = roundtrip(&patch);
        assert_eq!(decoded.entries[0].key, big_key);
        assert_eq!(decoded.entries[0].value, big_val);
    }

    #[test]
    fn invalid_magic_error() {
        let mut data = SyncPatch::empty(NodeId::from_u64(1)).serialize();
        data[0] = 0xFF; // corrupt magic
        let err = SyncPatch::deserialize(&data).unwrap_err();
        assert!(matches!(err, PatchError::InvalidMagic { .. }));
    }

    #[test]
    fn unsupported_version_error() {
        let mut data = SyncPatch::empty(NodeId::from_u64(1)).serialize();
        data[4] = 99; // bad version
        let err = SyncPatch::deserialize(&data).unwrap_err();
        assert!(matches!(err, PatchError::UnsupportedVersion(99)));
    }

    #[test]
    fn truncated_header_error() {
        let err = SyncPatch::deserialize(&[0u8; 5]).unwrap_err();
        assert!(matches!(err, PatchError::Truncated { .. }));
    }

    #[test]
    fn truncated_entry_error() {
        let mut data = plain_patch(1, vec![put(b"key", b"value")]).serialize();
        data.truncate(data.len() - 3); // cut off end of value
        let err = SyncPatch::deserialize(&data).unwrap_err();
        assert!(matches!(err, PatchError::Truncated { .. }));
    }

    #[test]
    fn invalid_entry_kind_error() {
        let mut data = plain_patch(1, vec![put(b"k", b"v")]).serialize();
        // Entry kind byte is at offset 18 (header) + 6 (after key_len + value_len)
        data[18 + 6] = 255; // invalid kind
        let err = SyncPatch::deserialize(&data).unwrap_err();
        assert!(matches!(err, PatchError::InvalidEntryKind(255)));
    }

    #[test]
    fn many_entries_roundtrip() {
        let entries: Vec<PatchEntry> = (0..1000u32)
            .map(|i| put(&i.to_be_bytes(), format!("val-{i}").as_bytes()))
            .collect();

        let decoded = roundtrip(&plain_patch(7, entries));
        assert_eq!(decoded.len(), 1000);
        for (i, entry) in decoded.entries.iter().enumerate() {
            assert_eq!(entry.key, (i as u32).to_be_bytes());
        }
    }

    #[test]
    fn from_diff_non_crdt() {
        let diff = DiffResult {
            entries: vec![
                diff_entry(b"k1", b"v1".to_vec()),
                diff_entry(b"k2", b"v2".to_vec()),
            ],
            pages_compared: 5,
            subtrees_skipped: 2,
        };

        let patch = SyncPatch::from_diff(NodeId::from_u64(1), &diff, false);
        assert_eq!(patch.len(), 2);
        assert!(!patch.crdt_aware);
        assert_eq!(patch.entries[0].key, b"k1");
        assert!(patch.entries[0].crdt_meta.is_none());
    }

    #[test]
    fn from_diff_crdt_extracts_meta() {
        let m = meta(1_000_000_000, 5, 42);
        let crdt_value = crate::crdt::encode_lww_value(&m, EntryKind::Put, b"user-data");

        let diff = DiffResult {
            entries: vec![diff_entry(b"k1", crdt_value)],
            pages_compared: 1,
            subtrees_skipped: 0,
        };

        let patch = SyncPatch::from_diff(NodeId::from_u64(1), &diff, true);
        assert_eq!(patch.len(), 1);
        assert!(patch.crdt_aware);
        assert_eq!(patch.entries[0].crdt_meta, Some(m));
        assert_eq!(patch.entries[0].kind, EntryKind::Put);
    }
}