1use citadel_core::types::PageId;
2use citadel_core::MERKLE_HASH_SIZE;
3
4use crate::apply::ApplyResult;
5use crate::diff::{DiffEntry, MerkleHash, PageDigest};
6use crate::node_id::NodeId;
7
// Wire type tags. Every frame written by `SyncMessage::serialize` begins with exactly
// one of these bytes and `SyncMessage::deserialize` dispatches on it, so the numeric
// values are part of the on-wire format and must not be renumbered.

const MSG_HELLO: u8 = 0x00;
const MSG_HELLO_ACK: u8 = 0x01;

const MSG_DIGEST_REQUEST: u8 = 0x02;
const MSG_DIGEST_RESPONSE: u8 = 0x03;
const MSG_ENTRIES_REQUEST: u8 = 0x04;
const MSG_ENTRIES_RESPONSE: u8 = 0x05;

const MSG_PATCH_DATA: u8 = 0x06;
const MSG_PATCH_ACK: u8 = 0x07;

const MSG_DONE: u8 = 0x08;
const MSG_ERROR: u8 = 0x09;

const MSG_PULL_REQUEST: u8 = 0x0A;
const MSG_PULL_RESPONSE: u8 = 0x0B;

const MSG_TABLE_LIST_REQUEST: u8 = 0x0C;
const MSG_TABLE_LIST_RESPONSE: u8 = 0x0D;
const MSG_TABLE_SYNC_BEGIN: u8 = 0x0E;
const MSG_TABLE_SYNC_END: u8 = 0x0F;
25
26#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct TableInfo {
29 pub name: Vec<u8>,
30 pub root_page: PageId,
31 pub root_hash: MerkleHash,
32}
33
34#[derive(Debug, Clone)]
36pub enum SyncMessage {
37 Hello {
39 node_id: NodeId,
40 root_page: PageId,
41 root_hash: MerkleHash,
42 },
43 HelloAck {
45 node_id: NodeId,
46 root_page: PageId,
47 root_hash: MerkleHash,
48 in_sync: bool,
49 },
50 DigestRequest {
52 page_ids: Vec<PageId>,
53 },
54 DigestResponse {
56 digests: Vec<PageDigest>,
57 },
58 EntriesRequest {
60 page_ids: Vec<PageId>,
61 },
62 EntriesResponse {
64 entries: Vec<DiffEntry>,
65 },
66 PatchData {
68 data: Vec<u8>,
69 },
70 PatchAck {
72 result: ApplyResult,
73 },
74 Done,
76 Error {
78 message: String,
79 },
80 PullRequest,
82 PullResponse {
84 root_page: PageId,
85 root_hash: MerkleHash,
86 },
87 TableListRequest,
89 TableListResponse {
91 tables: Vec<TableInfo>,
92 },
93 TableSyncBegin {
95 table_name: Vec<u8>,
96 root_page: PageId,
97 root_hash: MerkleHash,
98 },
99 TableSyncEnd {
101 table_name: Vec<u8>,
102 },
103}
104
105#[derive(Debug, thiserror::Error)]
107pub enum ProtocolError {
108 #[error("{context}: expected at least {expected} bytes, got {actual}")]
109 Truncated { context: String, expected: usize, actual: usize },
110
111 #[error("unknown message type: {0}")]
112 UnknownMessageType(u8),
113}
114
115impl SyncMessage {
116 pub fn serialize(&self) -> Vec<u8> {
118 let (msg_type, payload) = match self {
119 SyncMessage::Hello { node_id, root_page, root_hash } => {
120 let mut p = Vec::with_capacity(40);
121 p.extend_from_slice(&node_id.to_bytes());
122 p.extend_from_slice(&root_page.0.to_le_bytes());
123 p.extend_from_slice(root_hash);
124 (MSG_HELLO, p)
125 }
126 SyncMessage::HelloAck { node_id, root_page, root_hash, in_sync } => {
127 let mut p = Vec::with_capacity(41);
128 p.extend_from_slice(&node_id.to_bytes());
129 p.extend_from_slice(&root_page.0.to_le_bytes());
130 p.extend_from_slice(root_hash);
131 p.push(if *in_sync { 1 } else { 0 });
132 (MSG_HELLO_ACK, p)
133 }
134 SyncMessage::DigestRequest { page_ids } => {
135 let mut p = Vec::with_capacity(4 + page_ids.len() * 4);
136 p.extend_from_slice(&(page_ids.len() as u32).to_le_bytes());
137 for pid in page_ids {
138 p.extend_from_slice(&pid.0.to_le_bytes());
139 }
140 (MSG_DIGEST_REQUEST, p)
141 }
142 SyncMessage::DigestResponse { digests } => {
143 let mut p = Vec::new();
144 p.extend_from_slice(&(digests.len() as u32).to_le_bytes());
145 for d in digests {
146 serialize_page_digest(&mut p, d);
147 }
148 (MSG_DIGEST_RESPONSE, p)
149 }
150 SyncMessage::EntriesRequest { page_ids } => {
151 let mut p = Vec::with_capacity(4 + page_ids.len() * 4);
152 p.extend_from_slice(&(page_ids.len() as u32).to_le_bytes());
153 for pid in page_ids {
154 p.extend_from_slice(&pid.0.to_le_bytes());
155 }
156 (MSG_ENTRIES_REQUEST, p)
157 }
158 SyncMessage::EntriesResponse { entries } => {
159 let mut p = Vec::new();
160 p.extend_from_slice(&(entries.len() as u32).to_le_bytes());
161 for e in entries {
162 serialize_diff_entry(&mut p, e);
163 }
164 (MSG_ENTRIES_RESPONSE, p)
165 }
166 SyncMessage::PatchData { data } => {
167 (MSG_PATCH_DATA, data.clone())
168 }
169 SyncMessage::PatchAck { result } => {
170 let mut p = Vec::with_capacity(24);
171 p.extend_from_slice(&result.entries_applied.to_le_bytes());
172 p.extend_from_slice(&result.entries_skipped.to_le_bytes());
173 p.extend_from_slice(&result.entries_equal.to_le_bytes());
174 (MSG_PATCH_ACK, p)
175 }
176 SyncMessage::Done => {
177 (MSG_DONE, Vec::new())
178 }
179 SyncMessage::Error { message } => {
180 let bytes = message.as_bytes();
181 let mut p = Vec::with_capacity(4 + bytes.len());
182 p.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
183 p.extend_from_slice(bytes);
184 (MSG_ERROR, p)
185 }
186 SyncMessage::PullRequest => {
187 (MSG_PULL_REQUEST, Vec::new())
188 }
189 SyncMessage::PullResponse { root_page, root_hash } => {
190 let mut p = Vec::with_capacity(32);
191 p.extend_from_slice(&root_page.0.to_le_bytes());
192 p.extend_from_slice(root_hash);
193 (MSG_PULL_RESPONSE, p)
194 }
195 SyncMessage::TableListRequest => {
196 (MSG_TABLE_LIST_REQUEST, Vec::new())
197 }
198 SyncMessage::TableListResponse { tables } => {
199 let mut p = Vec::new();
200 p.extend_from_slice(&(tables.len() as u32).to_le_bytes());
201 for t in tables {
202 p.extend_from_slice(&(t.name.len() as u16).to_le_bytes());
203 p.extend_from_slice(&t.name);
204 p.extend_from_slice(&t.root_page.0.to_le_bytes());
205 p.extend_from_slice(&t.root_hash);
206 }
207 (MSG_TABLE_LIST_RESPONSE, p)
208 }
209 SyncMessage::TableSyncBegin { table_name, root_page, root_hash } => {
210 let mut p = Vec::with_capacity(2 + table_name.len() + 4 + MERKLE_HASH_SIZE);
211 p.extend_from_slice(&(table_name.len() as u16).to_le_bytes());
212 p.extend_from_slice(table_name);
213 p.extend_from_slice(&root_page.0.to_le_bytes());
214 p.extend_from_slice(root_hash);
215 (MSG_TABLE_SYNC_BEGIN, p)
216 }
217 SyncMessage::TableSyncEnd { table_name } => {
218 let mut p = Vec::with_capacity(2 + table_name.len());
219 p.extend_from_slice(&(table_name.len() as u16).to_le_bytes());
220 p.extend_from_slice(table_name);
221 (MSG_TABLE_SYNC_END, p)
222 }
223 };
224
225 let mut buf = Vec::with_capacity(5 + payload.len());
226 buf.push(msg_type);
227 buf.extend_from_slice(&(payload.len() as u32).to_le_bytes());
228 buf.extend_from_slice(&payload);
229 buf
230 }
231
232 pub fn deserialize(data: &[u8]) -> Result<Self, ProtocolError> {
234 if data.len() < 5 {
235 return Err(ProtocolError::Truncated {
236 context: "message header".to_string(),
237 expected: 5,
238 actual: data.len(),
239 });
240 }
241
242 let msg_type = data[0];
243 let payload_len = u32::from_le_bytes(data[1..5].try_into().unwrap()) as usize;
244
245 if data.len() < 5 + payload_len {
246 return Err(ProtocolError::Truncated {
247 context: "message payload".to_string(),
248 expected: 5 + payload_len,
249 actual: data.len(),
250 });
251 }
252
253 let payload = &data[5..5 + payload_len];
254
255 match msg_type {
256 MSG_HELLO => {
257 ensure_len(payload, 40, "Hello")?;
258 let node_id = NodeId::from_bytes(payload[0..8].try_into().unwrap());
259 let root_page = PageId(u32::from_le_bytes(payload[8..12].try_into().unwrap()));
260 let mut root_hash = [0u8; MERKLE_HASH_SIZE];
261 root_hash.copy_from_slice(&payload[12..40]);
262 Ok(SyncMessage::Hello { node_id, root_page, root_hash })
263 }
264 MSG_HELLO_ACK => {
265 ensure_len(payload, 41, "HelloAck")?;
266 let node_id = NodeId::from_bytes(payload[0..8].try_into().unwrap());
267 let root_page = PageId(u32::from_le_bytes(payload[8..12].try_into().unwrap()));
268 let mut root_hash = [0u8; MERKLE_HASH_SIZE];
269 root_hash.copy_from_slice(&payload[12..40]);
270 let in_sync = payload[40] != 0;
271 Ok(SyncMessage::HelloAck { node_id, root_page, root_hash, in_sync })
272 }
273 MSG_DIGEST_REQUEST => {
274 ensure_len(payload, 4, "DigestRequest")?;
275 let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
276 ensure_len(payload, 4 + count * 4, "DigestRequest")?;
277 let page_ids = (0..count)
278 .map(|i| {
279 let off = 4 + i * 4;
280 PageId(u32::from_le_bytes(payload[off..off + 4].try_into().unwrap()))
281 })
282 .collect();
283 Ok(SyncMessage::DigestRequest { page_ids })
284 }
285 MSG_DIGEST_RESPONSE => {
286 ensure_len(payload, 4, "DigestResponse")?;
287 let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
288 let mut pos = 4;
289 let mut digests = Vec::with_capacity(count);
290 for _ in 0..count {
291 let (digest, consumed) = deserialize_page_digest(payload, pos)?;
292 digests.push(digest);
293 pos += consumed;
294 }
295 Ok(SyncMessage::DigestResponse { digests })
296 }
297 MSG_ENTRIES_REQUEST => {
298 ensure_len(payload, 4, "EntriesRequest")?;
299 let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
300 ensure_len(payload, 4 + count * 4, "EntriesRequest")?;
301 let page_ids = (0..count)
302 .map(|i| {
303 let off = 4 + i * 4;
304 PageId(u32::from_le_bytes(payload[off..off + 4].try_into().unwrap()))
305 })
306 .collect();
307 Ok(SyncMessage::EntriesRequest { page_ids })
308 }
309 MSG_ENTRIES_RESPONSE => {
310 ensure_len(payload, 4, "EntriesResponse")?;
311 let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
312 let mut pos = 4;
313 let mut entries = Vec::with_capacity(count);
314 for _ in 0..count {
315 let (entry, consumed) = deserialize_diff_entry(payload, pos)?;
316 entries.push(entry);
317 pos += consumed;
318 }
319 Ok(SyncMessage::EntriesResponse { entries })
320 }
321 MSG_PATCH_DATA => {
322 Ok(SyncMessage::PatchData { data: payload.to_vec() })
323 }
324 MSG_PATCH_ACK => {
325 ensure_len(payload, 24, "PatchAck")?;
326 let entries_applied = u64::from_le_bytes(payload[0..8].try_into().unwrap());
327 let entries_skipped = u64::from_le_bytes(payload[8..16].try_into().unwrap());
328 let entries_equal = u64::from_le_bytes(payload[16..24].try_into().unwrap());
329 Ok(SyncMessage::PatchAck {
330 result: ApplyResult { entries_applied, entries_skipped, entries_equal },
331 })
332 }
333 MSG_DONE => {
334 Ok(SyncMessage::Done)
335 }
336 MSG_ERROR => {
337 ensure_len(payload, 4, "Error")?;
338 let msg_len = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
339 ensure_len(payload, 4 + msg_len, "Error")?;
340 let message = String::from_utf8_lossy(&payload[4..4 + msg_len]).into_owned();
341 Ok(SyncMessage::Error { message })
342 }
343 MSG_PULL_REQUEST => {
344 Ok(SyncMessage::PullRequest)
345 }
346 MSG_PULL_RESPONSE => {
347 ensure_len(payload, 32, "PullResponse")?;
348 let root_page = PageId(u32::from_le_bytes(payload[0..4].try_into().unwrap()));
349 let mut root_hash = [0u8; MERKLE_HASH_SIZE];
350 root_hash.copy_from_slice(&payload[4..32]);
351 Ok(SyncMessage::PullResponse { root_page, root_hash })
352 }
353 MSG_TABLE_LIST_REQUEST => {
354 Ok(SyncMessage::TableListRequest)
355 }
356 MSG_TABLE_LIST_RESPONSE => {
357 ensure_len(payload, 4, "TableListResponse")?;
358 let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
359 let mut pos = 4;
360 let mut tables = Vec::with_capacity(count);
361 for _ in 0..count {
362 ensure_len(payload, pos + 2, "TableInfo name_len")?;
363 let name_len = u16::from_le_bytes(
364 payload[pos..pos + 2].try_into().unwrap(),
365 ) as usize;
366 pos += 2;
367 ensure_len(payload, pos + name_len + 4 + MERKLE_HASH_SIZE, "TableInfo")?;
368 let name = payload[pos..pos + name_len].to_vec();
369 pos += name_len;
370 let root_page = PageId(u32::from_le_bytes(
371 payload[pos..pos + 4].try_into().unwrap(),
372 ));
373 pos += 4;
374 let mut root_hash = [0u8; MERKLE_HASH_SIZE];
375 root_hash.copy_from_slice(&payload[pos..pos + MERKLE_HASH_SIZE]);
376 pos += MERKLE_HASH_SIZE;
377 tables.push(TableInfo { name, root_page, root_hash });
378 }
379 Ok(SyncMessage::TableListResponse { tables })
380 }
381 MSG_TABLE_SYNC_BEGIN => {
382 ensure_len(payload, 2, "TableSyncBegin")?;
383 let name_len = u16::from_le_bytes(payload[0..2].try_into().unwrap()) as usize;
384 ensure_len(payload, 2 + name_len + 4 + MERKLE_HASH_SIZE, "TableSyncBegin")?;
385 let table_name = payload[2..2 + name_len].to_vec();
386 let off = 2 + name_len;
387 let root_page = PageId(u32::from_le_bytes(
388 payload[off..off + 4].try_into().unwrap(),
389 ));
390 let mut root_hash = [0u8; MERKLE_HASH_SIZE];
391 root_hash.copy_from_slice(&payload[off + 4..off + 4 + MERKLE_HASH_SIZE]);
392 Ok(SyncMessage::TableSyncBegin { table_name, root_page, root_hash })
393 }
394 MSG_TABLE_SYNC_END => {
395 ensure_len(payload, 2, "TableSyncEnd")?;
396 let name_len = u16::from_le_bytes(payload[0..2].try_into().unwrap()) as usize;
397 ensure_len(payload, 2 + name_len, "TableSyncEnd")?;
398 let table_name = payload[2..2 + name_len].to_vec();
399 Ok(SyncMessage::TableSyncEnd { table_name })
400 }
401 _ => Err(ProtocolError::UnknownMessageType(msg_type)),
402 }
403 }
404}
405
406fn ensure_len(data: &[u8], needed: usize, ctx: &str) -> Result<(), ProtocolError> {
407 if data.len() < needed {
408 Err(ProtocolError::Truncated {
409 context: ctx.to_string(),
410 expected: needed,
411 actual: data.len(),
412 })
413 } else {
414 Ok(())
415 }
416}
417
418fn serialize_page_digest(buf: &mut Vec<u8>, d: &PageDigest) {
419 buf.extend_from_slice(&d.page_id.0.to_le_bytes());
420 buf.extend_from_slice(&(d.page_type as u16).to_le_bytes());
421 buf.extend_from_slice(&d.merkle_hash);
422 buf.extend_from_slice(&(d.children.len() as u32).to_le_bytes());
423 for c in &d.children {
424 buf.extend_from_slice(&c.0.to_le_bytes());
425 }
426}
427
428fn deserialize_page_digest(data: &[u8], offset: usize) -> Result<(PageDigest, usize), ProtocolError> {
429 let min = 38;
431 if data.len() < offset + min {
432 return Err(ProtocolError::Truncated {
433 context: "PageDigest header".to_string(),
434 expected: offset + min,
435 actual: data.len(),
436 });
437 }
438
439 let page_id = PageId(u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()));
440 let page_type_raw = u16::from_le_bytes(data[offset + 4..offset + 6].try_into().unwrap());
441 let page_type = citadel_core::types::PageType::from_u16(page_type_raw)
442 .unwrap_or(citadel_core::types::PageType::Leaf);
443 let mut merkle_hash = [0u8; MERKLE_HASH_SIZE];
444 merkle_hash.copy_from_slice(&data[offset + 6..offset + 34]);
445 let child_count = u32::from_le_bytes(data[offset + 34..offset + 38].try_into().unwrap()) as usize;
446
447 if data.len() < offset + min + child_count * 4 {
448 return Err(ProtocolError::Truncated {
449 context: "PageDigest children".to_string(),
450 expected: offset + min + child_count * 4,
451 actual: data.len(),
452 });
453 }
454
455 let children = (0..child_count)
456 .map(|i| {
457 let off = offset + 38 + i * 4;
458 PageId(u32::from_le_bytes(data[off..off + 4].try_into().unwrap()))
459 })
460 .collect();
461
462 Ok((
463 PageDigest { page_id, page_type, merkle_hash, children },
464 min + child_count * 4,
465 ))
466}
467
468fn serialize_diff_entry(buf: &mut Vec<u8>, e: &DiffEntry) {
469 buf.extend_from_slice(&(e.key.len() as u16).to_le_bytes());
470 buf.extend_from_slice(&(e.value.len() as u32).to_le_bytes());
471 buf.push(e.val_type);
472 buf.extend_from_slice(&e.key);
473 buf.extend_from_slice(&e.value);
474}
475
476fn deserialize_diff_entry(data: &[u8], offset: usize) -> Result<(DiffEntry, usize), ProtocolError> {
477 let header = 7;
479 if data.len() < offset + header {
480 return Err(ProtocolError::Truncated {
481 context: "DiffEntry header".to_string(),
482 expected: offset + header,
483 actual: data.len(),
484 });
485 }
486
487 let key_len = u16::from_le_bytes(data[offset..offset + 2].try_into().unwrap()) as usize;
488 let val_len = u32::from_le_bytes(data[offset + 2..offset + 6].try_into().unwrap()) as usize;
489 let val_type = data[offset + 6];
490
491 let total = header + key_len + val_len;
492 if data.len() < offset + total {
493 return Err(ProtocolError::Truncated {
494 context: "DiffEntry data".to_string(),
495 expected: offset + total,
496 actual: data.len(),
497 });
498 }
499
500 let key = data[offset + 7..offset + 7 + key_len].to_vec();
501 let value = data[offset + 7 + key_len..offset + 7 + key_len + val_len].to_vec();
502
503 Ok((DiffEntry { key, value, val_type }, total))
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use citadel_core::types::PageType;

    /// A deterministic, non-uniform hash so byte-order mistakes show up.
    fn sample_hash() -> MerkleHash {
        let mut h = [0u8; MERKLE_HASH_SIZE];
        for (i, b) in h.iter_mut().enumerate() {
            *b = i as u8;
        }
        h
    }

    /// Builds a raw frame whose declared payload length may disagree with `payload`.
    fn raw_frame(msg_type: u8, declared_len: u32, payload: &[u8]) -> Vec<u8> {
        let mut frame = vec![msg_type];
        frame.extend_from_slice(&declared_len.to_le_bytes());
        frame.extend_from_slice(payload);
        frame
    }

    #[test]
    fn hello_roundtrip() {
        let msg = SyncMessage::Hello {
            node_id: NodeId::from_u64(42),
            root_page: PageId(7),
            root_hash: sample_hash(),
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::Hello { node_id, root_page, root_hash } => {
                assert_eq!(node_id, NodeId::from_u64(42));
                assert_eq!(root_page, PageId(7));
                assert_eq!(root_hash, sample_hash());
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn hello_ack_roundtrip() {
        let msg = SyncMessage::HelloAck {
            node_id: NodeId::from_u64(99),
            root_page: PageId(3),
            root_hash: sample_hash(),
            in_sync: true,
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::HelloAck { node_id, root_page, root_hash, in_sync } => {
                assert_eq!(node_id, NodeId::from_u64(99));
                assert_eq!(root_page, PageId(3));
                assert_eq!(root_hash, sample_hash());
                assert!(in_sync);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn hello_ack_not_in_sync_roundtrip() {
        let msg = SyncMessage::HelloAck {
            node_id: NodeId::from_u64(1),
            root_page: PageId(0),
            root_hash: [0u8; MERKLE_HASH_SIZE],
            in_sync: false,
        };
        let decoded = SyncMessage::deserialize(&msg.serialize()).unwrap();
        assert!(matches!(decoded, SyncMessage::HelloAck { in_sync: false, .. }));
    }

    #[test]
    fn digest_request_roundtrip() {
        let msg = SyncMessage::DigestRequest {
            page_ids: vec![PageId(1), PageId(5), PageId(100)],
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::DigestRequest { page_ids } => {
                assert_eq!(page_ids, vec![PageId(1), PageId(5), PageId(100)]);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn digest_response_roundtrip() {
        let msg = SyncMessage::DigestResponse {
            digests: vec![
                PageDigest {
                    page_id: PageId(1),
                    page_type: PageType::Leaf,
                    merkle_hash: sample_hash(),
                    children: vec![],
                },
                PageDigest {
                    page_id: PageId(2),
                    page_type: PageType::Branch,
                    merkle_hash: [0xAA; MERKLE_HASH_SIZE],
                    children: vec![PageId(3), PageId(4)],
                },
            ],
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::DigestResponse { digests } => {
                assert_eq!(digests.len(), 2);
                assert_eq!(digests[0].page_id, PageId(1));
                assert!(digests[0].children.is_empty());
                assert_eq!(digests[1].children, vec![PageId(3), PageId(4)]);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn digest_response_count_exceeds_payload() {
        // Declares 1000 digests but carries none of them.
        let frame = raw_frame(MSG_DIGEST_RESPONSE, 4, &1000u32.to_le_bytes());
        let err = SyncMessage::deserialize(&frame).unwrap_err();
        assert!(matches!(err, ProtocolError::Truncated { .. }));
    }

    #[test]
    fn entries_request_roundtrip() {
        let msg = SyncMessage::EntriesRequest {
            page_ids: vec![PageId(10)],
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::EntriesRequest { page_ids } => {
                assert_eq!(page_ids, vec![PageId(10)]);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn entries_response_roundtrip() {
        let msg = SyncMessage::EntriesResponse {
            entries: vec![
                DiffEntry { key: b"k1".to_vec(), value: b"v1".to_vec(), val_type: 0 },
                DiffEntry { key: b"k2".to_vec(), value: b"v2".to_vec(), val_type: 1 },
            ],
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::EntriesResponse { entries } => {
                assert_eq!(entries.len(), 2);
                assert_eq!(entries[0].key, b"k1");
                assert_eq!(entries[1].val_type, 1);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn entries_response_count_exceeds_payload() {
        let frame = raw_frame(MSG_ENTRIES_RESPONSE, 4, &1000u32.to_le_bytes());
        let err = SyncMessage::deserialize(&frame).unwrap_err();
        assert!(matches!(err, ProtocolError::Truncated { .. }));
    }

    #[test]
    fn patch_data_roundtrip() {
        let msg = SyncMessage::PatchData { data: vec![1, 2, 3, 4, 5] };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::PatchData { data: d } => {
                assert_eq!(d, vec![1, 2, 3, 4, 5]);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn patch_ack_roundtrip() {
        let msg = SyncMessage::PatchAck {
            result: ApplyResult {
                entries_applied: 10,
                entries_skipped: 3,
                entries_equal: 2,
            },
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::PatchAck { result } => {
                assert_eq!(result.entries_applied, 10);
                assert_eq!(result.entries_skipped, 3);
                assert_eq!(result.entries_equal, 2);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn done_roundtrip() {
        let data = SyncMessage::Done.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        assert!(matches!(decoded, SyncMessage::Done));
    }

    #[test]
    fn done_frame_layout() {
        // Type byte followed by a zero little-endian u32 payload length.
        assert_eq!(SyncMessage::Done.serialize(), vec![MSG_DONE, 0, 0, 0, 0]);
    }

    #[test]
    fn trailing_bytes_after_frame_are_ignored() {
        let mut data = SyncMessage::Done.serialize();
        data.extend_from_slice(&[0xAA, 0xBB]);
        assert!(matches!(SyncMessage::deserialize(&data).unwrap(), SyncMessage::Done));
    }

    #[test]
    fn error_roundtrip() {
        let msg = SyncMessage::Error { message: "something broke".into() };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::Error { message } => {
                assert_eq!(message, "something broke");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn error_invalid_utf8_is_decoded_lossily() {
        let mut payload = 1u32.to_le_bytes().to_vec();
        payload.push(0xFF);
        let frame = raw_frame(MSG_ERROR, payload.len() as u32, &payload);
        match SyncMessage::deserialize(&frame).unwrap() {
            SyncMessage::Error { message } => assert_eq!(message, "\u{FFFD}"),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn pull_request_roundtrip() {
        let data = SyncMessage::PullRequest.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        assert!(matches!(decoded, SyncMessage::PullRequest));
    }

    #[test]
    fn pull_response_roundtrip() {
        let msg = SyncMessage::PullResponse {
            root_page: PageId(15),
            root_hash: sample_hash(),
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::PullResponse { root_page, root_hash } => {
                assert_eq!(root_page, PageId(15));
                assert_eq!(root_hash, sample_hash());
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn truncated_data() {
        let err = SyncMessage::deserialize(&[0, 1]).unwrap_err();
        assert!(matches!(err, ProtocolError::Truncated { .. }));
    }

    #[test]
    fn payload_shorter_than_declared_length() {
        let frame = raw_frame(MSG_PATCH_DATA, 10, &[1, 2, 3]);
        let err = SyncMessage::deserialize(&frame).unwrap_err();
        assert!(matches!(err, ProtocolError::Truncated { expected: 15, actual: 8, .. }));
    }

    #[test]
    fn hello_with_short_payload_is_truncated() {
        let frame = raw_frame(MSG_HELLO, 4, &[0; 4]);
        let err = SyncMessage::deserialize(&frame).unwrap_err();
        assert!(matches!(err, ProtocolError::Truncated { actual: 4, .. }));
    }

    #[test]
    fn unknown_message_type() {
        let data = [255, 0, 0, 0, 0];
        let err = SyncMessage::deserialize(&data).unwrap_err();
        assert!(matches!(err, ProtocolError::UnknownMessageType(255)));
    }

    #[test]
    fn empty_digest_request() {
        let msg = SyncMessage::DigestRequest { page_ids: vec![] };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::DigestRequest { page_ids } => assert!(page_ids.is_empty()),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn table_list_request_roundtrip() {
        let data = SyncMessage::TableListRequest.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        assert!(matches!(decoded, SyncMessage::TableListRequest));
    }

    #[test]
    fn table_list_response_roundtrip() {
        let msg = SyncMessage::TableListResponse {
            tables: vec![
                TableInfo {
                    name: b"users".to_vec(),
                    root_page: PageId(10),
                    root_hash: sample_hash(),
                },
                TableInfo {
                    name: b"orders".to_vec(),
                    root_page: PageId(20),
                    root_hash: [0xBB; MERKLE_HASH_SIZE],
                },
            ],
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::TableListResponse { tables } => {
                assert_eq!(tables.len(), 2);
                assert_eq!(tables[0].name, b"users");
                assert_eq!(tables[0].root_page, PageId(10));
                assert_eq!(tables[0].root_hash, sample_hash());
                assert_eq!(tables[1].name, b"orders");
                assert_eq!(tables[1].root_page, PageId(20));
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn table_list_response_empty() {
        let msg = SyncMessage::TableListResponse { tables: vec![] };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::TableListResponse { tables } => assert!(tables.is_empty()),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn table_sync_begin_roundtrip() {
        let msg = SyncMessage::TableSyncBegin {
            table_name: b"products".to_vec(),
            root_page: PageId(77),
            root_hash: sample_hash(),
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::TableSyncBegin { table_name, root_page, root_hash } => {
                assert_eq!(table_name, b"products");
                assert_eq!(root_page, PageId(77));
                assert_eq!(root_hash, sample_hash());
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn table_sync_end_roundtrip() {
        let msg = SyncMessage::TableSyncEnd {
            table_name: b"products".to_vec(),
        };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::TableSyncEnd { table_name } => {
                assert_eq!(table_name, b"products");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn table_sync_end_empty_name_roundtrip() {
        let msg = SyncMessage::TableSyncEnd { table_name: Vec::new() };
        match SyncMessage::deserialize(&msg.serialize()).unwrap() {
            SyncMessage::TableSyncEnd { table_name } => assert!(table_name.is_empty()),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn empty_entries_response() {
        let msg = SyncMessage::EntriesResponse { entries: vec![] };
        let data = msg.serialize();
        let decoded = SyncMessage::deserialize(&data).unwrap();
        match decoded {
            SyncMessage::EntriesResponse { entries } => assert!(entries.is_empty()),
            _ => panic!("wrong variant"),
        }
    }
}
835}