Skip to main content

citadeldb_sync/
protocol.rs

1use citadel_core::types::PageId;
2use citadel_core::MERKLE_HASH_SIZE;
3
4use crate::apply::ApplyResult;
5use crate::diff::{DiffEntry, MerkleHash, PageDigest};
6use crate::node_id::NodeId;
7
// Wire-format message type tags: the first byte of every serialized frame.
// The numeric values are part of the protocol and must stay stable.

// Handshake.
/// Tag for `SyncMessage::Hello`.
const MSG_HELLO: u8 = 0;
/// Tag for `SyncMessage::HelloAck`.
const MSG_HELLO_ACK: u8 = 1;

// Tree comparison.
/// Tag for `SyncMessage::DigestRequest`.
const MSG_DIGEST_REQUEST: u8 = 2;
/// Tag for `SyncMessage::DigestResponse`.
const MSG_DIGEST_RESPONSE: u8 = 3;
/// Tag for `SyncMessage::EntriesRequest`.
const MSG_ENTRIES_REQUEST: u8 = 4;
/// Tag for `SyncMessage::EntriesResponse`.
const MSG_ENTRIES_RESPONSE: u8 = 5;

// Patch transfer.
/// Tag for `SyncMessage::PatchData`.
const MSG_PATCH_DATA: u8 = 6;
/// Tag for `SyncMessage::PatchAck`.
const MSG_PATCH_ACK: u8 = 7;

// Session control.
/// Tag for `SyncMessage::Done`.
const MSG_DONE: u8 = 8;
/// Tag for `SyncMessage::Error`.
const MSG_ERROR: u8 = 9;

// Pull phase.
/// Tag for `SyncMessage::PullRequest`.
const MSG_PULL_REQUEST: u8 = 10;
/// Tag for `SyncMessage::PullResponse`.
const MSG_PULL_RESPONSE: u8 = 11;

// Multi-table sync.
/// Tag for `SyncMessage::TableListRequest`.
const MSG_TABLE_LIST_REQUEST: u8 = 12;
/// Tag for `SyncMessage::TableListResponse`.
const MSG_TABLE_LIST_RESPONSE: u8 = 13;
/// Tag for `SyncMessage::TableSyncBegin`.
const MSG_TABLE_SYNC_BEGIN: u8 = 14;
/// Tag for `SyncMessage::TableSyncEnd`.
const MSG_TABLE_SYNC_END: u8 = 15;
25
26/// Metadata about a named table for multi-table sync negotiation.
27#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct TableInfo {
29    pub name: Vec<u8>,
30    pub root_page: PageId,
31    pub root_hash: MerkleHash,
32}
33
34/// Sync protocol messages exchanged between initiator and responder.
35#[derive(Debug, Clone)]
36pub enum SyncMessage {
37    /// Initiator greeting with identity and tree root state.
38    Hello {
39        node_id: NodeId,
40        root_page: PageId,
41        root_hash: MerkleHash,
42    },
43    /// Responder acknowledgment with its own tree root state.
44    HelloAck {
45        node_id: NodeId,
46        root_page: PageId,
47        root_hash: MerkleHash,
48        in_sync: bool,
49    },
50    /// Request page digests from the remote tree.
51    DigestRequest {
52        page_ids: Vec<PageId>,
53    },
54    /// Response with page digests.
55    DigestResponse {
56        digests: Vec<PageDigest>,
57    },
58    /// Request leaf entries from remote pages.
59    EntriesRequest {
60        page_ids: Vec<PageId>,
61    },
62    /// Response with leaf entries.
63    EntriesResponse {
64        entries: Vec<DiffEntry>,
65    },
66    /// Serialized SyncPatch data.
67    PatchData {
68        data: Vec<u8>,
69    },
70    /// Acknowledgment after applying a patch.
71    PatchAck {
72        result: ApplyResult,
73    },
74    /// Session complete.
75    Done,
76    /// Error during sync.
77    Error {
78        message: String,
79    },
80    /// Request updated root info for pull phase after push.
81    PullRequest,
82    /// Response with updated root info for pull phase.
83    PullResponse {
84        root_page: PageId,
85        root_hash: MerkleHash,
86    },
87    /// Request list of named tables from the remote peer.
88    TableListRequest,
89    /// Response with the list of named tables.
90    TableListResponse {
91        tables: Vec<TableInfo>,
92    },
93    /// Begin syncing a specific named table.
94    TableSyncBegin {
95        table_name: Vec<u8>,
96        root_page: PageId,
97        root_hash: MerkleHash,
98    },
99    /// End syncing a specific named table.
100    TableSyncEnd {
101        table_name: Vec<u8>,
102    },
103}
104
105/// Errors from sync message serialization/deserialization.
106#[derive(Debug, thiserror::Error)]
107pub enum ProtocolError {
108    #[error("{context}: expected at least {expected} bytes, got {actual}")]
109    Truncated { context: String, expected: usize, actual: usize },
110
111    #[error("unknown message type: {0}")]
112    UnknownMessageType(u8),
113}
114
115impl SyncMessage {
116    /// Serialize to wire format: `[msg_type: u8][payload_len: u32 LE][payload]`.
117    pub fn serialize(&self) -> Vec<u8> {
118        let (msg_type, payload) = match self {
119            SyncMessage::Hello { node_id, root_page, root_hash } => {
120                let mut p = Vec::with_capacity(40);
121                p.extend_from_slice(&node_id.to_bytes());
122                p.extend_from_slice(&root_page.0.to_le_bytes());
123                p.extend_from_slice(root_hash);
124                (MSG_HELLO, p)
125            }
126            SyncMessage::HelloAck { node_id, root_page, root_hash, in_sync } => {
127                let mut p = Vec::with_capacity(41);
128                p.extend_from_slice(&node_id.to_bytes());
129                p.extend_from_slice(&root_page.0.to_le_bytes());
130                p.extend_from_slice(root_hash);
131                p.push(if *in_sync { 1 } else { 0 });
132                (MSG_HELLO_ACK, p)
133            }
134            SyncMessage::DigestRequest { page_ids } => {
135                let mut p = Vec::with_capacity(4 + page_ids.len() * 4);
136                p.extend_from_slice(&(page_ids.len() as u32).to_le_bytes());
137                for pid in page_ids {
138                    p.extend_from_slice(&pid.0.to_le_bytes());
139                }
140                (MSG_DIGEST_REQUEST, p)
141            }
142            SyncMessage::DigestResponse { digests } => {
143                let mut p = Vec::new();
144                p.extend_from_slice(&(digests.len() as u32).to_le_bytes());
145                for d in digests {
146                    serialize_page_digest(&mut p, d);
147                }
148                (MSG_DIGEST_RESPONSE, p)
149            }
150            SyncMessage::EntriesRequest { page_ids } => {
151                let mut p = Vec::with_capacity(4 + page_ids.len() * 4);
152                p.extend_from_slice(&(page_ids.len() as u32).to_le_bytes());
153                for pid in page_ids {
154                    p.extend_from_slice(&pid.0.to_le_bytes());
155                }
156                (MSG_ENTRIES_REQUEST, p)
157            }
158            SyncMessage::EntriesResponse { entries } => {
159                let mut p = Vec::new();
160                p.extend_from_slice(&(entries.len() as u32).to_le_bytes());
161                for e in entries {
162                    serialize_diff_entry(&mut p, e);
163                }
164                (MSG_ENTRIES_RESPONSE, p)
165            }
166            SyncMessage::PatchData { data } => {
167                (MSG_PATCH_DATA, data.clone())
168            }
169            SyncMessage::PatchAck { result } => {
170                let mut p = Vec::with_capacity(24);
171                p.extend_from_slice(&result.entries_applied.to_le_bytes());
172                p.extend_from_slice(&result.entries_skipped.to_le_bytes());
173                p.extend_from_slice(&result.entries_equal.to_le_bytes());
174                (MSG_PATCH_ACK, p)
175            }
176            SyncMessage::Done => {
177                (MSG_DONE, Vec::new())
178            }
179            SyncMessage::Error { message } => {
180                let bytes = message.as_bytes();
181                let mut p = Vec::with_capacity(4 + bytes.len());
182                p.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
183                p.extend_from_slice(bytes);
184                (MSG_ERROR, p)
185            }
186            SyncMessage::PullRequest => {
187                (MSG_PULL_REQUEST, Vec::new())
188            }
189            SyncMessage::PullResponse { root_page, root_hash } => {
190                let mut p = Vec::with_capacity(32);
191                p.extend_from_slice(&root_page.0.to_le_bytes());
192                p.extend_from_slice(root_hash);
193                (MSG_PULL_RESPONSE, p)
194            }
195            SyncMessage::TableListRequest => {
196                (MSG_TABLE_LIST_REQUEST, Vec::new())
197            }
198            SyncMessage::TableListResponse { tables } => {
199                let mut p = Vec::new();
200                p.extend_from_slice(&(tables.len() as u32).to_le_bytes());
201                for t in tables {
202                    p.extend_from_slice(&(t.name.len() as u16).to_le_bytes());
203                    p.extend_from_slice(&t.name);
204                    p.extend_from_slice(&t.root_page.0.to_le_bytes());
205                    p.extend_from_slice(&t.root_hash);
206                }
207                (MSG_TABLE_LIST_RESPONSE, p)
208            }
209            SyncMessage::TableSyncBegin { table_name, root_page, root_hash } => {
210                let mut p = Vec::with_capacity(2 + table_name.len() + 4 + MERKLE_HASH_SIZE);
211                p.extend_from_slice(&(table_name.len() as u16).to_le_bytes());
212                p.extend_from_slice(table_name);
213                p.extend_from_slice(&root_page.0.to_le_bytes());
214                p.extend_from_slice(root_hash);
215                (MSG_TABLE_SYNC_BEGIN, p)
216            }
217            SyncMessage::TableSyncEnd { table_name } => {
218                let mut p = Vec::with_capacity(2 + table_name.len());
219                p.extend_from_slice(&(table_name.len() as u16).to_le_bytes());
220                p.extend_from_slice(table_name);
221                (MSG_TABLE_SYNC_END, p)
222            }
223        };
224
225        let mut buf = Vec::with_capacity(5 + payload.len());
226        buf.push(msg_type);
227        buf.extend_from_slice(&(payload.len() as u32).to_le_bytes());
228        buf.extend_from_slice(&payload);
229        buf
230    }
231
232    /// Deserialize from wire format.
233    pub fn deserialize(data: &[u8]) -> Result<Self, ProtocolError> {
234        if data.len() < 5 {
235            return Err(ProtocolError::Truncated {
236                context: "message header".to_string(),
237                expected: 5,
238                actual: data.len(),
239            });
240        }
241
242        let msg_type = data[0];
243        let payload_len = u32::from_le_bytes(data[1..5].try_into().unwrap()) as usize;
244
245        if data.len() < 5 + payload_len {
246            return Err(ProtocolError::Truncated {
247                context: "message payload".to_string(),
248                expected: 5 + payload_len,
249                actual: data.len(),
250            });
251        }
252
253        let payload = &data[5..5 + payload_len];
254
255        match msg_type {
256            MSG_HELLO => {
257                ensure_len(payload, 40, "Hello")?;
258                let node_id = NodeId::from_bytes(payload[0..8].try_into().unwrap());
259                let root_page = PageId(u32::from_le_bytes(payload[8..12].try_into().unwrap()));
260                let mut root_hash = [0u8; MERKLE_HASH_SIZE];
261                root_hash.copy_from_slice(&payload[12..40]);
262                Ok(SyncMessage::Hello { node_id, root_page, root_hash })
263            }
264            MSG_HELLO_ACK => {
265                ensure_len(payload, 41, "HelloAck")?;
266                let node_id = NodeId::from_bytes(payload[0..8].try_into().unwrap());
267                let root_page = PageId(u32::from_le_bytes(payload[8..12].try_into().unwrap()));
268                let mut root_hash = [0u8; MERKLE_HASH_SIZE];
269                root_hash.copy_from_slice(&payload[12..40]);
270                let in_sync = payload[40] != 0;
271                Ok(SyncMessage::HelloAck { node_id, root_page, root_hash, in_sync })
272            }
273            MSG_DIGEST_REQUEST => {
274                ensure_len(payload, 4, "DigestRequest")?;
275                let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
276                ensure_len(payload, 4 + count * 4, "DigestRequest")?;
277                let page_ids = (0..count)
278                    .map(|i| {
279                        let off = 4 + i * 4;
280                        PageId(u32::from_le_bytes(payload[off..off + 4].try_into().unwrap()))
281                    })
282                    .collect();
283                Ok(SyncMessage::DigestRequest { page_ids })
284            }
285            MSG_DIGEST_RESPONSE => {
286                ensure_len(payload, 4, "DigestResponse")?;
287                let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
288                let mut pos = 4;
289                let mut digests = Vec::with_capacity(count);
290                for _ in 0..count {
291                    let (digest, consumed) = deserialize_page_digest(payload, pos)?;
292                    digests.push(digest);
293                    pos += consumed;
294                }
295                Ok(SyncMessage::DigestResponse { digests })
296            }
297            MSG_ENTRIES_REQUEST => {
298                ensure_len(payload, 4, "EntriesRequest")?;
299                let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
300                ensure_len(payload, 4 + count * 4, "EntriesRequest")?;
301                let page_ids = (0..count)
302                    .map(|i| {
303                        let off = 4 + i * 4;
304                        PageId(u32::from_le_bytes(payload[off..off + 4].try_into().unwrap()))
305                    })
306                    .collect();
307                Ok(SyncMessage::EntriesRequest { page_ids })
308            }
309            MSG_ENTRIES_RESPONSE => {
310                ensure_len(payload, 4, "EntriesResponse")?;
311                let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
312                let mut pos = 4;
313                let mut entries = Vec::with_capacity(count);
314                for _ in 0..count {
315                    let (entry, consumed) = deserialize_diff_entry(payload, pos)?;
316                    entries.push(entry);
317                    pos += consumed;
318                }
319                Ok(SyncMessage::EntriesResponse { entries })
320            }
321            MSG_PATCH_DATA => {
322                Ok(SyncMessage::PatchData { data: payload.to_vec() })
323            }
324            MSG_PATCH_ACK => {
325                ensure_len(payload, 24, "PatchAck")?;
326                let entries_applied = u64::from_le_bytes(payload[0..8].try_into().unwrap());
327                let entries_skipped = u64::from_le_bytes(payload[8..16].try_into().unwrap());
328                let entries_equal = u64::from_le_bytes(payload[16..24].try_into().unwrap());
329                Ok(SyncMessage::PatchAck {
330                    result: ApplyResult { entries_applied, entries_skipped, entries_equal },
331                })
332            }
333            MSG_DONE => {
334                Ok(SyncMessage::Done)
335            }
336            MSG_ERROR => {
337                ensure_len(payload, 4, "Error")?;
338                let msg_len = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
339                ensure_len(payload, 4 + msg_len, "Error")?;
340                let message = String::from_utf8_lossy(&payload[4..4 + msg_len]).into_owned();
341                Ok(SyncMessage::Error { message })
342            }
343            MSG_PULL_REQUEST => {
344                Ok(SyncMessage::PullRequest)
345            }
346            MSG_PULL_RESPONSE => {
347                ensure_len(payload, 32, "PullResponse")?;
348                let root_page = PageId(u32::from_le_bytes(payload[0..4].try_into().unwrap()));
349                let mut root_hash = [0u8; MERKLE_HASH_SIZE];
350                root_hash.copy_from_slice(&payload[4..32]);
351                Ok(SyncMessage::PullResponse { root_page, root_hash })
352            }
353            MSG_TABLE_LIST_REQUEST => {
354                Ok(SyncMessage::TableListRequest)
355            }
356            MSG_TABLE_LIST_RESPONSE => {
357                ensure_len(payload, 4, "TableListResponse")?;
358                let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
359                let mut pos = 4;
360                let mut tables = Vec::with_capacity(count);
361                for _ in 0..count {
362                    ensure_len(payload, pos + 2, "TableInfo name_len")?;
363                    let name_len = u16::from_le_bytes(
364                        payload[pos..pos + 2].try_into().unwrap(),
365                    ) as usize;
366                    pos += 2;
367                    ensure_len(payload, pos + name_len + 4 + MERKLE_HASH_SIZE, "TableInfo")?;
368                    let name = payload[pos..pos + name_len].to_vec();
369                    pos += name_len;
370                    let root_page = PageId(u32::from_le_bytes(
371                        payload[pos..pos + 4].try_into().unwrap(),
372                    ));
373                    pos += 4;
374                    let mut root_hash = [0u8; MERKLE_HASH_SIZE];
375                    root_hash.copy_from_slice(&payload[pos..pos + MERKLE_HASH_SIZE]);
376                    pos += MERKLE_HASH_SIZE;
377                    tables.push(TableInfo { name, root_page, root_hash });
378                }
379                Ok(SyncMessage::TableListResponse { tables })
380            }
381            MSG_TABLE_SYNC_BEGIN => {
382                ensure_len(payload, 2, "TableSyncBegin")?;
383                let name_len = u16::from_le_bytes(payload[0..2].try_into().unwrap()) as usize;
384                ensure_len(payload, 2 + name_len + 4 + MERKLE_HASH_SIZE, "TableSyncBegin")?;
385                let table_name = payload[2..2 + name_len].to_vec();
386                let off = 2 + name_len;
387                let root_page = PageId(u32::from_le_bytes(
388                    payload[off..off + 4].try_into().unwrap(),
389                ));
390                let mut root_hash = [0u8; MERKLE_HASH_SIZE];
391                root_hash.copy_from_slice(&payload[off + 4..off + 4 + MERKLE_HASH_SIZE]);
392                Ok(SyncMessage::TableSyncBegin { table_name, root_page, root_hash })
393            }
394            MSG_TABLE_SYNC_END => {
395                ensure_len(payload, 2, "TableSyncEnd")?;
396                let name_len = u16::from_le_bytes(payload[0..2].try_into().unwrap()) as usize;
397                ensure_len(payload, 2 + name_len, "TableSyncEnd")?;
398                let table_name = payload[2..2 + name_len].to_vec();
399                Ok(SyncMessage::TableSyncEnd { table_name })
400            }
401            _ => Err(ProtocolError::UnknownMessageType(msg_type)),
402        }
403    }
404}
405
406fn ensure_len(data: &[u8], needed: usize, ctx: &str) -> Result<(), ProtocolError> {
407    if data.len() < needed {
408        Err(ProtocolError::Truncated {
409            context: ctx.to_string(),
410            expected: needed,
411            actual: data.len(),
412        })
413    } else {
414        Ok(())
415    }
416}
417
418fn serialize_page_digest(buf: &mut Vec<u8>, d: &PageDigest) {
419    buf.extend_from_slice(&d.page_id.0.to_le_bytes());
420    buf.extend_from_slice(&(d.page_type as u16).to_le_bytes());
421    buf.extend_from_slice(&d.merkle_hash);
422    buf.extend_from_slice(&(d.children.len() as u32).to_le_bytes());
423    for c in &d.children {
424        buf.extend_from_slice(&c.0.to_le_bytes());
425    }
426}
427
428fn deserialize_page_digest(data: &[u8], offset: usize) -> Result<(PageDigest, usize), ProtocolError> {
429    // page_id(4) + page_type(2) + merkle_hash(28) + child_count(4) = 38
430    let min = 38;
431    if data.len() < offset + min {
432        return Err(ProtocolError::Truncated {
433            context: "PageDigest header".to_string(),
434            expected: offset + min,
435            actual: data.len(),
436        });
437    }
438
439    let page_id = PageId(u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()));
440    let page_type_raw = u16::from_le_bytes(data[offset + 4..offset + 6].try_into().unwrap());
441    let page_type = citadel_core::types::PageType::from_u16(page_type_raw)
442        .unwrap_or(citadel_core::types::PageType::Leaf);
443    let mut merkle_hash = [0u8; MERKLE_HASH_SIZE];
444    merkle_hash.copy_from_slice(&data[offset + 6..offset + 34]);
445    let child_count = u32::from_le_bytes(data[offset + 34..offset + 38].try_into().unwrap()) as usize;
446
447    if data.len() < offset + min + child_count * 4 {
448        return Err(ProtocolError::Truncated {
449            context: "PageDigest children".to_string(),
450            expected: offset + min + child_count * 4,
451            actual: data.len(),
452        });
453    }
454
455    let children = (0..child_count)
456        .map(|i| {
457            let off = offset + 38 + i * 4;
458            PageId(u32::from_le_bytes(data[off..off + 4].try_into().unwrap()))
459        })
460        .collect();
461
462    Ok((
463        PageDigest { page_id, page_type, merkle_hash, children },
464        min + child_count * 4,
465    ))
466}
467
468fn serialize_diff_entry(buf: &mut Vec<u8>, e: &DiffEntry) {
469    buf.extend_from_slice(&(e.key.len() as u16).to_le_bytes());
470    buf.extend_from_slice(&(e.value.len() as u32).to_le_bytes());
471    buf.push(e.val_type);
472    buf.extend_from_slice(&e.key);
473    buf.extend_from_slice(&e.value);
474}
475
476fn deserialize_diff_entry(data: &[u8], offset: usize) -> Result<(DiffEntry, usize), ProtocolError> {
477    // key_len(2) + val_len(4) + val_type(1) = 7
478    let header = 7;
479    if data.len() < offset + header {
480        return Err(ProtocolError::Truncated {
481            context: "DiffEntry header".to_string(),
482            expected: offset + header,
483            actual: data.len(),
484        });
485    }
486
487    let key_len = u16::from_le_bytes(data[offset..offset + 2].try_into().unwrap()) as usize;
488    let val_len = u32::from_le_bytes(data[offset + 2..offset + 6].try_into().unwrap()) as usize;
489    let val_type = data[offset + 6];
490
491    let total = header + key_len + val_len;
492    if data.len() < offset + total {
493        return Err(ProtocolError::Truncated {
494            context: "DiffEntry data".to_string(),
495            expected: offset + total,
496            actual: data.len(),
497        });
498    }
499
500    let key = data[offset + 7..offset + 7 + key_len].to_vec();
501    let value = data[offset + 7 + key_len..offset + 7 + key_len + val_len].to_vec();
502
503    Ok((DiffEntry { key, value, val_type }, total))
504}
505
#[cfg(test)]
mod tests {
    //! Round-trip and malformed-input tests for the sync wire format.

    use super::*;
    use citadel_core::types::PageType;

    /// A hash whose bytes are `0, 1, 2, ...`, so misplaced bytes are easy to spot.
    fn sample_hash() -> MerkleHash {
        let mut h = [0u8; MERKLE_HASH_SIZE];
        for (i, byte) in h.iter_mut().enumerate() {
            *byte = i as u8;
        }
        h
    }

    /// Serializes `msg` and decodes the resulting frame again.
    fn roundtrip(msg: &SyncMessage) -> SyncMessage {
        SyncMessage::deserialize(&msg.serialize()).unwrap()
    }

    /// Builds a frame with the given tag and raw payload bytes.
    fn frame(msg_type: u8, payload: &[u8]) -> Vec<u8> {
        let mut data = vec![msg_type];
        data.extend_from_slice(&(payload.len() as u32).to_le_bytes());
        data.extend_from_slice(payload);
        data
    }

    #[test]
    fn hello_roundtrip() {
        let msg = SyncMessage::Hello {
            node_id: NodeId::from_u64(42),
            root_page: PageId(7),
            root_hash: sample_hash(),
        };
        match roundtrip(&msg) {
            SyncMessage::Hello { node_id, root_page, root_hash } => {
                assert_eq!(node_id, NodeId::from_u64(42));
                assert_eq!(root_page, PageId(7));
                assert_eq!(root_hash, sample_hash());
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn hello_ack_roundtrip() {
        let msg = SyncMessage::HelloAck {
            node_id: NodeId::from_u64(99),
            root_page: PageId(3),
            root_hash: sample_hash(),
            in_sync: true,
        };
        match roundtrip(&msg) {
            SyncMessage::HelloAck { node_id, root_page, root_hash, in_sync } => {
                assert_eq!(node_id, NodeId::from_u64(99));
                assert_eq!(root_page, PageId(3));
                assert_eq!(root_hash, sample_hash());
                assert!(in_sync);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn hello_ack_not_in_sync_roundtrip() {
        let msg = SyncMessage::HelloAck {
            node_id: NodeId::from_u64(1),
            root_page: PageId(2),
            root_hash: sample_hash(),
            in_sync: false,
        };
        match roundtrip(&msg) {
            SyncMessage::HelloAck { in_sync, .. } => assert!(!in_sync),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn digest_request_roundtrip() {
        let msg = SyncMessage::DigestRequest {
            page_ids: vec![PageId(1), PageId(5), PageId(100)],
        };
        match roundtrip(&msg) {
            SyncMessage::DigestRequest { page_ids } => {
                assert_eq!(page_ids, vec![PageId(1), PageId(5), PageId(100)]);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn digest_response_roundtrip() {
        let msg = SyncMessage::DigestResponse {
            digests: vec![
                PageDigest {
                    page_id: PageId(1),
                    page_type: PageType::Leaf,
                    merkle_hash: sample_hash(),
                    children: vec![],
                },
                PageDigest {
                    page_id: PageId(2),
                    page_type: PageType::Branch,
                    merkle_hash: [0xAA; MERKLE_HASH_SIZE],
                    children: vec![PageId(3), PageId(4)],
                },
            ],
        };
        match roundtrip(&msg) {
            SyncMessage::DigestResponse { digests } => {
                assert_eq!(digests.len(), 2);
                assert_eq!(digests[0].page_id, PageId(1));
                assert!(matches!(digests[0].page_type, PageType::Leaf));
                assert_eq!(digests[0].merkle_hash, sample_hash());
                assert!(digests[0].children.is_empty());
                assert_eq!(digests[1].page_id, PageId(2));
                assert!(matches!(digests[1].page_type, PageType::Branch));
                assert_eq!(digests[1].merkle_hash, [0xAA; MERKLE_HASH_SIZE]);
                assert_eq!(digests[1].children, vec![PageId(3), PageId(4)]);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn entries_request_roundtrip() {
        let msg = SyncMessage::EntriesRequest { page_ids: vec![PageId(10)] };
        match roundtrip(&msg) {
            SyncMessage::EntriesRequest { page_ids } => {
                assert_eq!(page_ids, vec![PageId(10)]);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn entries_response_roundtrip() {
        let msg = SyncMessage::EntriesResponse {
            entries: vec![
                DiffEntry { key: b"k1".to_vec(), value: b"v1".to_vec(), val_type: 0 },
                DiffEntry { key: b"k2".to_vec(), value: b"v2".to_vec(), val_type: 1 },
            ],
        };
        match roundtrip(&msg) {
            SyncMessage::EntriesResponse { entries } => {
                assert_eq!(entries.len(), 2);
                assert_eq!(entries[0].key, b"k1");
                assert_eq!(entries[0].value, b"v1");
                assert_eq!(entries[0].val_type, 0);
                assert_eq!(entries[1].key, b"k2");
                assert_eq!(entries[1].value, b"v2");
                assert_eq!(entries[1].val_type, 1);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn patch_data_roundtrip() {
        let msg = SyncMessage::PatchData { data: vec![1, 2, 3, 4, 5] };
        match roundtrip(&msg) {
            SyncMessage::PatchData { data: d } => {
                assert_eq!(d, vec![1, 2, 3, 4, 5]);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn patch_ack_roundtrip() {
        let msg = SyncMessage::PatchAck {
            result: ApplyResult {
                entries_applied: 10,
                entries_skipped: 3,
                entries_equal: 2,
            },
        };
        match roundtrip(&msg) {
            SyncMessage::PatchAck { result } => {
                assert_eq!(result.entries_applied, 10);
                assert_eq!(result.entries_skipped, 3);
                assert_eq!(result.entries_equal, 2);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn done_roundtrip() {
        assert!(matches!(roundtrip(&SyncMessage::Done), SyncMessage::Done));
    }

    #[test]
    fn done_frame_layout() {
        // Tag byte followed by a zero little-endian payload length.
        assert_eq!(SyncMessage::Done.serialize(), vec![MSG_DONE, 0, 0, 0, 0]);
    }

    #[test]
    fn error_roundtrip() {
        let msg = SyncMessage::Error { message: "something broke".into() };
        match roundtrip(&msg) {
            SyncMessage::Error { message } => {
                assert_eq!(message, "something broke");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn error_with_invalid_utf8_is_decoded_lossily() {
        let data = frame(MSG_ERROR, &[1, 0, 0, 0, 0xFF]);
        match SyncMessage::deserialize(&data).unwrap() {
            SyncMessage::Error { message } => assert_eq!(message, "\u{FFFD}"),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn pull_request_roundtrip() {
        assert!(matches!(roundtrip(&SyncMessage::PullRequest), SyncMessage::PullRequest));
    }

    #[test]
    fn pull_response_roundtrip() {
        let msg = SyncMessage::PullResponse {
            root_page: PageId(15),
            root_hash: sample_hash(),
        };
        match roundtrip(&msg) {
            SyncMessage::PullResponse { root_page, root_hash } => {
                assert_eq!(root_page, PageId(15));
                assert_eq!(root_hash, sample_hash());
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn truncated_data() {
        let err = SyncMessage::deserialize(&[0, 1]).unwrap_err();
        assert!(matches!(err, ProtocolError::Truncated { .. }));
    }

    #[test]
    fn truncated_payload() {
        // Header announces 10 payload bytes but none follow.
        let err = SyncMessage::deserialize(&[MSG_DONE, 10, 0, 0, 0]).unwrap_err();
        assert!(matches!(
            err,
            ProtocolError::Truncated { expected: 15, actual: 5, .. }
        ));
    }

    #[test]
    fn hello_with_short_payload_is_truncated() {
        let data = frame(MSG_HELLO, &[1, 2, 3, 4]);
        let err = SyncMessage::deserialize(&data).unwrap_err();
        assert!(matches!(err, ProtocolError::Truncated { .. }));
    }

    #[test]
    fn digest_request_with_missing_ids_is_truncated() {
        // Count says two ids, payload carries only one.
        let data = frame(MSG_DIGEST_REQUEST, &[2, 0, 0, 0, 7, 0, 0, 0]);
        let err = SyncMessage::deserialize(&data).unwrap_err();
        assert!(matches!(err, ProtocolError::Truncated { .. }));
    }

    #[test]
    fn digest_response_with_missing_digest_is_truncated() {
        let data = frame(MSG_DIGEST_RESPONSE, &[1, 0, 0, 0]);
        let err = SyncMessage::deserialize(&data).unwrap_err();
        assert!(matches!(err, ProtocolError::Truncated { .. }));
    }

    #[test]
    fn entries_response_with_missing_entry_is_truncated() {
        let data = frame(MSG_ENTRIES_RESPONSE, &[1, 0, 0, 0]);
        let err = SyncMessage::deserialize(&data).unwrap_err();
        assert!(matches!(err, ProtocolError::Truncated { .. }));
    }

    #[test]
    fn table_list_response_with_missing_table_is_truncated() {
        let data = frame(MSG_TABLE_LIST_RESPONSE, &[1, 0, 0, 0]);
        let err = SyncMessage::deserialize(&data).unwrap_err();
        assert!(matches!(err, ProtocolError::Truncated { .. }));
    }

    #[test]
    fn trailing_bytes_after_frame_are_ignored() {
        let mut data = SyncMessage::Done.serialize();
        data.extend_from_slice(&[0xDE, 0xAD]);
        let decoded = SyncMessage::deserialize(&data).unwrap();
        assert!(matches!(decoded, SyncMessage::Done));
    }

    #[test]
    fn unknown_message_type() {
        let data = [255, 0, 0, 0, 0];
        let err = SyncMessage::deserialize(&data).unwrap_err();
        assert!(matches!(err, ProtocolError::UnknownMessageType(255)));
    }

    #[test]
    fn empty_digest_request() {
        let msg = SyncMessage::DigestRequest { page_ids: vec![] };
        match roundtrip(&msg) {
            SyncMessage::DigestRequest { page_ids } => assert!(page_ids.is_empty()),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn table_list_request_roundtrip() {
        assert!(matches!(
            roundtrip(&SyncMessage::TableListRequest),
            SyncMessage::TableListRequest
        ));
    }

    #[test]
    fn table_list_response_roundtrip() {
        let msg = SyncMessage::TableListResponse {
            tables: vec![
                TableInfo {
                    name: b"users".to_vec(),
                    root_page: PageId(10),
                    root_hash: sample_hash(),
                },
                TableInfo {
                    name: b"orders".to_vec(),
                    root_page: PageId(20),
                    root_hash: [0xBB; MERKLE_HASH_SIZE],
                },
            ],
        };
        match roundtrip(&msg) {
            SyncMessage::TableListResponse { tables } => {
                assert_eq!(tables.len(), 2);
                assert_eq!(tables[0].name, b"users");
                assert_eq!(tables[0].root_page, PageId(10));
                assert_eq!(tables[0].root_hash, sample_hash());
                assert_eq!(tables[1].name, b"orders");
                assert_eq!(tables[1].root_page, PageId(20));
                assert_eq!(tables[1].root_hash, [0xBB; MERKLE_HASH_SIZE]);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn table_list_response_empty() {
        let msg = SyncMessage::TableListResponse { tables: vec![] };
        match roundtrip(&msg) {
            SyncMessage::TableListResponse { tables } => assert!(tables.is_empty()),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn table_sync_begin_roundtrip() {
        let msg = SyncMessage::TableSyncBegin {
            table_name: b"products".to_vec(),
            root_page: PageId(77),
            root_hash: sample_hash(),
        };
        match roundtrip(&msg) {
            SyncMessage::TableSyncBegin { table_name, root_page, root_hash } => {
                assert_eq!(table_name, b"products");
                assert_eq!(root_page, PageId(77));
                assert_eq!(root_hash, sample_hash());
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn table_sync_end_roundtrip() {
        let msg = SyncMessage::TableSyncEnd { table_name: b"products".to_vec() };
        match roundtrip(&msg) {
            SyncMessage::TableSyncEnd { table_name } => {
                assert_eq!(table_name, b"products");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn empty_entries_response() {
        let msg = SyncMessage::EntriesResponse { entries: vec![] };
        match roundtrip(&msg) {
            SyncMessage::EntriesResponse { entries } => assert!(entries.is_empty()),
            _ => panic!("wrong variant"),
        }
    }
}