//! citadel_sync/session.rs — orchestration of sync sessions between
//! two databases (initiator and responder roles).
use citadel_core::types::PageId;
use citadel_txn::manager::TxnManager;

use crate::apply::{apply_patch, apply_patch_to_table, ApplyResult};
use crate::diff::{merkle_diff, MerkleHash, TreeReader};
use crate::local_reader::LocalTreeReader;
use crate::node_id::NodeId;
use crate::patch::SyncPatch;
use crate::protocol::{SyncMessage, TableInfo};
use crate::transport::{msg_name, RemoteTreeReader, SyncError, SyncTransport};
/// Direction of data flow for a sync session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyncDirection {
    /// Send local changes to the remote peer only.
    Push,
    /// Fetch remote changes into the local database only.
    Pull,
    /// Push first, then pull (full two-way sync).
    Bidirectional,
}
23
24/// Configuration for a sync session.
25#[derive(Debug, Clone)]
26pub struct SyncConfig {
27    pub node_id: NodeId,
28    pub direction: SyncDirection,
29    pub crdt_aware: bool,
30}
31
32/// Outcome of a sync session.
33#[derive(Debug, Clone)]
34pub struct SyncOutcome {
35    /// Result of the push phase (if Push or Bidirectional).
36    pub pushed: Option<ApplyResult>,
37    /// Result of the pull phase (if Pull or Bidirectional).
38    pub pulled: Option<ApplyResult>,
39    /// True if both databases were already identical.
40    pub already_in_sync: bool,
41}
42
43/// Orchestrates a sync session between two databases.
44///
45/// The initiator drives the protocol: sends Hello, computes diffs,
46/// builds patches, and coordinates push/pull phases.
47/// The responder answers requests and applies patches.
48pub struct SyncSession {
49    config: SyncConfig,
50}
51
52impl SyncSession {
53    pub fn new(config: SyncConfig) -> Self {
54        Self { config }
55    }
56
    /// Run as the initiator (client) side of a sync session.
    ///
    /// Protocol: send `Hello`, receive `HelloAck`, then run the push
    /// and/or pull phase according to `config.direction`, and finish
    /// with `Done`.
    ///
    /// # Errors
    /// Returns `SyncError::Remote` when the peer reports a failure,
    /// `SyncError::UnexpectedMessage` on a protocol violation, and
    /// `SyncError::Database` on local read failures.
    pub fn sync_as_initiator(
        &self,
        manager: &TxnManager,
        transport: &dyn SyncTransport,
    ) -> std::result::Result<SyncOutcome, SyncError> {
        let local_reader = LocalTreeReader::new(manager);
        let (local_root, local_hash) = local_reader.root_info().map_err(SyncError::Database)?;

        // Hello exchange: advertise our identity and current root state.
        transport.send(&SyncMessage::Hello {
            node_id: self.config.node_id,
            root_page: local_root,
            root_hash: local_hash,
        })?;

        let (remote_root, remote_hash, in_sync) = match transport.recv()? {
            SyncMessage::HelloAck {
                root_page,
                root_hash,
                in_sync,
                ..
            } => (root_page, root_hash, in_sync),
            SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
            other => {
                return Err(SyncError::UnexpectedMessage {
                    expected: "HelloAck".into(),
                    actual: msg_name(&other).into(),
                })
            }
        };

        // The responder compared root hashes for us; if they already
        // match, close the session cleanly without any diffing.
        if in_sync {
            transport.send(&SyncMessage::Done)?;
            return Ok(SyncOutcome {
                pushed: None,
                pulled: None,
                already_in_sync: true,
            });
        }

        let mut outcome = SyncOutcome {
            pushed: None,
            pulled: None,
            already_in_sync: false,
        };

        // Push phase: diff(local → remote), send patch to remote.
        if self.config.direction == SyncDirection::Push
            || self.config.direction == SyncDirection::Bidirectional
        {
            let result = self.initiator_push(manager, transport, remote_root, remote_hash)?;
            outcome.pushed = Some(result);
        }

        // Pull phase: diff(remote → local), apply patch locally.
        if self.config.direction == SyncDirection::Pull
            || self.config.direction == SyncDirection::Bidirectional
        {
            // For bidirectional sync the push phase may have changed the
            // remote root, so re-request its state before diffing.
            let (pull_root, pull_hash) = if self.config.direction == SyncDirection::Bidirectional {
                transport.send(&SyncMessage::PullRequest)?;
                match transport.recv()? {
                    SyncMessage::PullResponse {
                        root_page,
                        root_hash,
                    } => (root_page, root_hash),
                    SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
                    other => {
                        return Err(SyncError::UnexpectedMessage {
                            expected: "PullResponse".into(),
                            actual: msg_name(&other).into(),
                        })
                    }
                }
            } else {
                (remote_root, remote_hash)
            };

            let result = self.initiator_pull(manager, transport, pull_root, pull_hash)?;
            outcome.pulled = Some(result);
        }

        transport.send(&SyncMessage::Done)?;
        Ok(outcome)
    }
143
144    /// Run as the responder (server) side of a sync session.
145    pub fn sync_as_responder(
146        &self,
147        manager: &TxnManager,
148        transport: &dyn SyncTransport,
149    ) -> std::result::Result<SyncOutcome, SyncError> {
150        let local_reader = LocalTreeReader::new(manager);
151        let (local_root, local_hash) = local_reader.root_info().map_err(SyncError::Database)?;
152
153        // Receive Hello
154        let remote_hash = match transport.recv()? {
155            SyncMessage::Hello { root_hash, .. } => root_hash,
156            SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
157            other => {
158                return Err(SyncError::UnexpectedMessage {
159                    expected: "Hello".into(),
160                    actual: msg_name(&other).into(),
161                })
162            }
163        };
164
165        let in_sync = local_hash == remote_hash;
166
167        transport.send(&SyncMessage::HelloAck {
168            node_id: self.config.node_id,
169            root_page: local_root,
170            root_hash: local_hash,
171            in_sync,
172        })?;
173
174        if in_sync {
175            let _ = transport.recv()?;
176            return Ok(SyncOutcome {
177                pushed: None,
178                pulled: None,
179                already_in_sync: true,
180            });
181        }
182
183        let mut outcome = SyncOutcome {
184            pushed: None,
185            pulled: None,
186            already_in_sync: false,
187        };
188
189        // Serve requests until Done
190        loop {
191            let msg = transport.recv()?;
192            match msg {
193                SyncMessage::DigestRequest { page_ids } => {
194                    let reader = LocalTreeReader::new(manager);
195                    let mut digests = Vec::with_capacity(page_ids.len());
196                    for pid in &page_ids {
197                        match reader.page_digest(*pid) {
198                            Ok(d) => digests.push(d),
199                            Err(e) => {
200                                transport.send(&SyncMessage::Error {
201                                    message: e.to_string(),
202                                })?;
203                                continue;
204                            }
205                        }
206                    }
207                    transport.send(&SyncMessage::DigestResponse { digests })?;
208                }
209                SyncMessage::EntriesRequest { page_ids } => {
210                    let reader = LocalTreeReader::new(manager);
211                    let mut entries = Vec::new();
212                    for pid in &page_ids {
213                        match reader.leaf_entries(*pid) {
214                            Ok(e) => entries.extend(e),
215                            Err(e) => {
216                                transport.send(&SyncMessage::Error {
217                                    message: e.to_string(),
218                                })?;
219                                continue;
220                            }
221                        }
222                    }
223                    transport.send(&SyncMessage::EntriesResponse { entries })?;
224                }
225                SyncMessage::PatchData { data } => {
226                    let patch = SyncPatch::deserialize(&data).map_err(SyncError::Patch)?;
227                    let result = apply_patch(manager, &patch).map_err(SyncError::Database)?;
228                    outcome.pushed = Some(result.clone());
229                    transport.send(&SyncMessage::PatchAck { result })?;
230                }
231                SyncMessage::PullRequest => {
232                    let reader = LocalTreeReader::new(manager);
233                    let (root_page, root_hash) = reader.root_info().map_err(SyncError::Database)?;
234                    transport.send(&SyncMessage::PullResponse {
235                        root_page,
236                        root_hash,
237                    })?;
238                }
239                SyncMessage::Done => {
240                    break;
241                }
242                SyncMessage::Error { message } => {
243                    return Err(SyncError::Remote(message));
244                }
245                _ => {
246                    transport.send(&SyncMessage::Error {
247                        message: "unexpected message".into(),
248                    })?;
249                }
250            }
251        }
252
253        Ok(outcome)
254    }
255
256    /// Push: diff(local → remote) via merkle_diff, send patch.
257    fn initiator_push(
258        &self,
259        manager: &TxnManager,
260        transport: &dyn SyncTransport,
261        remote_root: PageId,
262        remote_hash: MerkleHash,
263    ) -> std::result::Result<ApplyResult, SyncError> {
264        let local_reader = LocalTreeReader::new(manager);
265        let remote_reader = RemoteTreeReader::new(transport, remote_root, remote_hash);
266
267        // source = local, target = remote
268        let diff = merkle_diff(&local_reader, &remote_reader).map_err(SyncError::Database)?;
269
270        if diff.is_empty() {
271            return Ok(ApplyResult::empty());
272        }
273
274        let patch = SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
275        let patch_data = patch.serialize();
276
277        transport.send(&SyncMessage::PatchData { data: patch_data })?;
278
279        match transport.recv()? {
280            SyncMessage::PatchAck { result } => Ok(result),
281            SyncMessage::Error { message } => Err(SyncError::Remote(message)),
282            other => Err(SyncError::UnexpectedMessage {
283                expected: "PatchAck".into(),
284                actual: msg_name(&other).into(),
285            }),
286        }
287    }
288
289    /// Run multi-table sync as the initiator.
290    pub fn sync_tables_as_initiator(
291        &self,
292        manager: &TxnManager,
293        transport: &dyn SyncTransport,
294    ) -> std::result::Result<Vec<(Vec<u8>, ApplyResult)>, SyncError> {
295        transport.send(&SyncMessage::TableListRequest)?;
296
297        let remote_tables = match transport.recv()? {
298            SyncMessage::TableListResponse { tables } => tables,
299            SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
300            other => {
301                return Err(SyncError::UnexpectedMessage {
302                    expected: "TableListResponse".into(),
303                    actual: msg_name(&other).into(),
304                })
305            }
306        };
307
308        let local_tables = manager.list_tables().map_err(SyncError::Database)?;
309
310        let mut all_names: Vec<Vec<u8>> = Vec::new();
311        for (name, _) in &local_tables {
312            if !name.starts_with(b"__idx_") && !all_names.contains(name) {
313                all_names.push(name.clone());
314            }
315        }
316        for info in &remote_tables {
317            if !info.name.starts_with(b"__idx_") && !all_names.contains(&info.name) {
318                all_names.push(info.name.clone());
319            }
320        }
321
322        let mut results = Vec::new();
323
324        for table_name in &all_names {
325            let local_info = local_tables.iter().find(|(n, _)| n == table_name);
326            let remote_info = remote_tables.iter().find(|t| t.name == *table_name);
327
328            let local_root = local_info
329                .map(|(_, desc)| desc.root_page)
330                .unwrap_or(PageId::INVALID);
331            let local_hash = if local_root.is_valid() {
332                manager
333                    .read_page_from_disk(local_root)
334                    .map(|p| p.merkle_hash())
335                    .unwrap_or([0u8; citadel_core::MERKLE_HASH_SIZE])
336            } else {
337                [0u8; citadel_core::MERKLE_HASH_SIZE]
338            };
339
340            let remote_root = remote_info.map(|t| t.root_page).unwrap_or(PageId::INVALID);
341            let remote_hash = remote_info
342                .map(|t| t.root_hash)
343                .unwrap_or([0u8; citadel_core::MERKLE_HASH_SIZE]);
344
345            if local_hash == remote_hash && local_root.is_valid() && remote_root.is_valid() {
346                continue;
347            }
348
349            transport.send(&SyncMessage::TableSyncBegin {
350                table_name: table_name.clone(),
351                root_page: local_root,
352                root_hash: local_hash,
353            })?;
354
355            if local_root.is_valid() && remote_root.is_valid() {
356                let local_reader =
357                    LocalTreeReader::for_table(manager, local_root).map_err(SyncError::Database)?;
358                let remote_reader = RemoteTreeReader::new(transport, remote_root, remote_hash);
359                let diff =
360                    merkle_diff(&local_reader, &remote_reader).map_err(SyncError::Database)?;
361
362                if !diff.is_empty() {
363                    let patch =
364                        SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
365                    transport.send(&SyncMessage::PatchData {
366                        data: patch.serialize(),
367                    })?;
368                    match transport.recv()? {
369                        SyncMessage::PatchAck { result } => {
370                            results.push((table_name.clone(), result));
371                        }
372                        SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
373                        other => {
374                            return Err(SyncError::UnexpectedMessage {
375                                expected: "PatchAck".into(),
376                                actual: msg_name(&other).into(),
377                            })
378                        }
379                    }
380                }
381            } else if local_root.is_valid() {
382                let local_reader =
383                    LocalTreeReader::for_table(manager, local_root).map_err(SyncError::Database)?;
384                let entries = local_reader
385                    .subtree_entries(local_root)
386                    .map_err(SyncError::Database)?;
387                if !entries.is_empty() {
388                    let diff = crate::diff::DiffResult {
389                        entries,
390                        pages_compared: 0,
391                        subtrees_skipped: 0,
392                    };
393                    let patch =
394                        SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
395                    transport.send(&SyncMessage::PatchData {
396                        data: patch.serialize(),
397                    })?;
398                    match transport.recv()? {
399                        SyncMessage::PatchAck { result } => {
400                            results.push((table_name.clone(), result));
401                        }
402                        SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
403                        other => {
404                            return Err(SyncError::UnexpectedMessage {
405                                expected: "PatchAck".into(),
406                                actual: msg_name(&other).into(),
407                            })
408                        }
409                    }
410                }
411            }
412
413            transport.send(&SyncMessage::TableSyncEnd {
414                table_name: table_name.clone(),
415            })?;
416        }
417
418        transport.send(&SyncMessage::Done)?;
419        Ok(results)
420    }
421
422    /// Handle multi-table sync as the responder.
423    pub fn handle_table_sync_as_responder(
424        &self,
425        manager: &TxnManager,
426        transport: &dyn SyncTransport,
427    ) -> std::result::Result<Vec<(Vec<u8>, ApplyResult)>, SyncError> {
428        match transport.recv()? {
429            SyncMessage::TableListRequest => {}
430            SyncMessage::Done => return Ok(Vec::new()),
431            SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
432            other => {
433                return Err(SyncError::UnexpectedMessage {
434                    expected: "TableListRequest".into(),
435                    actual: msg_name(&other).into(),
436                })
437            }
438        }
439
440        let local_tables = manager.list_tables().map_err(SyncError::Database)?;
441        let table_infos: Vec<TableInfo> = local_tables
442            .iter()
443            .filter(|(name, _)| !name.starts_with(b"__idx_"))
444            .filter_map(|(name, desc)| {
445                if desc.root_page.is_valid() {
446                    let hash = manager
447                        .read_page_from_disk(desc.root_page)
448                        .map(|p| p.merkle_hash())
449                        .unwrap_or([0u8; citadel_core::MERKLE_HASH_SIZE]);
450                    Some(TableInfo {
451                        name: name.clone(),
452                        root_page: desc.root_page,
453                        root_hash: hash,
454                    })
455                } else {
456                    None
457                }
458            })
459            .collect();
460        transport.send(&SyncMessage::TableListResponse {
461            tables: table_infos,
462        })?;
463
464        let mut results = Vec::new();
465        let mut current_table: Option<Vec<u8>> = None;
466
467        loop {
468            let msg = transport.recv()?;
469            match msg {
470                SyncMessage::TableSyncBegin { table_name, .. } => {
471                    current_table = Some(table_name);
472                }
473                SyncMessage::TableSyncEnd { .. } => {
474                    current_table = None;
475                }
476                SyncMessage::DigestRequest { page_ids } => {
477                    let reader = if let Some(ref tname) = current_table {
478                        let root = manager.table_root(tname).map_err(SyncError::Database)?;
479                        if let Some(r) = root {
480                            LocalTreeReader::for_table(manager, r).map_err(SyncError::Database)?
481                        } else {
482                            LocalTreeReader::new(manager)
483                        }
484                    } else {
485                        LocalTreeReader::new(manager)
486                    };
487
488                    let mut digests = Vec::with_capacity(page_ids.len());
489                    for pid in &page_ids {
490                        match reader.page_digest(*pid) {
491                            Ok(d) => digests.push(d),
492                            Err(e) => {
493                                transport.send(&SyncMessage::Error {
494                                    message: e.to_string(),
495                                })?;
496                                continue;
497                            }
498                        }
499                    }
500                    transport.send(&SyncMessage::DigestResponse { digests })?;
501                }
502                SyncMessage::EntriesRequest { page_ids } => {
503                    let reader = if let Some(ref tname) = current_table {
504                        let root = manager.table_root(tname).map_err(SyncError::Database)?;
505                        if let Some(r) = root {
506                            LocalTreeReader::for_table(manager, r).map_err(SyncError::Database)?
507                        } else {
508                            LocalTreeReader::new(manager)
509                        }
510                    } else {
511                        LocalTreeReader::new(manager)
512                    };
513
514                    let mut entries = Vec::new();
515                    for pid in &page_ids {
516                        match reader.leaf_entries(*pid) {
517                            Ok(e) => entries.extend(e),
518                            Err(e) => {
519                                transport.send(&SyncMessage::Error {
520                                    message: e.to_string(),
521                                })?;
522                                continue;
523                            }
524                        }
525                    }
526                    transport.send(&SyncMessage::EntriesResponse { entries })?;
527                }
528                SyncMessage::PatchData { data } => {
529                    let patch = SyncPatch::deserialize(&data).map_err(SyncError::Patch)?;
530                    let result = if let Some(ref tname) = current_table {
531                        apply_patch_to_table(manager, tname, &patch).map_err(SyncError::Database)?
532                    } else {
533                        apply_patch(manager, &patch).map_err(SyncError::Database)?
534                    };
535                    if let Some(ref tname) = current_table {
536                        results.push((tname.clone(), result.clone()));
537                    }
538                    transport.send(&SyncMessage::PatchAck { result })?;
539                }
540                SyncMessage::Done => break,
541                SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
542                _ => {
543                    transport.send(&SyncMessage::Error {
544                        message: "unexpected message in table sync".into(),
545                    })?;
546                }
547            }
548        }
549
550        Ok(results)
551    }
552
553    /// Pull: diff(remote → local) via merkle_diff, apply locally.
554    fn initiator_pull(
555        &self,
556        manager: &TxnManager,
557        transport: &dyn SyncTransport,
558        remote_root: PageId,
559        remote_hash: MerkleHash,
560    ) -> std::result::Result<ApplyResult, SyncError> {
561        let local_reader = LocalTreeReader::new(manager);
562        let (_, local_hash) = local_reader.root_info().map_err(SyncError::Database)?;
563
564        if local_hash == remote_hash {
565            return Ok(ApplyResult::empty());
566        }
567
568        let remote_reader = RemoteTreeReader::new(transport, remote_root, remote_hash);
569
570        // source = remote, target = local
571        let diff = merkle_diff(&remote_reader, &local_reader).map_err(SyncError::Database)?;
572
573        if diff.is_empty() {
574            return Ok(ApplyResult::empty());
575        }
576
577        let patch = SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
578        let result = apply_patch(manager, &patch).map_err(SyncError::Database)?;
579        Ok(result)
580    }
581}