1use citadel_txn::manager::TxnManager;
2
3use crate::apply::{apply_patch, apply_patch_to_table, ApplyResult};
4use crate::diff::{merkle_diff, MerkleHash, TreeReader};
5use crate::local_reader::LocalTreeReader;
6use crate::node_id::NodeId;
7use crate::patch::SyncPatch;
8use crate::protocol::{SyncMessage, TableInfo};
9use crate::transport::{msg_name, RemoteTreeReader, SyncError, SyncTransport};
10
11use citadel_core::types::PageId;
12
/// Which way data flows during a sync session, seen from the initiator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyncDirection {
    /// The initiator sends its changes to the responder.
    Push,
    /// The initiator fetches the responder's changes and applies them locally.
    Pull,
    /// The initiator pushes first and then pulls.
    Bidirectional,
}
23
24#[derive(Debug, Clone)]
26pub struct SyncConfig {
27 pub node_id: NodeId,
28 pub direction: SyncDirection,
29 pub crdt_aware: bool,
30}
31
32#[derive(Debug, Clone)]
34pub struct SyncOutcome {
35 pub pushed: Option<ApplyResult>,
37 pub pulled: Option<ApplyResult>,
39 pub already_in_sync: bool,
41}
42
43pub struct SyncSession {
49 config: SyncConfig,
50}
51
52impl SyncSession {
53 pub fn new(config: SyncConfig) -> Self {
54 Self { config }
55 }
56
57 pub fn sync_as_initiator(
59 &self,
60 manager: &TxnManager,
61 transport: &dyn SyncTransport,
62 ) -> std::result::Result<SyncOutcome, SyncError> {
63 let local_reader = LocalTreeReader::new(manager);
64 let (local_root, local_hash) = local_reader.root_info().map_err(SyncError::Database)?;
65
66 transport.send(&SyncMessage::Hello {
67 node_id: self.config.node_id,
68 root_page: local_root,
69 root_hash: local_hash,
70 })?;
71
72 let (remote_root, remote_hash, in_sync) = match transport.recv()? {
73 SyncMessage::HelloAck {
74 root_page,
75 root_hash,
76 in_sync,
77 ..
78 } => (root_page, root_hash, in_sync),
79 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
80 other => {
81 return Err(SyncError::UnexpectedMessage {
82 expected: "HelloAck".into(),
83 actual: msg_name(&other).into(),
84 })
85 }
86 };
87
88 if in_sync {
89 transport.send(&SyncMessage::Done)?;
90 return Ok(SyncOutcome {
91 pushed: None,
92 pulled: None,
93 already_in_sync: true,
94 });
95 }
96
97 let mut outcome = SyncOutcome {
98 pushed: None,
99 pulled: None,
100 already_in_sync: false,
101 };
102
103 if self.config.direction == SyncDirection::Push
105 || self.config.direction == SyncDirection::Bidirectional
106 {
107 let result = self.initiator_push(manager, transport, remote_root, remote_hash)?;
108 outcome.pushed = Some(result);
109 }
110
111 if self.config.direction == SyncDirection::Pull
113 || self.config.direction == SyncDirection::Bidirectional
114 {
115 let (pull_root, pull_hash) = if self.config.direction == SyncDirection::Bidirectional {
117 transport.send(&SyncMessage::PullRequest)?;
118 match transport.recv()? {
119 SyncMessage::PullResponse {
120 root_page,
121 root_hash,
122 } => (root_page, root_hash),
123 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
124 other => {
125 return Err(SyncError::UnexpectedMessage {
126 expected: "PullResponse".into(),
127 actual: msg_name(&other).into(),
128 })
129 }
130 }
131 } else {
132 (remote_root, remote_hash)
133 };
134
135 let result = self.initiator_pull(manager, transport, pull_root, pull_hash)?;
136 outcome.pulled = Some(result);
137 }
138
139 transport.send(&SyncMessage::Done)?;
140 Ok(outcome)
141 }
142
143 pub fn sync_as_responder(
145 &self,
146 manager: &TxnManager,
147 transport: &dyn SyncTransport,
148 ) -> std::result::Result<SyncOutcome, SyncError> {
149 let local_reader = LocalTreeReader::new(manager);
150 let (local_root, local_hash) = local_reader.root_info().map_err(SyncError::Database)?;
151
152 let remote_hash = match transport.recv()? {
153 SyncMessage::Hello { root_hash, .. } => root_hash,
154 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
155 other => {
156 return Err(SyncError::UnexpectedMessage {
157 expected: "Hello".into(),
158 actual: msg_name(&other).into(),
159 })
160 }
161 };
162
163 let in_sync = local_hash == remote_hash;
164
165 transport.send(&SyncMessage::HelloAck {
166 node_id: self.config.node_id,
167 root_page: local_root,
168 root_hash: local_hash,
169 in_sync,
170 })?;
171
172 if in_sync {
173 let _ = transport.recv()?;
174 return Ok(SyncOutcome {
175 pushed: None,
176 pulled: None,
177 already_in_sync: true,
178 });
179 }
180
181 let mut outcome = SyncOutcome {
182 pushed: None,
183 pulled: None,
184 already_in_sync: false,
185 };
186
187 loop {
188 let msg = transport.recv()?;
189 match msg {
190 SyncMessage::DigestRequest { page_ids } => {
191 let reader = LocalTreeReader::new(manager);
192 let mut digests = Vec::with_capacity(page_ids.len());
193 for pid in &page_ids {
194 match reader.page_digest(*pid) {
195 Ok(d) => digests.push(d),
196 Err(e) => {
197 transport.send(&SyncMessage::Error {
198 message: e.to_string(),
199 })?;
200 continue;
201 }
202 }
203 }
204 transport.send(&SyncMessage::DigestResponse { digests })?;
205 }
206 SyncMessage::EntriesRequest { page_ids } => {
207 let reader = LocalTreeReader::new(manager);
208 let mut entries = Vec::new();
209 for pid in &page_ids {
210 match reader.leaf_entries(*pid) {
211 Ok(e) => entries.extend(e),
212 Err(e) => {
213 transport.send(&SyncMessage::Error {
214 message: e.to_string(),
215 })?;
216 continue;
217 }
218 }
219 }
220 transport.send(&SyncMessage::EntriesResponse { entries })?;
221 }
222 SyncMessage::PatchData { data } => {
223 let patch = SyncPatch::deserialize(&data).map_err(SyncError::Patch)?;
224 let result = apply_patch(manager, &patch).map_err(SyncError::Database)?;
225 outcome.pushed = Some(result.clone());
226 transport.send(&SyncMessage::PatchAck { result })?;
227 }
228 SyncMessage::PullRequest => {
229 let reader = LocalTreeReader::new(manager);
230 let (root_page, root_hash) = reader.root_info().map_err(SyncError::Database)?;
231 transport.send(&SyncMessage::PullResponse {
232 root_page,
233 root_hash,
234 })?;
235 }
236 SyncMessage::Done => {
237 break;
238 }
239 SyncMessage::Error { message } => {
240 return Err(SyncError::Remote(message));
241 }
242 _ => {
243 transport.send(&SyncMessage::Error {
244 message: "unexpected message".into(),
245 })?;
246 }
247 }
248 }
249
250 Ok(outcome)
251 }
252
253 fn initiator_push(
255 &self,
256 manager: &TxnManager,
257 transport: &dyn SyncTransport,
258 remote_root: PageId,
259 remote_hash: MerkleHash,
260 ) -> std::result::Result<ApplyResult, SyncError> {
261 let local_reader = LocalTreeReader::new(manager);
262 let remote_reader = RemoteTreeReader::new(transport, remote_root, remote_hash);
263
264 let diff = merkle_diff(&local_reader, &remote_reader).map_err(SyncError::Database)?;
266
267 if diff.is_empty() {
268 return Ok(ApplyResult::empty());
269 }
270
271 let patch = SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
272 let patch_data = patch.serialize();
273
274 transport.send(&SyncMessage::PatchData { data: patch_data })?;
275
276 match transport.recv()? {
277 SyncMessage::PatchAck { result } => Ok(result),
278 SyncMessage::Error { message } => Err(SyncError::Remote(message)),
279 other => Err(SyncError::UnexpectedMessage {
280 expected: "PatchAck".into(),
281 actual: msg_name(&other).into(),
282 }),
283 }
284 }
285
286 pub fn sync_tables_as_initiator(
288 &self,
289 manager: &TxnManager,
290 transport: &dyn SyncTransport,
291 ) -> std::result::Result<Vec<(Vec<u8>, ApplyResult)>, SyncError> {
292 transport.send(&SyncMessage::TableListRequest)?;
293
294 let remote_tables = match transport.recv()? {
295 SyncMessage::TableListResponse { tables } => tables,
296 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
297 other => {
298 return Err(SyncError::UnexpectedMessage {
299 expected: "TableListResponse".into(),
300 actual: msg_name(&other).into(),
301 })
302 }
303 };
304
305 let local_tables = manager.list_tables().map_err(SyncError::Database)?;
306
307 let mut all_names: Vec<Vec<u8>> = Vec::new();
308 for (name, _) in &local_tables {
309 if !name.starts_with(b"__idx_") && !all_names.contains(name) {
310 all_names.push(name.clone());
311 }
312 }
313 for info in &remote_tables {
314 if !info.name.starts_with(b"__idx_") && !all_names.contains(&info.name) {
315 all_names.push(info.name.clone());
316 }
317 }
318
319 let mut results = Vec::new();
320
321 for table_name in &all_names {
322 let local_info = local_tables.iter().find(|(n, _)| n == table_name);
323 let remote_info = remote_tables.iter().find(|t| t.name == *table_name);
324
325 let local_root = local_info
326 .map(|(_, desc)| desc.root_page)
327 .unwrap_or(PageId::INVALID);
328 let local_hash = if local_root.is_valid() {
329 manager
330 .read_page_from_disk(local_root)
331 .map(|p| p.merkle_hash())
332 .unwrap_or([0u8; citadel_core::MERKLE_HASH_SIZE])
333 } else {
334 [0u8; citadel_core::MERKLE_HASH_SIZE]
335 };
336
337 let remote_root = remote_info.map(|t| t.root_page).unwrap_or(PageId::INVALID);
338 let remote_hash = remote_info
339 .map(|t| t.root_hash)
340 .unwrap_or([0u8; citadel_core::MERKLE_HASH_SIZE]);
341
342 if local_hash == remote_hash && local_root.is_valid() && remote_root.is_valid() {
343 continue;
344 }
345
346 transport.send(&SyncMessage::TableSyncBegin {
347 table_name: table_name.clone(),
348 root_page: local_root,
349 root_hash: local_hash,
350 })?;
351
352 if local_root.is_valid() && remote_root.is_valid() {
353 let local_reader =
354 LocalTreeReader::for_table(manager, local_root).map_err(SyncError::Database)?;
355 let remote_reader = RemoteTreeReader::new(transport, remote_root, remote_hash);
356 let diff =
357 merkle_diff(&local_reader, &remote_reader).map_err(SyncError::Database)?;
358
359 if !diff.is_empty() {
360 let patch =
361 SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
362 transport.send(&SyncMessage::PatchData {
363 data: patch.serialize(),
364 })?;
365 match transport.recv()? {
366 SyncMessage::PatchAck { result } => {
367 results.push((table_name.clone(), result));
368 }
369 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
370 other => {
371 return Err(SyncError::UnexpectedMessage {
372 expected: "PatchAck".into(),
373 actual: msg_name(&other).into(),
374 })
375 }
376 }
377 }
378 } else if local_root.is_valid() {
379 let local_reader =
380 LocalTreeReader::for_table(manager, local_root).map_err(SyncError::Database)?;
381 let entries = local_reader
382 .subtree_entries(local_root)
383 .map_err(SyncError::Database)?;
384 if !entries.is_empty() {
385 let diff = crate::diff::DiffResult {
386 entries,
387 pages_compared: 0,
388 subtrees_skipped: 0,
389 };
390 let patch =
391 SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
392 transport.send(&SyncMessage::PatchData {
393 data: patch.serialize(),
394 })?;
395 match transport.recv()? {
396 SyncMessage::PatchAck { result } => {
397 results.push((table_name.clone(), result));
398 }
399 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
400 other => {
401 return Err(SyncError::UnexpectedMessage {
402 expected: "PatchAck".into(),
403 actual: msg_name(&other).into(),
404 })
405 }
406 }
407 }
408 }
409
410 transport.send(&SyncMessage::TableSyncEnd {
411 table_name: table_name.clone(),
412 })?;
413 }
414
415 transport.send(&SyncMessage::Done)?;
416 Ok(results)
417 }
418
419 pub fn handle_table_sync_as_responder(
421 &self,
422 manager: &TxnManager,
423 transport: &dyn SyncTransport,
424 ) -> std::result::Result<Vec<(Vec<u8>, ApplyResult)>, SyncError> {
425 match transport.recv()? {
426 SyncMessage::TableListRequest => {}
427 SyncMessage::Done => return Ok(Vec::new()),
428 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
429 other => {
430 return Err(SyncError::UnexpectedMessage {
431 expected: "TableListRequest".into(),
432 actual: msg_name(&other).into(),
433 })
434 }
435 }
436
437 let local_tables = manager.list_tables().map_err(SyncError::Database)?;
438 let table_infos: Vec<TableInfo> = local_tables
439 .iter()
440 .filter(|(name, _)| !name.starts_with(b"__idx_"))
441 .filter_map(|(name, desc)| {
442 if desc.root_page.is_valid() {
443 let hash = manager
444 .read_page_from_disk(desc.root_page)
445 .map(|p| p.merkle_hash())
446 .unwrap_or([0u8; citadel_core::MERKLE_HASH_SIZE]);
447 Some(TableInfo {
448 name: name.clone(),
449 root_page: desc.root_page,
450 root_hash: hash,
451 })
452 } else {
453 None
454 }
455 })
456 .collect();
457 transport.send(&SyncMessage::TableListResponse {
458 tables: table_infos,
459 })?;
460
461 let mut results = Vec::new();
462 let mut current_table: Option<Vec<u8>> = None;
463
464 loop {
465 let msg = transport.recv()?;
466 match msg {
467 SyncMessage::TableSyncBegin { table_name, .. } => {
468 current_table = Some(table_name);
469 }
470 SyncMessage::TableSyncEnd { .. } => {
471 current_table = None;
472 }
473 SyncMessage::DigestRequest { page_ids } => {
474 let reader = if let Some(ref tname) = current_table {
475 let root = manager.table_root(tname).map_err(SyncError::Database)?;
476 if let Some(r) = root {
477 LocalTreeReader::for_table(manager, r).map_err(SyncError::Database)?
478 } else {
479 LocalTreeReader::new(manager)
480 }
481 } else {
482 LocalTreeReader::new(manager)
483 };
484
485 let mut digests = Vec::with_capacity(page_ids.len());
486 for pid in &page_ids {
487 match reader.page_digest(*pid) {
488 Ok(d) => digests.push(d),
489 Err(e) => {
490 transport.send(&SyncMessage::Error {
491 message: e.to_string(),
492 })?;
493 continue;
494 }
495 }
496 }
497 transport.send(&SyncMessage::DigestResponse { digests })?;
498 }
499 SyncMessage::EntriesRequest { page_ids } => {
500 let reader = if let Some(ref tname) = current_table {
501 let root = manager.table_root(tname).map_err(SyncError::Database)?;
502 if let Some(r) = root {
503 LocalTreeReader::for_table(manager, r).map_err(SyncError::Database)?
504 } else {
505 LocalTreeReader::new(manager)
506 }
507 } else {
508 LocalTreeReader::new(manager)
509 };
510
511 let mut entries = Vec::new();
512 for pid in &page_ids {
513 match reader.leaf_entries(*pid) {
514 Ok(e) => entries.extend(e),
515 Err(e) => {
516 transport.send(&SyncMessage::Error {
517 message: e.to_string(),
518 })?;
519 continue;
520 }
521 }
522 }
523 transport.send(&SyncMessage::EntriesResponse { entries })?;
524 }
525 SyncMessage::PatchData { data } => {
526 let patch = SyncPatch::deserialize(&data).map_err(SyncError::Patch)?;
527 let result = if let Some(ref tname) = current_table {
528 apply_patch_to_table(manager, tname, &patch).map_err(SyncError::Database)?
529 } else {
530 apply_patch(manager, &patch).map_err(SyncError::Database)?
531 };
532 if let Some(ref tname) = current_table {
533 results.push((tname.clone(), result.clone()));
534 }
535 transport.send(&SyncMessage::PatchAck { result })?;
536 }
537 SyncMessage::Done => break,
538 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
539 _ => {
540 transport.send(&SyncMessage::Error {
541 message: "unexpected message in table sync".into(),
542 })?;
543 }
544 }
545 }
546
547 Ok(results)
548 }
549
550 fn initiator_pull(
552 &self,
553 manager: &TxnManager,
554 transport: &dyn SyncTransport,
555 remote_root: PageId,
556 remote_hash: MerkleHash,
557 ) -> std::result::Result<ApplyResult, SyncError> {
558 let local_reader = LocalTreeReader::new(manager);
559 let (_, local_hash) = local_reader.root_info().map_err(SyncError::Database)?;
560
561 if local_hash == remote_hash {
562 return Ok(ApplyResult::empty());
563 }
564
565 let remote_reader = RemoteTreeReader::new(transport, remote_root, remote_hash);
566
567 let diff = merkle_diff(&remote_reader, &local_reader).map_err(SyncError::Database)?;
569
570 if diff.is_empty() {
571 return Ok(ApplyResult::empty());
572 }
573
574 let patch = SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
575 let result = apply_patch(manager, &patch).map_err(SyncError::Database)?;
576 Ok(result)
577 }
578}