1use citadel_core::types::PageId;
2use citadel_core::MERKLE_HASH_SIZE;
3
4use crate::apply::ApplyResult;
5use crate::diff::{DiffEntry, MerkleHash, PageDigest};
6use crate::node_id::NodeId;
7
// Message tags: the first byte of every frame written by `SyncMessage::serialize`
// and the value `SyncMessage::deserialize` dispatches on. They are visible on the
// wire, so an existing value must never be renumbered.

// Handshake.
const MSG_HELLO: u8 = 0x00;
const MSG_HELLO_ACK: u8 = 0x01;
// Page digest exchange.
const MSG_DIGEST_REQUEST: u8 = 0x02;
const MSG_DIGEST_RESPONSE: u8 = 0x03;
// Entry exchange.
const MSG_ENTRIES_REQUEST: u8 = 0x04;
const MSG_ENTRIES_RESPONSE: u8 = 0x05;
// Patch transfer.
const MSG_PATCH_DATA: u8 = 0x06;
const MSG_PATCH_ACK: u8 = 0x07;
// Session control.
const MSG_DONE: u8 = 0x08;
const MSG_ERROR: u8 = 0x09;
// Pull.
const MSG_PULL_REQUEST: u8 = 0x0A;
const MSG_PULL_RESPONSE: u8 = 0x0B;
// Multi-table sync.
const MSG_TABLE_LIST_REQUEST: u8 = 0x0C;
const MSG_TABLE_LIST_RESPONSE: u8 = 0x0D;
const MSG_TABLE_SYNC_BEGIN: u8 = 0x0E;
const MSG_TABLE_SYNC_END: u8 = 0x0F;
25
/// Name and root of one table, as carried in a `TableListResponse`.
///
/// Wire layout (little-endian): `name_len: u16 | name | root_page: u32 | root_hash`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TableInfo {
    /// Raw name bytes; the codec copies them verbatim and does not check UTF-8.
    pub name: Vec<u8>,
    /// Root page reported for this table.
    pub root_page: PageId,
    /// Merkle hash reported alongside `root_page`.
    pub root_hash: MerkleHash,
}
33
/// One message of the sync protocol.
///
/// `SyncMessage::serialize` frames every message as
/// `tag: u8 | payload_len: u32 | payload`. The notes on each variant describe
/// its payload. All integers are little-endian.
#[derive(Debug, Clone)]
pub enum SyncMessage {
    /// Sender's node id and root.
    /// Payload: `node_id (8 bytes) | root_page: u32 | root_hash`.
    Hello {
        node_id: NodeId,
        root_page: PageId,
        root_hash: MerkleHash,
    },
    /// Same payload as `Hello`, followed by `in_sync: u8` (any non-zero byte
    /// decodes as `true`).
    HelloAck {
        node_id: NodeId,
        root_page: PageId,
        root_hash: MerkleHash,
        in_sync: bool,
    },
    /// Payload: `count: u32 | count x page_id: u32`.
    DigestRequest { page_ids: Vec<PageId> },
    /// Payload: `count: u32` followed by `count` encoded page digests
    /// (`page_id: u32 | page_type: u16 | merkle_hash | child_count: u32 | children`).
    DigestResponse { digests: Vec<PageDigest> },
    /// Same layout as `DigestRequest`.
    EntriesRequest { page_ids: Vec<PageId> },
    /// Payload: `count: u32` followed by `count` encoded entries
    /// (`key_len: u16 | val_len: u32 | val_type: u8 | key | value`).
    EntriesResponse { entries: Vec<DiffEntry> },
    /// Opaque bytes; the payload is `data` verbatim, with no inner length prefix.
    PatchData { data: Vec<u8> },
    /// Payload: `entries_applied | entries_skipped | entries_equal`, each `u64`.
    PatchAck { result: ApplyResult },
    /// Empty payload.
    Done,
    /// Payload: `len: u32 | message bytes`. Decoding replaces invalid UTF-8
    /// (lossy conversion).
    Error { message: String },
    /// Empty payload.
    PullRequest,
    /// Payload: `root_page: u32 | root_hash`.
    PullResponse {
        root_page: PageId,
        root_hash: MerkleHash,
    },
    /// Empty payload.
    TableListRequest,
    /// Payload: `count: u32` followed by `count` encoded [`TableInfo`]s.
    TableListResponse { tables: Vec<TableInfo> },
    /// Payload: `name_len: u16 | table_name | root_page: u32 | root_hash`.
    TableSyncBegin {
        table_name: Vec<u8>,
        root_page: PageId,
        root_hash: MerkleHash,
    },
    /// Payload: `name_len: u16 | table_name`.
    TableSyncEnd { table_name: Vec<u8> },
}
86
/// Errors returned by `SyncMessage::deserialize`.
#[derive(Debug, thiserror::Error)]
pub enum ProtocolError {
    /// A field extends past the end of the buffer being parsed.
    ///
    /// `context` names the part being decoded. `expected` is the minimum
    /// length, in bytes, that buffer needed at that point. `actual` is its
    /// real length.
    #[error("{context}: expected at least {expected} bytes, got {actual}")]
    Truncated {
        context: String,
        expected: usize,
        actual: usize,
    },

    /// The frame's first byte is not one of the known message tags.
    #[error("unknown message type: {0}")]
    UnknownMessageType(u8),
}
100
101impl SyncMessage {
102 pub fn serialize(&self) -> Vec<u8> {
104 let (msg_type, payload) = match self {
105 SyncMessage::Hello {
106 node_id,
107 root_page,
108 root_hash,
109 } => {
110 let mut p = Vec::with_capacity(40);
111 p.extend_from_slice(&node_id.to_bytes());
112 p.extend_from_slice(&root_page.0.to_le_bytes());
113 p.extend_from_slice(root_hash);
114 (MSG_HELLO, p)
115 }
116 SyncMessage::HelloAck {
117 node_id,
118 root_page,
119 root_hash,
120 in_sync,
121 } => {
122 let mut p = Vec::with_capacity(41);
123 p.extend_from_slice(&node_id.to_bytes());
124 p.extend_from_slice(&root_page.0.to_le_bytes());
125 p.extend_from_slice(root_hash);
126 p.push(if *in_sync { 1 } else { 0 });
127 (MSG_HELLO_ACK, p)
128 }
129 SyncMessage::DigestRequest { page_ids } => {
130 let mut p = Vec::with_capacity(4 + page_ids.len() * 4);
131 p.extend_from_slice(&(page_ids.len() as u32).to_le_bytes());
132 for pid in page_ids {
133 p.extend_from_slice(&pid.0.to_le_bytes());
134 }
135 (MSG_DIGEST_REQUEST, p)
136 }
137 SyncMessage::DigestResponse { digests } => {
138 let mut p = Vec::new();
139 p.extend_from_slice(&(digests.len() as u32).to_le_bytes());
140 for d in digests {
141 serialize_page_digest(&mut p, d);
142 }
143 (MSG_DIGEST_RESPONSE, p)
144 }
145 SyncMessage::EntriesRequest { page_ids } => {
146 let mut p = Vec::with_capacity(4 + page_ids.len() * 4);
147 p.extend_from_slice(&(page_ids.len() as u32).to_le_bytes());
148 for pid in page_ids {
149 p.extend_from_slice(&pid.0.to_le_bytes());
150 }
151 (MSG_ENTRIES_REQUEST, p)
152 }
153 SyncMessage::EntriesResponse { entries } => {
154 let mut p = Vec::new();
155 p.extend_from_slice(&(entries.len() as u32).to_le_bytes());
156 for e in entries {
157 serialize_diff_entry(&mut p, e);
158 }
159 (MSG_ENTRIES_RESPONSE, p)
160 }
161 SyncMessage::PatchData { data } => (MSG_PATCH_DATA, data.clone()),
162 SyncMessage::PatchAck { result } => {
163 let mut p = Vec::with_capacity(24);
164 p.extend_from_slice(&result.entries_applied.to_le_bytes());
165 p.extend_from_slice(&result.entries_skipped.to_le_bytes());
166 p.extend_from_slice(&result.entries_equal.to_le_bytes());
167 (MSG_PATCH_ACK, p)
168 }
169 SyncMessage::Done => (MSG_DONE, Vec::new()),
170 SyncMessage::Error { message } => {
171 let bytes = message.as_bytes();
172 let mut p = Vec::with_capacity(4 + bytes.len());
173 p.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
174 p.extend_from_slice(bytes);
175 (MSG_ERROR, p)
176 }
177 SyncMessage::PullRequest => (MSG_PULL_REQUEST, Vec::new()),
178 SyncMessage::PullResponse {
179 root_page,
180 root_hash,
181 } => {
182 let mut p = Vec::with_capacity(32);
183 p.extend_from_slice(&root_page.0.to_le_bytes());
184 p.extend_from_slice(root_hash);
185 (MSG_PULL_RESPONSE, p)
186 }
187 SyncMessage::TableListRequest => (MSG_TABLE_LIST_REQUEST, Vec::new()),
188 SyncMessage::TableListResponse { tables } => {
189 let mut p = Vec::new();
190 p.extend_from_slice(&(tables.len() as u32).to_le_bytes());
191 for t in tables {
192 p.extend_from_slice(&(t.name.len() as u16).to_le_bytes());
193 p.extend_from_slice(&t.name);
194 p.extend_from_slice(&t.root_page.0.to_le_bytes());
195 p.extend_from_slice(&t.root_hash);
196 }
197 (MSG_TABLE_LIST_RESPONSE, p)
198 }
199 SyncMessage::TableSyncBegin {
200 table_name,
201 root_page,
202 root_hash,
203 } => {
204 let mut p = Vec::with_capacity(2 + table_name.len() + 4 + MERKLE_HASH_SIZE);
205 p.extend_from_slice(&(table_name.len() as u16).to_le_bytes());
206 p.extend_from_slice(table_name);
207 p.extend_from_slice(&root_page.0.to_le_bytes());
208 p.extend_from_slice(root_hash);
209 (MSG_TABLE_SYNC_BEGIN, p)
210 }
211 SyncMessage::TableSyncEnd { table_name } => {
212 let mut p = Vec::with_capacity(2 + table_name.len());
213 p.extend_from_slice(&(table_name.len() as u16).to_le_bytes());
214 p.extend_from_slice(table_name);
215 (MSG_TABLE_SYNC_END, p)
216 }
217 };
218
219 let mut buf = Vec::with_capacity(5 + payload.len());
220 buf.push(msg_type);
221 buf.extend_from_slice(&(payload.len() as u32).to_le_bytes());
222 buf.extend_from_slice(&payload);
223 buf
224 }
225
226 pub fn deserialize(data: &[u8]) -> Result<Self, ProtocolError> {
228 if data.len() < 5 {
229 return Err(ProtocolError::Truncated {
230 context: "message header".to_string(),
231 expected: 5,
232 actual: data.len(),
233 });
234 }
235
236 let msg_type = data[0];
237 let payload_len = u32::from_le_bytes(data[1..5].try_into().unwrap()) as usize;
238
239 if data.len() < 5 + payload_len {
240 return Err(ProtocolError::Truncated {
241 context: "message payload".to_string(),
242 expected: 5 + payload_len,
243 actual: data.len(),
244 });
245 }
246
247 let payload = &data[5..5 + payload_len];
248
249 match msg_type {
250 MSG_HELLO => {
251 ensure_len(payload, 40, "Hello")?;
252 let node_id = NodeId::from_bytes(payload[0..8].try_into().unwrap());
253 let root_page = PageId(u32::from_le_bytes(payload[8..12].try_into().unwrap()));
254 let mut root_hash = [0u8; MERKLE_HASH_SIZE];
255 root_hash.copy_from_slice(&payload[12..40]);
256 Ok(SyncMessage::Hello {
257 node_id,
258 root_page,
259 root_hash,
260 })
261 }
262 MSG_HELLO_ACK => {
263 ensure_len(payload, 41, "HelloAck")?;
264 let node_id = NodeId::from_bytes(payload[0..8].try_into().unwrap());
265 let root_page = PageId(u32::from_le_bytes(payload[8..12].try_into().unwrap()));
266 let mut root_hash = [0u8; MERKLE_HASH_SIZE];
267 root_hash.copy_from_slice(&payload[12..40]);
268 let in_sync = payload[40] != 0;
269 Ok(SyncMessage::HelloAck {
270 node_id,
271 root_page,
272 root_hash,
273 in_sync,
274 })
275 }
276 MSG_DIGEST_REQUEST => {
277 ensure_len(payload, 4, "DigestRequest")?;
278 let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
279 ensure_len(payload, 4 + count * 4, "DigestRequest")?;
280 let page_ids = (0..count)
281 .map(|i| {
282 let off = 4 + i * 4;
283 PageId(u32::from_le_bytes(
284 payload[off..off + 4].try_into().unwrap(),
285 ))
286 })
287 .collect();
288 Ok(SyncMessage::DigestRequest { page_ids })
289 }
290 MSG_DIGEST_RESPONSE => {
291 ensure_len(payload, 4, "DigestResponse")?;
292 let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
293 let mut pos = 4;
294 let mut digests = Vec::with_capacity(count);
295 for _ in 0..count {
296 let (digest, consumed) = deserialize_page_digest(payload, pos)?;
297 digests.push(digest);
298 pos += consumed;
299 }
300 Ok(SyncMessage::DigestResponse { digests })
301 }
302 MSG_ENTRIES_REQUEST => {
303 ensure_len(payload, 4, "EntriesRequest")?;
304 let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
305 ensure_len(payload, 4 + count * 4, "EntriesRequest")?;
306 let page_ids = (0..count)
307 .map(|i| {
308 let off = 4 + i * 4;
309 PageId(u32::from_le_bytes(
310 payload[off..off + 4].try_into().unwrap(),
311 ))
312 })
313 .collect();
314 Ok(SyncMessage::EntriesRequest { page_ids })
315 }
316 MSG_ENTRIES_RESPONSE => {
317 ensure_len(payload, 4, "EntriesResponse")?;
318 let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
319 let mut pos = 4;
320 let mut entries = Vec::with_capacity(count);
321 for _ in 0..count {
322 let (entry, consumed) = deserialize_diff_entry(payload, pos)?;
323 entries.push(entry);
324 pos += consumed;
325 }
326 Ok(SyncMessage::EntriesResponse { entries })
327 }
328 MSG_PATCH_DATA => Ok(SyncMessage::PatchData {
329 data: payload.to_vec(),
330 }),
331 MSG_PATCH_ACK => {
332 ensure_len(payload, 24, "PatchAck")?;
333 let entries_applied = u64::from_le_bytes(payload[0..8].try_into().unwrap());
334 let entries_skipped = u64::from_le_bytes(payload[8..16].try_into().unwrap());
335 let entries_equal = u64::from_le_bytes(payload[16..24].try_into().unwrap());
336 Ok(SyncMessage::PatchAck {
337 result: ApplyResult {
338 entries_applied,
339 entries_skipped,
340 entries_equal,
341 },
342 })
343 }
344 MSG_DONE => Ok(SyncMessage::Done),
345 MSG_ERROR => {
346 ensure_len(payload, 4, "Error")?;
347 let msg_len = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
348 ensure_len(payload, 4 + msg_len, "Error")?;
349 let message = String::from_utf8_lossy(&payload[4..4 + msg_len]).into_owned();
350 Ok(SyncMessage::Error { message })
351 }
352 MSG_PULL_REQUEST => Ok(SyncMessage::PullRequest),
353 MSG_PULL_RESPONSE => {
354 ensure_len(payload, 32, "PullResponse")?;
355 let root_page = PageId(u32::from_le_bytes(payload[0..4].try_into().unwrap()));
356 let mut root_hash = [0u8; MERKLE_HASH_SIZE];
357 root_hash.copy_from_slice(&payload[4..32]);
358 Ok(SyncMessage::PullResponse {
359 root_page,
360 root_hash,
361 })
362 }
363 MSG_TABLE_LIST_REQUEST => Ok(SyncMessage::TableListRequest),
364 MSG_TABLE_LIST_RESPONSE => {
365 ensure_len(payload, 4, "TableListResponse")?;
366 let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
367 let mut pos = 4;
368 let mut tables = Vec::with_capacity(count);
369 for _ in 0..count {
370 ensure_len(payload, pos + 2, "TableInfo name_len")?;
371 let name_len =
372 u16::from_le_bytes(payload[pos..pos + 2].try_into().unwrap()) as usize;
373 pos += 2;
374 ensure_len(payload, pos + name_len + 4 + MERKLE_HASH_SIZE, "TableInfo")?;
375 let name = payload[pos..pos + name_len].to_vec();
376 pos += name_len;
377 let root_page = PageId(u32::from_le_bytes(
378 payload[pos..pos + 4].try_into().unwrap(),
379 ));
380 pos += 4;
381 let mut root_hash = [0u8; MERKLE_HASH_SIZE];
382 root_hash.copy_from_slice(&payload[pos..pos + MERKLE_HASH_SIZE]);
383 pos += MERKLE_HASH_SIZE;
384 tables.push(TableInfo {
385 name,
386 root_page,
387 root_hash,
388 });
389 }
390 Ok(SyncMessage::TableListResponse { tables })
391 }
392 MSG_TABLE_SYNC_BEGIN => {
393 ensure_len(payload, 2, "TableSyncBegin")?;
394 let name_len = u16::from_le_bytes(payload[0..2].try_into().unwrap()) as usize;
395 ensure_len(
396 payload,
397 2 + name_len + 4 + MERKLE_HASH_SIZE,
398 "TableSyncBegin",
399 )?;
400 let table_name = payload[2..2 + name_len].to_vec();
401 let off = 2 + name_len;
402 let root_page = PageId(u32::from_le_bytes(
403 payload[off..off + 4].try_into().unwrap(),
404 ));
405 let mut root_hash = [0u8; MERKLE_HASH_SIZE];
406 root_hash.copy_from_slice(&payload[off + 4..off + 4 + MERKLE_HASH_SIZE]);
407 Ok(SyncMessage::TableSyncBegin {
408 table_name,
409 root_page,
410 root_hash,
411 })
412 }
413 MSG_TABLE_SYNC_END => {
414 ensure_len(payload, 2, "TableSyncEnd")?;
415 let name_len = u16::from_le_bytes(payload[0..2].try_into().unwrap()) as usize;
416 ensure_len(payload, 2 + name_len, "TableSyncEnd")?;
417 let table_name = payload[2..2 + name_len].to_vec();
418 Ok(SyncMessage::TableSyncEnd { table_name })
419 }
420 _ => Err(ProtocolError::UnknownMessageType(msg_type)),
421 }
422 }
423}
424
425fn ensure_len(data: &[u8], needed: usize, ctx: &str) -> Result<(), ProtocolError> {
426 if data.len() < needed {
427 Err(ProtocolError::Truncated {
428 context: ctx.to_string(),
429 expected: needed,
430 actual: data.len(),
431 })
432 } else {
433 Ok(())
434 }
435}
436
437fn serialize_page_digest(buf: &mut Vec<u8>, d: &PageDigest) {
438 buf.extend_from_slice(&d.page_id.0.to_le_bytes());
439 buf.extend_from_slice(&(d.page_type as u16).to_le_bytes());
440 buf.extend_from_slice(&d.merkle_hash);
441 buf.extend_from_slice(&(d.children.len() as u32).to_le_bytes());
442 for c in &d.children {
443 buf.extend_from_slice(&c.0.to_le_bytes());
444 }
445}
446
447fn deserialize_page_digest(
448 data: &[u8],
449 offset: usize,
450) -> Result<(PageDigest, usize), ProtocolError> {
451 let min = 38;
453 if data.len() < offset + min {
454 return Err(ProtocolError::Truncated {
455 context: "PageDigest header".to_string(),
456 expected: offset + min,
457 actual: data.len(),
458 });
459 }
460
461 let page_id = PageId(u32::from_le_bytes(
462 data[offset..offset + 4].try_into().unwrap(),
463 ));
464 let page_type_raw = u16::from_le_bytes(data[offset + 4..offset + 6].try_into().unwrap());
465 let page_type = citadel_core::types::PageType::from_u16(page_type_raw)
466 .unwrap_or(citadel_core::types::PageType::Leaf);
467 let mut merkle_hash = [0u8; MERKLE_HASH_SIZE];
468 merkle_hash.copy_from_slice(&data[offset + 6..offset + 34]);
469 let child_count =
470 u32::from_le_bytes(data[offset + 34..offset + 38].try_into().unwrap()) as usize;
471
472 if data.len() < offset + min + child_count * 4 {
473 return Err(ProtocolError::Truncated {
474 context: "PageDigest children".to_string(),
475 expected: offset + min + child_count * 4,
476 actual: data.len(),
477 });
478 }
479
480 let children = (0..child_count)
481 .map(|i| {
482 let off = offset + 38 + i * 4;
483 PageId(u32::from_le_bytes(data[off..off + 4].try_into().unwrap()))
484 })
485 .collect();
486
487 Ok((
488 PageDigest {
489 page_id,
490 page_type,
491 merkle_hash,
492 children,
493 },
494 min + child_count * 4,
495 ))
496}
497
498fn serialize_diff_entry(buf: &mut Vec<u8>, e: &DiffEntry) {
499 buf.extend_from_slice(&(e.key.len() as u16).to_le_bytes());
500 buf.extend_from_slice(&(e.value.len() as u32).to_le_bytes());
501 buf.push(e.val_type);
502 buf.extend_from_slice(&e.key);
503 buf.extend_from_slice(&e.value);
504}
505
506fn deserialize_diff_entry(data: &[u8], offset: usize) -> Result<(DiffEntry, usize), ProtocolError> {
507 let header = 7;
509 if data.len() < offset + header {
510 return Err(ProtocolError::Truncated {
511 context: "DiffEntry header".to_string(),
512 expected: offset + header,
513 actual: data.len(),
514 });
515 }
516
517 let key_len = u16::from_le_bytes(data[offset..offset + 2].try_into().unwrap()) as usize;
518 let val_len = u32::from_le_bytes(data[offset + 2..offset + 6].try_into().unwrap()) as usize;
519 let val_type = data[offset + 6];
520
521 let total = header + key_len + val_len;
522 if data.len() < offset + total {
523 return Err(ProtocolError::Truncated {
524 context: "DiffEntry data".to_string(),
525 expected: offset + total,
526 actual: data.len(),
527 });
528 }
529
530 let key = data[offset + 7..offset + 7 + key_len].to_vec();
531 let value = data[offset + 7 + key_len..offset + 7 + key_len + val_len].to_vec();
532
533 Ok((
534 DiffEntry {
535 key,
536 value,
537 val_type,
538 },
539 total,
540 ))
541}
542
#[cfg(test)]
mod tests {
    //! Round-trip tests for every message variant, plus malformed-input cases
    //! that must be rejected with `ProtocolError` rather than panicking.

    use super::*;
    use citadel_core::types::PageType;

    fn sample_hash() -> MerkleHash {
        let mut h = [0u8; MERKLE_HASH_SIZE];
        for (i, byte) in h.iter_mut().enumerate() {
            *byte = i as u8;
        }
        h
    }

    #[test]
    fn hello_roundtrip() {
        let msg = SyncMessage::Hello {
            node_id: NodeId::from_u64(42),
            root_page: PageId(7),
            root_hash: sample_hash(),
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::Hello {
                node_id,
                root_page,
                root_hash,
            } => {
                assert_eq!(node_id, NodeId::from_u64(42));
                assert_eq!(root_page, PageId(7));
                assert_eq!(root_hash, sample_hash());
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn hello_ack_roundtrip() {
        let msg = SyncMessage::HelloAck {
            node_id: NodeId::from_u64(99),
            root_page: PageId(3),
            root_hash: sample_hash(),
            in_sync: true,
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::HelloAck {
                node_id,
                root_page,
                root_hash,
                in_sync,
            } => {
                assert_eq!(node_id, NodeId::from_u64(99));
                assert_eq!(root_page, PageId(3));
                assert_eq!(root_hash, sample_hash());
                assert!(in_sync);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn hello_ack_not_in_sync_roundtrip() {
        let msg = SyncMessage::HelloAck {
            node_id: NodeId::from_u64(1),
            root_page: PageId(0),
            root_hash: [0u8; MERKLE_HASH_SIZE],
            in_sync: false,
        };
        let decoded = SyncMessage::deserialize(&msg.serialize()).unwrap();
        assert!(matches!(
            decoded,
            SyncMessage::HelloAck { in_sync: false, .. }
        ));
    }

    #[test]
    fn digest_request_roundtrip() {
        let msg = SyncMessage::DigestRequest {
            page_ids: vec![PageId(1), PageId(5), PageId(100)],
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::DigestRequest { page_ids } => {
                assert_eq!(page_ids, vec![PageId(1), PageId(5), PageId(100)]);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn digest_response_roundtrip() {
        let msg = SyncMessage::DigestResponse {
            digests: vec![
                PageDigest {
                    page_id: PageId(1),
                    page_type: PageType::Leaf,
                    merkle_hash: sample_hash(),
                    children: vec![],
                },
                PageDigest {
                    page_id: PageId(2),
                    page_type: PageType::Branch,
                    merkle_hash: [0xAA; MERKLE_HASH_SIZE],
                    children: vec![PageId(3), PageId(4)],
                },
            ],
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::DigestResponse { digests } => {
                assert_eq!(digests.len(), 2);
                assert_eq!(digests[0].page_id, PageId(1));
                assert!(digests[0].children.is_empty());
                assert_eq!(digests[1].children, vec![PageId(3), PageId(4)]);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn entries_request_roundtrip() {
        let msg = SyncMessage::EntriesRequest {
            page_ids: vec![PageId(10)],
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::EntriesRequest { page_ids } => {
                assert_eq!(page_ids, vec![PageId(10)]);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn entries_response_roundtrip() {
        let msg = SyncMessage::EntriesResponse {
            entries: vec![
                DiffEntry {
                    key: b"k1".to_vec(),
                    value: b"v1".to_vec(),
                    val_type: 0,
                },
                DiffEntry {
                    key: b"k2".to_vec(),
                    value: b"v2".to_vec(),
                    val_type: 1,
                },
            ],
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::EntriesResponse { entries } => {
                assert_eq!(entries.len(), 2);
                assert_eq!(entries[0].key, b"k1");
                assert_eq!(entries[1].val_type, 1);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn patch_data_roundtrip() {
        let msg = SyncMessage::PatchData {
            data: vec![1, 2, 3, 4, 5],
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::PatchData { data: d } => {
                assert_eq!(d, vec![1, 2, 3, 4, 5]);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn empty_patch_data_roundtrip() {
        let data = SyncMessage::PatchData { data: vec![] }.serialize();
        assert_eq!(data, vec![MSG_PATCH_DATA, 0, 0, 0, 0]);
        match SyncMessage::deserialize(&data).unwrap() {
            SyncMessage::PatchData { data: d } => assert!(d.is_empty()),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn patch_ack_roundtrip() {
        let msg = SyncMessage::PatchAck {
            result: ApplyResult {
                entries_applied: 10,
                entries_skipped: 3,
                entries_equal: 2,
            },
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::PatchAck { result } => {
                assert_eq!(result.entries_applied, 10);
                assert_eq!(result.entries_skipped, 3);
                assert_eq!(result.entries_equal, 2);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn done_roundtrip() {
        let data = SyncMessage::Done.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        assert!(matches!(decoded, SyncMessage::Done));
    }

    #[test]
    fn error_roundtrip() {
        let msg = SyncMessage::Error {
            message: "something broke".into(),
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::Error { message } => {
                assert_eq!(message, "something broke");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn frame_header_layout() {
        // tag | payload_len: u32 LE | (len: u32 LE | text)
        let data = SyncMessage::Error {
            message: "ab".into(),
        }
        .serialize();
        assert_eq!(data, vec![MSG_ERROR, 6, 0, 0, 0, 2, 0, 0, 0, b'a', b'b']);
    }

    #[test]
    fn error_message_invalid_utf8_is_lossy() {
        let data = [MSG_ERROR, 6, 0, 0, 0, 2, 0, 0, 0, 0xFF, b'a'];
        match SyncMessage::deserialize(&data).unwrap() {
            SyncMessage::Error { message } => assert_eq!(message, "\u{FFFD}a"),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn pull_request_roundtrip() {
        let data = SyncMessage::PullRequest.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        assert!(matches!(decoded, SyncMessage::PullRequest));
    }

    #[test]
    fn pull_response_roundtrip() {
        let msg = SyncMessage::PullResponse {
            root_page: PageId(15),
            root_hash: sample_hash(),
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::PullResponse {
                root_page,
                root_hash,
            } => {
                assert_eq!(root_page, PageId(15));
                assert_eq!(root_hash, sample_hash());
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn truncated_data() {
        let err = SyncMessage::deserialize(&[0, 1]).unwrap_err();
        assert!(matches!(err, ProtocolError::Truncated { .. }));
    }

    #[test]
    fn payload_shorter_than_header_claims() {
        // The header announces a 10-byte payload but only 3 bytes follow.
        let data = [MSG_PATCH_DATA, 10, 0, 0, 0, 1, 2, 3];
        let err = SyncMessage::deserialize(&data).unwrap_err();
        assert!(matches!(
            err,
            ProtocolError::Truncated {
                expected: 15,
                actual: 8,
                ..
            }
        ));
    }

    #[test]
    fn truncated_hello_payload() {
        let data = [MSG_HELLO, 4, 0, 0, 0, 1, 2, 3, 4];
        let err = SyncMessage::deserialize(&data).unwrap_err();
        assert!(matches!(
            err,
            ProtocolError::Truncated { expected, actual: 4, .. }
                if expected == 8 + 4 + MERKLE_HASH_SIZE
        ));
    }

    #[test]
    fn digest_request_count_exceeds_payload() {
        // Claims five page ids but carries none.
        let data = [MSG_DIGEST_REQUEST, 4, 0, 0, 0, 5, 0, 0, 0];
        let err = SyncMessage::deserialize(&data).unwrap_err();
        assert!(matches!(
            err,
            ProtocolError::Truncated {
                expected: 24,
                actual: 4,
                ..
            }
        ));
    }

    #[test]
    fn error_length_exceeds_payload() {
        let data = [MSG_ERROR, 4, 0, 0, 0, 10, 0, 0, 0];
        let err = SyncMessage::deserialize(&data).unwrap_err();
        assert!(matches!(
            err,
            ProtocolError::Truncated {
                expected: 14,
                actual: 4,
                ..
            }
        ));
    }

    #[test]
    fn entries_response_missing_entry() {
        // The count says one entry, but no entry bytes follow.
        let data = [MSG_ENTRIES_RESPONSE, 4, 0, 0, 0, 1, 0, 0, 0];
        let err = SyncMessage::deserialize(&data).unwrap_err();
        assert!(matches!(err, ProtocolError::Truncated { .. }));
    }

    #[test]
    fn table_sync_begin_truncated_name() {
        // name_len = 8 but only 3 name bytes (and no root) are present.
        let data = [MSG_TABLE_SYNC_BEGIN, 5, 0, 0, 0, 8, 0, b'a', b'b', b'c'];
        let err = SyncMessage::deserialize(&data).unwrap_err();
        assert!(matches!(
            err,
            ProtocolError::Truncated { expected, actual: 5, .. }
                if expected == 2 + 8 + 4 + MERKLE_HASH_SIZE
        ));
    }

    #[test]
    fn unknown_message_type() {
        let data = [255, 0, 0, 0, 0];
        let err = SyncMessage::deserialize(&data).unwrap_err();
        assert!(matches!(err, ProtocolError::UnknownMessageType(255)));
    }

    #[test]
    fn empty_digest_request() {
        let msg = SyncMessage::DigestRequest { page_ids: vec![] };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::DigestRequest { page_ids } => assert!(page_ids.is_empty()),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn table_list_request_roundtrip() {
        let data = SyncMessage::TableListRequest.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        assert!(matches!(decoded, SyncMessage::TableListRequest));
    }

    #[test]
    fn table_list_response_roundtrip() {
        let msg = SyncMessage::TableListResponse {
            tables: vec![
                TableInfo {
                    name: b"users".to_vec(),
                    root_page: PageId(10),
                    root_hash: sample_hash(),
                },
                TableInfo {
                    name: b"orders".to_vec(),
                    root_page: PageId(20),
                    root_hash: [0xBB; MERKLE_HASH_SIZE],
                },
            ],
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::TableListResponse { tables } => {
                assert_eq!(tables.len(), 2);
                assert_eq!(tables[0].name, b"users");
                assert_eq!(tables[0].root_page, PageId(10));
                assert_eq!(tables[0].root_hash, sample_hash());
                assert_eq!(tables[1].name, b"orders");
                assert_eq!(tables[1].root_page, PageId(20));
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn table_list_response_empty() {
        let msg = SyncMessage::TableListResponse { tables: vec![] };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::TableListResponse { tables } => assert!(tables.is_empty()),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn table_sync_begin_roundtrip() {
        let msg = SyncMessage::TableSyncBegin {
            table_name: b"products".to_vec(),
            root_page: PageId(77),
            root_hash: sample_hash(),
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::TableSyncBegin {
                table_name,
                root_page,
                root_hash,
            } => {
                assert_eq!(table_name, b"products");
                assert_eq!(root_page, PageId(77));
                assert_eq!(root_hash, sample_hash());
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn table_sync_end_roundtrip() {
        let msg = SyncMessage::TableSyncEnd {
            table_name: b"products".to_vec(),
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::TableSyncEnd { table_name } => {
                assert_eq!(table_name, b"products");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn empty_entries_response() {
        let msg = SyncMessage::EntriesResponse { entries: vec![] };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::EntriesResponse { entries } => assert!(entries.is_empty()),
            _ => panic!("wrong variant"),
        }
    }
}
900}