Skip to main content

citadel_sync/
protocol.rs

1use citadel_core::types::PageId;
2use citadel_core::MERKLE_HASH_SIZE;
3
4use crate::apply::ApplyResult;
5use crate::diff::{DiffEntry, MerkleHash, PageDigest};
6use crate::node_id::NodeId;
7
// Message type tags for the wire format. The tag is the first byte of every
// serialized frame, so these numeric values are part of the protocol: changing
// one breaks interoperability with peers that still use the old value.

/// Tag for `SyncMessage::Hello`.
const MSG_HELLO: u8 = 0;
/// Tag for `SyncMessage::HelloAck`.
const MSG_HELLO_ACK: u8 = 1;
/// Tag for `SyncMessage::DigestRequest`.
const MSG_DIGEST_REQUEST: u8 = 2;
/// Tag for `SyncMessage::DigestResponse`.
const MSG_DIGEST_RESPONSE: u8 = 3;
/// Tag for `SyncMessage::EntriesRequest`.
const MSG_ENTRIES_REQUEST: u8 = 4;
/// Tag for `SyncMessage::EntriesResponse`.
const MSG_ENTRIES_RESPONSE: u8 = 5;
/// Tag for `SyncMessage::PatchData`.
const MSG_PATCH_DATA: u8 = 6;
/// Tag for `SyncMessage::PatchAck`.
const MSG_PATCH_ACK: u8 = 7;
/// Tag for `SyncMessage::Done`.
const MSG_DONE: u8 = 8;
/// Tag for `SyncMessage::Error`.
const MSG_ERROR: u8 = 9;
/// Tag for `SyncMessage::PullRequest`.
const MSG_PULL_REQUEST: u8 = 10;
/// Tag for `SyncMessage::PullResponse`.
const MSG_PULL_RESPONSE: u8 = 11;
/// Tag for `SyncMessage::TableListRequest`.
const MSG_TABLE_LIST_REQUEST: u8 = 12;
/// Tag for `SyncMessage::TableListResponse`.
const MSG_TABLE_LIST_RESPONSE: u8 = 13;
/// Tag for `SyncMessage::TableSyncBegin`.
const MSG_TABLE_SYNC_BEGIN: u8 = 14;
/// Tag for `SyncMessage::TableSyncEnd`.
const MSG_TABLE_SYNC_END: u8 = 15;
25
26/// Metadata about a named table for multi-table sync negotiation.
27#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct TableInfo {
29    pub name: Vec<u8>,
30    pub root_page: PageId,
31    pub root_hash: MerkleHash,
32}
33
34/// Sync protocol messages exchanged between initiator and responder.
35#[derive(Debug, Clone)]
36pub enum SyncMessage {
37    /// Initiator greeting with identity and tree root state.
38    Hello {
39        node_id: NodeId,
40        root_page: PageId,
41        root_hash: MerkleHash,
42    },
43    /// Responder acknowledgment with its own tree root state.
44    HelloAck {
45        node_id: NodeId,
46        root_page: PageId,
47        root_hash: MerkleHash,
48        in_sync: bool,
49    },
50    /// Request page digests from the remote tree.
51    DigestRequest { page_ids: Vec<PageId> },
52    /// Response with page digests.
53    DigestResponse { digests: Vec<PageDigest> },
54    /// Request leaf entries from remote pages.
55    EntriesRequest { page_ids: Vec<PageId> },
56    /// Response with leaf entries.
57    EntriesResponse { entries: Vec<DiffEntry> },
58    /// Serialized SyncPatch data.
59    PatchData { data: Vec<u8> },
60    /// Acknowledgment after applying a patch.
61    PatchAck { result: ApplyResult },
62    /// Session complete.
63    Done,
64    /// Error during sync.
65    Error { message: String },
66    /// Request updated root info for pull phase after push.
67    PullRequest,
68    /// Response with updated root info for pull phase.
69    PullResponse {
70        root_page: PageId,
71        root_hash: MerkleHash,
72    },
73    /// Request list of named tables from the remote peer.
74    TableListRequest,
75    /// Response with the list of named tables.
76    TableListResponse { tables: Vec<TableInfo> },
77    /// Begin syncing a specific named table.
78    TableSyncBegin {
79        table_name: Vec<u8>,
80        root_page: PageId,
81        root_hash: MerkleHash,
82    },
83    /// End syncing a specific named table.
84    TableSyncEnd { table_name: Vec<u8> },
85}
86
87/// Errors from sync message serialization/deserialization.
88#[derive(Debug, thiserror::Error)]
89pub enum ProtocolError {
90    #[error("{context}: expected at least {expected} bytes, got {actual}")]
91    Truncated {
92        context: String,
93        expected: usize,
94        actual: usize,
95    },
96
97    #[error("unknown message type: {0}")]
98    UnknownMessageType(u8),
99}
100
101impl SyncMessage {
102    /// Serialize to wire format: `[msg_type: u8][payload_len: u32 LE][payload]`.
103    pub fn serialize(&self) -> Vec<u8> {
104        let (msg_type, payload) = match self {
105            SyncMessage::Hello {
106                node_id,
107                root_page,
108                root_hash,
109            } => {
110                let mut p = Vec::with_capacity(40);
111                p.extend_from_slice(&node_id.to_bytes());
112                p.extend_from_slice(&root_page.0.to_le_bytes());
113                p.extend_from_slice(root_hash);
114                (MSG_HELLO, p)
115            }
116            SyncMessage::HelloAck {
117                node_id,
118                root_page,
119                root_hash,
120                in_sync,
121            } => {
122                let mut p = Vec::with_capacity(41);
123                p.extend_from_slice(&node_id.to_bytes());
124                p.extend_from_slice(&root_page.0.to_le_bytes());
125                p.extend_from_slice(root_hash);
126                p.push(if *in_sync { 1 } else { 0 });
127                (MSG_HELLO_ACK, p)
128            }
129            SyncMessage::DigestRequest { page_ids } => {
130                let mut p = Vec::with_capacity(4 + page_ids.len() * 4);
131                p.extend_from_slice(&(page_ids.len() as u32).to_le_bytes());
132                for pid in page_ids {
133                    p.extend_from_slice(&pid.0.to_le_bytes());
134                }
135                (MSG_DIGEST_REQUEST, p)
136            }
137            SyncMessage::DigestResponse { digests } => {
138                let mut p = Vec::new();
139                p.extend_from_slice(&(digests.len() as u32).to_le_bytes());
140                for d in digests {
141                    serialize_page_digest(&mut p, d);
142                }
143                (MSG_DIGEST_RESPONSE, p)
144            }
145            SyncMessage::EntriesRequest { page_ids } => {
146                let mut p = Vec::with_capacity(4 + page_ids.len() * 4);
147                p.extend_from_slice(&(page_ids.len() as u32).to_le_bytes());
148                for pid in page_ids {
149                    p.extend_from_slice(&pid.0.to_le_bytes());
150                }
151                (MSG_ENTRIES_REQUEST, p)
152            }
153            SyncMessage::EntriesResponse { entries } => {
154                let mut p = Vec::new();
155                p.extend_from_slice(&(entries.len() as u32).to_le_bytes());
156                for e in entries {
157                    serialize_diff_entry(&mut p, e);
158                }
159                (MSG_ENTRIES_RESPONSE, p)
160            }
161            SyncMessage::PatchData { data } => (MSG_PATCH_DATA, data.clone()),
162            SyncMessage::PatchAck { result } => {
163                let mut p = Vec::with_capacity(24);
164                p.extend_from_slice(&result.entries_applied.to_le_bytes());
165                p.extend_from_slice(&result.entries_skipped.to_le_bytes());
166                p.extend_from_slice(&result.entries_equal.to_le_bytes());
167                (MSG_PATCH_ACK, p)
168            }
169            SyncMessage::Done => (MSG_DONE, Vec::new()),
170            SyncMessage::Error { message } => {
171                let bytes = message.as_bytes();
172                let mut p = Vec::with_capacity(4 + bytes.len());
173                p.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
174                p.extend_from_slice(bytes);
175                (MSG_ERROR, p)
176            }
177            SyncMessage::PullRequest => (MSG_PULL_REQUEST, Vec::new()),
178            SyncMessage::PullResponse {
179                root_page,
180                root_hash,
181            } => {
182                let mut p = Vec::with_capacity(32);
183                p.extend_from_slice(&root_page.0.to_le_bytes());
184                p.extend_from_slice(root_hash);
185                (MSG_PULL_RESPONSE, p)
186            }
187            SyncMessage::TableListRequest => (MSG_TABLE_LIST_REQUEST, Vec::new()),
188            SyncMessage::TableListResponse { tables } => {
189                let mut p = Vec::new();
190                p.extend_from_slice(&(tables.len() as u32).to_le_bytes());
191                for t in tables {
192                    p.extend_from_slice(&(t.name.len() as u16).to_le_bytes());
193                    p.extend_from_slice(&t.name);
194                    p.extend_from_slice(&t.root_page.0.to_le_bytes());
195                    p.extend_from_slice(&t.root_hash);
196                }
197                (MSG_TABLE_LIST_RESPONSE, p)
198            }
199            SyncMessage::TableSyncBegin {
200                table_name,
201                root_page,
202                root_hash,
203            } => {
204                let mut p = Vec::with_capacity(2 + table_name.len() + 4 + MERKLE_HASH_SIZE);
205                p.extend_from_slice(&(table_name.len() as u16).to_le_bytes());
206                p.extend_from_slice(table_name);
207                p.extend_from_slice(&root_page.0.to_le_bytes());
208                p.extend_from_slice(root_hash);
209                (MSG_TABLE_SYNC_BEGIN, p)
210            }
211            SyncMessage::TableSyncEnd { table_name } => {
212                let mut p = Vec::with_capacity(2 + table_name.len());
213                p.extend_from_slice(&(table_name.len() as u16).to_le_bytes());
214                p.extend_from_slice(table_name);
215                (MSG_TABLE_SYNC_END, p)
216            }
217        };
218
219        let mut buf = Vec::with_capacity(5 + payload.len());
220        buf.push(msg_type);
221        buf.extend_from_slice(&(payload.len() as u32).to_le_bytes());
222        buf.extend_from_slice(&payload);
223        buf
224    }
225
226    /// Deserialize from wire format.
227    pub fn deserialize(data: &[u8]) -> Result<Self, ProtocolError> {
228        if data.len() < 5 {
229            return Err(ProtocolError::Truncated {
230                context: "message header".to_string(),
231                expected: 5,
232                actual: data.len(),
233            });
234        }
235
236        let msg_type = data[0];
237        let payload_len = u32::from_le_bytes(data[1..5].try_into().unwrap()) as usize;
238
239        if data.len() < 5 + payload_len {
240            return Err(ProtocolError::Truncated {
241                context: "message payload".to_string(),
242                expected: 5 + payload_len,
243                actual: data.len(),
244            });
245        }
246
247        let payload = &data[5..5 + payload_len];
248
249        match msg_type {
250            MSG_HELLO => {
251                ensure_len(payload, 40, "Hello")?;
252                let node_id = NodeId::from_bytes(payload[0..8].try_into().unwrap());
253                let root_page = PageId(u32::from_le_bytes(payload[8..12].try_into().unwrap()));
254                let mut root_hash = [0u8; MERKLE_HASH_SIZE];
255                root_hash.copy_from_slice(&payload[12..40]);
256                Ok(SyncMessage::Hello {
257                    node_id,
258                    root_page,
259                    root_hash,
260                })
261            }
262            MSG_HELLO_ACK => {
263                ensure_len(payload, 41, "HelloAck")?;
264                let node_id = NodeId::from_bytes(payload[0..8].try_into().unwrap());
265                let root_page = PageId(u32::from_le_bytes(payload[8..12].try_into().unwrap()));
266                let mut root_hash = [0u8; MERKLE_HASH_SIZE];
267                root_hash.copy_from_slice(&payload[12..40]);
268                let in_sync = payload[40] != 0;
269                Ok(SyncMessage::HelloAck {
270                    node_id,
271                    root_page,
272                    root_hash,
273                    in_sync,
274                })
275            }
276            MSG_DIGEST_REQUEST => {
277                ensure_len(payload, 4, "DigestRequest")?;
278                let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
279                ensure_len(payload, 4 + count * 4, "DigestRequest")?;
280                let page_ids = (0..count)
281                    .map(|i| {
282                        let off = 4 + i * 4;
283                        PageId(u32::from_le_bytes(
284                            payload[off..off + 4].try_into().unwrap(),
285                        ))
286                    })
287                    .collect();
288                Ok(SyncMessage::DigestRequest { page_ids })
289            }
290            MSG_DIGEST_RESPONSE => {
291                ensure_len(payload, 4, "DigestResponse")?;
292                let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
293                let mut pos = 4;
294                let mut digests = Vec::with_capacity(count);
295                for _ in 0..count {
296                    let (digest, consumed) = deserialize_page_digest(payload, pos)?;
297                    digests.push(digest);
298                    pos += consumed;
299                }
300                Ok(SyncMessage::DigestResponse { digests })
301            }
302            MSG_ENTRIES_REQUEST => {
303                ensure_len(payload, 4, "EntriesRequest")?;
304                let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
305                ensure_len(payload, 4 + count * 4, "EntriesRequest")?;
306                let page_ids = (0..count)
307                    .map(|i| {
308                        let off = 4 + i * 4;
309                        PageId(u32::from_le_bytes(
310                            payload[off..off + 4].try_into().unwrap(),
311                        ))
312                    })
313                    .collect();
314                Ok(SyncMessage::EntriesRequest { page_ids })
315            }
316            MSG_ENTRIES_RESPONSE => {
317                ensure_len(payload, 4, "EntriesResponse")?;
318                let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
319                let mut pos = 4;
320                let mut entries = Vec::with_capacity(count);
321                for _ in 0..count {
322                    let (entry, consumed) = deserialize_diff_entry(payload, pos)?;
323                    entries.push(entry);
324                    pos += consumed;
325                }
326                Ok(SyncMessage::EntriesResponse { entries })
327            }
328            MSG_PATCH_DATA => Ok(SyncMessage::PatchData {
329                data: payload.to_vec(),
330            }),
331            MSG_PATCH_ACK => {
332                ensure_len(payload, 24, "PatchAck")?;
333                let entries_applied = u64::from_le_bytes(payload[0..8].try_into().unwrap());
334                let entries_skipped = u64::from_le_bytes(payload[8..16].try_into().unwrap());
335                let entries_equal = u64::from_le_bytes(payload[16..24].try_into().unwrap());
336                Ok(SyncMessage::PatchAck {
337                    result: ApplyResult {
338                        entries_applied,
339                        entries_skipped,
340                        entries_equal,
341                    },
342                })
343            }
344            MSG_DONE => Ok(SyncMessage::Done),
345            MSG_ERROR => {
346                ensure_len(payload, 4, "Error")?;
347                let msg_len = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
348                ensure_len(payload, 4 + msg_len, "Error")?;
349                let message = String::from_utf8_lossy(&payload[4..4 + msg_len]).into_owned();
350                Ok(SyncMessage::Error { message })
351            }
352            MSG_PULL_REQUEST => Ok(SyncMessage::PullRequest),
353            MSG_PULL_RESPONSE => {
354                ensure_len(payload, 32, "PullResponse")?;
355                let root_page = PageId(u32::from_le_bytes(payload[0..4].try_into().unwrap()));
356                let mut root_hash = [0u8; MERKLE_HASH_SIZE];
357                root_hash.copy_from_slice(&payload[4..32]);
358                Ok(SyncMessage::PullResponse {
359                    root_page,
360                    root_hash,
361                })
362            }
363            MSG_TABLE_LIST_REQUEST => Ok(SyncMessage::TableListRequest),
364            MSG_TABLE_LIST_RESPONSE => {
365                ensure_len(payload, 4, "TableListResponse")?;
366                let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
367                let mut pos = 4;
368                let mut tables = Vec::with_capacity(count);
369                for _ in 0..count {
370                    ensure_len(payload, pos + 2, "TableInfo name_len")?;
371                    let name_len =
372                        u16::from_le_bytes(payload[pos..pos + 2].try_into().unwrap()) as usize;
373                    pos += 2;
374                    ensure_len(payload, pos + name_len + 4 + MERKLE_HASH_SIZE, "TableInfo")?;
375                    let name = payload[pos..pos + name_len].to_vec();
376                    pos += name_len;
377                    let root_page = PageId(u32::from_le_bytes(
378                        payload[pos..pos + 4].try_into().unwrap(),
379                    ));
380                    pos += 4;
381                    let mut root_hash = [0u8; MERKLE_HASH_SIZE];
382                    root_hash.copy_from_slice(&payload[pos..pos + MERKLE_HASH_SIZE]);
383                    pos += MERKLE_HASH_SIZE;
384                    tables.push(TableInfo {
385                        name,
386                        root_page,
387                        root_hash,
388                    });
389                }
390                Ok(SyncMessage::TableListResponse { tables })
391            }
392            MSG_TABLE_SYNC_BEGIN => {
393                ensure_len(payload, 2, "TableSyncBegin")?;
394                let name_len = u16::from_le_bytes(payload[0..2].try_into().unwrap()) as usize;
395                ensure_len(
396                    payload,
397                    2 + name_len + 4 + MERKLE_HASH_SIZE,
398                    "TableSyncBegin",
399                )?;
400                let table_name = payload[2..2 + name_len].to_vec();
401                let off = 2 + name_len;
402                let root_page = PageId(u32::from_le_bytes(
403                    payload[off..off + 4].try_into().unwrap(),
404                ));
405                let mut root_hash = [0u8; MERKLE_HASH_SIZE];
406                root_hash.copy_from_slice(&payload[off + 4..off + 4 + MERKLE_HASH_SIZE]);
407                Ok(SyncMessage::TableSyncBegin {
408                    table_name,
409                    root_page,
410                    root_hash,
411                })
412            }
413            MSG_TABLE_SYNC_END => {
414                ensure_len(payload, 2, "TableSyncEnd")?;
415                let name_len = u16::from_le_bytes(payload[0..2].try_into().unwrap()) as usize;
416                ensure_len(payload, 2 + name_len, "TableSyncEnd")?;
417                let table_name = payload[2..2 + name_len].to_vec();
418                Ok(SyncMessage::TableSyncEnd { table_name })
419            }
420            _ => Err(ProtocolError::UnknownMessageType(msg_type)),
421        }
422    }
423}
424
425fn ensure_len(data: &[u8], needed: usize, ctx: &str) -> Result<(), ProtocolError> {
426    if data.len() < needed {
427        Err(ProtocolError::Truncated {
428            context: ctx.to_string(),
429            expected: needed,
430            actual: data.len(),
431        })
432    } else {
433        Ok(())
434    }
435}
436
437fn serialize_page_digest(buf: &mut Vec<u8>, d: &PageDigest) {
438    buf.extend_from_slice(&d.page_id.0.to_le_bytes());
439    buf.extend_from_slice(&(d.page_type as u16).to_le_bytes());
440    buf.extend_from_slice(&d.merkle_hash);
441    buf.extend_from_slice(&(d.children.len() as u32).to_le_bytes());
442    for c in &d.children {
443        buf.extend_from_slice(&c.0.to_le_bytes());
444    }
445}
446
447fn deserialize_page_digest(
448    data: &[u8],
449    offset: usize,
450) -> Result<(PageDigest, usize), ProtocolError> {
451    // page_id(4) + page_type(2) + merkle_hash(28) + child_count(4) = 38
452    let min = 38;
453    if data.len() < offset + min {
454        return Err(ProtocolError::Truncated {
455            context: "PageDigest header".to_string(),
456            expected: offset + min,
457            actual: data.len(),
458        });
459    }
460
461    let page_id = PageId(u32::from_le_bytes(
462        data[offset..offset + 4].try_into().unwrap(),
463    ));
464    let page_type_raw = u16::from_le_bytes(data[offset + 4..offset + 6].try_into().unwrap());
465    let page_type = citadel_core::types::PageType::from_u16(page_type_raw)
466        .unwrap_or(citadel_core::types::PageType::Leaf);
467    let mut merkle_hash = [0u8; MERKLE_HASH_SIZE];
468    merkle_hash.copy_from_slice(&data[offset + 6..offset + 34]);
469    let child_count =
470        u32::from_le_bytes(data[offset + 34..offset + 38].try_into().unwrap()) as usize;
471
472    if data.len() < offset + min + child_count * 4 {
473        return Err(ProtocolError::Truncated {
474            context: "PageDigest children".to_string(),
475            expected: offset + min + child_count * 4,
476            actual: data.len(),
477        });
478    }
479
480    let children = (0..child_count)
481        .map(|i| {
482            let off = offset + 38 + i * 4;
483            PageId(u32::from_le_bytes(data[off..off + 4].try_into().unwrap()))
484        })
485        .collect();
486
487    Ok((
488        PageDigest {
489            page_id,
490            page_type,
491            merkle_hash,
492            children,
493        },
494        min + child_count * 4,
495    ))
496}
497
498fn serialize_diff_entry(buf: &mut Vec<u8>, e: &DiffEntry) {
499    buf.extend_from_slice(&(e.key.len() as u16).to_le_bytes());
500    buf.extend_from_slice(&(e.value.len() as u32).to_le_bytes());
501    buf.push(e.val_type);
502    buf.extend_from_slice(&e.key);
503    buf.extend_from_slice(&e.value);
504}
505
506fn deserialize_diff_entry(data: &[u8], offset: usize) -> Result<(DiffEntry, usize), ProtocolError> {
507    // key_len(2) + val_len(4) + val_type(1) = 7
508    let header = 7;
509    if data.len() < offset + header {
510        return Err(ProtocolError::Truncated {
511            context: "DiffEntry header".to_string(),
512            expected: offset + header,
513            actual: data.len(),
514        });
515    }
516
517    let key_len = u16::from_le_bytes(data[offset..offset + 2].try_into().unwrap()) as usize;
518    let val_len = u32::from_le_bytes(data[offset + 2..offset + 6].try_into().unwrap()) as usize;
519    let val_type = data[offset + 6];
520
521    let total = header + key_len + val_len;
522    if data.len() < offset + total {
523        return Err(ProtocolError::Truncated {
524            context: "DiffEntry data".to_string(),
525            expected: offset + total,
526            actual: data.len(),
527        });
528    }
529
530    let key = data[offset + 7..offset + 7 + key_len].to_vec();
531    let value = data[offset + 7 + key_len..offset + 7 + key_len + val_len].to_vec();
532
533    Ok((
534        DiffEntry {
535            key,
536            value,
537            val_type,
538        },
539        total,
540    ))
541}
542
#[cfg(test)]
mod tests {
    //! Round-trip tests for every message variant, plus rejection tests for
    //! malformed frames.

    use super::*;
    use citadel_core::types::PageType;

    /// Non-uniform hash (`[0, 1, 2, ...]`) so byte-order mistakes are visible.
    fn sample_hash() -> MerkleHash {
        let mut h = [0u8; MERKLE_HASH_SIZE];
        for (i, byte) in h.iter_mut().enumerate() {
            *byte = i as u8;
        }
        h
    }

    /// Serialize `msg` and decode it again, panicking if decoding fails.
    fn roundtrip(msg: &SyncMessage) -> SyncMessage {
        SyncMessage::deserialize(&msg.serialize()).unwrap()
    }

    /// Build a raw frame `[msg_type][payload_len: u32 LE][payload]` by hand.
    fn frame(msg_type: u8, payload: &[u8]) -> Vec<u8> {
        let mut data = vec![msg_type];
        data.extend_from_slice(&(payload.len() as u32).to_le_bytes());
        data.extend_from_slice(payload);
        data
    }

    /// Assert that decoding `data` fails with `ProtocolError::Truncated`.
    fn assert_truncated(data: &[u8]) {
        let err = SyncMessage::deserialize(data).unwrap_err();
        assert!(
            matches!(err, ProtocolError::Truncated { .. }),
            "expected Truncated, got {:?}",
            err
        );
    }

    #[test]
    fn hello_roundtrip() {
        let msg = SyncMessage::Hello {
            node_id: NodeId::from_u64(42),
            root_page: PageId(7),
            root_hash: sample_hash(),
        };
        match roundtrip(&msg) {
            SyncMessage::Hello {
                node_id,
                root_page,
                root_hash,
            } => {
                assert_eq!(node_id, NodeId::from_u64(42));
                assert_eq!(root_page, PageId(7));
                assert_eq!(root_hash, sample_hash());
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn hello_ack_roundtrip() {
        let msg = SyncMessage::HelloAck {
            node_id: NodeId::from_u64(99),
            root_page: PageId(3),
            root_hash: sample_hash(),
            in_sync: true,
        };
        match roundtrip(&msg) {
            SyncMessage::HelloAck {
                node_id,
                root_page,
                root_hash,
                in_sync,
            } => {
                assert_eq!(node_id, NodeId::from_u64(99));
                assert_eq!(root_page, PageId(3));
                assert_eq!(root_hash, sample_hash());
                assert!(in_sync);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn hello_ack_not_in_sync_roundtrip() {
        let msg = SyncMessage::HelloAck {
            node_id: NodeId::from_u64(1),
            root_page: PageId(0),
            root_hash: [0u8; MERKLE_HASH_SIZE],
            in_sync: false,
        };
        match roundtrip(&msg) {
            SyncMessage::HelloAck { in_sync, .. } => assert!(!in_sync),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn digest_request_roundtrip() {
        let msg = SyncMessage::DigestRequest {
            page_ids: vec![PageId(1), PageId(5), PageId(100)],
        };
        match roundtrip(&msg) {
            SyncMessage::DigestRequest { page_ids } => {
                assert_eq!(page_ids, vec![PageId(1), PageId(5), PageId(100)]);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn digest_response_roundtrip() {
        let msg = SyncMessage::DigestResponse {
            digests: vec![
                PageDigest {
                    page_id: PageId(1),
                    page_type: PageType::Leaf,
                    merkle_hash: sample_hash(),
                    children: vec![],
                },
                PageDigest {
                    page_id: PageId(2),
                    page_type: PageType::Branch,
                    merkle_hash: [0xAA; MERKLE_HASH_SIZE],
                    children: vec![PageId(3), PageId(4)],
                },
            ],
        };
        match roundtrip(&msg) {
            SyncMessage::DigestResponse { digests } => {
                assert_eq!(digests.len(), 2);
                assert_eq!(digests[0].page_id, PageId(1));
                assert!(digests[0].children.is_empty());
                assert_eq!(digests[1].children, vec![PageId(3), PageId(4)]);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn entries_request_roundtrip() {
        let msg = SyncMessage::EntriesRequest {
            page_ids: vec![PageId(10)],
        };
        match roundtrip(&msg) {
            SyncMessage::EntriesRequest { page_ids } => {
                assert_eq!(page_ids, vec![PageId(10)]);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn entries_response_roundtrip() {
        let msg = SyncMessage::EntriesResponse {
            entries: vec![
                DiffEntry {
                    key: b"k1".to_vec(),
                    value: b"v1".to_vec(),
                    val_type: 0,
                },
                DiffEntry {
                    key: b"k2".to_vec(),
                    value: b"v2".to_vec(),
                    val_type: 1,
                },
            ],
        };
        match roundtrip(&msg) {
            SyncMessage::EntriesResponse { entries } => {
                assert_eq!(entries.len(), 2);
                assert_eq!(entries[0].key, b"k1");
                assert_eq!(entries[1].val_type, 1);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn patch_data_roundtrip() {
        let msg = SyncMessage::PatchData {
            data: vec![1, 2, 3, 4, 5],
        };
        match roundtrip(&msg) {
            SyncMessage::PatchData { data } => assert_eq!(data, vec![1, 2, 3, 4, 5]),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn patch_ack_roundtrip() {
        let msg = SyncMessage::PatchAck {
            result: ApplyResult {
                entries_applied: 10,
                entries_skipped: 3,
                entries_equal: 2,
            },
        };
        match roundtrip(&msg) {
            SyncMessage::PatchAck { result } => {
                assert_eq!(result.entries_applied, 10);
                assert_eq!(result.entries_skipped, 3);
                assert_eq!(result.entries_equal, 2);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn done_roundtrip() {
        assert!(matches!(roundtrip(&SyncMessage::Done), SyncMessage::Done));
    }

    #[test]
    fn error_roundtrip() {
        let msg = SyncMessage::Error {
            message: "something broke".into(),
        };
        match roundtrip(&msg) {
            SyncMessage::Error { message } => assert_eq!(message, "something broke"),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn pull_request_roundtrip() {
        assert!(matches!(
            roundtrip(&SyncMessage::PullRequest),
            SyncMessage::PullRequest
        ));
    }

    #[test]
    fn pull_response_roundtrip() {
        let msg = SyncMessage::PullResponse {
            root_page: PageId(15),
            root_hash: sample_hash(),
        };
        match roundtrip(&msg) {
            SyncMessage::PullResponse {
                root_page,
                root_hash,
            } => {
                assert_eq!(root_page, PageId(15));
                assert_eq!(root_hash, sample_hash());
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn truncated_data() {
        let err = SyncMessage::deserialize(&[0, 1]).unwrap_err();
        assert!(matches!(err, ProtocolError::Truncated { .. }));
    }

    #[test]
    fn unknown_message_type() {
        let data = [255, 0, 0, 0, 0];
        let err = SyncMessage::deserialize(&data).unwrap_err();
        assert!(matches!(err, ProtocolError::UnknownMessageType(255)));
    }

    #[test]
    fn empty_digest_request() {
        let msg = SyncMessage::DigestRequest { page_ids: vec![] };
        match roundtrip(&msg) {
            SyncMessage::DigestRequest { page_ids } => assert!(page_ids.is_empty()),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn table_list_request_roundtrip() {
        assert!(matches!(
            roundtrip(&SyncMessage::TableListRequest),
            SyncMessage::TableListRequest
        ));
    }

    #[test]
    fn table_list_response_roundtrip() {
        let msg = SyncMessage::TableListResponse {
            tables: vec![
                TableInfo {
                    name: b"users".to_vec(),
                    root_page: PageId(10),
                    root_hash: sample_hash(),
                },
                TableInfo {
                    name: b"orders".to_vec(),
                    root_page: PageId(20),
                    root_hash: [0xBB; MERKLE_HASH_SIZE],
                },
            ],
        };
        match roundtrip(&msg) {
            SyncMessage::TableListResponse { tables } => {
                assert_eq!(tables.len(), 2);
                assert_eq!(tables[0].name, b"users");
                assert_eq!(tables[0].root_page, PageId(10));
                assert_eq!(tables[0].root_hash, sample_hash());
                assert_eq!(tables[1].name, b"orders");
                assert_eq!(tables[1].root_page, PageId(20));
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn table_list_response_empty() {
        let msg = SyncMessage::TableListResponse { tables: vec![] };
        match roundtrip(&msg) {
            SyncMessage::TableListResponse { tables } => assert!(tables.is_empty()),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn table_sync_begin_roundtrip() {
        let msg = SyncMessage::TableSyncBegin {
            table_name: b"products".to_vec(),
            root_page: PageId(77),
            root_hash: sample_hash(),
        };
        match roundtrip(&msg) {
            SyncMessage::TableSyncBegin {
                table_name,
                root_page,
                root_hash,
            } => {
                assert_eq!(table_name, b"products");
                assert_eq!(root_page, PageId(77));
                assert_eq!(root_hash, sample_hash());
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn table_sync_end_roundtrip() {
        let msg = SyncMessage::TableSyncEnd {
            table_name: b"products".to_vec(),
        };
        match roundtrip(&msg) {
            SyncMessage::TableSyncEnd { table_name } => assert_eq!(table_name, b"products"),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn empty_entries_response() {
        let msg = SyncMessage::EntriesResponse { entries: vec![] };
        match roundtrip(&msg) {
            SyncMessage::EntriesResponse { entries } => assert!(entries.is_empty()),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn declared_payload_longer_than_data() {
        // Header announces a 10-byte payload but only 2 bytes follow.
        let data = [MSG_PATCH_DATA, 10, 0, 0, 0, 1, 2];
        match SyncMessage::deserialize(&data).unwrap_err() {
            ProtocolError::Truncated {
                expected, actual, ..
            } => {
                assert_eq!(expected, 15);
                assert_eq!(actual, 7);
            }
            other => panic!("expected Truncated, got {:?}", other),
        }
    }

    #[test]
    fn hello_short_payload_is_truncated() {
        let payload = vec![0u8; 8 + 4 + MERKLE_HASH_SIZE - 1];
        assert_truncated(&frame(MSG_HELLO, &payload));
    }

    #[test]
    fn patch_ack_short_payload_is_truncated() {
        assert_truncated(&frame(MSG_PATCH_ACK, &[0u8; 23]));
    }

    #[test]
    fn digest_request_count_exceeds_payload() {
        // Count claims three page ids, but only one is present.
        let mut payload = 3u32.to_le_bytes().to_vec();
        payload.extend_from_slice(&7u32.to_le_bytes());
        assert_truncated(&frame(MSG_DIGEST_REQUEST, &payload));
    }

    #[test]
    fn entries_response_truncated_entry() {
        let mut payload = 1u32.to_le_bytes().to_vec();
        payload.extend_from_slice(&5u16.to_le_bytes()); // key_len
        payload.extend_from_slice(&0u32.to_le_bytes()); // val_len
        payload.push(0); // val_type
        payload.extend_from_slice(b"ab"); // only 2 of the 5 key bytes
        assert_truncated(&frame(MSG_ENTRIES_RESPONSE, &payload));
    }

    #[test]
    fn table_sync_end_truncated_name() {
        assert_truncated(&frame(MSG_TABLE_SYNC_END, &[10, 0, b'a', b'b', b'c']));
    }

    #[test]
    fn error_message_invalid_utf8_is_lossy() {
        let payload = [2, 0, 0, 0, 0xFF, b'a'];
        match SyncMessage::deserialize(&frame(MSG_ERROR, &payload)).unwrap() {
            SyncMessage::Error { message } => assert_eq!(message, "\u{FFFD}a"),
            _ => panic!("wrong variant"),
        }
    }
}