Skip to main content

contextdb_server/
sync_plugin.rs

1use contextdb_core::{AtomicLsn, Lsn};
2use contextdb_engine::plugin::{CommitSource, DatabasePlugin};
3use contextdb_tx::WriteSet;
4use std::sync::atomic::{AtomicBool, Ordering};
5use tokio::sync::mpsc;
6
7/// Plugin that marks auto-sync as active.
8/// Sends change notifications to the background push task via an mpsc channel.
9pub struct SyncPlugin {
10    tx: std::sync::Mutex<Option<mpsc::UnboundedSender<()>>>,
11    auto_enabled: AtomicBool,
12    pending_lsn: AtomicLsn,
13}
14
15impl SyncPlugin {
16    pub fn new(tx: mpsc::UnboundedSender<()>) -> Self {
17        Self {
18            tx: std::sync::Mutex::new(Some(tx)),
19            auto_enabled: AtomicBool::new(false),
20            pending_lsn: AtomicLsn::new(Lsn(0)),
21        }
22    }
23
24    /// Enable or disable auto-sync.
25    pub fn set_auto(&self, enabled: bool) {
26        self.auto_enabled.store(enabled, Ordering::SeqCst);
27    }
28
29    /// Check if auto-sync is enabled.
30    pub fn is_auto(&self) -> bool {
31        self.auto_enabled.load(Ordering::SeqCst)
32    }
33
34    pub fn pending_lsn(&self) -> Lsn {
35        self.pending_lsn.load(Ordering::SeqCst)
36    }
37
38    /// Signal the background push task that a DML change occurred.
39    pub fn notify_change(&self) -> Result<(), &'static str> {
40        match self.tx.lock() {
41            Ok(guard) => {
42                if let Some(tx) = guard.as_ref() {
43                    if tx.send(()).is_err() {
44                        tracing::warn!("sync plugin receiver dropped; change notification lost");
45                        return Err("auto-sync worker unavailable");
46                    }
47                    Ok(())
48                } else {
49                    Err("auto-sync worker unavailable")
50                }
51            }
52            Err(_) => {
53                tracing::warn!("sync plugin mutex poisoned; skipping change notification");
54                Err("auto-sync worker unavailable")
55            }
56        }
57    }
58
59    /// Shutdown: drop the sender to close the channel and stop the background task.
60    pub fn shutdown(&self) {
61        match self.tx.lock() {
62            Ok(mut guard) => {
63                let _ = guard.take();
64            }
65            Err(_) => tracing::warn!("sync plugin mutex poisoned during shutdown"),
66        }
67    }
68}
69
70impl DatabasePlugin for SyncPlugin {
71    fn post_commit(&self, ws: &WriteSet, source: CommitSource) {
72        if !self.is_auto() || source == CommitSource::SyncPull || ws.is_empty() {
73            return;
74        }
75        if let Some(lsn) = ws.commit_lsn {
76            self.pending_lsn.fetch_max(lsn, Ordering::SeqCst);
77        }
78        let _ = self.notify_change();
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Builds a non-empty write set holding one relational delete, stamped with `commit_lsn`.
    fn single_delete(
        row_id: contextdb_core::RowId,
        tx_id: contextdb_core::TxId,
        commit_lsn: Lsn,
    ) -> WriteSet {
        let mut ws = WriteSet::new();
        ws.commit_lsn = Some(commit_lsn);
        ws.relational_deletes.push(("t".to_string(), row_id, tx_id));
        ws
    }

    #[test]
    fn sync_03_plugin_survives_poisoned_mutex() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let plugin = Arc::new(SyncPlugin::new(tx));
        let poison_plugin = plugin.clone();
        let _ = std::thread::spawn(move || {
            let _guard = poison_plugin.tx.lock().unwrap();
            panic!("poison sync_plugin mutex");
        })
        .join();
        // Guard against a vacuous pass: the setup must really have poisoned the lock.
        assert!(
            plugin.tx.is_poisoned(),
            "test setup failed to poison the sync plugin mutex"
        );

        let panic = std::panic::catch_unwind(|| {
            let _ = plugin.notify_change();
        });
        assert!(
            panic.is_ok(),
            "notify_change should not panic on a poisoned sync plugin mutex"
        );

        let panic = std::panic::catch_unwind(|| plugin.shutdown());
        assert!(
            panic.is_ok(),
            "shutdown should not panic on a poisoned sync plugin mutex"
        );
    }

    #[test]
    fn sync_04_plugin_queues_multiple_notifications() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let plugin = SyncPlugin::new(tx);

        plugin.notify_change().unwrap();
        plugin.notify_change().unwrap();

        assert_eq!(rx.try_recv(), Ok(()));
        assert_eq!(rx.try_recv(), Ok(()));
    }

    #[test]
    fn sync_05_plugin_reports_closed_receiver() {
        let (tx, rx) = mpsc::unbounded_channel();
        let plugin = SyncPlugin::new(tx);
        drop(rx);

        assert_eq!(plugin.notify_change(), Err("auto-sync worker unavailable"));
    }

    #[test]
    fn sync_06_post_commit_notifies_only_for_local_writes_when_enabled() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let plugin = SyncPlugin::new(tx);
        let mut ws = WriteSet::new();
        ws.relational_inserts.push((
            "t".to_string(),
            contextdb_core::VersionedRow {
                row_id: contextdb_core::RowId(1),
                values: std::collections::HashMap::new(),
                created_tx: contextdb_core::TxId(1),
                deleted_tx: None,
                lsn: Lsn(1),
                created_at: None,
            },
        ));

        plugin.post_commit(&ws, CommitSource::AutoCommit);
        assert!(rx.try_recv().is_err(), "disabled auto-sync must stay quiet");

        plugin.set_auto(true);
        plugin.post_commit(&ws, CommitSource::SyncPull);
        assert!(
            rx.try_recv().is_err(),
            "sync-pull commits must not trigger another auto-sync push"
        );

        plugin.post_commit(&WriteSet::new(), CommitSource::AutoCommit);
        assert!(
            rx.try_recv().is_err(),
            "empty write sets must not trigger an auto-sync push"
        );

        plugin.post_commit(&ws, CommitSource::AutoCommit);
        assert_eq!(rx.try_recv(), Ok(()));
    }

    #[test]
    fn sync_07_post_commit_tracks_latest_pending_lsn() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let plugin = SyncPlugin::new(tx);
        plugin.set_auto(true);

        let first = single_delete(contextdb_core::RowId(1), contextdb_core::TxId(7), Lsn(7));
        plugin.post_commit(&first, CommitSource::AutoCommit);
        assert_eq!(plugin.pending_lsn(), Lsn(7));

        let newer = single_delete(contextdb_core::RowId(2), contextdb_core::TxId(11), Lsn(11));
        plugin.post_commit(&newer, CommitSource::AutoCommit);
        assert_eq!(plugin.pending_lsn(), Lsn(11));

        // An older LSN arriving late must not lower the watermark.
        let stale = single_delete(contextdb_core::RowId(3), contextdb_core::TxId(3), Lsn(3));
        plugin.post_commit(&stale, CommitSource::AutoCommit);
        assert_eq!(
            plugin.pending_lsn(),
            Lsn(11),
            "a lower commit LSN must not regress the pending watermark"
        );
    }

    #[test]
    fn sync_08_shutdown_closes_channel_and_rejects_notifications() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let plugin = SyncPlugin::new(tx);

        plugin.shutdown();
        assert_eq!(
            rx.try_recv(),
            Err(mpsc::error::TryRecvError::Disconnected),
            "shutdown must drop the only sender"
        );
        assert_eq!(plugin.notify_change(), Err("auto-sync worker unavailable"));

        // A second shutdown must be harmless.
        plugin.shutdown();
    }

    #[test]
    fn sync_09_post_commit_skips_lsn_tracking_when_not_pushing() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let plugin = SyncPlugin::new(tx);

        let ws = single_delete(contextdb_core::RowId(1), contextdb_core::TxId(5), Lsn(5));
        plugin.post_commit(&ws, CommitSource::AutoCommit);
        assert_eq!(
            plugin.pending_lsn(),
            Lsn(0),
            "disabled auto-sync must not record pending LSNs"
        );

        plugin.set_auto(true);
        plugin.post_commit(&ws, CommitSource::SyncPull);
        assert_eq!(
            plugin.pending_lsn(),
            Lsn(0),
            "sync-pull commits must not record pending LSNs"
        );
        assert!(rx.try_recv().is_err());
    }
}