use fixedbitset::FixedBitSet;
use flume::{Receiver, Sender, TryRecvError};
use parking_lot::RwLock;
use slotmap::{SlotMap, new_key_type};
use std::collections::btree_map::Entry;
use std::collections::{BTreeMap, BTreeSet};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Weak};
use tracing::{debug, error};
new_key_type! {
/// Key identifying an observer registered with a [`Watcher`].
///
/// Returned by [`Watcher::add_observer`] and accepted by
/// [`Watcher::remove_observer`] / [`Watcher::remove_observer_deferred`].
pub struct TableObserverHandle;
}
/// Observer handle that requests removal of its observer when dropped.
///
/// Only a [`Weak`] reference to the [`Watcher`] is stored, so holding this
/// handle does not by itself keep the watcher alive.
///
/// Note: the type derives `Clone` and every clone runs the same `Drop`
/// logic, so dropping any single clone requests removal of the shared
/// observer.
#[derive(Debug, Clone)]
pub struct DropRemoveTableObserverHandle {
/// Watcher the observer was registered with.
watcher: Weak<Watcher>,
/// Handle of the registered observer.
handle: TableObserverHandle,
}
impl DropRemoveTableObserverHandle {
fn new(handle: TableObserverHandle, watcher: &Arc<Watcher>) -> Self {
Self {
watcher: Arc::downgrade(watcher),
handle,
}
}
#[must_use]
pub fn handle(&self) -> TableObserverHandle {
self.handle
}
pub fn unsubscribe(&self) -> Result<(), Error> {
if let Some(watcher) = self.watcher.upgrade() {
watcher.remove_observer(self.handle)
} else {
Err(Error::Command)
}
}
}
impl Drop for DropRemoveTableObserverHandle {
fn drop(&mut self) {
if let Some(watcher) = self.watcher.upgrade() {
if watcher.remove_observer_deferred(self.handle).is_err() {
error!("Failed to remove watcher from observer on drop");
}
}
}
}
/// Receiver of table change notifications.
///
/// Implementations must be `Send + Sync`: within this file both methods are
/// called from the watcher's background thread.
pub trait TableObserver: Send + Sync {
/// Names of the tables this observer is interested in.
///
/// Queried once, when the observer is registered; duplicates are collapsed
/// because the result is stored in a set.
fn tables(&self) -> Vec<String>;
/// Called when at least one of the observer's tables changed.
///
/// `tables` holds every table reported as changed in the processed batch,
/// not only the ones this observer registered for.
fn on_tables_changed(&self, tables: &BTreeSet<String>);
}
/// Tracks which tables are observed and dispatches change notifications to
/// registered [`TableObserver`]s from a background thread.
pub struct Watcher {
/// Table bookkeeping shared between callers and the background thread.
tables: RwLock<ObservedTables>,
/// Incremented whenever a mutation of `tables` changed its `counter`.
tables_version: AtomicU64,
/// Sending side of the command channel consumed by the background thread.
sender: Sender<Command>,
}
/// Capacity of the bounded command channel created in [`Watcher::new`].
const WATCHER_CHANNEL_CAPACITY: usize = 24;
impl Watcher {
pub fn new() -> Result<Arc<Self>, Error> {
let (sender, receiver) = flume::bounded(WATCHER_CHANNEL_CAPACITY);
let watcher = Arc::new(Self {
tables: RwLock::new(ObservedTables::new()),
tables_version: AtomicU64::new(0),
sender,
});
let watcher_cloned = Arc::clone(&watcher);
std::thread::Builder::new()
.name("sqlite_watcher".into())
.spawn(move || {
Watcher::background_loop(receiver, &watcher_cloned);
})
.map_err(Error::Thread)?;
Ok(watcher)
}
pub fn add_observer(
&self,
observer: Box<dyn TableObserver>,
) -> Result<TableObserverHandle, Error> {
let (sender, receiver) = oneshot::channel();
if self
.sender
.send(Command::AddObserver(observer, sender))
.is_err()
{
error!("Failed to send add observer command");
return Err(Error::Command);
}
let Ok(handle) = receiver.recv() else {
error!("Failed to receive handle for new observer");
return Err(Error::Command);
};
Ok(handle)
}
/// Registers `observer` and returns a handle that removes it again on drop.
///
/// # Errors
///
/// Propagates any error from [`Watcher::add_observer`].
pub fn add_observer_with_drop_remove(
    self: &Arc<Self>,
    observer: Box<dyn TableObserver>,
) -> Result<DropRemoveTableObserverHandle, Error> {
    self.add_observer(observer)
        .map(|handle| DropRemoveTableObserverHandle::new(handle, self))
}
pub fn remove_observer_deferred(&self, handle: TableObserverHandle) -> Result<(), Error> {
self.sender
.send(Command::RemoveObserverDeferred(handle))
.map_err(|_| Error::Command)
}
pub fn remove_observer(&self, handle: TableObserverHandle) -> Result<(), Error> {
let (sender, receiver) = oneshot::channel();
self.sender
.send(Command::RemoveObserver(handle, sender))
.map_err(|_| Error::Command)?;
receiver.recv().map_err(|_| {
error!("Failed to receive reply for remove observer command");
Error::Command
})
}
pub(crate) fn publish_changes(&self, table_ids: FixedBitSet) {
if self
.sender
.send(Command::PublishChanges(table_ids))
.is_err()
{
error!("Watcher could not communicate with background thread");
}
}
pub(crate) async fn publish_changes_async(&self, table_ids: FixedBitSet) {
if self
.sender
.send_async(Command::PublishChanges(table_ids))
.await
.is_err()
{
error!("Watcher could not communicate with background thread");
}
}
/// Test helper: looks up the numeric id assigned to `table`, if any.
#[cfg(test)]
pub(crate) fn get_table_id(&self, table: &str) -> Option<usize> {
    let tables = self.tables.read();
    tables.table_ids.get(table).copied()
}
/// Runs `f` with exclusive access to the observed tables.
///
/// If `f` changed the tables' `counter`, `tables_version` is incremented
/// while the write lock is still held.
fn with_tables_mut(&self, f: impl FnOnce(&mut ObservedTables)) {
    let mut guard = self.tables.write();
    let counter_before = guard.counter;
    f(&mut guard);
    let counter_changed = guard.counter != counter_before;
    if counter_changed {
        self.tables_version.fetch_add(1, Ordering::Release);
    }
}
/// Runs `f` with shared (read) access to the observed tables and returns
/// its result.
fn with_tables<R>(&self, f: impl (FnOnce(&ObservedTables) -> R)) -> R {
    f(&self.tables.read())
}
pub(crate) fn tables_version(&self) -> u64 {
self.tables_version.load(Ordering::Acquire)
}
/// Returns a copy of the names of all tables known to the watcher.
///
/// Names are never removed from the underlying list in this file, so the
/// result can include tables whose observer count has dropped to zero.
pub fn observed_tables(&self) -> Vec<String> {
    let guard = self.tables.read();
    guard.tables.clone()
}
/// Compares `connection_state` with the currently observed tables.
///
/// Returns the updated state bitset together with the add/remove operations
/// that lead to it; see `ObservedTables::calculate_changes`.
pub(crate) fn calculate_sync_changes(
    &self,
    connection_state: &FixedBitSet,
) -> (FixedBitSet, Vec<ObservedTableOp>) {
    let guard = self.tables.read();
    guard.calculate_changes(connection_state)
}
/// Body of the background thread: batches queued commands and processes
/// each batch with `WatcherWorker::tick`.
///
/// Returns as soon as the channel reports an error on `recv` or a
/// disconnect on `try_recv`; commands already unpacked for the current
/// batch are not processed in that case.
// The receiver is taken by value so that the thread owns it.
#[allow(clippy::needless_pass_by_value)]
#[tracing::instrument(level = tracing::Level::TRACE, skip(receiver, watcher))]
fn background_loop(receiver: Receiver<Command>, watcher: &Watcher) {
    let mut worker = WatcherWorker::new();
    loop {
        // The previous tick must have drained every queue.
        debug_assert!(worker.add_observers.is_empty());
        debug_assert!(worker.remove_observers.is_empty());
        debug_assert!(worker.publish_changes.is_empty());

        // Wait for the first command of the next batch.
        match receiver.recv() {
            Ok(first) => worker.unpack_command(first),
            Err(_) => return,
        }

        // Pull in everything else that is already queued, without waiting.
        loop {
            match receiver.try_recv() {
                Ok(next) => worker.unpack_command(next),
                Err(TryRecvError::Empty) => break,
                Err(TryRecvError::Disconnected) => return,
            }
        }

        worker.tick(watcher);
    }
}
}
/// State owned by the background loop: the registered observers plus the
/// commands collected for the batch currently being assembled.
struct WatcherWorker {
/// Registered observers, keyed by the handle given to callers.
observers: SlotMap<TableObserverHandle, ActiveObserver>,
/// Scratch set of table names that changed in the current tick.
updated_tables: BTreeSet<String>,
/// Pending removals; the reply sender is `None` for deferred removals.
remove_observers: Vec<(TableObserverHandle, Option<oneshot::Sender<()>>)>,
/// Pending additions together with the sender used to return the handle.
add_observers: Vec<(Box<dyn TableObserver>, oneshot::Sender<TableObserverHandle>)>,
/// Pending change sets; set bits are indices into `ObservedTables::tables`.
publish_changes: Vec<FixedBitSet>,
}
impl WatcherWorker {
/// Creates a worker with no observers and empty command queues.
fn new() -> Self {
    Self {
        observers: SlotMap::with_capacity_and_key(4),
        updated_tables: BTreeSet::new(),
        remove_observers: Vec::new(),
        add_observers: Vec::new(),
        publish_changes: Vec::new(),
    }
}
/// Sorts `command` into the queue matching its kind; nothing is executed
/// until `tick` runs.
fn unpack_command(&mut self, command: Command) {
    match command {
        Command::PublishChanges(table_ids) => self.publish_changes.push(table_ids),
        Command::AddObserver(observer, reply) => self.add_observers.push((observer, reply)),
        Command::RemoveObserver(handle, reply) => {
            self.remove_observers.push((handle, Some(reply)));
        }
        // Deferred removals have nobody waiting for a confirmation.
        Command::RemoveObserverDeferred(handle) => self.remove_observers.push((handle, None)),
    }
}
/// Processes one batch of queued commands.
///
/// The queues are handled in a fixed order, independent of the order in
/// which the commands arrived:
///
/// 1. removals (untracking the observer's tables and confirming, if asked),
/// 2. additions (tracking the observer's tables and replying with the handle),
/// 3. published changes, which are merged into one set of table names and
///    delivered to every observer watching at least one of them.
fn tick(&mut self, watcher: &Watcher) {
    for (handle, reply) in self.remove_observers.drain(..) {
        // Unknown or already removed handles are ignored, but still confirmed.
        if let Some(observer) = self.observers.remove(handle) {
            watcher.with_tables_mut(|tables| {
                tables.untrack_tables(observer.tables.iter());
            });
        }
        if let Some(reply) = reply {
            if reply.send(()).is_err() {
                error!("Failed to send reply for observer removal");
            }
        }
    }

    for (observer, reply) in self.add_observers.drain(..) {
        let active_observer = ActiveObserver::new(observer);
        // Track before replying so the caller sees the tables as soon as it
        // receives the handle.
        watcher.with_tables_mut(|tables| {
            tables.track_tables(active_observer.tables.iter().cloned());
        });
        let handle = self.observers.insert(active_observer);
        if reply.send(handle).is_err() {
            error!("Failed to send reply back to caller, new observer will not be added");
            // Nobody will ever hold this handle: undo the registration
            // completely, including the table observer counts, which would
            // otherwise stay inflated forever.
            if let Some(orphan) = self.observers.remove(handle) {
                watcher.with_tables_mut(|tables| {
                    tables.untrack_tables(orphan.tables.iter());
                });
            }
        }
    }

    self.updated_tables.clear();
    for table_ids in self.publish_changes.drain(..) {
        if table_ids.is_clear() {
            continue;
        }
        watcher.with_tables(|observer_tables| {
            for idx in table_ids.ones() {
                // Ids without a known table name are skipped.
                if let Some(name) = observer_tables.tables.get(idx).cloned() {
                    self.updated_tables.insert(name);
                }
            }
        });
    }

    if self.updated_tables.is_empty() {
        return;
    }
    debug!("Changes detected on tables: {:?}", self.updated_tables);
    for (_, active_observer) in &self.observers {
        // Notify only observers watching at least one changed table; they
        // receive the complete set of changed tables.
        if !self.updated_tables.is_disjoint(&active_observer.tables) {
            active_observer
                .observer
                .on_tables_changed(&self.updated_tables);
        }
    }
}
}
/// A registered observer together with the tables it asked for.
struct ActiveObserver {
/// The observer to notify.
observer: Box<dyn TableObserver>,
/// Snapshot of `observer.tables()` taken at registration, de-duplicated.
tables: BTreeSet<String>,
}
impl ActiveObserver {
    /// Wraps `observer`, querying its table list once and storing it as a set.
    fn new(observer: Box<dyn TableObserver>) -> ActiveObserver {
        let tables: BTreeSet<String> = observer.tables().into_iter().collect();
        Self { observer, tables }
    }
}
/// Messages sent from [`Watcher`] methods to the background thread.
enum Command {
/// Register an observer; the new handle is sent back through the sender.
AddObserver(Box<dyn TableObserver>, oneshot::Sender<TableObserverHandle>),
/// Remove an observer without confirming completion.
RemoveObserverDeferred(TableObserverHandle),
/// Remove an observer and confirm completion through the sender.
RemoveObserver(TableObserverHandle, oneshot::Sender<()>),
/// Report changed tables; set bits are table ids.
PublishChanges(FixedBitSet),
}
/// Errors returned by [`Watcher`] operations.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// A command could not be sent to the background thread, or its reply was
/// never received. Also returned when the watcher behind a
/// [`DropRemoveTableObserverHandle`] no longer exists.
#[error("Failed to send or receive command to/from background thread")]
Command,
/// The background thread could not be spawned.
#[error("Failed to create background thread: {0}")]
Thread(std::io::Error),
}
/// A single difference between a connection's tracking state and the
/// currently observed tables, as produced by `ObservedTables::calculate_changes`.
///
/// Both variants carry the table name and its table id (index).
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) enum ObservedTableOp {
/// The table has observers but is not set in the connection state.
Add(String, usize),
/// The table is set in the connection state but has no observers left.
Remove(String, usize),
}
/// Bookkeeping of every table that has ever been observed.
///
/// `tables` and `num_observers` are parallel vectors indexed by table id;
/// entries are appended together and never removed in this file.
struct ObservedTables {
/// Table name -> table id (index into `tables` / `num_observers`).
table_ids: BTreeMap<String, usize>,
/// Table names by table id.
tables: Vec<String>,
/// Number of observers currently watching each table, by table id.
num_observers: Vec<usize>,
/// Incremented when a track/untrack call adds a new table or moves a
/// table's observer count from or to zero.
counter: u64,
}
impl ObservedTables {
/// Creates empty bookkeeping with a little room pre-allocated for tables.
fn new() -> Self {
    const INITIAL_TABLE_CAPACITY: usize = 8;
    Self {
        table_ids: BTreeMap::new(),
        tables: Vec::with_capacity(INITIAL_TABLE_CAPACITY),
        num_observers: Vec::with_capacity(INITIAL_TABLE_CAPACITY),
        counter: 0,
    }
}
fn track_tables(&mut self, tables: impl Iterator<Item = String>) {
let mut requires_version_bump = false;
for table in tables {
match self.table_ids.entry(table.clone()) {
Entry::Vacant(v) => {
let id = self.num_observers.len();
self.tables.push(table.clone());
self.num_observers.push(1);
v.insert(id);
requires_version_bump = true;
}
Entry::Occupied(o) => {
let id = *o.get();
let current = self.num_observers[id];
if current == 0 {
requires_version_bump = true;
}
self.num_observers[*o.get()] = current + 1;
}
}
}
if requires_version_bump {
self.counter = self.counter.saturating_add(1);
}
}
fn untrack_tables<'i>(&mut self, tables: impl Iterator<Item = &'i String>) {
let mut requires_version_bump = false;
for table in tables {
if let Some(id) = self.table_ids.get(table) {
self.num_observers[*id] -= 1;
if self.num_observers[*id] == 0 {
requires_version_bump = true;
}
}
}
if requires_version_bump {
self.counter = self.counter.saturating_add(1);
}
}
/// Computes how `connection_state` must change to match the observed tables.
///
/// Returns a copy of `connection_state`, grown to cover every known table,
/// with the bit of each table set iff it currently has observers, plus the
/// list of operations describing the bits that were flipped. Bits beyond
/// the known tables are copied unchanged.
fn calculate_changes(
    &self,
    connection_state: &FixedBitSet,
) -> (FixedBitSet, Vec<ObservedTableOp>) {
    let table_count = self.tables.len();
    let mut updated_state = connection_state.clone();
    updated_state.grow(table_count);
    let mut ops = Vec::with_capacity(table_count);

    // `tables` and `num_observers` are parallel vectors of equal length.
    let per_table = self.tables.iter().zip(&self.num_observers);
    for (idx, (name, &observers)) in per_table.enumerate() {
        // Tables the connection has never heard of count as "not tracking".
        let is_tracking = idx < connection_state.len() && connection_state[idx];
        match (is_tracking, observers != 0) {
            (true, false) => {
                ops.push(ObservedTableOp::Remove(name.clone(), idx));
                updated_state.set(idx, false);
            }
            (false, true) => {
                ops.push(ObservedTableOp::Add(name.clone(), idx));
                updated_state.set(idx, true);
            }
            // Already in sync.
            _ => {}
        }
    }
    (updated_state, ops)
}
}
#[cfg(test)]
pub(crate) mod tests {
use crate::watcher::{ObservedTables, TableObserver, Watcher};
use std::collections::BTreeSet;
use std::sync::atomic::Ordering;
/// Minimal observer for tests: reports a fixed table list and ignores
/// change notifications.
pub struct TestObserver {
/// Table names returned from `TableObserver::tables`.
tables: Vec<String>,
}
impl TableObserver for TestObserver {
    /// Returns a copy of the configured table names.
    fn tables(&self) -> Vec<String> {
        self.tables.to_vec()
    }

    /// Notifications are intentionally ignored.
    fn on_tables_changed(&self, _changed: &BTreeSet<String>) {}
}
/// Builds a boxed [`TestObserver`] watching the given table names.
pub(crate) fn new_test_observer(
    tables: impl IntoIterator<Item = &'static str>,
) -> Box<dyn TableObserver + Send + 'static> {
    let tables: Vec<String> = tables.into_iter().map(String::from).collect();
    Box::new(TestObserver { tables })
}
/// Asserts that table `name` is known and has exactly `expected` observers.
fn check_table_counter(tables: &ObservedTables, name: &str, expected: usize) {
    let Some(&idx) = tables.table_ids.get(name) else {
        panic!("could not find table by name");
    };
    assert_eq!(tables.num_observers[idx], expected);
}
/// Checks the per-table observer counts and verifies that the tables version
/// only changes when a table gains its first or loses its last observer.
#[test]
fn test_observer_tables_version_counter() {
    /// Asserts the number of known tables and the observer count of each
    /// listed table, in the given order.
    fn assert_tables(service: &Watcher, table_count: usize, expected: &[(&str, usize)]) {
        service.with_tables(|tables| {
            assert_eq!(tables.num_observers.len(), table_count);
            for &(name, count) in expected {
                check_table_counter(tables, name, count);
            }
        });
    }

    let service = Watcher::new().unwrap();
    let current_version = || service.tables_version.load(Ordering::Relaxed);
    let mut version = current_version();

    // "foo" and "bar" are new tables: version bump.
    let observer_1_id = service
        .add_observer(new_test_observer(["foo", "bar"]))
        .unwrap();
    assert_tables(&service, 2, &[("foo", 1), ("bar", 1)]);
    version += 1;
    assert_eq!(version, current_version());

    // "bar" is already observed: no version bump.
    let observer_2_id = service.add_observer(new_test_observer(["bar"])).unwrap();
    assert_tables(&service, 2, &[("foo", 1), ("bar", 2)]);
    assert_eq!(version, current_version());

    // "omega" is new: version bump.
    let observer_3_id = service
        .add_observer(new_test_observer(["bar", "omega"]))
        .unwrap();
    assert_tables(&service, 3, &[("foo", 1), ("omega", 1), ("bar", 3)]);
    version += 1;
    assert_eq!(version, current_version());

    // "bar" keeps other observers: no version bump.
    service.remove_observer(observer_2_id).unwrap();
    assert_tables(&service, 3, &[("foo", 1), ("bar", 2), ("omega", 1)]);
    assert_eq!(version, current_version());

    // "omega" loses its last observer: version bump.
    service.remove_observer(observer_3_id).unwrap();
    assert_tables(&service, 3, &[("foo", 1), ("bar", 1), ("omega", 0)]);
    version += 1;
    assert_eq!(version, current_version());

    // "foo" and "bar" lose their last observer: a single version bump.
    service.remove_observer(observer_1_id).unwrap();
    assert_tables(&service, 3, &[("foo", 0), ("bar", 0), ("omega", 0)]);
    version += 1;
    assert_eq!(version, current_version());
}
}