Skip to main content

citadel_sync/
patch.rs

1use crate::crdt::{CrdtMeta, EntryKind, CRDT_HEADER_SIZE, CRDT_META_SIZE};
2use crate::diff::DiffResult;
3use crate::node_id::NodeId;
4
/// Magic number identifying a sync patch: the ASCII tag "SYNC" read as a
/// big-endian `u32` (0x53594E43). It is written little-endian, so the bytes
/// appear on the wire in reverse order.
const PATCH_MAGIC: u32 = u32::from_be_bytes(*b"SYNC");

/// The only wire-format version this module writes and accepts.
const PATCH_VERSION: u8 = 1;

/// Header flag bit: every entry carries a CRDT metadata block.
const FLAG_HAS_CRDT: u8 = 1 << 0;
9
10/// A single entry in a sync patch.
11#[derive(Debug, Clone, PartialEq, Eq)]
12pub struct PatchEntry {
13    pub key: Vec<u8>,
14    pub value: Vec<u8>,
15    pub kind: EntryKind,
16    pub crdt_meta: Option<CrdtMeta>,
17}
18
19/// A serializable sync patch containing entries to apply to a target database.
20#[derive(Debug, Clone)]
21pub struct SyncPatch {
22    pub source_node: NodeId,
23    pub entries: Vec<PatchEntry>,
24    pub crdt_aware: bool,
25}
26
27/// Errors from patch serialization/deserialization.
28#[derive(Debug, thiserror::Error)]
29pub enum PatchError {
30    #[error("invalid patch magic: expected {expected:#010x}, got {actual:#010x}")]
31    InvalidMagic { expected: u32, actual: u32 },
32
33    #[error("unsupported patch version: {0}")]
34    UnsupportedVersion(u8),
35
36    #[error("patch data truncated: expected at least {expected} bytes, got {actual}")]
37    Truncated { expected: usize, actual: usize },
38
39    #[error("invalid entry kind: {0}")]
40    InvalidEntryKind(u8),
41}
42
43impl SyncPatch {
44    /// Build a SyncPatch from a DiffResult.
45    ///
46    /// If `crdt_aware` is true, values are expected to contain CRDT headers
47    /// and entries will carry CrdtMeta extracted from the value prefix.
48    pub fn from_diff(source_node: NodeId, diff: &DiffResult, crdt_aware: bool) -> Self {
49        let entries = diff
50            .entries
51            .iter()
52            .map(|e| {
53                if crdt_aware && e.value.len() >= CRDT_HEADER_SIZE {
54                    if let Ok(decoded) = crate::crdt::decode_lww_value(&e.value) {
55                        return PatchEntry {
56                            key: e.key.clone(),
57                            value: e.value.clone(),
58                            kind: decoded.kind,
59                            crdt_meta: Some(decoded.meta),
60                        };
61                    }
62                }
63                PatchEntry {
64                    key: e.key.clone(),
65                    value: e.value.clone(),
66                    kind: EntryKind::Put,
67                    crdt_meta: None,
68                }
69            })
70            .collect();
71
72        SyncPatch {
73            source_node,
74            entries,
75            crdt_aware,
76        }
77    }
78
79    /// Create an empty patch.
80    pub fn empty(source_node: NodeId) -> Self {
81        SyncPatch {
82            source_node,
83            entries: Vec::new(),
84            crdt_aware: false,
85        }
86    }
87
88    pub fn len(&self) -> usize {
89        self.entries.len()
90    }
91
92    pub fn is_empty(&self) -> bool {
93        self.entries.is_empty()
94    }
95
96    /// Serialize to binary wire format.
97    ///
98    /// Format:
99    /// ```text
100    /// [magic: u32 LE][version: u8][flags: u8][source_node: 8B][entry_count: u32 LE]
101    /// Per entry:
102    ///   [key_len: u16 LE][value_len: u32 LE][kind: u8]
103    ///   [crdt_meta: 20B]  (if FLAG_HAS_CRDT)
104    ///   [key: key_len bytes][value: value_len bytes]
105    /// ```
106    pub fn serialize(&self) -> Vec<u8> {
107        let flags = if self.crdt_aware { FLAG_HAS_CRDT } else { 0 };
108
109        let header_size = 4 + 1 + 1 + 8 + 4; // 18
110        let per_entry_overhead = 2 + 4 + 1 + if self.crdt_aware { CRDT_META_SIZE } else { 0 };
111        let data_size: usize = self
112            .entries
113            .iter()
114            .map(|e| per_entry_overhead + e.key.len() + e.value.len())
115            .sum();
116
117        let mut buf = Vec::with_capacity(header_size + data_size);
118
119        buf.extend_from_slice(&PATCH_MAGIC.to_le_bytes());
120        buf.push(PATCH_VERSION);
121        buf.push(flags);
122        buf.extend_from_slice(&self.source_node.to_bytes());
123        buf.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
124
125        for entry in &self.entries {
126            buf.extend_from_slice(&(entry.key.len() as u16).to_le_bytes());
127            buf.extend_from_slice(&(entry.value.len() as u32).to_le_bytes());
128            buf.push(entry.kind as u8);
129            if self.crdt_aware {
130                if let Some(ref meta) = entry.crdt_meta {
131                    buf.extend_from_slice(&meta.to_bytes());
132                } else {
133                    buf.extend_from_slice(&[0u8; CRDT_META_SIZE]);
134                }
135            }
136            buf.extend_from_slice(&entry.key);
137            buf.extend_from_slice(&entry.value);
138        }
139
140        buf
141    }
142
143    /// Deserialize from binary wire format.
144    pub fn deserialize(data: &[u8]) -> Result<Self, PatchError> {
145        let header_size = 4 + 1 + 1 + 8 + 4; // 18 bytes
146        if data.len() < header_size {
147            return Err(PatchError::Truncated {
148                expected: header_size,
149                actual: data.len(),
150            });
151        }
152
153        let magic = u32::from_le_bytes(data[0..4].try_into().unwrap());
154        if magic != PATCH_MAGIC {
155            return Err(PatchError::InvalidMagic {
156                expected: PATCH_MAGIC,
157                actual: magic,
158            });
159        }
160
161        let version = data[4];
162        if version != PATCH_VERSION {
163            return Err(PatchError::UnsupportedVersion(version));
164        }
165
166        let flags = data[5];
167        let crdt_aware = (flags & FLAG_HAS_CRDT) != 0;
168        let source_node = NodeId::from_bytes(data[6..14].try_into().unwrap());
169        let entry_count = u32::from_le_bytes(data[14..18].try_into().unwrap()) as usize;
170
171        let mut entries = Vec::with_capacity(entry_count);
172        let mut pos = header_size;
173
174        for _ in 0..entry_count {
175            // key_len (2) + value_len (4) + kind (1) = 7
176            let entry_header = 7 + if crdt_aware { CRDT_META_SIZE } else { 0 };
177            if pos + entry_header > data.len() {
178                return Err(PatchError::Truncated {
179                    expected: pos + entry_header,
180                    actual: data.len(),
181                });
182            }
183
184            let key_len = u16::from_le_bytes(data[pos..pos + 2].try_into().unwrap()) as usize;
185            let value_len = u32::from_le_bytes(data[pos + 2..pos + 6].try_into().unwrap()) as usize;
186            let kind_byte = data[pos + 6];
187            let kind =
188                EntryKind::from_u8(kind_byte).ok_or(PatchError::InvalidEntryKind(kind_byte))?;
189            pos += 7;
190
191            let crdt_meta = if crdt_aware {
192                let meta_bytes: &[u8; CRDT_META_SIZE] =
193                    data[pos..pos + CRDT_META_SIZE].try_into().unwrap();
194                pos += CRDT_META_SIZE;
195                Some(CrdtMeta::from_bytes(meta_bytes))
196            } else {
197                None
198            };
199
200            if pos + key_len + value_len > data.len() {
201                return Err(PatchError::Truncated {
202                    expected: pos + key_len + value_len,
203                    actual: data.len(),
204                });
205            }
206
207            let key = data[pos..pos + key_len].to_vec();
208            pos += key_len;
209            let value = data[pos..pos + value_len].to_vec();
210            pos += value_len;
211
212            entries.push(PatchEntry {
213                key,
214                value,
215                kind,
216                crdt_meta,
217            });
218        }
219
220        Ok(SyncPatch {
221            source_node,
222            entries,
223            crdt_aware,
224        })
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hlc::HlcTimestamp;

    /// Builds CRDT metadata from raw HLC components and a numeric node id.
    fn meta(wall_ns: i64, logical: i32, node: u64) -> CrdtMeta {
        let timestamp = HlcTimestamp::new(wall_ns, logical);
        CrdtMeta::new(timestamp, NodeId::from_u64(node))
    }

    /// Builds a `Put` entry without CRDT metadata.
    fn put(key: &[u8], value: &[u8]) -> PatchEntry {
        PatchEntry {
            key: key.to_vec(),
            value: value.to_vec(),
            kind: EntryKind::Put,
            crdt_meta: None,
        }
    }

    /// Builds a non-CRDT patch originating from node `node`.
    fn plain_patch(node: u64, entries: Vec<PatchEntry>) -> SyncPatch {
        SyncPatch {
            source_node: NodeId::from_u64(node),
            entries,
            crdt_aware: false,
        }
    }

    /// Serializes a patch and decodes the bytes again, panicking on failure.
    fn roundtrip(patch: &SyncPatch) -> SyncPatch {
        SyncPatch::deserialize(&patch.serialize()).unwrap()
    }

    #[test]
    fn empty_patch_roundtrip() {
        let decoded = roundtrip(&SyncPatch::empty(NodeId::from_u64(42)));

        assert!(decoded.is_empty());
        assert_eq!(decoded.source_node, NodeId::from_u64(42));
        assert!(!decoded.crdt_aware);
    }

    #[test]
    fn patch_with_entries_roundtrip() {
        let patch = plain_patch(1, vec![put(b"key1", b"value1"), put(b"key2", b"value2")]);

        let decoded = roundtrip(&patch);

        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded.entries[0].key, b"key1");
        assert_eq!(decoded.entries[0].value, b"value1");
        assert_eq!(decoded.entries[1].key, b"key2");
        assert_eq!(decoded.entries[1].value, b"value2");
    }

    #[test]
    fn crdt_patch_roundtrip() {
        let m = meta(1_000_000_000, 5, 42);
        let live = PatchEntry {
            key: b"key1".to_vec(),
            value: b"value1".to_vec(),
            kind: EntryKind::Put,
            crdt_meta: Some(m),
        };
        let deleted = PatchEntry {
            key: b"key2".to_vec(),
            value: Vec::new(),
            kind: EntryKind::Tombstone,
            crdt_meta: Some(m),
        };
        let patch = SyncPatch {
            source_node: NodeId::from_u64(1),
            entries: vec![live, deleted],
            crdt_aware: true,
        };

        let decoded = roundtrip(&patch);

        assert_eq!(decoded.len(), 2);
        assert!(decoded.crdt_aware);
        assert_eq!(decoded.entries[0].crdt_meta, Some(m));
        assert_eq!(decoded.entries[0].kind, EntryKind::Put);
        assert_eq!(decoded.entries[1].crdt_meta, Some(m));
        assert_eq!(decoded.entries[1].kind, EntryKind::Tombstone);
    }

    #[test]
    fn large_values_roundtrip() {
        let big_key = vec![0xAA; 2048];
        let big_val = vec![0xBB; 8192];
        let patch = plain_patch(99, vec![put(&big_key, &big_val)]);

        let decoded = roundtrip(&patch);

        assert_eq!(decoded.entries[0].key, big_key);
        assert_eq!(decoded.entries[0].value, big_val);
    }

    #[test]
    fn invalid_magic_error() {
        let mut bytes = SyncPatch::empty(NodeId::from_u64(1)).serialize();
        bytes[0] = 0xFF; // corrupt the first magic byte

        let err = SyncPatch::deserialize(&bytes).unwrap_err();
        assert!(matches!(err, PatchError::InvalidMagic { .. }));
    }

    #[test]
    fn unsupported_version_error() {
        let mut bytes = SyncPatch::empty(NodeId::from_u64(1)).serialize();
        bytes[4] = 99; // version byte follows the 4-byte magic

        let err = SyncPatch::deserialize(&bytes).unwrap_err();
        assert!(matches!(err, PatchError::UnsupportedVersion(99)));
    }

    #[test]
    fn truncated_header_error() {
        let too_short = [0u8; 5];

        let err = SyncPatch::deserialize(&too_short).unwrap_err();
        assert!(matches!(err, PatchError::Truncated { .. }));
    }

    #[test]
    fn truncated_entry_error() {
        let mut bytes = plain_patch(1, vec![put(b"key", b"value")]).serialize();
        let shortened = bytes.len() - 3; // cut off the end of the value
        bytes.truncate(shortened);

        let err = SyncPatch::deserialize(&bytes).unwrap_err();
        assert!(matches!(err, PatchError::Truncated { .. }));
    }

    #[test]
    fn invalid_entry_kind_error() {
        let mut bytes = plain_patch(1, vec![put(b"k", b"v")]).serialize();
        // The kind byte sits after the 18-byte header plus key_len (2) and value_len (4).
        bytes[18 + 6] = 255;

        let err = SyncPatch::deserialize(&bytes).unwrap_err();
        assert!(matches!(err, PatchError::InvalidEntryKind(255)));
    }

    #[test]
    fn many_entries_roundtrip() {
        let mut entries = Vec::new();
        for i in 0..1000u32 {
            entries.push(put(&i.to_be_bytes(), format!("val-{i}").as_bytes()));
        }
        let patch = plain_patch(7, entries);

        let decoded = roundtrip(&patch);

        assert_eq!(decoded.len(), 1000);
        for (expected, entry) in (0u32..).zip(&decoded.entries) {
            assert_eq!(entry.key, expected.to_be_bytes());
        }
    }

    #[test]
    fn from_diff_non_crdt() {
        let first = crate::diff::DiffEntry {
            key: b"k1".to_vec(),
            value: b"v1".to_vec(),
            val_type: 0,
        };
        let second = crate::diff::DiffEntry {
            key: b"k2".to_vec(),
            value: b"v2".to_vec(),
            val_type: 0,
        };
        let diff = DiffResult {
            entries: vec![first, second],
            pages_compared: 5,
            subtrees_skipped: 2,
        };

        let patch = SyncPatch::from_diff(NodeId::from_u64(1), &diff, false);

        assert_eq!(patch.len(), 2);
        assert!(!patch.crdt_aware);
        assert_eq!(patch.entries[0].key, b"k1");
        assert!(patch.entries[0].crdt_meta.is_none());
    }

    #[test]
    fn from_diff_crdt_extracts_meta() {
        let m = meta(1_000_000_000, 5, 42);
        let encoded = crate::crdt::encode_lww_value(&m, EntryKind::Put, b"user-data");
        let diff = DiffResult {
            entries: vec![crate::diff::DiffEntry {
                key: b"k1".to_vec(),
                value: encoded,
                val_type: 0,
            }],
            pages_compared: 1,
            subtrees_skipped: 0,
        };

        let patch = SyncPatch::from_diff(NodeId::from_u64(1), &diff, true);

        assert_eq!(patch.len(), 1);
        assert!(patch.crdt_aware);
        assert_eq!(patch.entries[0].crdt_meta, Some(m));
        assert_eq!(patch.entries[0].kind, EntryKind::Put);
    }
}