// citadeldb_sync/session.rs

1use citadel_txn::manager::TxnManager;
2
3use crate::apply::{apply_patch, apply_patch_to_table, ApplyResult};
4use crate::diff::{merkle_diff, MerkleHash, TreeReader};
5use crate::local_reader::LocalTreeReader;
6use crate::node_id::NodeId;
7use crate::patch::SyncPatch;
8use crate::protocol::{SyncMessage, TableInfo};
9use crate::transport::{msg_name, RemoteTreeReader, SyncError, SyncTransport};
10
11use citadel_core::types::PageId;
12
/// Which way data flows during a [`SyncSession`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyncDirection {
    /// Send local changes to the remote peer only.
    Push,
    /// Bring remote changes into the local database only.
    Pull,
    /// A push phase followed by a pull phase (full two-way sync).
    Bidirectional,
}
23
24/// Configuration for a sync session.
25#[derive(Debug, Clone)]
26pub struct SyncConfig {
27    pub node_id: NodeId,
28    pub direction: SyncDirection,
29    pub crdt_aware: bool,
30}
31
32/// Outcome of a sync session.
33#[derive(Debug, Clone)]
34pub struct SyncOutcome {
35    /// Result of the push phase (if Push or Bidirectional).
36    pub pushed: Option<ApplyResult>,
37    /// Result of the pull phase (if Pull or Bidirectional).
38    pub pulled: Option<ApplyResult>,
39    /// True if both databases were already identical.
40    pub already_in_sync: bool,
41}
42
43/// Orchestrates a sync session between two databases.
44///
45/// The initiator drives the protocol: sends Hello, computes diffs,
46/// builds patches, and coordinates push/pull phases.
47/// The responder answers requests and applies patches.
48pub struct SyncSession {
49    config: SyncConfig,
50}
51
52impl SyncSession {
53    pub fn new(config: SyncConfig) -> Self {
54        Self { config }
55    }
56
57    /// Run as the initiator (client) side of a sync session.
58    pub fn sync_as_initiator(
59        &self,
60        manager: &TxnManager,
61        transport: &dyn SyncTransport,
62    ) -> std::result::Result<SyncOutcome, SyncError> {
63        let local_reader = LocalTreeReader::new(manager);
64        let (local_root, local_hash) = local_reader.root_info().map_err(SyncError::Database)?;
65
66        // Hello exchange
67        transport.send(&SyncMessage::Hello {
68            node_id: self.config.node_id,
69            root_page: local_root,
70            root_hash: local_hash,
71        })?;
72
73        let (remote_root, remote_hash, in_sync) = match transport.recv()? {
74            SyncMessage::HelloAck { root_page, root_hash, in_sync, .. } => {
75                (root_page, root_hash, in_sync)
76            }
77            SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
78            other => return Err(SyncError::UnexpectedMessage {
79                expected: "HelloAck".into(),
80                actual: msg_name(&other).into(),
81            }),
82        };
83
84        if in_sync {
85            transport.send(&SyncMessage::Done)?;
86            return Ok(SyncOutcome {
87                pushed: None,
88                pulled: None,
89                already_in_sync: true,
90            });
91        }
92
93        let mut outcome = SyncOutcome {
94            pushed: None,
95            pulled: None,
96            already_in_sync: false,
97        };
98
99        // Push phase: diff(local → remote), send patch to remote
100        if self.config.direction == SyncDirection::Push
101            || self.config.direction == SyncDirection::Bidirectional
102        {
103            let result = self.initiator_push(
104                manager, transport, remote_root, remote_hash,
105            )?;
106            outcome.pushed = Some(result);
107        }
108
109        // Pull phase: diff(remote → local), apply patch locally
110        if self.config.direction == SyncDirection::Pull
111            || self.config.direction == SyncDirection::Bidirectional
112        {
113            // For bidirectional after push, get updated remote state
114            let (pull_root, pull_hash) = if self.config.direction == SyncDirection::Bidirectional {
115                transport.send(&SyncMessage::PullRequest)?;
116                match transport.recv()? {
117                    SyncMessage::PullResponse { root_page, root_hash } => {
118                        (root_page, root_hash)
119                    }
120                    SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
121                    other => return Err(SyncError::UnexpectedMessage {
122                        expected: "PullResponse".into(),
123                        actual: msg_name(&other).into(),
124                    }),
125                }
126            } else {
127                (remote_root, remote_hash)
128            };
129
130            let result = self.initiator_pull(
131                manager, transport, pull_root, pull_hash,
132            )?;
133            outcome.pulled = Some(result);
134        }
135
136        transport.send(&SyncMessage::Done)?;
137        Ok(outcome)
138    }
139
140    /// Run as the responder (server) side of a sync session.
141    pub fn sync_as_responder(
142        &self,
143        manager: &TxnManager,
144        transport: &dyn SyncTransport,
145    ) -> std::result::Result<SyncOutcome, SyncError> {
146        let local_reader = LocalTreeReader::new(manager);
147        let (local_root, local_hash) = local_reader.root_info().map_err(SyncError::Database)?;
148
149        // Receive Hello
150        let remote_hash = match transport.recv()? {
151            SyncMessage::Hello { root_hash, .. } => root_hash,
152            SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
153            other => return Err(SyncError::UnexpectedMessage {
154                expected: "Hello".into(),
155                actual: msg_name(&other).into(),
156            }),
157        };
158
159        let in_sync = local_hash == remote_hash;
160
161        transport.send(&SyncMessage::HelloAck {
162            node_id: self.config.node_id,
163            root_page: local_root,
164            root_hash: local_hash,
165            in_sync,
166        })?;
167
168        if in_sync {
169            match transport.recv()? {
170                SyncMessage::Done => {}
171                _ => {}
172            }
173            return Ok(SyncOutcome {
174                pushed: None,
175                pulled: None,
176                already_in_sync: true,
177            });
178        }
179
180        let mut outcome = SyncOutcome {
181            pushed: None,
182            pulled: None,
183            already_in_sync: false,
184        };
185
186        // Serve requests until Done
187        loop {
188            let msg = transport.recv()?;
189            match msg {
190                SyncMessage::DigestRequest { page_ids } => {
191                    let reader = LocalTreeReader::new(manager);
192                    let mut digests = Vec::with_capacity(page_ids.len());
193                    for pid in &page_ids {
194                        match reader.page_digest(*pid) {
195                            Ok(d) => digests.push(d),
196                            Err(e) => {
197                                transport.send(&SyncMessage::Error {
198                                    message: e.to_string(),
199                                })?;
200                                continue;
201                            }
202                        }
203                    }
204                    transport.send(&SyncMessage::DigestResponse { digests })?;
205                }
206                SyncMessage::EntriesRequest { page_ids } => {
207                    let reader = LocalTreeReader::new(manager);
208                    let mut entries = Vec::new();
209                    for pid in &page_ids {
210                        match reader.leaf_entries(*pid) {
211                            Ok(e) => entries.extend(e),
212                            Err(e) => {
213                                transport.send(&SyncMessage::Error {
214                                    message: e.to_string(),
215                                })?;
216                                continue;
217                            }
218                        }
219                    }
220                    transport.send(&SyncMessage::EntriesResponse { entries })?;
221                }
222                SyncMessage::PatchData { data } => {
223                    let patch = SyncPatch::deserialize(&data).map_err(SyncError::Patch)?;
224                    let result = apply_patch(manager, &patch).map_err(SyncError::Database)?;
225                    outcome.pushed = Some(result.clone());
226                    transport.send(&SyncMessage::PatchAck { result })?;
227                }
228                SyncMessage::PullRequest => {
229                    let reader = LocalTreeReader::new(manager);
230                    let (root_page, root_hash) =
231                        reader.root_info().map_err(SyncError::Database)?;
232                    transport.send(&SyncMessage::PullResponse { root_page, root_hash })?;
233                }
234                SyncMessage::Done => {
235                    break;
236                }
237                SyncMessage::Error { message } => {
238                    return Err(SyncError::Remote(message));
239                }
240                _ => {
241                    transport.send(&SyncMessage::Error {
242                        message: "unexpected message".into(),
243                    })?;
244                }
245            }
246        }
247
248        Ok(outcome)
249    }
250
251    /// Push: diff(local → remote) via merkle_diff, send patch.
252    fn initiator_push(
253        &self,
254        manager: &TxnManager,
255        transport: &dyn SyncTransport,
256        remote_root: PageId,
257        remote_hash: MerkleHash,
258    ) -> std::result::Result<ApplyResult, SyncError> {
259        let local_reader = LocalTreeReader::new(manager);
260        let remote_reader = RemoteTreeReader::new(transport, remote_root, remote_hash);
261
262        // source = local, target = remote
263        let diff = merkle_diff(&local_reader, &remote_reader)
264            .map_err(SyncError::Database)?;
265
266        if diff.is_empty() {
267            return Ok(ApplyResult::empty());
268        }
269
270        let patch = SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
271        let patch_data = patch.serialize();
272
273        transport.send(&SyncMessage::PatchData { data: patch_data })?;
274
275        match transport.recv()? {
276            SyncMessage::PatchAck { result } => Ok(result),
277            SyncMessage::Error { message } => Err(SyncError::Remote(message)),
278            other => Err(SyncError::UnexpectedMessage {
279                expected: "PatchAck".into(),
280                actual: msg_name(&other).into(),
281            }),
282        }
283    }
284
285    /// Run multi-table sync as the initiator.
286    pub fn sync_tables_as_initiator(
287        &self,
288        manager: &TxnManager,
289        transport: &dyn SyncTransport,
290    ) -> std::result::Result<Vec<(Vec<u8>, ApplyResult)>, SyncError> {
291        transport.send(&SyncMessage::TableListRequest)?;
292
293        let remote_tables = match transport.recv()? {
294            SyncMessage::TableListResponse { tables } => tables,
295            SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
296            other => return Err(SyncError::UnexpectedMessage {
297                expected: "TableListResponse".into(),
298                actual: msg_name(&other).into(),
299            }),
300        };
301
302        let local_tables = manager.list_tables().map_err(SyncError::Database)?;
303
304        let mut all_names: Vec<Vec<u8>> = Vec::new();
305        for (name, _) in &local_tables {
306            if !name.starts_with(b"__idx_") {
307                if !all_names.iter().any(|n| n == name) {
308                    all_names.push(name.clone());
309                }
310            }
311        }
312        for info in &remote_tables {
313            if !info.name.starts_with(b"__idx_") {
314                if !all_names.iter().any(|n| *n == info.name) {
315                    all_names.push(info.name.clone());
316                }
317            }
318        }
319
320        let mut results = Vec::new();
321
322        for table_name in &all_names {
323            let local_info = local_tables.iter().find(|(n, _)| n == table_name);
324            let remote_info = remote_tables.iter().find(|t| t.name == *table_name);
325
326            let local_root = local_info
327                .map(|(_, desc)| desc.root_page)
328                .unwrap_or(PageId::INVALID);
329            let local_hash = if local_root.is_valid() {
330                manager.read_page_from_disk(local_root)
331                    .map(|p| p.merkle_hash())
332                    .unwrap_or([0u8; citadel_core::MERKLE_HASH_SIZE])
333            } else {
334                [0u8; citadel_core::MERKLE_HASH_SIZE]
335            };
336
337            let remote_root = remote_info.map(|t| t.root_page).unwrap_or(PageId::INVALID);
338            let remote_hash = remote_info
339                .map(|t| t.root_hash)
340                .unwrap_or([0u8; citadel_core::MERKLE_HASH_SIZE]);
341
342            if local_hash == remote_hash && local_root.is_valid() && remote_root.is_valid() {
343                continue;
344            }
345
346            transport.send(&SyncMessage::TableSyncBegin {
347                table_name: table_name.clone(),
348                root_page: local_root,
349                root_hash: local_hash,
350            })?;
351
352            if local_root.is_valid() && remote_root.is_valid() {
353                let local_reader = LocalTreeReader::for_table(manager, local_root)
354                    .map_err(SyncError::Database)?;
355                let remote_reader = RemoteTreeReader::new(transport, remote_root, remote_hash);
356                let diff = merkle_diff(&local_reader, &remote_reader)
357                    .map_err(SyncError::Database)?;
358
359                if !diff.is_empty() {
360                    let patch = SyncPatch::from_diff(
361                        self.config.node_id, &diff, self.config.crdt_aware,
362                    );
363                    transport.send(&SyncMessage::PatchData { data: patch.serialize() })?;
364                    match transport.recv()? {
365                        SyncMessage::PatchAck { result } => {
366                            results.push((table_name.clone(), result));
367                        }
368                        SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
369                        other => return Err(SyncError::UnexpectedMessage {
370                            expected: "PatchAck".into(),
371                            actual: msg_name(&other).into(),
372                        }),
373                    }
374                }
375            } else if local_root.is_valid() {
376                let local_reader = LocalTreeReader::for_table(manager, local_root)
377                    .map_err(SyncError::Database)?;
378                let entries = local_reader.subtree_entries(local_root)
379                    .map_err(SyncError::Database)?;
380                if !entries.is_empty() {
381                    let diff = crate::diff::DiffResult {
382                        entries,
383                        pages_compared: 0,
384                        subtrees_skipped: 0,
385                    };
386                    let patch = SyncPatch::from_diff(
387                        self.config.node_id, &diff, self.config.crdt_aware,
388                    );
389                    transport.send(&SyncMessage::PatchData { data: patch.serialize() })?;
390                    match transport.recv()? {
391                        SyncMessage::PatchAck { result } => {
392                            results.push((table_name.clone(), result));
393                        }
394                        SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
395                        other => return Err(SyncError::UnexpectedMessage {
396                            expected: "PatchAck".into(),
397                            actual: msg_name(&other).into(),
398                        }),
399                    }
400                }
401            }
402
403            transport.send(&SyncMessage::TableSyncEnd {
404                table_name: table_name.clone(),
405            })?;
406        }
407
408        transport.send(&SyncMessage::Done)?;
409        Ok(results)
410    }
411
412    /// Handle multi-table sync as the responder.
413    pub fn handle_table_sync_as_responder(
414        &self,
415        manager: &TxnManager,
416        transport: &dyn SyncTransport,
417    ) -> std::result::Result<Vec<(Vec<u8>, ApplyResult)>, SyncError> {
418        match transport.recv()? {
419            SyncMessage::TableListRequest => {}
420            SyncMessage::Done => return Ok(Vec::new()),
421            SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
422            other => return Err(SyncError::UnexpectedMessage {
423                expected: "TableListRequest".into(),
424                actual: msg_name(&other).into(),
425            }),
426        }
427
428        let local_tables = manager.list_tables().map_err(SyncError::Database)?;
429        let table_infos: Vec<TableInfo> = local_tables
430            .iter()
431            .filter(|(name, _)| !name.starts_with(b"__idx_"))
432            .filter_map(|(name, desc)| {
433                if desc.root_page.is_valid() {
434                    let hash = manager.read_page_from_disk(desc.root_page)
435                        .map(|p| p.merkle_hash())
436                        .unwrap_or([0u8; citadel_core::MERKLE_HASH_SIZE]);
437                    Some(TableInfo {
438                        name: name.clone(),
439                        root_page: desc.root_page,
440                        root_hash: hash,
441                    })
442                } else {
443                    None
444                }
445            })
446            .collect();
447        transport.send(&SyncMessage::TableListResponse { tables: table_infos })?;
448
449        let mut results = Vec::new();
450        let mut current_table: Option<Vec<u8>> = None;
451
452        loop {
453            let msg = transport.recv()?;
454            match msg {
455                SyncMessage::TableSyncBegin { table_name, .. } => {
456                    current_table = Some(table_name);
457                }
458                SyncMessage::TableSyncEnd { .. } => {
459                    current_table = None;
460                }
461                SyncMessage::DigestRequest { page_ids } => {
462                    let reader = if let Some(ref tname) = current_table {
463                        let root = manager.table_root(tname).map_err(SyncError::Database)?;
464                        if let Some(r) = root {
465                            LocalTreeReader::for_table(manager, r)
466                                .map_err(SyncError::Database)?
467                        } else {
468                            LocalTreeReader::new(manager)
469                        }
470                    } else {
471                        LocalTreeReader::new(manager)
472                    };
473
474                    let mut digests = Vec::with_capacity(page_ids.len());
475                    for pid in &page_ids {
476                        match reader.page_digest(*pid) {
477                            Ok(d) => digests.push(d),
478                            Err(e) => {
479                                transport.send(&SyncMessage::Error {
480                                    message: e.to_string(),
481                                })?;
482                                continue;
483                            }
484                        }
485                    }
486                    transport.send(&SyncMessage::DigestResponse { digests })?;
487                }
488                SyncMessage::EntriesRequest { page_ids } => {
489                    let reader = if let Some(ref tname) = current_table {
490                        let root = manager.table_root(tname).map_err(SyncError::Database)?;
491                        if let Some(r) = root {
492                            LocalTreeReader::for_table(manager, r)
493                                .map_err(SyncError::Database)?
494                        } else {
495                            LocalTreeReader::new(manager)
496                        }
497                    } else {
498                        LocalTreeReader::new(manager)
499                    };
500
501                    let mut entries = Vec::new();
502                    for pid in &page_ids {
503                        match reader.leaf_entries(*pid) {
504                            Ok(e) => entries.extend(e),
505                            Err(e) => {
506                                transport.send(&SyncMessage::Error {
507                                    message: e.to_string(),
508                                })?;
509                                continue;
510                            }
511                        }
512                    }
513                    transport.send(&SyncMessage::EntriesResponse { entries })?;
514                }
515                SyncMessage::PatchData { data } => {
516                    let patch = SyncPatch::deserialize(&data).map_err(SyncError::Patch)?;
517                    let result = if let Some(ref tname) = current_table {
518                        apply_patch_to_table(manager, tname, &patch)
519                            .map_err(SyncError::Database)?
520                    } else {
521                        apply_patch(manager, &patch).map_err(SyncError::Database)?
522                    };
523                    if let Some(ref tname) = current_table {
524                        results.push((tname.clone(), result.clone()));
525                    }
526                    transport.send(&SyncMessage::PatchAck { result })?;
527                }
528                SyncMessage::Done => break,
529                SyncMessage::Error { message } => return Err(SyncError::Remote(message)),
530                _ => {
531                    transport.send(&SyncMessage::Error {
532                        message: "unexpected message in table sync".into(),
533                    })?;
534                }
535            }
536        }
537
538        Ok(results)
539    }
540
541    /// Pull: diff(remote → local) via merkle_diff, apply locally.
542    fn initiator_pull(
543        &self,
544        manager: &TxnManager,
545        transport: &dyn SyncTransport,
546        remote_root: PageId,
547        remote_hash: MerkleHash,
548    ) -> std::result::Result<ApplyResult, SyncError> {
549        let local_reader = LocalTreeReader::new(manager);
550        let (_, local_hash) = local_reader.root_info().map_err(SyncError::Database)?;
551
552        if local_hash == remote_hash {
553            return Ok(ApplyResult::empty());
554        }
555
556        let remote_reader = RemoteTreeReader::new(transport, remote_root, remote_hash);
557
558        // source = remote, target = local
559        let diff = merkle_diff(&remote_reader, &local_reader)
560            .map_err(SyncError::Database)?;
561
562        if diff.is_empty() {
563            return Ok(ApplyResult::empty());
564        }
565
566        let patch = SyncPatch::from_diff(self.config.node_id, &diff, self.config.crdt_aware);
567        let result = apply_patch(manager, &patch).map_err(SyncError::Database)?;
568        Ok(result)
569    }
570}