1use citadel_core::types::PageId;
2use citadel_core::MERKLE_HASH_SIZE;
3
4use crate::apply::ApplyResult;
5use crate::diff::{DiffEntry, MerkleHash, PageDigest};
6use crate::node_id::NodeId;
7
// Wire type tags.
//
// `SyncMessage::serialize` writes exactly one of these as the first byte of
// every frame and `SyncMessage::deserialize` dispatches on that byte, so the
// numeric values are part of the wire encoding. Any byte value not listed
// here is reported by `deserialize` as `ProtocolError::UnknownMessageType`.
const MSG_HELLO: u8 = 0; // SyncMessage::Hello
const MSG_HELLO_ACK: u8 = 1; // SyncMessage::HelloAck
const MSG_DIGEST_REQUEST: u8 = 2; // SyncMessage::DigestRequest
const MSG_DIGEST_RESPONSE: u8 = 3; // SyncMessage::DigestResponse
const MSG_ENTRIES_REQUEST: u8 = 4; // SyncMessage::EntriesRequest
const MSG_ENTRIES_RESPONSE: u8 = 5; // SyncMessage::EntriesResponse
const MSG_PATCH_DATA: u8 = 6; // SyncMessage::PatchData
const MSG_PATCH_ACK: u8 = 7; // SyncMessage::PatchAck
const MSG_DONE: u8 = 8; // SyncMessage::Done
const MSG_ERROR: u8 = 9; // SyncMessage::Error
const MSG_PULL_REQUEST: u8 = 10; // SyncMessage::PullRequest
const MSG_PULL_RESPONSE: u8 = 11; // SyncMessage::PullResponse
const MSG_TABLE_LIST_REQUEST: u8 = 12; // SyncMessage::TableListRequest
const MSG_TABLE_LIST_RESPONSE: u8 = 13; // SyncMessage::TableListResponse
const MSG_TABLE_SYNC_BEGIN: u8 = 14; // SyncMessage::TableSyncBegin
const MSG_TABLE_SYNC_END: u8 = 15; // SyncMessage::TableSyncEnd
25
/// One element of a [`SyncMessage::TableListResponse`]: a table name paired
/// with a root page id and a root hash.
///
/// On the wire each element is a little-endian `u16` name length, the name
/// bytes, the root page id as a little-endian `u32`, and the raw root hash
/// bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TableInfo {
    /// Table name as raw bytes; this module copies it verbatim and never
    /// interprets it as text.
    pub name: Vec<u8>,
    /// Root page id transmitted for the table.
    pub root_page: PageId,
    /// Root hash transmitted for the table.
    pub root_hash: MerkleHash,
}
33
/// A sync protocol message, convertible to and from a wire frame with
/// [`SyncMessage::serialize`] and [`SyncMessage::deserialize`].
///
/// Every frame is `[type tag: u8][payload length: u32][payload]`. The variant
/// docs below describe the payload only; all integers are little-endian.
// NOTE(review): which peer sends which message, and in what order, is not
// visible in this file — confirm against the session/driver code.
#[derive(Debug, Clone)]
pub enum SyncMessage {
    /// Payload: node id (8 bytes), root page id (`u32`), root hash.
    Hello {
        node_id: NodeId,
        root_page: PageId,
        root_hash: MerkleHash,
    },
    /// Payload: same fields as `Hello` followed by one `in_sync` byte
    /// (written as 1 or 0; any non-zero byte decodes as `true`).
    HelloAck {
        node_id: NodeId,
        root_page: PageId,
        root_hash: MerkleHash,
        in_sync: bool,
    },
    /// Payload: `u32` count followed by that many `u32` page ids.
    DigestRequest { page_ids: Vec<PageId> },
    /// Payload: `u32` count followed by that many page digest records.
    DigestResponse { digests: Vec<PageDigest> },
    /// Payload: `u32` count followed by that many `u32` page ids.
    EntriesRequest { page_ids: Vec<PageId> },
    /// Payload: `u32` count followed by that many diff entry records.
    EntriesResponse { entries: Vec<DiffEntry> },
    /// Payload: `data` verbatim, with no inner length prefix.
    PatchData { data: Vec<u8> },
    /// Payload: three `u64` counters — applied, skipped, equal.
    PatchAck { result: ApplyResult },
    /// Empty payload.
    Done,
    /// Payload: `u32` byte length followed by the message bytes. Invalid
    /// UTF-8 is replaced lossily when decoding.
    Error { message: String },
    /// Empty payload.
    PullRequest,
    /// Payload: root page id (`u32`) followed by the root hash.
    PullResponse {
        root_page: PageId,
        root_hash: MerkleHash,
    },
    /// Empty payload.
    TableListRequest,
    /// Payload: `u32` count followed by that many [`TableInfo`] records.
    TableListResponse { tables: Vec<TableInfo> },
    /// Payload: `u16` name length, name bytes, root page id (`u32`), root hash.
    TableSyncBegin {
        table_name: Vec<u8>,
        root_page: PageId,
        root_hash: MerkleHash,
    },
    /// Payload: `u16` name length followed by the name bytes.
    TableSyncEnd { table_name: Vec<u8> },
}
86
/// Errors produced while decoding a wire frame.
#[derive(Debug, thiserror::Error)]
pub enum ProtocolError {
    /// The buffer is shorter than the structure being decoded requires.
    ///
    /// `context` names the structure, `expected` is the minimum number of
    /// bytes the buffer needed to contain, and `actual` is the buffer length
    /// that was seen.
    #[error("{context}: expected at least {expected} bytes, got {actual}")]
    Truncated {
        context: String,
        expected: usize,
        actual: usize,
    },

    /// The frame's type byte does not match any of the `MSG_*` tags.
    #[error("unknown message type: {0}")]
    UnknownMessageType(u8),
}
100
101impl SyncMessage {
102 pub fn serialize(&self) -> Vec<u8> {
104 let (msg_type, payload) = match self {
105 SyncMessage::Hello {
106 node_id,
107 root_page,
108 root_hash,
109 } => {
110 let mut p = Vec::with_capacity(40);
111 p.extend_from_slice(&node_id.to_bytes());
112 p.extend_from_slice(&root_page.0.to_le_bytes());
113 p.extend_from_slice(root_hash);
114 (MSG_HELLO, p)
115 }
116 SyncMessage::HelloAck {
117 node_id,
118 root_page,
119 root_hash,
120 in_sync,
121 } => {
122 let mut p = Vec::with_capacity(41);
123 p.extend_from_slice(&node_id.to_bytes());
124 p.extend_from_slice(&root_page.0.to_le_bytes());
125 p.extend_from_slice(root_hash);
126 p.push(if *in_sync { 1 } else { 0 });
127 (MSG_HELLO_ACK, p)
128 }
129 SyncMessage::DigestRequest { page_ids } => {
130 let mut p = Vec::with_capacity(4 + page_ids.len() * 4);
131 p.extend_from_slice(&(page_ids.len() as u32).to_le_bytes());
132 for pid in page_ids {
133 p.extend_from_slice(&pid.0.to_le_bytes());
134 }
135 (MSG_DIGEST_REQUEST, p)
136 }
137 SyncMessage::DigestResponse { digests } => {
138 let mut p = Vec::new();
139 p.extend_from_slice(&(digests.len() as u32).to_le_bytes());
140 for d in digests {
141 serialize_page_digest(&mut p, d);
142 }
143 (MSG_DIGEST_RESPONSE, p)
144 }
145 SyncMessage::EntriesRequest { page_ids } => {
146 let mut p = Vec::with_capacity(4 + page_ids.len() * 4);
147 p.extend_from_slice(&(page_ids.len() as u32).to_le_bytes());
148 for pid in page_ids {
149 p.extend_from_slice(&pid.0.to_le_bytes());
150 }
151 (MSG_ENTRIES_REQUEST, p)
152 }
153 SyncMessage::EntriesResponse { entries } => {
154 let mut p = Vec::new();
155 p.extend_from_slice(&(entries.len() as u32).to_le_bytes());
156 for e in entries {
157 serialize_diff_entry(&mut p, e);
158 }
159 (MSG_ENTRIES_RESPONSE, p)
160 }
161 SyncMessage::PatchData { data } => (MSG_PATCH_DATA, data.clone()),
162 SyncMessage::PatchAck { result } => {
163 let mut p = Vec::with_capacity(24);
164 p.extend_from_slice(&result.entries_applied.to_le_bytes());
165 p.extend_from_slice(&result.entries_skipped.to_le_bytes());
166 p.extend_from_slice(&result.entries_equal.to_le_bytes());
167 (MSG_PATCH_ACK, p)
168 }
169 SyncMessage::Done => (MSG_DONE, Vec::new()),
170 SyncMessage::Error { message } => {
171 let bytes = message.as_bytes();
172 let mut p = Vec::with_capacity(4 + bytes.len());
173 p.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
174 p.extend_from_slice(bytes);
175 (MSG_ERROR, p)
176 }
177 SyncMessage::PullRequest => (MSG_PULL_REQUEST, Vec::new()),
178 SyncMessage::PullResponse {
179 root_page,
180 root_hash,
181 } => {
182 let mut p = Vec::with_capacity(32);
183 p.extend_from_slice(&root_page.0.to_le_bytes());
184 p.extend_from_slice(root_hash);
185 (MSG_PULL_RESPONSE, p)
186 }
187 SyncMessage::TableListRequest => (MSG_TABLE_LIST_REQUEST, Vec::new()),
188 SyncMessage::TableListResponse { tables } => {
189 let mut p = Vec::new();
190 p.extend_from_slice(&(tables.len() as u32).to_le_bytes());
191 for t in tables {
192 p.extend_from_slice(&(t.name.len() as u16).to_le_bytes());
193 p.extend_from_slice(&t.name);
194 p.extend_from_slice(&t.root_page.0.to_le_bytes());
195 p.extend_from_slice(&t.root_hash);
196 }
197 (MSG_TABLE_LIST_RESPONSE, p)
198 }
199 SyncMessage::TableSyncBegin {
200 table_name,
201 root_page,
202 root_hash,
203 } => {
204 let mut p = Vec::with_capacity(2 + table_name.len() + 4 + MERKLE_HASH_SIZE);
205 p.extend_from_slice(&(table_name.len() as u16).to_le_bytes());
206 p.extend_from_slice(table_name);
207 p.extend_from_slice(&root_page.0.to_le_bytes());
208 p.extend_from_slice(root_hash);
209 (MSG_TABLE_SYNC_BEGIN, p)
210 }
211 SyncMessage::TableSyncEnd { table_name } => {
212 let mut p = Vec::with_capacity(2 + table_name.len());
213 p.extend_from_slice(&(table_name.len() as u16).to_le_bytes());
214 p.extend_from_slice(table_name);
215 (MSG_TABLE_SYNC_END, p)
216 }
217 };
218
219 let mut buf = Vec::with_capacity(5 + payload.len());
220 buf.push(msg_type);
221 buf.extend_from_slice(&(payload.len() as u32).to_le_bytes());
222 buf.extend_from_slice(&payload);
223 buf
224 }
225
226 pub fn deserialize(data: &[u8]) -> Result<Self, ProtocolError> {
228 if data.len() < 5 {
229 return Err(ProtocolError::Truncated {
230 context: "message header".to_string(),
231 expected: 5,
232 actual: data.len(),
233 });
234 }
235
236 let msg_type = data[0];
237 let payload_len = u32::from_le_bytes(data[1..5].try_into().unwrap()) as usize;
238
239 if data.len() < 5 + payload_len {
240 return Err(ProtocolError::Truncated {
241 context: "message payload".to_string(),
242 expected: 5 + payload_len,
243 actual: data.len(),
244 });
245 }
246
247 let payload = &data[5..5 + payload_len];
248
249 match msg_type {
250 MSG_HELLO => {
251 ensure_len(payload, 40, "Hello")?;
252 let node_id = NodeId::from_bytes(payload[0..8].try_into().unwrap());
253 let root_page = PageId(u32::from_le_bytes(payload[8..12].try_into().unwrap()));
254 let mut root_hash = [0u8; MERKLE_HASH_SIZE];
255 root_hash.copy_from_slice(&payload[12..40]);
256 Ok(SyncMessage::Hello {
257 node_id,
258 root_page,
259 root_hash,
260 })
261 }
262 MSG_HELLO_ACK => {
263 ensure_len(payload, 41, "HelloAck")?;
264 let node_id = NodeId::from_bytes(payload[0..8].try_into().unwrap());
265 let root_page = PageId(u32::from_le_bytes(payload[8..12].try_into().unwrap()));
266 let mut root_hash = [0u8; MERKLE_HASH_SIZE];
267 root_hash.copy_from_slice(&payload[12..40]);
268 let in_sync = payload[40] != 0;
269 Ok(SyncMessage::HelloAck {
270 node_id,
271 root_page,
272 root_hash,
273 in_sync,
274 })
275 }
276 MSG_DIGEST_REQUEST => {
277 ensure_len(payload, 4, "DigestRequest")?;
278 let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
279 ensure_len(payload, 4 + count * 4, "DigestRequest")?;
280 let page_ids = (0..count)
281 .map(|i| {
282 let off = 4 + i * 4;
283 PageId(u32::from_le_bytes(
284 payload[off..off + 4].try_into().unwrap(),
285 ))
286 })
287 .collect();
288 Ok(SyncMessage::DigestRequest { page_ids })
289 }
290 MSG_DIGEST_RESPONSE => {
291 ensure_len(payload, 4, "DigestResponse")?;
292 let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
293 let mut pos = 4;
294 let mut digests = Vec::with_capacity(count);
295 for _ in 0..count {
296 let (digest, consumed) = deserialize_page_digest(payload, pos)?;
297 digests.push(digest);
298 pos += consumed;
299 }
300 Ok(SyncMessage::DigestResponse { digests })
301 }
302 MSG_ENTRIES_REQUEST => {
303 ensure_len(payload, 4, "EntriesRequest")?;
304 let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
305 ensure_len(payload, 4 + count * 4, "EntriesRequest")?;
306 let page_ids = (0..count)
307 .map(|i| {
308 let off = 4 + i * 4;
309 PageId(u32::from_le_bytes(
310 payload[off..off + 4].try_into().unwrap(),
311 ))
312 })
313 .collect();
314 Ok(SyncMessage::EntriesRequest { page_ids })
315 }
316 MSG_ENTRIES_RESPONSE => {
317 ensure_len(payload, 4, "EntriesResponse")?;
318 let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
319 let mut pos = 4;
320 let mut entries = Vec::with_capacity(count);
321 for _ in 0..count {
322 let (entry, consumed) = deserialize_diff_entry(payload, pos)?;
323 entries.push(entry);
324 pos += consumed;
325 }
326 Ok(SyncMessage::EntriesResponse { entries })
327 }
328 MSG_PATCH_DATA => Ok(SyncMessage::PatchData {
329 data: payload.to_vec(),
330 }),
331 MSG_PATCH_ACK => {
332 ensure_len(payload, 24, "PatchAck")?;
333 let entries_applied = u64::from_le_bytes(payload[0..8].try_into().unwrap());
334 let entries_skipped = u64::from_le_bytes(payload[8..16].try_into().unwrap());
335 let entries_equal = u64::from_le_bytes(payload[16..24].try_into().unwrap());
336 Ok(SyncMessage::PatchAck {
337 result: ApplyResult {
338 entries_applied,
339 entries_skipped,
340 entries_equal,
341 },
342 })
343 }
344 MSG_DONE => Ok(SyncMessage::Done),
345 MSG_ERROR => {
346 ensure_len(payload, 4, "Error")?;
347 let msg_len = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
348 ensure_len(payload, 4 + msg_len, "Error")?;
349 let message = String::from_utf8_lossy(&payload[4..4 + msg_len]).into_owned();
350 Ok(SyncMessage::Error { message })
351 }
352 MSG_PULL_REQUEST => Ok(SyncMessage::PullRequest),
353 MSG_PULL_RESPONSE => {
354 ensure_len(payload, 32, "PullResponse")?;
355 let root_page = PageId(u32::from_le_bytes(payload[0..4].try_into().unwrap()));
356 let mut root_hash = [0u8; MERKLE_HASH_SIZE];
357 root_hash.copy_from_slice(&payload[4..32]);
358 Ok(SyncMessage::PullResponse {
359 root_page,
360 root_hash,
361 })
362 }
363 MSG_TABLE_LIST_REQUEST => Ok(SyncMessage::TableListRequest),
364 MSG_TABLE_LIST_RESPONSE => {
365 ensure_len(payload, 4, "TableListResponse")?;
366 let count = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
367 let mut pos = 4;
368 let mut tables = Vec::with_capacity(count);
369 for _ in 0..count {
370 ensure_len(payload, pos + 2, "TableInfo name_len")?;
371 let name_len =
372 u16::from_le_bytes(payload[pos..pos + 2].try_into().unwrap()) as usize;
373 pos += 2;
374 ensure_len(payload, pos + name_len + 4 + MERKLE_HASH_SIZE, "TableInfo")?;
375 let name = payload[pos..pos + name_len].to_vec();
376 pos += name_len;
377 let root_page = PageId(u32::from_le_bytes(
378 payload[pos..pos + 4].try_into().unwrap(),
379 ));
380 pos += 4;
381 let mut root_hash = [0u8; MERKLE_HASH_SIZE];
382 root_hash.copy_from_slice(&payload[pos..pos + MERKLE_HASH_SIZE]);
383 pos += MERKLE_HASH_SIZE;
384 tables.push(TableInfo {
385 name,
386 root_page,
387 root_hash,
388 });
389 }
390 Ok(SyncMessage::TableListResponse { tables })
391 }
392 MSG_TABLE_SYNC_BEGIN => {
393 ensure_len(payload, 2, "TableSyncBegin")?;
394 let name_len = u16::from_le_bytes(payload[0..2].try_into().unwrap()) as usize;
395 ensure_len(
396 payload,
397 2 + name_len + 4 + MERKLE_HASH_SIZE,
398 "TableSyncBegin",
399 )?;
400 let table_name = payload[2..2 + name_len].to_vec();
401 let off = 2 + name_len;
402 let root_page = PageId(u32::from_le_bytes(
403 payload[off..off + 4].try_into().unwrap(),
404 ));
405 let mut root_hash = [0u8; MERKLE_HASH_SIZE];
406 root_hash.copy_from_slice(&payload[off + 4..off + 4 + MERKLE_HASH_SIZE]);
407 Ok(SyncMessage::TableSyncBegin {
408 table_name,
409 root_page,
410 root_hash,
411 })
412 }
413 MSG_TABLE_SYNC_END => {
414 ensure_len(payload, 2, "TableSyncEnd")?;
415 let name_len = u16::from_le_bytes(payload[0..2].try_into().unwrap()) as usize;
416 ensure_len(payload, 2 + name_len, "TableSyncEnd")?;
417 let table_name = payload[2..2 + name_len].to_vec();
418 Ok(SyncMessage::TableSyncEnd { table_name })
419 }
420 _ => Err(ProtocolError::UnknownMessageType(msg_type)),
421 }
422 }
423}
424
425fn ensure_len(data: &[u8], needed: usize, ctx: &str) -> Result<(), ProtocolError> {
426 if data.len() < needed {
427 Err(ProtocolError::Truncated {
428 context: ctx.to_string(),
429 expected: needed,
430 actual: data.len(),
431 })
432 } else {
433 Ok(())
434 }
435}
436
437fn serialize_page_digest(buf: &mut Vec<u8>, d: &PageDigest) {
438 buf.extend_from_slice(&d.page_id.0.to_le_bytes());
439 buf.extend_from_slice(&(d.page_type as u16).to_le_bytes());
440 buf.extend_from_slice(&d.merkle_hash);
441 buf.extend_from_slice(&(d.children.len() as u32).to_le_bytes());
442 for c in &d.children {
443 buf.extend_from_slice(&c.0.to_le_bytes());
444 }
445}
446
447fn deserialize_page_digest(
448 data: &[u8],
449 offset: usize,
450) -> Result<(PageDigest, usize), ProtocolError> {
451 let min = 38;
453 if data.len() < offset + min {
454 return Err(ProtocolError::Truncated {
455 context: "PageDigest header".to_string(),
456 expected: offset + min,
457 actual: data.len(),
458 });
459 }
460
461 let page_id = PageId(u32::from_le_bytes(
462 data[offset..offset + 4].try_into().unwrap(),
463 ));
464 let page_type_raw = u16::from_le_bytes(data[offset + 4..offset + 6].try_into().unwrap());
465 let page_type = citadel_core::types::PageType::from_u16(page_type_raw)
466 .unwrap_or(citadel_core::types::PageType::Leaf);
467 let mut merkle_hash = [0u8; MERKLE_HASH_SIZE];
468 merkle_hash.copy_from_slice(&data[offset + 6..offset + 34]);
469 let child_count =
470 u32::from_le_bytes(data[offset + 34..offset + 38].try_into().unwrap()) as usize;
471
472 if data.len() < offset + min + child_count * 4 {
473 return Err(ProtocolError::Truncated {
474 context: "PageDigest children".to_string(),
475 expected: offset + min + child_count * 4,
476 actual: data.len(),
477 });
478 }
479
480 let children = (0..child_count)
481 .map(|i| {
482 let off = offset + 38 + i * 4;
483 PageId(u32::from_le_bytes(data[off..off + 4].try_into().unwrap()))
484 })
485 .collect();
486
487 Ok((
488 PageDigest {
489 page_id,
490 page_type,
491 merkle_hash,
492 children,
493 },
494 min + child_count * 4,
495 ))
496}
497
498fn serialize_diff_entry(buf: &mut Vec<u8>, e: &DiffEntry) {
499 buf.extend_from_slice(&(e.key.len() as u16).to_le_bytes());
500 buf.extend_from_slice(&(e.value.len() as u32).to_le_bytes());
501 buf.push(e.val_type);
502 buf.extend_from_slice(&e.key);
503 buf.extend_from_slice(&e.value);
504}
505
506fn deserialize_diff_entry(data: &[u8], offset: usize) -> Result<(DiffEntry, usize), ProtocolError> {
507 let header = 7;
509 if data.len() < offset + header {
510 return Err(ProtocolError::Truncated {
511 context: "DiffEntry header".to_string(),
512 expected: offset + header,
513 actual: data.len(),
514 });
515 }
516
517 let key_len = u16::from_le_bytes(data[offset..offset + 2].try_into().unwrap()) as usize;
518 let val_len = u32::from_le_bytes(data[offset + 2..offset + 6].try_into().unwrap()) as usize;
519 let val_type = data[offset + 6];
520
521 let total = header + key_len + val_len;
522 if data.len() < offset + total {
523 return Err(ProtocolError::Truncated {
524 context: "DiffEntry data".to_string(),
525 expected: offset + total,
526 actual: data.len(),
527 });
528 }
529
530 let key = data[offset + 7..offset + 7 + key_len].to_vec();
531 let value = data[offset + 7 + key_len..offset + 7 + key_len + val_len].to_vec();
532
533 Ok((
534 DiffEntry {
535 key,
536 value,
537 val_type,
538 },
539 total,
540 ))
541}
542
// Unit tests are kept out of line in the file named by `#[path]` and are only
// compiled for test builds.
#[cfg(test)]
#[path = "protocol_tests.rs"]
mod tests;