Skip to main content

sqlite_watcher/
watcher.rs

1use fixedbitset::FixedBitSet;
2use flume::{Receiver, Sender, TryRecvError};
3use parking_lot::RwLock;
4use slotmap::{SlotMap, new_key_type};
5use std::collections::btree_map::Entry;
6use std::collections::{BTreeMap, BTreeSet};
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::{Arc, Weak};
9use tracing::{debug, error};
10
11new_key_type! {
12    /// Handle for a [`TableObserver`].
13    pub struct TableObserverHandle;
14}
15
16/// Utility type that removes an observer from a [`Watcher`] when the type is dropped.
17///
18/// The [`TableObserver`] will be removed with [`Watcher::remove_observer_deferred()`].
19#[derive(Debug, Clone)]
20pub struct DropRemoveTableObserverHandle {
21    watcher: Weak<Watcher>,
22    handle: TableObserverHandle,
23}
24
25impl DropRemoveTableObserverHandle {
26    fn new(handle: TableObserverHandle, watcher: &Arc<Watcher>) -> Self {
27        Self {
28            watcher: Arc::downgrade(watcher),
29            handle,
30        }
31    }
32
33    /// Returns the handle of the table observer.
34    #[must_use]
35    pub fn handle(&self) -> TableObserverHandle {
36        self.handle
37    }
38
39    /// Unsubscribe the observer immediately.
40    ///
41    /// This can be safely called multiple times, the observer
42    /// is only unsubscribed once.
43    ///
44    /// # Errors
45    ///
46    /// Returns error if we can't communicate with Watcher or the
47    /// removal of the observer failed.
48    pub fn unsubscribe(&self) -> Result<(), Error> {
49        if let Some(watcher) = self.watcher.upgrade() {
50            watcher.remove_observer(self.handle)
51        } else {
52            Err(Error::Command)
53        }
54    }
55}
56
57impl Drop for DropRemoveTableObserverHandle {
58    fn drop(&mut self) {
59        if let Some(watcher) = self.watcher.upgrade() {
60            if watcher.remove_observer_deferred(self.handle).is_err() {
61                error!("Failed to remove watcher from observer on drop");
62            }
63        }
64    }
65}
66
/// An observer that wants to be told when any table of a given set is modified.
pub trait TableObserver: Send + Sync {
    /// The names of the tables this observer is interested in.
    fn tables(&self) -> Vec<String>;

    /// Invoked by the [`Watcher`] when one or more of the tables returned by
    /// [`Self::tables()`] were modified.
    ///
    /// `tables` contains the set of tables that were modified.
    ///
    /// Keep the implementation as short as possible so that it does not delay or block
    /// the execution of other observers.
    fn on_tables_changed(&self, tables: &BTreeSet<String>);
}
81
82/// The [`Watcher`] is the hub where updates are published regarding tables that updated when
83/// observing a connection.
84///
85/// All changes are published to a background thread which then notifies the respective
86/// [`TableObserver`]s.
87///
88/// # Observing Tables
89///
90/// To be notified of changes, register an observer with [`Watcher::add_observer`].
91///
92/// The [`Watcher`] by itself does not automatically watch all tables. The observed tables
93/// are driven by the tables defined by each [`TableObserver`].
94///
95/// A table can be observed by many [`TableObserver`]. When the last [`TableObserver`] is removed
96/// for a given table, that table stops being tracked.
97///
98/// # Update Propagation
99///
100/// Every time a [`TableObserver`] is added or removed, the list of tracked tables is updated and
101/// a counter is bumped. These changes are propagated to [State] instances when they sync their
102/// state [`State::sync_tables()`](crate::connection::State::sync_tables).
103///
104/// Due to the nature of concurrent operations, it is possible that a connection on different
105/// thread will miss the changes applied from adding/removing an observer on the current thread. On
106/// the next sync this will be rectified.
107///
108/// If both operation happen on the same thread, everything will work as expected.
109///
110/// # Notifications
111///
112/// To notify the [`Watcher`] of changed tables, an instance of either [Connection] or
113/// [State] needs to be used. Check each type for more information on how to use
114/// it correctly.
115///
116/// # Remarks
117///
118/// The [`Watcher`] currently maintains a list of observed tables that is never pruned. It will
119/// keep growing with every new table that is observed. If you have af fixed set of tables that
120/// you watch on a regular basis this is not an issue. If you have a dynamic list of tables
121/// deleted tables are currently not removed. To be addressed in the future.
122///
123///
124/// [Connection]: `crate::connection::Connection`
125/// [State]: `crate::connection::State`
126pub struct Watcher {
127    tables: RwLock<ObservedTables>,
128    tables_version: AtomicU64,
129    sender: Sender<Command>,
130}
131
/// Capacity of the bounded command channel between a [`Watcher`] and its background thread.
const WATCHER_CHANNEL_CAPACITY: usize = 24;
133
134impl Watcher {
135    /// Create a new instance of an in process tracker service.
136    ///
137    /// # Errors
138    /// Returns error if the worker thread fails to spawn.
139    pub fn new() -> Result<Arc<Self>, Error> {
140        let (sender, receiver) = flume::bounded(WATCHER_CHANNEL_CAPACITY);
141        let watcher = Arc::new(Self {
142            tables: RwLock::new(ObservedTables::new()),
143            tables_version: AtomicU64::new(0),
144            sender,
145        });
146
147        let watcher_cloned = Arc::clone(&watcher);
148        std::thread::Builder::new()
149            .name("sqlite_watcher".into())
150            .spawn(move || {
151                Watcher::background_loop(receiver, &watcher_cloned);
152            })
153            .map_err(Error::Thread)?;
154
155        Ok(watcher)
156    }
157
158    /// Register a new observer with a list of interested tables.
159    ///
160    /// This function returns a [`TableObserverHandle`] which can later be used to
161    /// remove the current observer.
162    ///
163    /// # Errors
164    ///
165    /// Returns error if the command which adds the observer to the background thread
166    /// could not be sent or the handle could not be retrieved.
167    pub fn add_observer(
168        &self,
169        observer: Box<dyn TableObserver>,
170    ) -> Result<TableObserverHandle, Error> {
171        let (sender, receiver) = oneshot::channel();
172        if self
173            .sender
174            .send(Command::AddObserver(observer, sender))
175            .is_err()
176        {
177            error!("Failed to send add observer command");
178            return Err(Error::Command);
179        }
180
181        let Ok(handle) = receiver.recv() else {
182            error!("Failed to receive handle for new observer");
183            return Err(Error::Command);
184        };
185
186        Ok(handle)
187    }
188
189    /// Same as [`Self::add_observer`], but returns a handle that removes the observer
190    /// from this [`Watcher`] on drop.
191    ///
192    ///
193    /// # Errors
194    ///
195    /// See [`Self::add_observer`] for more details.
196    pub fn add_observer_with_drop_remove(
197        self: &Arc<Self>,
198        observer: Box<dyn TableObserver>,
199    ) -> Result<DropRemoveTableObserverHandle, Error> {
200        let handle = self.add_observer(observer)?;
201
202        Ok(DropRemoveTableObserverHandle::new(handle, self))
203    }
204
205    /// Remove an observer via its `handle` without waiting for the operation to complete.
206    ///
207    /// The removal of observers is deferred to the background thread and will
208    /// be executed as soon as possible.
209    ///
210    /// If you wish to wait for an observer to finish being removed from the list,
211    /// you should use [`Self::remove_observer()`]
212    ///
213    /// # Errors
214    ///
215    /// Returns error if the command to remove the observer could not be sent.
216    pub fn remove_observer_deferred(&self, handle: TableObserverHandle) -> Result<(), Error> {
217        self.sender
218            .send(Command::RemoveObserverDeferred(handle))
219            .map_err(|_| Error::Command)
220    }
221
222    /// Remove an observer via its `handle` and wait for it to be removed.
223    ///
224    /// If you wish do not wish to wait for an observer to finish being removed from the list,
225    /// you should use [`Self::remove_observer_deferred()`]
226    ///
227    /// # Errors
228    ///
229    /// Returns error if the command to remove the observer could not be sent or the reply
230    /// could not be received.
231    pub fn remove_observer(&self, handle: TableObserverHandle) -> Result<(), Error> {
232        let (sender, receiver) = oneshot::channel();
233        self.sender
234            .send(Command::RemoveObserver(handle, sender))
235            .map_err(|_| Error::Command)?;
236
237        receiver.recv().map_err(|_| {
238            error!("Failed to receive reply for remove observer command");
239            Error::Command
240        })
241    }
242
243    pub(crate) fn publish_changes(&self, table_ids: FixedBitSet) {
244        if self
245            .sender
246            .send(Command::PublishChanges(table_ids))
247            .is_err()
248        {
249            error!("Watcher could not communicate with background thread");
250        }
251    }
252
253    pub(crate) async fn publish_changes_async(&self, table_ids: FixedBitSet) {
254        if self
255            .sender
256            .send_async(Command::PublishChanges(table_ids))
257            .await
258            .is_err()
259        {
260            error!("Watcher could not communicate with background thread");
261        }
262    }
263
264    #[cfg(test)]
265    pub(crate) fn get_table_id(&self, table: &str) -> Option<usize> {
266        self.with_tables(|tables| tables.table_ids.get(table).copied())
267    }
268
269    fn with_tables_mut(&self, f: impl FnOnce(&mut ObservedTables)) {
270        let mut accessor = self.tables.write();
271        // Save counter to check for significant changes
272        let prev_counter = accessor.counter;
273
274        (f)(&mut accessor);
275
276        // Significant changes were made.
277        let cur_counter = accessor.counter;
278        if prev_counter != cur_counter {
279            self.tables_version.fetch_add(1, Ordering::Release);
280        }
281    }
282
283    fn with_tables<R>(&self, f: impl (FnOnce(&ObservedTables) -> R)) -> R {
284        let accessor = self.tables.read();
285        (f)(&accessor)
286    }
287
288    /// The current version of the tracked tables state.
289    pub(crate) fn tables_version(&self) -> u64 {
290        self.tables_version.load(Ordering::Acquire)
291    }
292
293    /// Return the list of observed tables at this point in time.
294    pub fn observed_tables(&self) -> Vec<String> {
295        self.with_tables(|t| t.tables.clone())
296    }
297
298    pub(crate) fn calculate_sync_changes(
299        &self,
300        connection_state: &FixedBitSet,
301    ) -> (FixedBitSet, Vec<ObservedTableOp>) {
302        self.with_tables(|t| t.calculate_changes(connection_state))
303    }
304
305    #[allow(clippy::needless_pass_by_value)]
306    #[tracing::instrument(level= tracing::Level::TRACE, skip(receiver, watcher))]
307    fn background_loop(receiver: Receiver<Command>, watcher: &Watcher) {
308        let mut worker = WatcherWorker::new();
309
310        loop {
311            debug_assert!(worker.add_observers.is_empty());
312            debug_assert!(worker.remove_observers.is_empty());
313            debug_assert!(worker.publish_changes.is_empty());
314
315            let Ok(command) = receiver.recv() else {
316                return;
317            };
318
319            worker.unpack_command(command);
320
321            // try to read more commands if any are queued.
322            loop {
323                match receiver.try_recv() {
324                    Ok(command) => {
325                        worker.unpack_command(command);
326                    }
327                    Err(e) => match e {
328                        TryRecvError::Empty => {
329                            break;
330                        }
331                        TryRecvError::Disconnected => {
332                            return;
333                        }
334                    },
335                }
336            }
337
338            worker.tick(watcher);
339        }
340    }
341}
342
343/// Background watcher worker which responds to [`Command`]s;
344struct WatcherWorker {
345    observers: SlotMap<TableObserverHandle, ActiveObserver>,
346    updated_tables: BTreeSet<String>,
347    remove_observers: Vec<(TableObserverHandle, Option<oneshot::Sender<()>>)>,
348    add_observers: Vec<(Box<dyn TableObserver>, oneshot::Sender<TableObserverHandle>)>,
349    publish_changes: Vec<FixedBitSet>,
350}
351
352impl WatcherWorker {
353    fn new() -> Self {
354        Self {
355            observers: SlotMap::with_capacity_and_key(4),
356            updated_tables: BTreeSet::default(),
357            remove_observers: vec![],
358            add_observers: vec![],
359            publish_changes: vec![],
360        }
361    }
362    fn unpack_command(&mut self, command: Command) {
363        match command {
364            Command::AddObserver(o, r) => self.add_observers.push((o, r)),
365            Command::RemoveObserver(h, r) => self.remove_observers.push((h, Some(r))),
366            Command::RemoveObserverDeferred(h) => {
367                self.remove_observers.push((h, None));
368            }
369            Command::PublishChanges(fixedbitset) => {
370                self.publish_changes.push(fixedbitset);
371            }
372        }
373    }
374
375    fn tick(&mut self, watcher: &Watcher) {
376        // Remove old observers,
377        for (handle, reply) in self.remove_observers.drain(..) {
378            if let Some(observer) = self.observers.remove(handle) {
379                watcher.with_tables_mut(|tables| {
380                    tables.untrack_tables(observer.tables.iter());
381                });
382            }
383
384            if let Some(reply) = reply {
385                if reply.send(()).is_err() {
386                    error!("Failed to send reply for observer removal");
387                }
388            }
389        }
390
391        // Add new observers
392        for (observer, reply) in self.add_observers.drain(..) {
393            let active_observer = ActiveObserver::new(observer);
394            watcher.with_tables_mut(|tables| {
395                tables.track_tables(active_observer.tables.iter().cloned());
396            });
397            let handle = self.observers.insert(active_observer);
398            if reply.send(handle).is_err() {
399                error!("Failed to send reply back to caller, new observer will not be added");
400                self.observers.remove(handle);
401            }
402        }
403
404        // Combine and publish changes.
405        self.updated_tables.clear();
406
407        for table_ids in self.publish_changes.drain(..) {
408            if table_ids.is_clear() {
409                continue;
410            }
411
412            // resolve table names.
413            watcher.with_tables(|observer_tables| {
414                for idx in table_ids.ones() {
415                    // Safeguard against some invalid index, just in case.
416                    if let Some(name) = observer_tables.tables.get(idx).cloned() {
417                        self.updated_tables.insert(name);
418                    }
419                }
420            });
421        }
422
423        if !self.updated_tables.is_empty() {
424            debug!("Changes detected on tables: {:?}", self.updated_tables);
425            // publish changes;
426            {
427                for (_, active_observer) in &self.observers {
428                    if self
429                        .updated_tables
430                        .intersection(&active_observer.tables)
431                        .next()
432                        .is_some()
433                    {
434                        active_observer
435                            .observer
436                            .on_tables_changed(&self.updated_tables);
437                    }
438                }
439            }
440        }
441    }
442}
443
444struct ActiveObserver {
445    observer: Box<dyn TableObserver>,
446    tables: BTreeSet<String>,
447}
448
449impl ActiveObserver {
450    fn new(observer: Box<dyn TableObserver>) -> ActiveObserver {
451        let tables = BTreeSet::from_iter(observer.tables());
452        Self { observer, tables }
453    }
454}
455
456/// Commands send to the background thread.
457enum Command {
458    /// Add a new observer
459    AddObserver(Box<dyn TableObserver>, oneshot::Sender<TableObserverHandle>),
460    /// Remove an observer
461    RemoveObserverDeferred(TableObserverHandle),
462    /// Remove an observer and wait for the operation to finish.
463    RemoveObserver(TableObserverHandle, oneshot::Sender<()>),
464    /// Publish new changes
465    PublishChanges(FixedBitSet),
466}
467
468#[derive(Debug, thiserror::Error)]
469pub enum Error {
470    #[error("Failed to send or receive command to/from background thread")]
471    Command,
472    #[error("Failed to create background thread: {0}")]
473    Thread(std::io::Error),
474}
475
/// A single change a connection has to apply to get in sync with the observed tables.
///
/// Both variants carry the table name and the id (index) assigned to that table.
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) enum ObservedTableOp {
    /// The table has observers but is not tracked by the connection yet.
    Add(String, usize),
    /// The table is tracked by the connection but has no observers left.
    Remove(String, usize),
}
481
482/// Keeps track of all the observed tables.
483///
484/// Each table is assigned an unique value (index) which is then propagated to all the trackers
485/// when they sync their state.
486struct ObservedTables {
487    /// Table names to index/id
488    table_ids: BTreeMap<String, usize>,
489    /// Table names by index/id
490    tables: Vec<String>,
491    /// Number of active observers for each table.
492    num_observers: Vec<usize>,
493    /// Version counter.
494    counter: u64,
495}
496
497impl ObservedTables {
498    fn new() -> Self {
499        Self {
500            table_ids: BTreeMap::new(),
501            tables: Vec::with_capacity(8),
502            num_observers: Vec::with_capacity(8),
503            counter: 0,
504        }
505    }
506
507    /// Add the `tables` to the list of tables that need to be observed.
508    fn track_tables(&mut self, tables: impl Iterator<Item = String>) {
509        let mut requires_version_bump = false;
510        for table in tables {
511            match self.table_ids.entry(table.clone()) {
512                Entry::Vacant(v) => {
513                    let id = self.num_observers.len();
514                    self.tables.push(table.clone());
515                    self.num_observers.push(1);
516                    v.insert(id);
517                    requires_version_bump = true;
518                }
519                Entry::Occupied(o) => {
520                    let id = *o.get();
521                    let current = self.num_observers[id];
522                    if current == 0 {
523                        // We should start following this table again. If it is not
524                        // 0, we are already observing it.
525                        requires_version_bump = true;
526                    }
527                    self.num_observers[*o.get()] = current + 1;
528                }
529            }
530        }
531
532        if requires_version_bump {
533            self.counter = self.counter.saturating_add(1);
534        }
535    }
536
537    /// Remove the `tables` from the list of tables that need to be observed.
538    fn untrack_tables<'i>(&mut self, tables: impl Iterator<Item = &'i String>) {
539        let mut requires_version_bump = false;
540        for table in tables {
541            if let Some(id) = self.table_ids.get(table) {
542                // We never remove the table entirely, but we need to stop tracking
543                // once all observers have been removed.
544                self.num_observers[*id] -= 1;
545                if self.num_observers[*id] == 0 {
546                    requires_version_bump = true;
547                }
548            }
549        }
550
551        if requires_version_bump {
552            self.counter = self.counter.saturating_add(1);
553        }
554    }
555
556    /// Calculate the which tables should be added or removed from a `connection_state` to
557    /// make sure it is synced up with the current list.
558    ///
559    /// This will return the new updated state as well as the list of triggers that should be
560    /// created or removed.
561    fn calculate_changes(
562        &self,
563        connection_state: &FixedBitSet,
564    ) -> (FixedBitSet, Vec<ObservedTableOp>) {
565        let mut result = connection_state.clone();
566        result.grow(self.tables.len());
567        let mut changes = Vec::with_capacity(self.tables.len());
568        let min_index = connection_state.len().min(self.tables.len());
569        for i in 0..min_index {
570            let is_tracking = connection_state[i];
571            let num_observers = self.num_observers[i];
572
573            if is_tracking && num_observers == 0 {
574                changes.push(ObservedTableOp::Remove(self.tables[i].clone(), i));
575                result.set(i, false);
576            } else if !is_tracking && num_observers != 0 {
577                changes.push(ObservedTableOp::Add(self.tables[i].clone(), i));
578                result.set(i, true);
579            }
580        }
581
582        // Process any new tables that might be missing.
583        for i in min_index..self.num_observers.len() {
584            if self.num_observers[i] != 0 {
585                changes.push(ObservedTableOp::Add(self.tables[i].clone(), i));
586                result.set(i, true);
587            }
588        }
589
590        (result, changes)
591    }
592}
593
#[cfg(test)]
pub(crate) mod tests {
    use crate::watcher::{ObservedTables, TableObserver, Watcher};
    use std::collections::BTreeSet;
    use std::sync::atomic::Ordering;

    /// Observer over a fixed list of tables which ignores all notifications.
    pub struct TestObserver {
        tables: Vec<String>,
    }

    impl TableObserver for TestObserver {
        fn tables(&self) -> Vec<String> {
            self.tables.clone()
        }
        fn on_tables_changed(&self, _: &BTreeSet<String>) {}
    }

    /// Create a boxed [`TestObserver`] interested in `tables`.
    pub(crate) fn new_test_observer(
        tables: impl IntoIterator<Item = &'static str>,
    ) -> Box<dyn TableObserver + Send + 'static> {
        let tables = tables.into_iter().map(ToString::to_string).collect();
        Box::new(TestObserver { tables })
    }

    /// Assert that table `name` is known and has exactly `expected` observers.
    fn check_table_counter(tables: &ObservedTables, name: &str, expected: usize) {
        let idx = *tables
            .table_ids
            .get(name)
            .expect("could not find table by name");
        assert_eq!(tables.num_observers[idx], expected);
    }

    /// Assert the number of known tables and the observer count of each listed table.
    fn check_tables(service: &Watcher, num_tables: usize, counters: &[(&str, usize)]) {
        service.with_tables(|tables| {
            assert_eq!(tables.num_observers.len(), num_tables);
            for &(name, expected) in counters {
                check_table_counter(tables, name, expected);
            }
        });
    }

    /// Read the raw tables version of `service`.
    fn current_version(service: &Watcher) -> u64 {
        service.tables_version.load(Ordering::Relaxed)
    }

    #[test]
    fn test_observer_tables_version_counter() {
        let service = Watcher::new().unwrap();

        let mut version = current_version(&service);
        let observer_1 = new_test_observer(["foo", "bar"]);
        let observer_2 = new_test_observer(["bar"]);
        let observer_3 = new_test_observer(["bar", "omega"]);

        // Adding new observer triggers change.
        let observer_1_id = service.add_observer(observer_1).unwrap();
        check_tables(&service, 2, &[("foo", 1), ("bar", 1)]);
        version += 1;
        assert_eq!(version, current_version(&service));

        // Adding an observer for only bar does not change version counter.
        let observer_2_id = service.add_observer(observer_2).unwrap();
        check_tables(&service, 2, &[("foo", 1), ("bar", 2)]);
        assert_eq!(version, current_version(&service));

        // Adding this observer causes another change
        let observer_3_id = service.add_observer(observer_3).unwrap();
        check_tables(&service, 3, &[("foo", 1), ("omega", 1), ("bar", 3)]);
        version += 1;
        assert_eq!(version, current_version(&service));

        // Remove observer 2 causes no version change.
        service.remove_observer(observer_2_id).unwrap();
        check_tables(&service, 3, &[("foo", 1), ("bar", 2), ("omega", 1)]);
        assert_eq!(version, current_version(&service));

        // Remove observer 3 causes version change.
        service.remove_observer(observer_3_id).unwrap();
        check_tables(&service, 3, &[("foo", 1), ("bar", 1), ("omega", 0)]);
        version += 1;
        assert_eq!(version, current_version(&service));

        // Remove observer 1 causes version change.
        service.remove_observer(observer_1_id).unwrap();
        check_tables(&service, 3, &[("foo", 0), ("bar", 0), ("omega", 0)]);
        version += 1;
        assert_eq!(version, current_version(&service));
    }
}