Skip to main content

citadeldb_sync/
patch.rs

1use crate::crdt::{CrdtMeta, EntryKind, CRDT_HEADER_SIZE, CRDT_META_SIZE};
2use crate::diff::DiffResult;
3use crate::node_id::NodeId;
4
/// Magic number that opens every serialized patch: the ASCII tag "SYNC" read
/// as a big-endian `u32` (`0x53594E43`). The serializer emits it little-endian.
const PATCH_MAGIC: u32 = u32::from_be_bytes(*b"SYNC");
/// Wire-format version written by the serializer; the deserializer rejects
/// any other value.
const PATCH_VERSION: u8 = 1;

/// Header flag bit: set when every entry is followed by a CRDT metadata block.
const FLAG_HAS_CRDT: u8 = 1 << 0;
9
10/// A single entry in a sync patch.
11#[derive(Debug, Clone, PartialEq, Eq)]
12pub struct PatchEntry {
13    pub key: Vec<u8>,
14    pub value: Vec<u8>,
15    pub kind: EntryKind,
16    pub crdt_meta: Option<CrdtMeta>,
17}
18
19/// A serializable sync patch containing entries to apply to a target database.
20#[derive(Debug, Clone)]
21pub struct SyncPatch {
22    pub source_node: NodeId,
23    pub entries: Vec<PatchEntry>,
24    pub crdt_aware: bool,
25}
26
27/// Errors from patch serialization/deserialization.
28#[derive(Debug, thiserror::Error)]
29pub enum PatchError {
30    #[error("invalid patch magic: expected {expected:#010x}, got {actual:#010x}")]
31    InvalidMagic { expected: u32, actual: u32 },
32
33    #[error("unsupported patch version: {0}")]
34    UnsupportedVersion(u8),
35
36    #[error("patch data truncated: expected at least {expected} bytes, got {actual}")]
37    Truncated { expected: usize, actual: usize },
38
39    #[error("invalid entry kind: {0}")]
40    InvalidEntryKind(u8),
41}
42
43impl SyncPatch {
44    /// Build a SyncPatch from a DiffResult.
45    ///
46    /// If `crdt_aware` is true, values are expected to contain CRDT headers
47    /// and entries will carry CrdtMeta extracted from the value prefix.
48    pub fn from_diff(source_node: NodeId, diff: &DiffResult, crdt_aware: bool) -> Self {
49        let entries = diff
50            .entries
51            .iter()
52            .map(|e| {
53                if crdt_aware && e.value.len() >= CRDT_HEADER_SIZE {
54                    // Try to decode CRDT header from the value
55                    if let Ok(decoded) = crate::crdt::decode_lww_value(&e.value) {
56                        return PatchEntry {
57                            key: e.key.clone(),
58                            value: e.value.clone(),
59                            kind: decoded.kind,
60                            crdt_meta: Some(decoded.meta),
61                        };
62                    }
63                }
64                PatchEntry {
65                    key: e.key.clone(),
66                    value: e.value.clone(),
67                    kind: EntryKind::Put,
68                    crdt_meta: None,
69                }
70            })
71            .collect();
72
73        SyncPatch {
74            source_node,
75            entries,
76            crdt_aware,
77        }
78    }
79
80    /// Create an empty patch.
81    pub fn empty(source_node: NodeId) -> Self {
82        SyncPatch {
83            source_node,
84            entries: Vec::new(),
85            crdt_aware: false,
86        }
87    }
88
89    pub fn len(&self) -> usize {
90        self.entries.len()
91    }
92
93    pub fn is_empty(&self) -> bool {
94        self.entries.is_empty()
95    }
96
97    /// Serialize to binary wire format.
98    ///
99    /// Format:
100    /// ```text
101    /// [magic: u32 LE][version: u8][flags: u8][source_node: 8B][entry_count: u32 LE]
102    /// Per entry:
103    ///   [key_len: u16 LE][value_len: u32 LE][kind: u8]
104    ///   [crdt_meta: 20B]  (if FLAG_HAS_CRDT)
105    ///   [key: key_len bytes][value: value_len bytes]
106    /// ```
107    pub fn serialize(&self) -> Vec<u8> {
108        let flags = if self.crdt_aware { FLAG_HAS_CRDT } else { 0 };
109
110        // Estimate capacity
111        let header_size = 4 + 1 + 1 + 8 + 4; // 18
112        let per_entry_overhead = 2 + 4 + 1 + if self.crdt_aware { CRDT_META_SIZE } else { 0 };
113        let data_size: usize = self
114            .entries
115            .iter()
116            .map(|e| per_entry_overhead + e.key.len() + e.value.len())
117            .sum();
118
119        let mut buf = Vec::with_capacity(header_size + data_size);
120
121        // Header
122        buf.extend_from_slice(&PATCH_MAGIC.to_le_bytes());
123        buf.push(PATCH_VERSION);
124        buf.push(flags);
125        buf.extend_from_slice(&self.source_node.to_bytes());
126        buf.extend_from_slice(&(self.entries.len() as u32).to_le_bytes());
127
128        // Entries
129        for entry in &self.entries {
130            buf.extend_from_slice(&(entry.key.len() as u16).to_le_bytes());
131            buf.extend_from_slice(&(entry.value.len() as u32).to_le_bytes());
132            buf.push(entry.kind as u8);
133            if self.crdt_aware {
134                if let Some(ref meta) = entry.crdt_meta {
135                    buf.extend_from_slice(&meta.to_bytes());
136                } else {
137                    buf.extend_from_slice(&[0u8; CRDT_META_SIZE]);
138                }
139            }
140            buf.extend_from_slice(&entry.key);
141            buf.extend_from_slice(&entry.value);
142        }
143
144        buf
145    }
146
147    /// Deserialize from binary wire format.
148    pub fn deserialize(data: &[u8]) -> Result<Self, PatchError> {
149        let header_size = 4 + 1 + 1 + 8 + 4; // 18 bytes
150        if data.len() < header_size {
151            return Err(PatchError::Truncated {
152                expected: header_size,
153                actual: data.len(),
154            });
155        }
156
157        let magic = u32::from_le_bytes(data[0..4].try_into().unwrap());
158        if magic != PATCH_MAGIC {
159            return Err(PatchError::InvalidMagic {
160                expected: PATCH_MAGIC,
161                actual: magic,
162            });
163        }
164
165        let version = data[4];
166        if version != PATCH_VERSION {
167            return Err(PatchError::UnsupportedVersion(version));
168        }
169
170        let flags = data[5];
171        let crdt_aware = (flags & FLAG_HAS_CRDT) != 0;
172        let source_node = NodeId::from_bytes(data[6..14].try_into().unwrap());
173        let entry_count = u32::from_le_bytes(data[14..18].try_into().unwrap()) as usize;
174
175        let mut entries = Vec::with_capacity(entry_count);
176        let mut pos = header_size;
177
178        for _ in 0..entry_count {
179            // key_len (2) + value_len (4) + kind (1) = 7
180            let entry_header = 7 + if crdt_aware { CRDT_META_SIZE } else { 0 };
181            if pos + entry_header > data.len() {
182                return Err(PatchError::Truncated {
183                    expected: pos + entry_header,
184                    actual: data.len(),
185                });
186            }
187
188            let key_len = u16::from_le_bytes(data[pos..pos + 2].try_into().unwrap()) as usize;
189            let value_len =
190                u32::from_le_bytes(data[pos + 2..pos + 6].try_into().unwrap()) as usize;
191            let kind_byte = data[pos + 6];
192            let kind = EntryKind::from_u8(kind_byte)
193                .ok_or(PatchError::InvalidEntryKind(kind_byte))?;
194            pos += 7;
195
196            let crdt_meta = if crdt_aware {
197                let meta_bytes: &[u8; CRDT_META_SIZE] =
198                    data[pos..pos + CRDT_META_SIZE].try_into().unwrap();
199                pos += CRDT_META_SIZE;
200                Some(CrdtMeta::from_bytes(meta_bytes))
201            } else {
202                None
203            };
204
205            if pos + key_len + value_len > data.len() {
206                return Err(PatchError::Truncated {
207                    expected: pos + key_len + value_len,
208                    actual: data.len(),
209                });
210            }
211
212            let key = data[pos..pos + key_len].to_vec();
213            pos += key_len;
214            let value = data[pos..pos + value_len].to_vec();
215            pos += value_len;
216
217            entries.push(PatchEntry {
218                key,
219                value,
220                kind,
221                crdt_meta,
222            });
223        }
224
225        Ok(SyncPatch {
226            source_node,
227            entries,
228            crdt_aware,
229        })
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hlc::HlcTimestamp;

    /// CRDT metadata stamped with the given HLC components and node id.
    fn meta(wall_ns: i64, logical: i32, node: u64) -> CrdtMeta {
        let timestamp = HlcTimestamp::new(wall_ns, logical);
        CrdtMeta::new(timestamp, NodeId::from_u64(node))
    }

    /// Plain `Put` entry without CRDT metadata.
    fn put(key: &[u8], value: &[u8]) -> PatchEntry {
        PatchEntry {
            key: key.to_vec(),
            value: value.to_vec(),
            kind: EntryKind::Put,
            crdt_meta: None,
        }
    }

    /// Non-CRDT patch attributed to node `node`.
    fn plain_patch(node: u64, entries: Vec<PatchEntry>) -> SyncPatch {
        SyncPatch {
            source_node: NodeId::from_u64(node),
            entries,
            crdt_aware: false,
        }
    }

    /// Serializes `patch` and decodes the bytes again, panicking on failure.
    fn roundtrip(patch: &SyncPatch) -> SyncPatch {
        let bytes = patch.serialize();
        SyncPatch::deserialize(&bytes).unwrap()
    }

    /// Diff entry with the default value type used throughout these tests.
    fn diff_entry(key: &[u8], value: Vec<u8>) -> crate::diff::DiffEntry {
        crate::diff::DiffEntry {
            key: key.to_vec(),
            value,
            val_type: 0,
        }
    }

    #[test]
    fn empty_patch_roundtrip() {
        let origin = NodeId::from_u64(42);
        let decoded = roundtrip(&SyncPatch::empty(origin));

        assert!(decoded.is_empty());
        assert_eq!(decoded.source_node, NodeId::from_u64(42));
        assert!(!decoded.crdt_aware);
    }

    #[test]
    fn patch_with_entries_roundtrip() {
        let patch = plain_patch(1, vec![put(b"key1", b"value1"), put(b"key2", b"value2")]);

        let decoded = roundtrip(&patch);

        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded.entries[0].key, b"key1");
        assert_eq!(decoded.entries[0].value, b"value1");
        assert_eq!(decoded.entries[1].key, b"key2");
        assert_eq!(decoded.entries[1].value, b"value2");
    }

    #[test]
    fn crdt_patch_roundtrip() {
        let m = meta(1_000_000_000, 5, 42);
        let live = PatchEntry {
            crdt_meta: Some(m),
            ..put(b"key1", b"value1")
        };
        let deleted = PatchEntry {
            key: b"key2".to_vec(),
            value: Vec::new(),
            kind: EntryKind::Tombstone,
            crdt_meta: Some(m),
        };
        let patch = SyncPatch {
            source_node: NodeId::from_u64(1),
            entries: vec![live, deleted],
            crdt_aware: true,
        };

        let decoded = roundtrip(&patch);

        assert_eq!(decoded.len(), 2);
        assert!(decoded.crdt_aware);
        assert_eq!(decoded.entries[0].crdt_meta, Some(m));
        assert_eq!(decoded.entries[0].kind, EntryKind::Put);
        assert_eq!(decoded.entries[1].crdt_meta, Some(m));
        assert_eq!(decoded.entries[1].kind, EntryKind::Tombstone);
    }

    #[test]
    fn large_values_roundtrip() {
        let big_key = vec![0xAA; 2048];
        let big_val = vec![0xBB; 8192];
        let patch = plain_patch(99, vec![put(&big_key, &big_val)]);

        let decoded = roundtrip(&patch);

        assert_eq!(decoded.entries[0].key, big_key);
        assert_eq!(decoded.entries[0].value, big_val);
    }

    #[test]
    fn invalid_magic_error() {
        let mut bytes = SyncPatch::empty(NodeId::from_u64(1)).serialize();
        bytes[0] = 0xFF; // first magic byte no longer matches

        let err = SyncPatch::deserialize(&bytes).unwrap_err();
        assert!(matches!(err, PatchError::InvalidMagic { .. }));
    }

    #[test]
    fn unsupported_version_error() {
        let mut bytes = SyncPatch::empty(NodeId::from_u64(1)).serialize();
        bytes[4] = 99; // version byte sits right after the 4-byte magic

        let err = SyncPatch::deserialize(&bytes).unwrap_err();
        assert!(matches!(err, PatchError::UnsupportedVersion(99)));
    }

    #[test]
    fn truncated_header_error() {
        let too_short = [0u8; 5];

        let err = SyncPatch::deserialize(&too_short).unwrap_err();
        assert!(matches!(err, PatchError::Truncated { .. }));
    }

    #[test]
    fn truncated_entry_error() {
        let mut bytes = plain_patch(1, vec![put(b"key", b"value")]).serialize();
        let shortened = bytes.len() - 3; // drop the tail of the value
        bytes.truncate(shortened);

        let err = SyncPatch::deserialize(&bytes).unwrap_err();
        assert!(matches!(err, PatchError::Truncated { .. }));
    }

    #[test]
    fn invalid_entry_kind_error() {
        let mut bytes = plain_patch(1, vec![put(b"k", b"v")]).serialize();
        // The kind byte follows the 18-byte header and the 6 bytes of
        // key_len + value_len.
        bytes[18 + 6] = 255;

        let err = SyncPatch::deserialize(&bytes).unwrap_err();
        assert!(matches!(err, PatchError::InvalidEntryKind(255)));
    }

    #[test]
    fn many_entries_roundtrip() {
        let mut entries = Vec::with_capacity(1000);
        for i in 0..1000u32 {
            entries.push(put(&i.to_be_bytes(), format!("val-{i}").as_bytes()));
        }
        let patch = plain_patch(7, entries);

        let decoded = roundtrip(&patch);

        assert_eq!(decoded.len(), 1000);
        for (i, entry) in decoded.entries.iter().enumerate() {
            assert_eq!(entry.key, (i as u32).to_be_bytes());
        }
    }

    #[test]
    fn from_diff_non_crdt() {
        let diff = DiffResult {
            entries: vec![
                diff_entry(b"k1", b"v1".to_vec()),
                diff_entry(b"k2", b"v2".to_vec()),
            ],
            pages_compared: 5,
            subtrees_skipped: 2,
        };

        let patch = SyncPatch::from_diff(NodeId::from_u64(1), &diff, false);

        assert_eq!(patch.len(), 2);
        assert!(!patch.crdt_aware);
        assert_eq!(patch.entries[0].key, b"k1");
        assert!(patch.entries[0].crdt_meta.is_none());
    }

    #[test]
    fn from_diff_crdt_extracts_meta() {
        let m = meta(1_000_000_000, 5, 42);
        let crdt_value = crate::crdt::encode_lww_value(&m, EntryKind::Put, b"user-data");
        let diff = DiffResult {
            entries: vec![diff_entry(b"k1", crdt_value)],
            pages_compared: 1,
            subtrees_skipped: 0,
        };

        let patch = SyncPatch::from_diff(NodeId::from_u64(1), &diff, true);

        assert_eq!(patch.len(), 1);
        assert!(patch.crdt_aware);
        assert_eq!(patch.entries[0].crdt_meta, Some(m));
        assert_eq!(patch.entries[0].kind, EntryKind::Put);
    }
}