1use citadel_txn::manager::TxnManager;
2
3use crate::apply::{apply_patch, apply_patch_to_table, ApplyResult};
4use crate::diff::{merkle_diff, MerkleHash, TreeReader};
5use crate::local_reader::LocalTreeReader;
6use crate::node_id::NodeId;
7use crate::patch::SyncPatch;
8use crate::protocol::{SyncMessage, TableInfo};
9use crate::transport::{msg_name, RemoteTreeReader, SyncError, SyncTransport};
10
11use citadel_core::types::PageId;
12
/// Which way data flows when [`SyncSession::sync_as_initiator`] runs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SyncDirection {
    /// Diff the local tree against the peer and send the peer a patch to apply.
    Push,
    /// Diff the peer's tree against the local one and apply the patch locally.
    Pull,
    /// Push first, then pull against the peer's post-push root.
    Bidirectional,
}
23
24#[derive(Debug, Clone)]
26pub struct SyncConfig {
27 pub node_id: NodeId,
28 pub direction: SyncDirection,
29 pub crdt_aware: bool,
30}
31
32#[derive(Debug, Clone)]
34pub struct SyncOutcome {
35 pub pushed: Option<ApplyResult>,
37 pub pulled: Option<ApplyResult>,
39 pub already_in_sync: bool,
41}
42
43pub struct SyncSession {
49 config: SyncConfig,
50}
51
52impl SyncSession {
53 pub fn new(config: SyncConfig) -> Self {
54 Self { config }
55 }
56
57 pub fn sync_as_initiator(
59 &self,
60 manager: &TxnManager,
61 transport: &dyn SyncTransport,
62 ) -> std::result::Result<SyncOutcome, SyncError> {
63 let local_reader = LocalTreeReader::new(manager);
64 let (local_root, local_hash) = local_reader.root_info().map_err(SyncError::Database)?;
65
66 transport.send(&SyncMessage::Hello {
68 node_id: self.config.node_id,
69 root_page: local_root,
70 root_hash: local_hash,
71 })?;
72
73 let (remote_root, remote_hash, in_sync) = match transport.recv()? {
74 SyncMessage::HelloAck {
75 root_page,
76 root_hash,
77 in_sync,
78 ..
79 } => (root_page, root_hash, in_sync),
80 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
81 other => {
82 return Err(SyncError::UnexpectedMessage {
83 expected: "HelloAck".into(),
84 actual: msg_name(&other).into(),
85 })
86 }
87 };
88
89 if in_sync {
90 transport.send(&SyncMessage::Done)?;
91 return Ok(SyncOutcome {
92 pushed: None,
93 pulled: None,
94 already_in_sync: true,
95 });
96 }
97
98 let mut outcome = SyncOutcome {
99 pushed: None,
100 pulled: None,
101 already_in_sync: false,
102 };
103
104 if self.config.direction == SyncDirection::Push
106 || self.config.direction == SyncDirection::Bidirectional
107 {
108 let result = self.initiator_push(manager, transport, remote_root, remote_hash)?;
109 outcome.pushed = Some(result);
110 }
111
112 if self.config.direction == SyncDirection::Pull
114 || self.config.direction == SyncDirection::Bidirectional
115 {
116 let (pull_root, pull_hash) = if self.config.direction == SyncDirection::Bidirectional {
118 transport.send(&SyncMessage::PullRequest)?;
119 match transport.recv()? {
120 SyncMessage::PullResponse {
121 root_page,
122 root_hash,
123 } => (root_page, root_hash),
124 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
125 other => {
126 return Err(SyncError::UnexpectedMessage {
127 expected: "PullResponse".into(),
128 actual: msg_name(&other).into(),
129 })
130 }
131 }
132 } else {
133 (remote_root, remote_hash)
134 };
135
136 let result = self.initiator_pull(manager, transport, pull_root, pull_hash)?;
137 outcome.pulled = Some(result);
138 }
139
140 transport.send(&SyncMessage::Done)?;
141 Ok(outcome)
142 }
143
144 pub fn sync_as_responder(
146 &self,
147 manager: &TxnManager,
148 transport: &dyn SyncTransport,
149 ) -> std::result::Result<SyncOutcome, SyncError> {
150 let local_reader = LocalTreeReader::new(manager);
151 let (local_root, local_hash) = local_reader.root_info().map_err(SyncError::Database)?;
152
153 let remote_hash = match transport.recv()? {
155 SyncMessage::Hello { root_hash, .. } => root_hash,
156 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
157 other => {
158 return Err(SyncError::UnexpectedMessage {
159 expected: "Hello".into(),
160 actual: msg_name(&other).into(),
161 })
162 }
163 };
164
165 let in_sync = local_hash == remote_hash;
166
167 transport.send(&SyncMessage::HelloAck {
168 node_id: self.config.node_id,
169 root_page: local_root,
170 root_hash: local_hash,
171 in_sync,
172 })?;
173
174 if in_sync {
175 let _ = transport.recv()?;
176 return Ok(SyncOutcome {
177 pushed: None,
178 pulled: None,
179 already_in_sync: true,
180 });
181 }
182
183 let mut outcome = SyncOutcome {
184 pushed: None,
185 pulled: None,
186 already_in_sync: false,
187 };
188
189 loop {
191 let msg = transport.recv()?;
192 match msg {
193 SyncMessage::DigestRequest { page_ids } => {
194 let reader = LocalTreeReader::new(manager);
195 let mut digests = Vec::with_capacity(page_ids.len());
196 for pid in &page_ids {
197 match reader.page_digest(*pid) {
198 Ok(d) => digests.push(d),
199 Err(e) => {
200 transport.send(&SyncMessage::Error {
201 message: e.to_string(),
202 })?;
203 continue;
204 }
205 }
206 }
207 transport.send(&SyncMessage::DigestResponse { digests })?;
208 }
209 SyncMessage::EntriesRequest { page_ids } => {
210 let reader = LocalTreeReader::new(manager);
211 let mut entries = Vec::new();
212 for pid in &page_ids {
213 match reader.leaf_entries(*pid) {
214 Ok(e) => entries.extend(e),
215 Err(e) => {
216 transport.send(&SyncMessage::Error {
217 message: e.to_string(),
218 })?;
219 continue;
220 }
221 }
222 }
223 transport.send(&SyncMessage::EntriesResponse { entries })?;
224 }
225 SyncMessage::PatchData { data } => {
226 let patch = SyncPatch::deserialize(&data).map_err(SyncError::Patch)?;
227 let result = apply_patch(manager, &patch).map_err(SyncError::Database)?;
228 outcome.pushed = Some(result.clone());
229 transport.send(&SyncMessage::PatchAck { result })?;
230 }
231 SyncMessage::PullRequest => {
232 let reader = LocalTreeReader::new(manager);
233 let (root_page, root_hash) = reader.root_info().map_err(SyncError::Database)?;
234 transport.send(&SyncMessage::PullResponse {
235 root_page,
236 root_hash,
237 })?;
238 }
239 SyncMessage::Done => {
240 break;
241 }
242 SyncMessage::Error { message } => {
243 return Err(SyncError::Remote(message));
244 }
245 _ => {
246 transport.send(&SyncMessage::Error {
247 message: "unexpected message".into(),
248 })?;
249 }
250 }
251 }
252
253 Ok(outcome)
254 }
255
256 fn initiator_push(
258 &self,
259 manager: &TxnManager,
260 transport: &dyn SyncTransport,
261 remote_root: PageId,
262 remote_hash: MerkleHash,
263 ) -> std::result::Result<ApplyResult, SyncError> {
264 let local_reader = LocalTreeReader::new(manager);
265 let remote_reader = RemoteTreeReader::new(transport, remote_root, remote_hash);
266
267 let diff = merkle_diff(&local_reader, &remote_reader).map_err(SyncError::Database)?;
269
270 if diff.is_empty() {
271 return Ok(ApplyResult::empty());
272 }
273
274 let patch = SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
275 let patch_data = patch.serialize();
276
277 transport.send(&SyncMessage::PatchData { data: patch_data })?;
278
279 match transport.recv()? {
280 SyncMessage::PatchAck { result } => Ok(result),
281 SyncMessage::Error { message } => Err(SyncError::Remote(message)),
282 other => Err(SyncError::UnexpectedMessage {
283 expected: "PatchAck".into(),
284 actual: msg_name(&other).into(),
285 }),
286 }
287 }
288
289 pub fn sync_tables_as_initiator(
291 &self,
292 manager: &TxnManager,
293 transport: &dyn SyncTransport,
294 ) -> std::result::Result<Vec<(Vec<u8>, ApplyResult)>, SyncError> {
295 transport.send(&SyncMessage::TableListRequest)?;
296
297 let remote_tables = match transport.recv()? {
298 SyncMessage::TableListResponse { tables } => tables,
299 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
300 other => {
301 return Err(SyncError::UnexpectedMessage {
302 expected: "TableListResponse".into(),
303 actual: msg_name(&other).into(),
304 })
305 }
306 };
307
308 let local_tables = manager.list_tables().map_err(SyncError::Database)?;
309
310 let mut all_names: Vec<Vec<u8>> = Vec::new();
311 for (name, _) in &local_tables {
312 if !name.starts_with(b"__idx_") && !all_names.contains(name) {
313 all_names.push(name.clone());
314 }
315 }
316 for info in &remote_tables {
317 if !info.name.starts_with(b"__idx_") && !all_names.contains(&info.name) {
318 all_names.push(info.name.clone());
319 }
320 }
321
322 let mut results = Vec::new();
323
324 for table_name in &all_names {
325 let local_info = local_tables.iter().find(|(n, _)| n == table_name);
326 let remote_info = remote_tables.iter().find(|t| t.name == *table_name);
327
328 let local_root = local_info
329 .map(|(_, desc)| desc.root_page)
330 .unwrap_or(PageId::INVALID);
331 let local_hash = if local_root.is_valid() {
332 manager
333 .read_page_from_disk(local_root)
334 .map(|p| p.merkle_hash())
335 .unwrap_or([0u8; citadel_core::MERKLE_HASH_SIZE])
336 } else {
337 [0u8; citadel_core::MERKLE_HASH_SIZE]
338 };
339
340 let remote_root = remote_info.map(|t| t.root_page).unwrap_or(PageId::INVALID);
341 let remote_hash = remote_info
342 .map(|t| t.root_hash)
343 .unwrap_or([0u8; citadel_core::MERKLE_HASH_SIZE]);
344
345 if local_hash == remote_hash && local_root.is_valid() && remote_root.is_valid() {
346 continue;
347 }
348
349 transport.send(&SyncMessage::TableSyncBegin {
350 table_name: table_name.clone(),
351 root_page: local_root,
352 root_hash: local_hash,
353 })?;
354
355 if local_root.is_valid() && remote_root.is_valid() {
356 let local_reader =
357 LocalTreeReader::for_table(manager, local_root).map_err(SyncError::Database)?;
358 let remote_reader = RemoteTreeReader::new(transport, remote_root, remote_hash);
359 let diff =
360 merkle_diff(&local_reader, &remote_reader).map_err(SyncError::Database)?;
361
362 if !diff.is_empty() {
363 let patch =
364 SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
365 transport.send(&SyncMessage::PatchData {
366 data: patch.serialize(),
367 })?;
368 match transport.recv()? {
369 SyncMessage::PatchAck { result } => {
370 results.push((table_name.clone(), result));
371 }
372 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
373 other => {
374 return Err(SyncError::UnexpectedMessage {
375 expected: "PatchAck".into(),
376 actual: msg_name(&other).into(),
377 })
378 }
379 }
380 }
381 } else if local_root.is_valid() {
382 let local_reader =
383 LocalTreeReader::for_table(manager, local_root).map_err(SyncError::Database)?;
384 let entries = local_reader
385 .subtree_entries(local_root)
386 .map_err(SyncError::Database)?;
387 if !entries.is_empty() {
388 let diff = crate::diff::DiffResult {
389 entries,
390 pages_compared: 0,
391 subtrees_skipped: 0,
392 };
393 let patch =
394 SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
395 transport.send(&SyncMessage::PatchData {
396 data: patch.serialize(),
397 })?;
398 match transport.recv()? {
399 SyncMessage::PatchAck { result } => {
400 results.push((table_name.clone(), result));
401 }
402 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
403 other => {
404 return Err(SyncError::UnexpectedMessage {
405 expected: "PatchAck".into(),
406 actual: msg_name(&other).into(),
407 })
408 }
409 }
410 }
411 }
412
413 transport.send(&SyncMessage::TableSyncEnd {
414 table_name: table_name.clone(),
415 })?;
416 }
417
418 transport.send(&SyncMessage::Done)?;
419 Ok(results)
420 }
421
422 pub fn handle_table_sync_as_responder(
424 &self,
425 manager: &TxnManager,
426 transport: &dyn SyncTransport,
427 ) -> std::result::Result<Vec<(Vec<u8>, ApplyResult)>, SyncError> {
428 match transport.recv()? {
429 SyncMessage::TableListRequest => {}
430 SyncMessage::Done => return Ok(Vec::new()),
431 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
432 other => {
433 return Err(SyncError::UnexpectedMessage {
434 expected: "TableListRequest".into(),
435 actual: msg_name(&other).into(),
436 })
437 }
438 }
439
440 let local_tables = manager.list_tables().map_err(SyncError::Database)?;
441 let table_infos: Vec<TableInfo> = local_tables
442 .iter()
443 .filter(|(name, _)| !name.starts_with(b"__idx_"))
444 .filter_map(|(name, desc)| {
445 if desc.root_page.is_valid() {
446 let hash = manager
447 .read_page_from_disk(desc.root_page)
448 .map(|p| p.merkle_hash())
449 .unwrap_or([0u8; citadel_core::MERKLE_HASH_SIZE]);
450 Some(TableInfo {
451 name: name.clone(),
452 root_page: desc.root_page,
453 root_hash: hash,
454 })
455 } else {
456 None
457 }
458 })
459 .collect();
460 transport.send(&SyncMessage::TableListResponse {
461 tables: table_infos,
462 })?;
463
464 let mut results = Vec::new();
465 let mut current_table: Option<Vec<u8>> = None;
466
467 loop {
468 let msg = transport.recv()?;
469 match msg {
470 SyncMessage::TableSyncBegin { table_name, .. } => {
471 current_table = Some(table_name);
472 }
473 SyncMessage::TableSyncEnd { .. } => {
474 current_table = None;
475 }
476 SyncMessage::DigestRequest { page_ids } => {
477 let reader = if let Some(ref tname) = current_table {
478 let root = manager.table_root(tname).map_err(SyncError::Database)?;
479 if let Some(r) = root {
480 LocalTreeReader::for_table(manager, r).map_err(SyncError::Database)?
481 } else {
482 LocalTreeReader::new(manager)
483 }
484 } else {
485 LocalTreeReader::new(manager)
486 };
487
488 let mut digests = Vec::with_capacity(page_ids.len());
489 for pid in &page_ids {
490 match reader.page_digest(*pid) {
491 Ok(d) => digests.push(d),
492 Err(e) => {
493 transport.send(&SyncMessage::Error {
494 message: e.to_string(),
495 })?;
496 continue;
497 }
498 }
499 }
500 transport.send(&SyncMessage::DigestResponse { digests })?;
501 }
502 SyncMessage::EntriesRequest { page_ids } => {
503 let reader = if let Some(ref tname) = current_table {
504 let root = manager.table_root(tname).map_err(SyncError::Database)?;
505 if let Some(r) = root {
506 LocalTreeReader::for_table(manager, r).map_err(SyncError::Database)?
507 } else {
508 LocalTreeReader::new(manager)
509 }
510 } else {
511 LocalTreeReader::new(manager)
512 };
513
514 let mut entries = Vec::new();
515 for pid in &page_ids {
516 match reader.leaf_entries(*pid) {
517 Ok(e) => entries.extend(e),
518 Err(e) => {
519 transport.send(&SyncMessage::Error {
520 message: e.to_string(),
521 })?;
522 continue;
523 }
524 }
525 }
526 transport.send(&SyncMessage::EntriesResponse { entries })?;
527 }
528 SyncMessage::PatchData { data } => {
529 let patch = SyncPatch::deserialize(&data).map_err(SyncError::Patch)?;
530 let result = if let Some(ref tname) = current_table {
531 apply_patch_to_table(manager, tname, &patch).map_err(SyncError::Database)?
532 } else {
533 apply_patch(manager, &patch).map_err(SyncError::Database)?
534 };
535 if let Some(ref tname) = current_table {
536 results.push((tname.clone(), result.clone()));
537 }
538 transport.send(&SyncMessage::PatchAck { result })?;
539 }
540 SyncMessage::Done => break,
541 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
542 _ => {
543 transport.send(&SyncMessage::Error {
544 message: "unexpected message in table sync".into(),
545 })?;
546 }
547 }
548 }
549
550 Ok(results)
551 }
552
553 fn initiator_pull(
555 &self,
556 manager: &TxnManager,
557 transport: &dyn SyncTransport,
558 remote_root: PageId,
559 remote_hash: MerkleHash,
560 ) -> std::result::Result<ApplyResult, SyncError> {
561 let local_reader = LocalTreeReader::new(manager);
562 let (_, local_hash) = local_reader.root_info().map_err(SyncError::Database)?;
563
564 if local_hash == remote_hash {
565 return Ok(ApplyResult::empty());
566 }
567
568 let remote_reader = RemoteTreeReader::new(transport, remote_root, remote_hash);
569
570 let diff = merkle_diff(&remote_reader, &local_reader).map_err(SyncError::Database)?;
572
573 if diff.is_empty() {
574 return Ok(ApplyResult::empty());
575 }
576
577 let patch = SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
578 let result = apply_patch(manager, &patch).map_err(SyncError::Database)?;
579 Ok(result)
580 }
581}