Skip to main content

citadel_sync/
session.rs

1use citadel_txn::manager::TxnManager;
2
3use crate::apply::{apply_patch, apply_patch_to_table, ApplyResult};
4use crate::diff::{merkle_diff, MerkleHash, TreeReader};
5use crate::local_reader::LocalTreeReader;
6use crate::node_id::NodeId;
7use crate::patch::SyncPatch;
8use crate::protocol::{SyncMessage, TableInfo};
9use crate::transport::{msg_name, RemoteTreeReader, SyncError, SyncTransport};
10
11use citadel_core::types::PageId;
12
/// Which way data flows during a [`SyncSession`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SyncDirection {
    /// Send local changes to the remote peer only.
    Push,
    /// Bring remote changes into the local database only.
    Pull,
    /// Run the push phase first and then the pull phase (full two-way sync).
    Bidirectional,
}
23
24/// Configuration for a sync session.
25#[derive(Debug, Clone)]
26pub struct SyncConfig {
27    pub node_id: NodeId,
28    pub direction: SyncDirection,
29    pub crdt_aware: bool,
30}
31
32/// Outcome of a sync session.
33#[derive(Debug, Clone)]
34pub struct SyncOutcome {
35    /// Result of the push phase (if Push or Bidirectional).
36    pub pushed: Option<ApplyResult>,
37    /// Result of the pull phase (if Pull or Bidirectional).
38    pub pulled: Option<ApplyResult>,
39    /// True if both databases were already identical.
40    pub already_in_sync: bool,
41}
42
43/// Orchestrates a sync session between two databases.
44///
45/// The initiator drives the protocol: sends Hello, computes diffs,
46/// builds patches, and coordinates push/pull phases.
47/// The responder answers requests and applies patches.
48pub struct SyncSession {
49    config: SyncConfig,
50}
51
52impl SyncSession {
53    pub fn new(config: SyncConfig) -> Self {
54        Self { config }
55    }
56
57    /// Run as the initiator (client) side of a sync session.
58    pub fn sync_as_initiator(
59        &self,
60        manager: &TxnManager,
61        transport: &dyn SyncTransport,
62    ) -> std::result::Result<SyncOutcome, SyncError> {
63        let local_reader = LocalTreeReader::new(manager);
64        let (local_root, local_hash) = local_reader.root_info().map_err(SyncError::Database)?;
65
66        transport.send(&SyncMessage::Hello {
67            node_id: self.config.node_id,
68            root_page: local_root,
69            root_hash: local_hash,
70        })?;
71
72        let (remote_root, remote_hash, in_sync) = match transport.recv()? {
73            SyncMessage::HelloAck {
74                root_page,
75                root_hash,
76                in_sync,
77                ..
78            } => (root_page, root_hash, in_sync),
79            SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
80            other => {
81                return Err(SyncError::UnexpectedMessage {
82                    expected: "HelloAck".into(),
83                    actual: msg_name(&other).into(),
84                })
85            }
86        };
87
88        if in_sync {
89            transport.send(&SyncMessage::Done)?;
90            return Ok(SyncOutcome {
91                pushed: None,
92                pulled: None,
93                already_in_sync: true,
94            });
95        }
96
97        let mut outcome = SyncOutcome {
98            pushed: None,
99            pulled: None,
100            already_in_sync: false,
101        };
102
103        // Push phase: diff(local -> remote), send patch to remote
104        if self.config.direction == SyncDirection::Push
105            || self.config.direction == SyncDirection::Bidirectional
106        {
107            let result = self.initiator_push(manager, transport, remote_root, remote_hash)?;
108            outcome.pushed = Some(result);
109        }
110
111        // Pull phase: diff(remote -> local), apply patch locally
112        if self.config.direction == SyncDirection::Pull
113            || self.config.direction == SyncDirection::Bidirectional
114        {
115            // For bidirectional after push, get updated remote state
116            let (pull_root, pull_hash) = if self.config.direction == SyncDirection::Bidirectional {
117                transport.send(&SyncMessage::PullRequest)?;
118                match transport.recv()? {
119                    SyncMessage::PullResponse {
120                        root_page,
121                        root_hash,
122                    } => (root_page, root_hash),
123                    SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
124                    other => {
125                        return Err(SyncError::UnexpectedMessage {
126                            expected: "PullResponse".into(),
127                            actual: msg_name(&other).into(),
128                        })
129                    }
130                }
131            } else {
132                (remote_root, remote_hash)
133            };
134
135            let result = self.initiator_pull(manager, transport, pull_root, pull_hash)?;
136            outcome.pulled = Some(result);
137        }
138
139        transport.send(&SyncMessage::Done)?;
140        Ok(outcome)
141    }
142
143    /// Run as the responder (server) side of a sync session.
144    pub fn sync_as_responder(
145        &self,
146        manager: &TxnManager,
147        transport: &dyn SyncTransport,
148    ) -> std::result::Result<SyncOutcome, SyncError> {
149        let local_reader = LocalTreeReader::new(manager);
150        let (local_root, local_hash) = local_reader.root_info().map_err(SyncError::Database)?;
151
152        let remote_hash = match transport.recv()? {
153            SyncMessage::Hello { root_hash, .. } => root_hash,
154            SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
155            other => {
156                return Err(SyncError::UnexpectedMessage {
157                    expected: "Hello".into(),
158                    actual: msg_name(&other).into(),
159                })
160            }
161        };
162
163        let in_sync = local_hash == remote_hash;
164
165        transport.send(&SyncMessage::HelloAck {
166            node_id: self.config.node_id,
167            root_page: local_root,
168            root_hash: local_hash,
169            in_sync,
170        })?;
171
172        if in_sync {
173            let _ = transport.recv()?;
174            return Ok(SyncOutcome {
175                pushed: None,
176                pulled: None,
177                already_in_sync: true,
178            });
179        }
180
181        let mut outcome = SyncOutcome {
182            pushed: None,
183            pulled: None,
184            already_in_sync: false,
185        };
186
187        loop {
188            let msg = transport.recv()?;
189            match msg {
190                SyncMessage::DigestRequest { page_ids } => {
191                    let reader = LocalTreeReader::new(manager);
192                    let mut digests = Vec::with_capacity(page_ids.len());
193                    for pid in &page_ids {
194                        match reader.page_digest(*pid) {
195                            Ok(d) => digests.push(d),
196                            Err(e) => {
197                                transport.send(&SyncMessage::Error {
198                                    message: e.to_string(),
199                                })?;
200                                continue;
201                            }
202                        }
203                    }
204                    transport.send(&SyncMessage::DigestResponse { digests })?;
205                }
206                SyncMessage::EntriesRequest { page_ids } => {
207                    let reader = LocalTreeReader::new(manager);
208                    let mut entries = Vec::new();
209                    for pid in &page_ids {
210                        match reader.leaf_entries(*pid) {
211                            Ok(e) => entries.extend(e),
212                            Err(e) => {
213                                transport.send(&SyncMessage::Error {
214                                    message: e.to_string(),
215                                })?;
216                                continue;
217                            }
218                        }
219                    }
220                    transport.send(&SyncMessage::EntriesResponse { entries })?;
221                }
222                SyncMessage::PatchData { data } => {
223                    let patch = SyncPatch::deserialize(&data).map_err(SyncError::Patch)?;
224                    let result = apply_patch(manager, &patch).map_err(SyncError::Database)?;
225                    outcome.pushed = Some(result.clone());
226                    transport.send(&SyncMessage::PatchAck { result })?;
227                }
228                SyncMessage::PullRequest => {
229                    let reader = LocalTreeReader::new(manager);
230                    let (root_page, root_hash) = reader.root_info().map_err(SyncError::Database)?;
231                    transport.send(&SyncMessage::PullResponse {
232                        root_page,
233                        root_hash,
234                    })?;
235                }
236                SyncMessage::Done => {
237                    break;
238                }
239                SyncMessage::Error { message } => {
240                    return Err(SyncError::Remote(message));
241                }
242                _ => {
243                    transport.send(&SyncMessage::Error {
244                        message: "unexpected message".into(),
245                    })?;
246                }
247            }
248        }
249
250        Ok(outcome)
251    }
252
253    /// Push: diff(local -> remote) via merkle_diff, send patch.
254    fn initiator_push(
255        &self,
256        manager: &TxnManager,
257        transport: &dyn SyncTransport,
258        remote_root: PageId,
259        remote_hash: MerkleHash,
260    ) -> std::result::Result<ApplyResult, SyncError> {
261        let local_reader = LocalTreeReader::new(manager);
262        let remote_reader = RemoteTreeReader::new(transport, remote_root, remote_hash);
263
264        // source = local, target = remote
265        let diff = merkle_diff(&local_reader, &remote_reader).map_err(SyncError::Database)?;
266
267        if diff.is_empty() {
268            return Ok(ApplyResult::empty());
269        }
270
271        let patch = SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
272        let patch_data = patch.serialize();
273
274        transport.send(&SyncMessage::PatchData { data: patch_data })?;
275
276        match transport.recv()? {
277            SyncMessage::PatchAck { result } => Ok(result),
278            SyncMessage::Error { message } => Err(SyncError::Remote(message)),
279            other => Err(SyncError::UnexpectedMessage {
280                expected: "PatchAck".into(),
281                actual: msg_name(&other).into(),
282            }),
283        }
284    }
285
286    /// Run multi-table sync as the initiator.
287    pub fn sync_tables_as_initiator(
288        &self,
289        manager: &TxnManager,
290        transport: &dyn SyncTransport,
291    ) -> std::result::Result<Vec<(Vec<u8>, ApplyResult)>, SyncError> {
292        transport.send(&SyncMessage::TableListRequest)?;
293
294        let remote_tables = match transport.recv()? {
295            SyncMessage::TableListResponse { tables } => tables,
296            SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
297            other => {
298                return Err(SyncError::UnexpectedMessage {
299                    expected: "TableListResponse".into(),
300                    actual: msg_name(&other).into(),
301                })
302            }
303        };
304
305        let local_tables = manager.list_tables().map_err(SyncError::Database)?;
306
307        let mut all_names: Vec<Vec<u8>> = Vec::new();
308        for (name, _) in &local_tables {
309            if !name.starts_with(b"__idx_") && !all_names.contains(name) {
310                all_names.push(name.clone());
311            }
312        }
313        for info in &remote_tables {
314            if !info.name.starts_with(b"__idx_") && !all_names.contains(&info.name) {
315                all_names.push(info.name.clone());
316            }
317        }
318
319        let mut results = Vec::new();
320
321        for table_name in &all_names {
322            let local_info = local_tables.iter().find(|(n, _)| n == table_name);
323            let remote_info = remote_tables.iter().find(|t| t.name == *table_name);
324
325            let local_root = local_info
326                .map(|(_, desc)| desc.root_page)
327                .unwrap_or(PageId::INVALID);
328            let local_hash = if local_root.is_valid() {
329                manager
330                    .read_page_from_disk(local_root)
331                    .map(|p| p.merkle_hash())
332                    .unwrap_or([0u8; citadel_core::MERKLE_HASH_SIZE])
333            } else {
334                [0u8; citadel_core::MERKLE_HASH_SIZE]
335            };
336
337            let remote_root = remote_info.map(|t| t.root_page).unwrap_or(PageId::INVALID);
338            let remote_hash = remote_info
339                .map(|t| t.root_hash)
340                .unwrap_or([0u8; citadel_core::MERKLE_HASH_SIZE]);
341
342            if local_hash == remote_hash && local_root.is_valid() && remote_root.is_valid() {
343                continue;
344            }
345
346            transport.send(&SyncMessage::TableSyncBegin {
347                table_name: table_name.clone(),
348                root_page: local_root,
349                root_hash: local_hash,
350            })?;
351
352            if local_root.is_valid() && remote_root.is_valid() {
353                let local_reader =
354                    LocalTreeReader::for_table(manager, local_root).map_err(SyncError::Database)?;
355                let remote_reader = RemoteTreeReader::new(transport, remote_root, remote_hash);
356                let diff =
357                    merkle_diff(&local_reader, &remote_reader).map_err(SyncError::Database)?;
358
359                if !diff.is_empty() {
360                    let patch =
361                        SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
362                    transport.send(&SyncMessage::PatchData {
363                        data: patch.serialize(),
364                    })?;
365                    match transport.recv()? {
366                        SyncMessage::PatchAck { result } => {
367                            results.push((table_name.clone(), result));
368                        }
369                        SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
370                        other => {
371                            return Err(SyncError::UnexpectedMessage {
372                                expected: "PatchAck".into(),
373                                actual: msg_name(&other).into(),
374                            })
375                        }
376                    }
377                }
378            } else if local_root.is_valid() {
379                let local_reader =
380                    LocalTreeReader::for_table(manager, local_root).map_err(SyncError::Database)?;
381                let entries = local_reader
382                    .subtree_entries(local_root)
383                    .map_err(SyncError::Database)?;
384                if !entries.is_empty() {
385                    let diff = crate::diff::DiffResult {
386                        entries,
387                        pages_compared: 0,
388                        subtrees_skipped: 0,
389                    };
390                    let patch =
391                        SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
392                    transport.send(&SyncMessage::PatchData {
393                        data: patch.serialize(),
394                    })?;
395                    match transport.recv()? {
396                        SyncMessage::PatchAck { result } => {
397                            results.push((table_name.clone(), result));
398                        }
399                        SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
400                        other => {
401                            return Err(SyncError::UnexpectedMessage {
402                                expected: "PatchAck".into(),
403                                actual: msg_name(&other).into(),
404                            })
405                        }
406                    }
407                }
408            }
409
410            transport.send(&SyncMessage::TableSyncEnd {
411                table_name: table_name.clone(),
412            })?;
413        }
414
415        transport.send(&SyncMessage::Done)?;
416        Ok(results)
417    }
418
419    /// Handle multi-table sync as the responder.
420    pub fn handle_table_sync_as_responder(
421        &self,
422        manager: &TxnManager,
423        transport: &dyn SyncTransport,
424    ) -> std::result::Result<Vec<(Vec<u8>, ApplyResult)>, SyncError> {
425        match transport.recv()? {
426            SyncMessage::TableListRequest => {}
427            SyncMessage::Done => return Ok(Vec::new()),
428            SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
429            other => {
430                return Err(SyncError::UnexpectedMessage {
431                    expected: "TableListRequest".into(),
432                    actual: msg_name(&other).into(),
433                })
434            }
435        }
436
437        let local_tables = manager.list_tables().map_err(SyncError::Database)?;
438        let table_infos: Vec<TableInfo> = local_tables
439            .iter()
440            .filter(|(name, _)| !name.starts_with(b"__idx_"))
441            .filter_map(|(name, desc)| {
442                if desc.root_page.is_valid() {
443                    let hash = manager
444                        .read_page_from_disk(desc.root_page)
445                        .map(|p| p.merkle_hash())
446                        .unwrap_or([0u8; citadel_core::MERKLE_HASH_SIZE]);
447                    Some(TableInfo {
448                        name: name.clone(),
449                        root_page: desc.root_page,
450                        root_hash: hash,
451                    })
452                } else {
453                    None
454                }
455            })
456            .collect();
457        transport.send(&SyncMessage::TableListResponse {
458            tables: table_infos,
459        })?;
460
461        let mut results = Vec::new();
462        let mut current_table: Option<Vec<u8>> = None;
463
464        loop {
465            let msg = transport.recv()?;
466            match msg {
467                SyncMessage::TableSyncBegin { table_name, .. } => {
468                    current_table = Some(table_name);
469                }
470                SyncMessage::TableSyncEnd { .. } => {
471                    current_table = None;
472                }
473                SyncMessage::DigestRequest { page_ids } => {
474                    let reader = if let Some(ref tname) = current_table {
475                        let root = manager.table_root(tname).map_err(SyncError::Database)?;
476                        if let Some(r) = root {
477                            LocalTreeReader::for_table(manager, r).map_err(SyncError::Database)?
478                        } else {
479                            LocalTreeReader::new(manager)
480                        }
481                    } else {
482                        LocalTreeReader::new(manager)
483                    };
484
485                    let mut digests = Vec::with_capacity(page_ids.len());
486                    for pid in &page_ids {
487                        match reader.page_digest(*pid) {
488                            Ok(d) => digests.push(d),
489                            Err(e) => {
490                                transport.send(&SyncMessage::Error {
491                                    message: e.to_string(),
492                                })?;
493                                continue;
494                            }
495                        }
496                    }
497                    transport.send(&SyncMessage::DigestResponse { digests })?;
498                }
499                SyncMessage::EntriesRequest { page_ids } => {
500                    let reader = if let Some(ref tname) = current_table {
501                        let root = manager.table_root(tname).map_err(SyncError::Database)?;
502                        if let Some(r) = root {
503                            LocalTreeReader::for_table(manager, r).map_err(SyncError::Database)?
504                        } else {
505                            LocalTreeReader::new(manager)
506                        }
507                    } else {
508                        LocalTreeReader::new(manager)
509                    };
510
511                    let mut entries = Vec::new();
512                    for pid in &page_ids {
513                        match reader.leaf_entries(*pid) {
514                            Ok(e) => entries.extend(e),
515                            Err(e) => {
516                                transport.send(&SyncMessage::Error {
517                                    message: e.to_string(),
518                                })?;
519                                continue;
520                            }
521                        }
522                    }
523                    transport.send(&SyncMessage::EntriesResponse { entries })?;
524                }
525                SyncMessage::PatchData { data } => {
526                    let patch = SyncPatch::deserialize(&data).map_err(SyncError::Patch)?;
527                    let result = if let Some(ref tname) = current_table {
528                        apply_patch_to_table(manager, tname, &patch).map_err(SyncError::Database)?
529                    } else {
530                        apply_patch(manager, &patch).map_err(SyncError::Database)?
531                    };
532                    if let Some(ref tname) = current_table {
533                        results.push((tname.clone(), result.clone()));
534                    }
535                    transport.send(&SyncMessage::PatchAck { result })?;
536                }
537                SyncMessage::Done => break,
538                SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
539                _ => {
540                    transport.send(&SyncMessage::Error {
541                        message: "unexpected message in table sync".into(),
542                    })?;
543                }
544            }
545        }
546
547        Ok(results)
548    }
549
550    /// Pull: diff(remote -> local) via merkle_diff, apply locally.
551    fn initiator_pull(
552        &self,
553        manager: &TxnManager,
554        transport: &dyn SyncTransport,
555        remote_root: PageId,
556        remote_hash: MerkleHash,
557    ) -> std::result::Result<ApplyResult, SyncError> {
558        let local_reader = LocalTreeReader::new(manager);
559        let (_, local_hash) = local_reader.root_info().map_err(SyncError::Database)?;
560
561        if local_hash == remote_hash {
562            return Ok(ApplyResult::empty());
563        }
564
565        let remote_reader = RemoteTreeReader::new(transport, remote_root, remote_hash);
566
567        // source = remote, target = local
568        let diff = merkle_diff(&remote_reader, &local_reader).map_err(SyncError::Database)?;
569
570        if diff.is_empty() {
571            return Ok(ApplyResult::empty());
572        }
573
574        let patch = SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
575        let result = apply_patch(manager, &patch).map_err(SyncError::Database)?;
576        Ok(result)
577    }
578}