1use citadel_txn::manager::TxnManager;
2
3use crate::apply::{apply_patch, apply_patch_to_table, ApplyResult};
4use crate::diff::{merkle_diff, MerkleHash, TreeReader};
5use crate::local_reader::LocalTreeReader;
6use crate::node_id::NodeId;
7use crate::patch::SyncPatch;
8use crate::protocol::{SyncMessage, TableInfo};
9use crate::transport::{msg_name, RemoteTreeReader, SyncError, SyncTransport};
10
11use citadel_core::types::PageId;
12
/// Which way data flows when a node runs [`SyncSession::sync_as_initiator`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SyncDirection {
    /// The initiator diffs its tree against the responder and sends a patch.
    Push,
    /// The initiator diffs the responder's tree against its own and applies
    /// the resulting patch locally.
    Pull,
    /// Push first, then pull using a freshly requested remote root.
    Bidirectional,
}
23
24#[derive(Debug, Clone)]
26pub struct SyncConfig {
27 pub node_id: NodeId,
28 pub direction: SyncDirection,
29 pub crdt_aware: bool,
30}
31
32#[derive(Debug, Clone)]
34pub struct SyncOutcome {
35 pub pushed: Option<ApplyResult>,
37 pub pulled: Option<ApplyResult>,
39 pub already_in_sync: bool,
41}
42
43pub struct SyncSession {
49 config: SyncConfig,
50}
51
52impl SyncSession {
53 pub fn new(config: SyncConfig) -> Self {
54 Self { config }
55 }
56
57 pub fn sync_as_initiator(
59 &self,
60 manager: &TxnManager,
61 transport: &dyn SyncTransport,
62 ) -> std::result::Result<SyncOutcome, SyncError> {
63 let local_reader = LocalTreeReader::new(manager);
64 let (local_root, local_hash) = local_reader.root_info().map_err(SyncError::Database)?;
65
66 transport.send(&SyncMessage::Hello {
68 node_id: self.config.node_id,
69 root_page: local_root,
70 root_hash: local_hash,
71 })?;
72
73 let (remote_root, remote_hash, in_sync) = match transport.recv()? {
74 SyncMessage::HelloAck { root_page, root_hash, in_sync, .. } => {
75 (root_page, root_hash, in_sync)
76 }
77 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
78 other => return Err(SyncError::UnexpectedMessage {
79 expected: "HelloAck".into(),
80 actual: msg_name(&other).into(),
81 }),
82 };
83
84 if in_sync {
85 transport.send(&SyncMessage::Done)?;
86 return Ok(SyncOutcome {
87 pushed: None,
88 pulled: None,
89 already_in_sync: true,
90 });
91 }
92
93 let mut outcome = SyncOutcome {
94 pushed: None,
95 pulled: None,
96 already_in_sync: false,
97 };
98
99 if self.config.direction == SyncDirection::Push
101 || self.config.direction == SyncDirection::Bidirectional
102 {
103 let result = self.initiator_push(
104 manager, transport, remote_root, remote_hash,
105 )?;
106 outcome.pushed = Some(result);
107 }
108
109 if self.config.direction == SyncDirection::Pull
111 || self.config.direction == SyncDirection::Bidirectional
112 {
113 let (pull_root, pull_hash) = if self.config.direction == SyncDirection::Bidirectional {
115 transport.send(&SyncMessage::PullRequest)?;
116 match transport.recv()? {
117 SyncMessage::PullResponse { root_page, root_hash } => {
118 (root_page, root_hash)
119 }
120 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
121 other => return Err(SyncError::UnexpectedMessage {
122 expected: "PullResponse".into(),
123 actual: msg_name(&other).into(),
124 }),
125 }
126 } else {
127 (remote_root, remote_hash)
128 };
129
130 let result = self.initiator_pull(
131 manager, transport, pull_root, pull_hash,
132 )?;
133 outcome.pulled = Some(result);
134 }
135
136 transport.send(&SyncMessage::Done)?;
137 Ok(outcome)
138 }
139
140 pub fn sync_as_responder(
142 &self,
143 manager: &TxnManager,
144 transport: &dyn SyncTransport,
145 ) -> std::result::Result<SyncOutcome, SyncError> {
146 let local_reader = LocalTreeReader::new(manager);
147 let (local_root, local_hash) = local_reader.root_info().map_err(SyncError::Database)?;
148
149 let remote_hash = match transport.recv()? {
151 SyncMessage::Hello { root_hash, .. } => root_hash,
152 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
153 other => return Err(SyncError::UnexpectedMessage {
154 expected: "Hello".into(),
155 actual: msg_name(&other).into(),
156 }),
157 };
158
159 let in_sync = local_hash == remote_hash;
160
161 transport.send(&SyncMessage::HelloAck {
162 node_id: self.config.node_id,
163 root_page: local_root,
164 root_hash: local_hash,
165 in_sync,
166 })?;
167
168 if in_sync {
169 match transport.recv()? {
170 SyncMessage::Done => {}
171 _ => {}
172 }
173 return Ok(SyncOutcome {
174 pushed: None,
175 pulled: None,
176 already_in_sync: true,
177 });
178 }
179
180 let mut outcome = SyncOutcome {
181 pushed: None,
182 pulled: None,
183 already_in_sync: false,
184 };
185
186 loop {
188 let msg = transport.recv()?;
189 match msg {
190 SyncMessage::DigestRequest { page_ids } => {
191 let reader = LocalTreeReader::new(manager);
192 let mut digests = Vec::with_capacity(page_ids.len());
193 for pid in &page_ids {
194 match reader.page_digest(*pid) {
195 Ok(d) => digests.push(d),
196 Err(e) => {
197 transport.send(&SyncMessage::Error {
198 message: e.to_string(),
199 })?;
200 continue;
201 }
202 }
203 }
204 transport.send(&SyncMessage::DigestResponse { digests })?;
205 }
206 SyncMessage::EntriesRequest { page_ids } => {
207 let reader = LocalTreeReader::new(manager);
208 let mut entries = Vec::new();
209 for pid in &page_ids {
210 match reader.leaf_entries(*pid) {
211 Ok(e) => entries.extend(e),
212 Err(e) => {
213 transport.send(&SyncMessage::Error {
214 message: e.to_string(),
215 })?;
216 continue;
217 }
218 }
219 }
220 transport.send(&SyncMessage::EntriesResponse { entries })?;
221 }
222 SyncMessage::PatchData { data } => {
223 let patch = SyncPatch::deserialize(&data).map_err(SyncError::Patch)?;
224 let result = apply_patch(manager, &patch).map_err(SyncError::Database)?;
225 outcome.pushed = Some(result.clone());
226 transport.send(&SyncMessage::PatchAck { result })?;
227 }
228 SyncMessage::PullRequest => {
229 let reader = LocalTreeReader::new(manager);
230 let (root_page, root_hash) =
231 reader.root_info().map_err(SyncError::Database)?;
232 transport.send(&SyncMessage::PullResponse { root_page, root_hash })?;
233 }
234 SyncMessage::Done => {
235 break;
236 }
237 SyncMessage::Error { message } => {
238 return Err(SyncError::Remote(message));
239 }
240 _ => {
241 transport.send(&SyncMessage::Error {
242 message: "unexpected message".into(),
243 })?;
244 }
245 }
246 }
247
248 Ok(outcome)
249 }
250
251 fn initiator_push(
253 &self,
254 manager: &TxnManager,
255 transport: &dyn SyncTransport,
256 remote_root: PageId,
257 remote_hash: MerkleHash,
258 ) -> std::result::Result<ApplyResult, SyncError> {
259 let local_reader = LocalTreeReader::new(manager);
260 let remote_reader = RemoteTreeReader::new(transport, remote_root, remote_hash);
261
262 let diff = merkle_diff(&local_reader, &remote_reader)
264 .map_err(SyncError::Database)?;
265
266 if diff.is_empty() {
267 return Ok(ApplyResult::empty());
268 }
269
270 let patch = SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
271 let patch_data = patch.serialize();
272
273 transport.send(&SyncMessage::PatchData { data: patch_data })?;
274
275 match transport.recv()? {
276 SyncMessage::PatchAck { result } => Ok(result),
277 SyncMessage::Error { message } => Err(SyncError::Remote(message)),
278 other => Err(SyncError::UnexpectedMessage {
279 expected: "PatchAck".into(),
280 actual: msg_name(&other).into(),
281 }),
282 }
283 }
284
285 pub fn sync_tables_as_initiator(
287 &self,
288 manager: &TxnManager,
289 transport: &dyn SyncTransport,
290 ) -> std::result::Result<Vec<(Vec<u8>, ApplyResult)>, SyncError> {
291 transport.send(&SyncMessage::TableListRequest)?;
292
293 let remote_tables = match transport.recv()? {
294 SyncMessage::TableListResponse { tables } => tables,
295 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
296 other => return Err(SyncError::UnexpectedMessage {
297 expected: "TableListResponse".into(),
298 actual: msg_name(&other).into(),
299 }),
300 };
301
302 let local_tables = manager.list_tables().map_err(SyncError::Database)?;
303
304 let mut all_names: Vec<Vec<u8>> = Vec::new();
305 for (name, _) in &local_tables {
306 if !name.starts_with(b"__idx_") {
307 if !all_names.iter().any(|n| n == name) {
308 all_names.push(name.clone());
309 }
310 }
311 }
312 for info in &remote_tables {
313 if !info.name.starts_with(b"__idx_") {
314 if !all_names.iter().any(|n| *n == info.name) {
315 all_names.push(info.name.clone());
316 }
317 }
318 }
319
320 let mut results = Vec::new();
321
322 for table_name in &all_names {
323 let local_info = local_tables.iter().find(|(n, _)| n == table_name);
324 let remote_info = remote_tables.iter().find(|t| t.name == *table_name);
325
326 let local_root = local_info
327 .map(|(_, desc)| desc.root_page)
328 .unwrap_or(PageId::INVALID);
329 let local_hash = if local_root.is_valid() {
330 manager.read_page_from_disk(local_root)
331 .map(|p| p.merkle_hash())
332 .unwrap_or([0u8; citadel_core::MERKLE_HASH_SIZE])
333 } else {
334 [0u8; citadel_core::MERKLE_HASH_SIZE]
335 };
336
337 let remote_root = remote_info.map(|t| t.root_page).unwrap_or(PageId::INVALID);
338 let remote_hash = remote_info
339 .map(|t| t.root_hash)
340 .unwrap_or([0u8; citadel_core::MERKLE_HASH_SIZE]);
341
342 if local_hash == remote_hash && local_root.is_valid() && remote_root.is_valid() {
343 continue;
344 }
345
346 transport.send(&SyncMessage::TableSyncBegin {
347 table_name: table_name.clone(),
348 root_page: local_root,
349 root_hash: local_hash,
350 })?;
351
352 if local_root.is_valid() && remote_root.is_valid() {
353 let local_reader = LocalTreeReader::for_table(manager, local_root)
354 .map_err(SyncError::Database)?;
355 let remote_reader = RemoteTreeReader::new(transport, remote_root, remote_hash);
356 let diff = merkle_diff(&local_reader, &remote_reader)
357 .map_err(SyncError::Database)?;
358
359 if !diff.is_empty() {
360 let patch = SyncPatch::from_diff(
361 self.config.node_id, &diff, self.config.crdt_aware,
362 );
363 transport.send(&SyncMessage::PatchData { data: patch.serialize() })?;
364 match transport.recv()? {
365 SyncMessage::PatchAck { result } => {
366 results.push((table_name.clone(), result));
367 }
368 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
369 other => return Err(SyncError::UnexpectedMessage {
370 expected: "PatchAck".into(),
371 actual: msg_name(&other).into(),
372 }),
373 }
374 }
375 } else if local_root.is_valid() {
376 let local_reader = LocalTreeReader::for_table(manager, local_root)
377 .map_err(SyncError::Database)?;
378 let entries = local_reader.subtree_entries(local_root)
379 .map_err(SyncError::Database)?;
380 if !entries.is_empty() {
381 let diff = crate::diff::DiffResult {
382 entries,
383 pages_compared: 0,
384 subtrees_skipped: 0,
385 };
386 let patch = SyncPatch::from_diff(
387 self.config.node_id, &diff, self.config.crdt_aware,
388 );
389 transport.send(&SyncMessage::PatchData { data: patch.serialize() })?;
390 match transport.recv()? {
391 SyncMessage::PatchAck { result } => {
392 results.push((table_name.clone(), result));
393 }
394 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
395 other => return Err(SyncError::UnexpectedMessage {
396 expected: "PatchAck".into(),
397 actual: msg_name(&other).into(),
398 }),
399 }
400 }
401 }
402
403 transport.send(&SyncMessage::TableSyncEnd {
404 table_name: table_name.clone(),
405 })?;
406 }
407
408 transport.send(&SyncMessage::Done)?;
409 Ok(results)
410 }
411
412 pub fn handle_table_sync_as_responder(
414 &self,
415 manager: &TxnManager,
416 transport: &dyn SyncTransport,
417 ) -> std::result::Result<Vec<(Vec<u8>, ApplyResult)>, SyncError> {
418 match transport.recv()? {
419 SyncMessage::TableListRequest => {}
420 SyncMessage::Done => return Ok(Vec::new()),
421 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
422 other => return Err(SyncError::UnexpectedMessage {
423 expected: "TableListRequest".into(),
424 actual: msg_name(&other).into(),
425 }),
426 }
427
428 let local_tables = manager.list_tables().map_err(SyncError::Database)?;
429 let table_infos: Vec<TableInfo> = local_tables
430 .iter()
431 .filter(|(name, _)| !name.starts_with(b"__idx_"))
432 .filter_map(|(name, desc)| {
433 if desc.root_page.is_valid() {
434 let hash = manager.read_page_from_disk(desc.root_page)
435 .map(|p| p.merkle_hash())
436 .unwrap_or([0u8; citadel_core::MERKLE_HASH_SIZE]);
437 Some(TableInfo {
438 name: name.clone(),
439 root_page: desc.root_page,
440 root_hash: hash,
441 })
442 } else {
443 None
444 }
445 })
446 .collect();
447 transport.send(&SyncMessage::TableListResponse { tables: table_infos })?;
448
449 let mut results = Vec::new();
450 let mut current_table: Option<Vec<u8>> = None;
451
452 loop {
453 let msg = transport.recv()?;
454 match msg {
455 SyncMessage::TableSyncBegin { table_name, .. } => {
456 current_table = Some(table_name);
457 }
458 SyncMessage::TableSyncEnd { .. } => {
459 current_table = None;
460 }
461 SyncMessage::DigestRequest { page_ids } => {
462 let reader = if let Some(ref tname) = current_table {
463 let root = manager.table_root(tname).map_err(SyncError::Database)?;
464 if let Some(r) = root {
465 LocalTreeReader::for_table(manager, r)
466 .map_err(SyncError::Database)?
467 } else {
468 LocalTreeReader::new(manager)
469 }
470 } else {
471 LocalTreeReader::new(manager)
472 };
473
474 let mut digests = Vec::with_capacity(page_ids.len());
475 for pid in &page_ids {
476 match reader.page_digest(*pid) {
477 Ok(d) => digests.push(d),
478 Err(e) => {
479 transport.send(&SyncMessage::Error {
480 message: e.to_string(),
481 })?;
482 continue;
483 }
484 }
485 }
486 transport.send(&SyncMessage::DigestResponse { digests })?;
487 }
488 SyncMessage::EntriesRequest { page_ids } => {
489 let reader = if let Some(ref tname) = current_table {
490 let root = manager.table_root(tname).map_err(SyncError::Database)?;
491 if let Some(r) = root {
492 LocalTreeReader::for_table(manager, r)
493 .map_err(SyncError::Database)?
494 } else {
495 LocalTreeReader::new(manager)
496 }
497 } else {
498 LocalTreeReader::new(manager)
499 };
500
501 let mut entries = Vec::new();
502 for pid in &page_ids {
503 match reader.leaf_entries(*pid) {
504 Ok(e) => entries.extend(e),
505 Err(e) => {
506 transport.send(&SyncMessage::Error {
507 message: e.to_string(),
508 })?;
509 continue;
510 }
511 }
512 }
513 transport.send(&SyncMessage::EntriesResponse { entries })?;
514 }
515 SyncMessage::PatchData { data } => {
516 let patch = SyncPatch::deserialize(&data).map_err(SyncError::Patch)?;
517 let result = if let Some(ref tname) = current_table {
518 apply_patch_to_table(manager, tname, &patch)
519 .map_err(SyncError::Database)?
520 } else {
521 apply_patch(manager, &patch).map_err(SyncError::Database)?
522 };
523 if let Some(ref tname) = current_table {
524 results.push((tname.clone(), result.clone()));
525 }
526 transport.send(&SyncMessage::PatchAck { result })?;
527 }
528 SyncMessage::Done => break,
529 SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
530 _ => {
531 transport.send(&SyncMessage::Error {
532 message: "unexpected message in table sync".into(),
533 })?;
534 }
535 }
536 }
537
538 Ok(results)
539 }
540
541 fn initiator_pull(
543 &self,
544 manager: &TxnManager,
545 transport: &dyn SyncTransport,
546 remote_root: PageId,
547 remote_hash: MerkleHash,
548 ) -> std::result::Result<ApplyResult, SyncError> {
549 let local_reader = LocalTreeReader::new(manager);
550 let (_, local_hash) = local_reader.root_info().map_err(SyncError::Database)?;
551
552 if local_hash == remote_hash {
553 return Ok(ApplyResult::empty());
554 }
555
556 let remote_reader = RemoteTreeReader::new(transport, remote_root, remote_hash);
557
558 let diff = merkle_diff(&remote_reader, &local_reader)
560 .map_err(SyncError::Database)?;
561
562 if diff.is_empty() {
563 return Ok(ApplyResult::empty());
564 }
565
566 let patch = SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
567 let result = apply_patch(manager, &patch).map_err(SyncError::Database)?;
568 Ok(result)
569 }
570}