sqlite_watcher/
connection.rs

1use crate::statement::{
2    BatchQuery, Sealed, SqlExecuteStatement, SqlTransactionStatement, Statement, StatementWithInput,
3};
4use crate::watcher::{ObservedTableOp, Watcher};
5use fixedbitset::FixedBitSet;
6use std::error::Error;
7use std::future::Future;
8use std::ops::{Deref, DerefMut};
9use std::sync::Arc;
10use tracing::{debug, trace, warn};
11
12#[cfg(feature = "rusqlite")]
13pub mod rusqlite;
14
15#[cfg(feature = "sqlx")]
16pub mod sqlx;
17
/// Synchronous access to a sqlite connection through a shared reference.
///
/// Implementations are used to run the SQL that sets up (and later queries) the
/// temporary triggers and tables this crate relies on to track changes.
pub trait SqlExecutor {
    /// Error reported when a query or statement fails.
    type Error: Error;

    /// Run `query` and collect its rows, each of which has exactly one column of type `u32`.
    ///
    /// The query may produce no rows, in which case the returned vector is empty.
    ///
    /// # Errors
    ///
    /// Should return error if the query failed.
    fn sql_query_values(&self, query: &str) -> Result<Vec<u32>, Self::Error>;

    /// Run `query`, a statement that produces no rows.
    ///
    /// # Errors
    ///
    /// Should return error if the query failed.
    fn sql_execute(&self, query: &str) -> Result<(), Self::Error>;
}
38
/// Variant of [`SqlExecutor`] for connections that can only execute SQL through a
/// mutable borrow.
pub trait SqlExecutorMut {
    /// Error reported when a query or statement fails.
    type Error: Error;

    /// Run `query` and collect its rows, each of which has exactly one column of type `u32`.
    ///
    /// The query may produce no rows, in which case the returned vector is empty.
    ///
    /// # Errors
    ///
    /// Should return error if the query failed.
    fn sql_query_values(&mut self, query: &str) -> Result<Vec<u32>, Self::Error>;

    /// Run `query`, a statement that produces no rows.
    ///
    /// # Errors
    ///
    /// Should return error if the query failed.
    fn sql_execute(&mut self, query: &str) -> Result<(), Self::Error>;
}
57
58// Automatically derive SqlExecutorMut for any implementation of SqlExecutor.
59impl<T: SqlExecutor> SqlExecutorMut for T {
60    type Error = T::Error;
61
62    fn sql_query_values(&mut self, query: &str) -> Result<Vec<u32>, Self::Error> {
63        SqlExecutor::sql_query_values(self, query)
64    }
65
66    fn sql_execute(&mut self, query: &str) -> Result<(), Self::Error> {
67        SqlExecutor::sql_execute(self, query)
68    }
69}
70
/// Asynchronous counterpart of [`SqlExecutor`], for connections whose queries are awaited.
///
/// This is required so we can set up the temporary triggers and tables required to
/// track changes on async connections. Implementors must be [`Send`], and the futures
/// returned by both methods must be [`Send`] as well, so they can be moved across threads.
pub trait SqlExecutorAsync: Send {
    /// Error reported when a query or statement fails; must be [`Send`].
    type Error: Error + Send;

    /// Run `query` and collect its rows, each of which has exactly one column of type `u32`.
    ///
    /// The query may produce no rows, in which case the resolved vector is empty.
    ///
    /// # Errors
    ///
    /// The returned future should resolve to an error if the query failed.
    fn sql_query_values(
        &mut self,
        query: &str,
    ) -> impl Future<Output = Result<Vec<u32>, Self::Error>> + Send;

    /// Run `query`, a statement that produces no rows.
    ///
    /// # Errors
    ///
    /// The returned future should resolve to an error if the query failed.
    fn sql_execute(&mut self, query: &str) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
94
95/// Building block to provide tracking capabilities to any type of sqlite connection which
96/// implements the [`SqlExecutor`] trait.
97///
98/// # Initialization
99///
100/// It's recommended to call [`State::set_pragmas()`] to enable in memory temporary tables and recursive
101/// triggers. If your connection already has this set up, this can be skipped.
102///
103/// Next you need to create the infrastructure to track changes. This can be accomplished with
104/// [`State::start_tracking()`].
105///
106/// # Tracking changes
107///
108/// To make sure we only track required tables always call [`State::sync_tables()`] before a query/statement
109/// or a transaction.
110///
111/// When the query/statement or transaction are completed, call [`State::publish_changes()`] to check
112/// which tables have been modified and send this information to the watcher.
113///
114/// # Disable Tracking
115///
116/// If you wish to remove all the tracking infrastructure from a connection on which
117/// [`State::start_tracking()`] was called, then call [`State::stop_tracking()`].
118///
119/// # See Also
120///
121/// The [`Connection`] type provided by this crate provides an example integration implementation.
122#[derive(Debug, Default)]
123pub struct State {
124    tracked_tables: FixedBitSet,
125    last_sync_version: u64,
126}
127
128impl State {
129    /// Enable required pragmas for execution.
130    #[must_use]
131    pub fn set_pragmas() -> impl Statement {
132        SqlExecuteStatement::new("PRAGMA temp_store = MEMORY")
133            .then(SqlExecuteStatement::new("PRAGMA recursive_triggers='ON'"))
134    }
135
136    /// Prepare the `connection` for tracking.
137    ///
138    /// This will create the temporary table used to track change.
139    #[must_use]
140    #[tracing::instrument(level = tracing::Level::DEBUG)]
141    pub fn start_tracking() -> impl Statement {
142        // create tracking table and cleanup previous data if re-used from a connection pool.
143        SqlTransactionStatement::temporary(
144            SqlExecuteStatement::new(create_tracking_table_query())
145                .then(SqlExecuteStatement::new(empty_tracking_table_query())),
146        )
147        .spanned_in_current()
148    }
149
150    /// Remove all triggers and the tracking table from `connection`.
151    //
152    /// # Errors
153    ///
154    /// Returns error if the initialization failed.
155    #[tracing::instrument(level = tracing::Level::DEBUG, skip_all)]
156    pub fn stop_tracking(&self, watcher: &Watcher) -> impl Statement {
157        let tables = watcher.observed_tables();
158        SqlTransactionStatement::temporary(
159            BatchQuery::new(
160                tables
161                    .into_iter()
162                    .enumerate()
163                    .flat_map(|(id, table_name)| drop_triggers(&table_name, id)),
164            )
165            .then(SqlExecuteStatement::new(drop_tracking_table_query())),
166        )
167        .spanned_in_current()
168    }
169
170    /// Create a new instance without initializing any connection.
171    #[must_use]
172    pub fn new() -> Self {
173        Self {
174            tracked_tables: FixedBitSet::new(),
175            last_sync_version: 0,
176        }
177    }
178
179    /// Synchronize the table list from the watcher.
180    ///
181    /// This method will create new triggers for tables that are not being watched over this
182    /// connection and remove triggers for tables that are no longer observed by the watcher.
183    ///
184    /// # Errors
185    ///
186    /// Returns error if creation or removal of triggers failed.
187    #[tracing::instrument(level=tracing::Level::DEBUG, skip(self, watcher))]
188    pub fn sync_tables(&mut self, watcher: &Watcher) -> Option<impl Statement + '_> {
189        let new_version = self.should_sync(watcher)?;
190
191        debug!("Syncing tables from observer");
192        let Some((new_tracker_state, tracker_changes)) = self.calculate_sync_changes(watcher)
193        else {
194            debug!("No changes");
195            return None;
196        };
197
198        let mut queries = BatchQuery::new([]);
199
200        if self.tracked_tables.is_empty() {
201            // It is possible on certain circumstances that if a connection can have leftover
202            // tracking data that is not cleared. To make sure this is reset, we force empty
203            // the table if we detect that we are not watching any tables at the moment.
204            queries.push(SqlExecuteStatement::new(empty_tracking_table_query()));
205        }
206        for change in tracker_changes {
207            match change {
208                ObservedTableOp::Add(table_name, id) => {
209                    debug!("Add watcher for table {table_name} id={id}");
210                    queries.extend(create_triggers(&table_name, id));
211                }
212                ObservedTableOp::Remove(table_name, id) => {
213                    debug!("Remove watcher for table {table_name}");
214                    queries.extend(drop_triggers(&table_name, id));
215                }
216            }
217        }
218
219        let tx = SqlTransactionStatement::temporary(queries);
220        Some(
221            tx.then(ConcludeStateChangeStatement {
222                state: self,
223                tracked_tables: new_tracker_state,
224                new_version,
225            })
226            .spanned_in_current(),
227        )
228    }
229
230    /// Check the tracking table and report finding to the [Watcher].
231    ///
232    /// The table where the changes are tracked is read and reset. Any
233    /// table that has been modified will be communicated to the [Watcher], which in turn
234    /// will notify the respective [TableObserver].
235    ///
236    /// # Errors
237    ///
238    /// Returns error if we failed to read from the temporary tables.
239    ///
240    /// [Watcher]: `crate::watcher::Watcher`
241    /// [TableObserver]: `crate::watcher::TableObserver`
242    #[tracing::instrument(level=tracing::Level::DEBUG, skip(self, watcher))]
243    pub fn publish_changes(&self, watcher: &Watcher) -> impl Statement {
244        SqlReadTableIdsStatement
245            .pipe(CalculateWatcherUpdatesStatement { state: self })
246            .pipe(MaybeResetResultsQuery)
247            .pipe(PublishWatcherChangesStatement(watcher))
248            .spanned_in_current()
249    }
250
251    fn prepare_watcher_changes(&self, modified_table_ids: Vec<u32>) -> FixedBitSet {
252        trace!("Preparing watcher changes");
253        let mut result = FixedBitSet::with_capacity(self.tracked_tables.len());
254        for id in modified_table_ids {
255            let id = id as usize;
256            debug!("Table {} has been modified", id);
257            if id >= result.len() {
258                warn!(
259                    "Received update for table {id}, but only tracking {} tables",
260                    self.tracked_tables.len(),
261                );
262                // We need to grow on the index + 1.
263                result.grow(id + 1);
264            }
265            result.set(id, true);
266        }
267
268        result
269    }
270
271    fn should_sync(&self, watcher: &Watcher) -> Option<u64> {
272        let service_version = watcher.tables_version();
273        if service_version == self.last_sync_version {
274            None
275        } else {
276            Some(service_version)
277        }
278    }
279
280    /// Determine which tables should start and/or stop being watched.
281    fn calculate_sync_changes(
282        &self,
283        watcher: &Watcher,
284    ) -> Option<(FixedBitSet, Vec<ObservedTableOp>)> {
285        trace!("Calculating sync changes");
286        let (new_tracker_state, tracker_changes) =
287            watcher.calculate_sync_changes(&self.tracked_tables);
288
289        if tracker_changes.is_empty() {
290            return None;
291        }
292
293        Some((new_tracker_state, tracker_changes))
294    }
295
296    /// Once we are satisfied with the changes, apply the new state.
297    fn apply_sync_changes(&mut self, new_tracker_state: FixedBitSet, new_version: u64) {
298        // Update local tracker bitset
299        trace!("Applying sync changes");
300        self.tracked_tables = new_tracker_state;
301        self.last_sync_version = new_version;
302    }
303}
304
305/// Connection abstraction that provides on possible implementation which uses the building
306/// blocks ([`State`]) provided by this crate.
307///
308/// For simplicity, it takes ownership of an existing type which implements [`SqlExecutor`] and
309/// initializes all the tracking infrastructure. The original type can still be accessed as
310/// [`Connection`] implements both [`Deref`] and [`DerefMut`].
311///
312/// # Remarks
313///
314/// To make sure all changes are capture, it's recommended to always call
315/// [`Connection::sync_watcher_tables()`]
316/// before any query/statement or transaction.
317///
318/// # Example
319///
320/// ## Single Query/Statement
321///
322/// ```rust
323/// use sqlite_watcher::connection::Connection;
324/// use sqlite_watcher::connection::SqlExecutor;
325/// use sqlite_watcher::watcher::Watcher;
326///
327/// pub fn track_changes<C:SqlExecutor>(connection: C) {
328///     let watcher = Watcher::new().unwrap();
329///     let mut connection = Connection::new(connection, watcher).unwrap();
330///
331///     // Sync tables so we are up to date.
332///     connection.sync_watcher_tables().unwrap();
333///
334///     connection.sql_execute("sql query here").unwrap();
335///
336///     // Publish changes to the watcher
337///     connection.publish_watcher_changes().unwrap();
338/// }
339/// ```
340///
341/// ## Transaction
342///
343/// ```rust
344/// use sqlite_watcher::connection::Connection;
345/// use sqlite_watcher::connection::{SqlExecutor};
346/// use sqlite_watcher::watcher::Watcher;
347///
348/// pub fn track_changes<C:SqlExecutor>(connection: C) {
349///     let watcher = Watcher::new().unwrap();
350///     let mut connection = Connection::new(connection, watcher).unwrap();
351///
352///     // Sync tables so we are up to date.
353///     connection.sync_watcher_tables().unwrap();
354///
355///     // Start a transaction
356///     connection.sql_execute("sql query here").unwrap();
357///     connection.sql_execute("sql query here").unwrap();
358///     // Commit transaction
359///
360///     // Publish changes to the watcher
361///     connection.publish_watcher_changes().unwrap();
362/// }
363/// ```
364pub struct Connection<C: SqlExecutor> {
365    state: State,
366    watcher: Arc<Watcher>,
367    connection: C,
368}
369impl<C: SqlExecutor> Connection<C> {
370    /// Create a new connection with `connection` and `watcher`.
371    ///
372    /// See [`State::start_tracking()`] for more information about initialization.
373    ///
374    /// # Errors
375    ///
376    /// Returns error if the initialization failed.
377    pub fn new(connection: C, watcher: Arc<Watcher>) -> Result<Self, C::Error> {
378        let state = State::new();
379        State::set_pragmas().execute(&connection)?;
380        State::start_tracking().execute(&connection)?;
381        Ok(Self {
382            state,
383            watcher,
384            connection,
385        })
386    }
387
388    /// Sync tables from the [`Watcher`] and update tracking infrastructure.
389    ///
390    /// See [`State::sync_tables()`] for more information.
391    ///
392    /// # Errors
393    ///
394    /// Returns error if we failed to sync the changes to the database.
395    pub fn sync_watcher_tables(&mut self) -> Result<(), C::Error> {
396        self.state
397            .sync_tables(&self.watcher)
398            .execute(&self.connection)?;
399        Ok(())
400    }
401
402    /// Check if any tables have changed and notify the [`Watcher`]
403    ///
404    /// See [`State::publish_changes()`] for more information.
405    ///
406    /// It is recommended to call this method
407    ///
408    /// # Errors
409    ///
410    /// Returns error if we failed to check for changes.
411    pub fn publish_watcher_changes(&mut self) -> Result<(), C::Error> {
412        self.state
413            .publish_changes(&self.watcher)
414            .execute(&self.connection)?;
415        Ok(())
416    }
417
418    /// Disable all tracking on this connection.
419    ///
420    /// See [`State::stop_tracking`] for more details.
421    ///
422    /// # Errors
423    ///
424    /// Returns error if the queries failed.
425    pub fn stop_tracking(&mut self) -> Result<(), C::Error> {
426        self.state
427            .stop_tracking(&self.watcher)
428            .execute(&self.connection)?;
429        Ok(())
430    }
431
432    /// Consume the current connection and take ownership of the real sql connection.
433    ///
434    /// # Remarks
435    ///
436    /// This does not stop the tracking infrastructure enabled on the connection.
437    /// Use [`Self::stop_tracking()`] to disable it first.
438    pub fn take(self) -> C {
439        self.connection
440    }
441}
442
443/// Same as [`Connection`] but with an async executor.
444#[allow(clippy::module_name_repetitions)]
445pub struct ConnectionAsync<C: SqlExecutorAsync> {
446    state: State,
447    watcher: Arc<Watcher>,
448    connection: C,
449}
450impl<C: SqlExecutorAsync> ConnectionAsync<C> {
451    /// Create a new connection with `connection` and `watcher`.
452    ///
453    /// See [`State::start_tracking()`] for more information about initialization.
454    ///
455    /// # Errors
456    ///
457    /// Returns error if the initialization failed.
458    pub async fn new(mut connection: C, watcher: Arc<Watcher>) -> Result<Self, C::Error> {
459        let state = State::new();
460        State::set_pragmas().execute_async(&mut connection).await?;
461        State::start_tracking()
462            .execute_async(&mut connection)
463            .await?;
464        Ok(Self {
465            state,
466            watcher,
467            connection,
468        })
469    }
470
471    /// See [`Connection::sync_watcher_tables`] for more details.
472    ///
473    /// # Errors
474    ///
475    /// Returns error if we failed to sync the changes to the database.
476    pub async fn sync_watcher_tables(&mut self) -> Result<(), C::Error> {
477        self.state
478            .sync_tables(&self.watcher)
479            .execute_async(&mut self.connection)
480            .await?;
481        Ok(())
482    }
483
484    /// See [`Connection::publish_watcher_changes`] for more details.
485    ///
486    /// # Errors
487    ///
488    /// Returns error if we failed to check for changes.
489    pub async fn publish_watcher_changes(&mut self) -> Result<(), C::Error> {
490        self.state
491            .publish_changes(&self.watcher)
492            .execute_async(&mut self.connection)
493            .await?;
494        Ok(())
495    }
496
497    /// See [`Connection::stop_tracking`] for more details.
498    ///
499    /// # Errors
500    ///
501    /// Returns error if the queries failed.
502    pub async fn stop_tracking(&mut self) -> Result<(), C::Error> {
503        self.state
504            .stop_tracking(&self.watcher)
505            .execute_async(&mut self.connection)
506            .await?;
507        Ok(())
508    }
509
510    /// Consume the current connection and take ownership of the real sql connection.
511    ///
512    /// # Remarks
513    ///
514    /// This does not stop the tracking infrastructure enabled on the connection.
515    /// Use [`Self::stop_tracking()`] to disable it first.
516    pub fn take(self) -> C {
517        self.connection
518    }
519}
520
521impl<C: SqlExecutorAsync> Deref for ConnectionAsync<C> {
522    type Target = C;
523
524    fn deref(&self) -> &Self::Target {
525        &self.connection
526    }
527}
528
529impl<C: SqlExecutorAsync> DerefMut for ConnectionAsync<C> {
530    fn deref_mut(&mut self) -> &mut Self::Target {
531        &mut self.connection
532    }
533}
534
535impl<C: SqlExecutorAsync> AsRef<C> for ConnectionAsync<C> {
536    fn as_ref(&self) -> &C {
537        &self.connection
538    }
539}
540
541impl<C: SqlExecutorAsync> AsMut<C> for ConnectionAsync<C> {
542    fn as_mut(&mut self) -> &mut C {
543        &mut self.connection
544    }
545}
546
547impl<C: SqlExecutor> Deref for Connection<C> {
548    type Target = C;
549
550    fn deref(&self) -> &Self::Target {
551        &self.connection
552    }
553}
554
555impl<C: SqlExecutor> DerefMut for Connection<C> {
556    fn deref_mut(&mut self) -> &mut Self::Target {
557        &mut self.connection
558    }
559}
560
561impl<C: SqlExecutor> AsRef<C> for Connection<C> {
562    fn as_ref(&self) -> &C {
563        &self.connection
564    }
565}
566
567impl<C: SqlExecutor> AsMut<C> for Connection<C> {
568    fn as_mut(&mut self) -> &mut C {
569        &mut self.connection
570    }
571}
572
/// Name of the temporary table that records which tracked tables were modified.
const TRACKER_TABLE_NAME: &str = "rsqlite_watcher_version_tracker";

/// Trigger events installed on every tracked table: `(SQL event keyword, trigger name suffix)`.
const TRIGGER_LIST: [(&str, &str); 3] = [
    ("INSERT", "insert"),
    ("UPDATE", "update"),
    ("DELETE", "delete"),
];

/// Escape `name` for use inside a backtick-quoted SQLite identifier.
///
/// SQLite reads a doubled quote character inside a quoted identifier as one literal
/// quote character, so every backtick is doubled. Names without backticks are returned
/// unchanged.
#[inline]
fn escape_identifier(name: &str) -> String {
    name.replace('`', "``")
}

/// Query creating the tracking table (one row per tracked table id) if it does not exist.
#[inline]
fn create_tracking_table_query() -> String {
    format!(
        "CREATE TEMP TABLE IF NOT EXISTS `{TRACKER_TABLE_NAME}` (table_id INTEGER PRIMARY KEY, updated INTEGER)"
    )
}

/// Query deleting every row of the tracking table.
#[inline]
fn empty_tracking_table_query() -> String {
    format!("DELETE FROM `{TRACKER_TABLE_NAME}`")
}

/// Query dropping the tracking table if it exists.
#[inline]
fn drop_tracking_table_query() -> String {
    format!("DROP TABLE IF EXISTS `{TRACKER_TABLE_NAME}`")
}

/// Query creating a temporary `AFTER {trigger}` trigger on `table_name` that sets
/// `updated=1` for `table_id` in the tracking table.
///
/// `table_name` and `trigger_name` are escaped, so names containing backticks still
/// produce valid SQL. `trigger` is inserted verbatim and must be an SQL event keyword
/// (see [`TRIGGER_LIST`]).
#[inline]
fn create_trigger_query(
    table_name: &str,
    trigger: &str,
    trigger_name: &str,
    table_id: usize,
) -> String {
    let table_name = escape_identifier(table_name);
    let trigger_name = escape_identifier(trigger_name);
    format!(
        r"
CREATE TEMP TRIGGER IF NOT EXISTS `{TRACKER_TABLE_NAME}_trigger_{table_name}_{trigger_name}` AFTER {trigger} ON `{table_name}`
BEGIN
    UPDATE  `{TRACKER_TABLE_NAME}` SET updated=1 WHERE table_id={table_id};
END
            "
    )
}

/// Query registering `id` in the tracking table with its `updated` flag cleared.
#[inline]
fn insert_table_id_into_tracking_table_query(id: usize) -> String {
    format!("INSERT INTO `{TRACKER_TABLE_NAME}` VALUES ({id},0)")
}

/// Query dropping the trigger created by [`create_trigger_query`] for the same
/// `table_name` / `trigger_name` pair (escaped identically).
#[inline]
fn drop_trigger_query(table_name: &str, trigger_name: &str) -> String {
    let table_name = escape_identifier(table_name);
    let trigger_name = escape_identifier(trigger_name);
    format!("DROP TRIGGER IF EXISTS `{TRACKER_TABLE_NAME}_trigger_{table_name}_{trigger_name}`")
}

/// Query removing the row of `table_id` from the tracking table.
#[inline]
fn remove_table_id_from_tracking_table_query(table_id: usize) -> String {
    format!("DELETE FROM `{TRACKER_TABLE_NAME}` WHERE table_id={table_id}")
}

/// Query selecting the ids of all tables whose `updated` flag is set.
#[inline]
fn select_updated_tables_query() -> String {
    format!("SELECT table_id  FROM `{TRACKER_TABLE_NAME}` WHERE updated=1")
}

/// Query clearing every set `updated` flag in the tracking table.
#[inline]
fn reset_updated_tables_query() -> String {
    format!("UPDATE `{TRACKER_TABLE_NAME}` SET updated=0 WHERE updated=1")
}
637
638/// Create tracking triggers for `table` with `id`.
639fn create_triggers(table: &str, id: usize) -> Vec<SqlExecuteStatement<String>> {
640    TRIGGER_LIST
641        .iter()
642        .map(|(trigger, trigger_name)| {
643            let query = create_trigger_query(table, trigger, trigger_name, id);
644            SqlExecuteStatement::new(query)
645        })
646        .chain(std::iter::once_with(|| {
647            let query = insert_table_id_into_tracking_table_query(id);
648            SqlExecuteStatement::new(query)
649        }))
650        .collect()
651}
652
653/// Remove tracking triggers for `table` with `id`.
654fn drop_triggers(table: &str, id: usize) -> Vec<SqlExecuteStatement<String>> {
655    TRIGGER_LIST
656        .iter()
657        .map(|(_, trigger_name)| {
658            let query = drop_trigger_query(table, trigger_name);
659            SqlExecuteStatement::new(query)
660        })
661        .chain(std::iter::once_with(|| {
662            let query = remove_table_id_from_tracking_table_query(id);
663            SqlExecuteStatement::new(query)
664        }))
665        .collect()
666}
667
668/// Apply the new tracked table state to a [`State`].
669struct ConcludeStateChangeStatement<'s> {
670    state: &'s mut State,
671    tracked_tables: FixedBitSet,
672    new_version: u64,
673}
674
675impl Sealed for ConcludeStateChangeStatement<'_> {}
676impl Statement for ConcludeStateChangeStatement<'_> {
677    type Output = ();
678    fn execute<S: SqlExecutor>(self, _: &S) -> Result<Self::Output, S::Error> {
679        self.state
680            .apply_sync_changes(self.tracked_tables, self.new_version);
681        Ok(())
682    }
683
684    fn execute_mut<S: SqlExecutorMut>(self, _: &mut S) -> Result<Self::Output, S::Error> {
685        self.state
686            .apply_sync_changes(self.tracked_tables, self.new_version);
687        Ok(())
688    }
689
690    async fn execute_async<S: SqlExecutorAsync>(self, _: &mut S) -> Result<Self::Output, S::Error> {
691        self.state
692            .apply_sync_changes(self.tracked_tables, self.new_version);
693        Ok(())
694    }
695}
696
697/// Calculate what the changes to be sent to the watcher.
698struct CalculateWatcherUpdatesStatement<'s> {
699    state: &'s State,
700}
701
702impl StatementWithInput for CalculateWatcherUpdatesStatement<'_> {
703    type Input = Vec<u32>;
704    type Output = FixedBitSet;
705
706    fn execute<S: SqlExecutor>(self, _: &S, input: Self::Input) -> Result<Self::Output, S::Error> {
707        Ok(self.state.prepare_watcher_changes(input))
708    }
709    fn execute_mut<S: SqlExecutorMut>(
710        self,
711        _: &mut S,
712        input: Self::Input,
713    ) -> Result<Self::Output, S::Error> {
714        Ok(self.state.prepare_watcher_changes(input))
715    }
716    async fn execute_async<S: SqlExecutorAsync>(
717        self,
718        _: &mut S,
719        input: Self::Input,
720    ) -> Result<Self::Output, S::Error> {
721        Ok(self.state.prepare_watcher_changes(input))
722    }
723}
724
725/// Publish the changes to the watcher.
726struct PublishWatcherChangesStatement<'w>(&'w Watcher);
727
728impl Sealed for PublishWatcherChangesStatement<'_> {}
729
730impl StatementWithInput for PublishWatcherChangesStatement<'_> {
731    type Input = FixedBitSet;
732    type Output = ();
733
734    fn execute<S: SqlExecutor>(self, _: &S, input: Self::Input) -> Result<Self::Output, S::Error> {
735        self.0.publish_changes(input);
736        Ok(())
737    }
738    fn execute_mut<S: SqlExecutorMut>(
739        self,
740        _: &mut S,
741        input: Self::Input,
742    ) -> Result<Self::Output, S::Error> {
743        self.0.publish_changes(input);
744        Ok(())
745    }
746    async fn execute_async<S: SqlExecutorAsync>(
747        self,
748        _: &mut S,
749        input: Self::Input,
750    ) -> Result<Self::Output, S::Error> {
751        self.0.publish_changes_async(input).await;
752        Ok(())
753    }
754}
755
756impl Sealed for SqlReadTableIdsStatement {}
757struct SqlReadTableIdsStatement;
758impl Statement for SqlReadTableIdsStatement {
759    type Output = Vec<u32>;
760    fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error> {
761        connection.sql_query_values(&select_updated_tables_query())
762    }
763    fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error> {
764        connection.sql_query_values(&select_updated_tables_query())
765    }
766    async fn execute_async<S: SqlExecutorAsync>(
767        self,
768        connection: &mut S,
769    ) -> Result<Self::Output, S::Error> {
770        connection
771            .sql_query_values(&select_updated_tables_query())
772            .await
773    }
774}
775
776/// It is possible on certain circumstances that if a connection can have leftover
777/// tracking data that is not cleared. To make sure this is reset, we force empty
778/// the table if we detect that we are not watching any tables at the moment.
779struct MaybeResetResultsQuery;
780impl StatementWithInput for MaybeResetResultsQuery {
781    type Input = FixedBitSet;
782    type Output = FixedBitSet;
783
784    fn execute<S: SqlExecutor>(
785        self,
786        connection: &S,
787        input: Self::Input,
788    ) -> Result<Self::Output, S::Error> {
789        if !input.is_clear() {
790            // Reset updated values.
791            connection.sql_execute(&reset_updated_tables_query())?;
792        }
793        Ok(input)
794    }
795    fn execute_mut<S: SqlExecutorMut>(
796        self,
797        connection: &mut S,
798        input: Self::Input,
799    ) -> Result<Self::Output, S::Error> {
800        if !input.is_clear() {
801            // Reset updated values.
802            connection.sql_execute(&reset_updated_tables_query())?;
803        }
804        Ok(input)
805    }
806    async fn execute_async<S: SqlExecutorAsync>(
807        self,
808        connection: &mut S,
809        input: Self::Input,
810    ) -> Result<Self::Output, S::Error> {
811        if !input.is_clear() {
812            // Reset updated values.
813            connection
814                .sql_execute(&reset_updated_tables_query())
815                .await?;
816        }
817        Ok(input)
818    }
819}
820
#[cfg(test)]
mod test {
    use crate::connection::State;
    use crate::watcher::tests::new_test_observer;
    use crate::watcher::{ObservedTableOp, TableObserver, Watcher};
    use std::collections::BTreeSet;
    use std::sync::Mutex;
    use std::sync::mpsc::{Receiver, SyncSender};

    /// [`TableObserver`] that checks each notification against a queue of expected table
    /// sets and acknowledges every check on a rendezvous channel.
    pub struct TestObserver {
        // Expected notifications in reverse order, so the next one is obtained with `pop()`.
        expected: Mutex<Vec<BTreeSet<String>>>,
        tables: Vec<String>,
        // Channel is here to make sure we don't trigger a merge of multiple pending updates.
        checked_channel: SyncSender<()>,
    }

    impl TestObserver {
        /// Create an observer for `tables` that expects the notifications in `expected`,
        /// in order. The returned receiver gets one message per verified notification.
        pub fn new(
            tables: Vec<String>,
            expected: impl IntoIterator<Item = BTreeSet<String>>,
        ) -> (Self, Receiver<()>) {
            // Capacity 0 makes this a rendezvous channel: every `send` waits for a `recv`.
            let (checked_channel, checked_receiver) = std::sync::mpsc::sync_channel::<()>(0);
            let mut queue: Vec<BTreeSet<String>> = expected.into_iter().collect();
            queue.reverse();
            let observer = Self {
                expected: Mutex::new(queue),
                tables,
                checked_channel,
            };
            (observer, checked_receiver)
        }
    }

    impl TableObserver for TestObserver {
        fn tables(&self) -> Vec<String> {
            self.tables.clone()
        }

        fn on_tables_changed(&self, tables: &BTreeSet<String>) {
            let next_expected = {
                let mut queue = self.expected.lock().unwrap();
                queue.pop().unwrap()
            };
            assert_eq!(*tables, next_expected);
            self.checked_channel.send(()).unwrap();
        }
    }

    #[test]
    fn connection_state() {
        let service = Watcher::new().unwrap();

        let observer_1 = new_test_observer(["foo", "bar"]);
        let observer_2 = new_test_observer(["bar"]);
        let observer_3 = new_test_observer(["bar", "omega"]);

        // Fetch the pending version and the sync delta; both are required to be present.
        let pending_sync = |state: &State| {
            let version = state
                .should_sync(&service)
                .expect("Should have new version");
            let (tracker, ops) = state
                .calculate_sync_changes(&service)
                .expect("must have changes");
            (version, tracker, ops)
        };

        let mut local_state = State::new();

        // Nothing registered yet: nothing to sync.
        assert!(local_state.should_sync(&service).is_none());

        // The first observer brings in `foo` and `bar`.
        let observer_id_1 = service.add_observer(observer_1).unwrap();
        let foo_table_id = service.get_table_id("foo").unwrap();
        let bar_table_id = service.get_table_id("bar").unwrap();
        let (version, tracker, ops) = pending_sync(&local_state);
        assert!(tracker[bar_table_id]);
        assert!(tracker[foo_table_id]);
        assert_eq!(
            ops,
            vec![
                ObservedTableOp::Add("bar".to_string(), bar_table_id),
                ObservedTableOp::Add("foo".to_string(), foo_table_id),
            ]
        );
        local_state.apply_sync_changes(tracker, version);

        // `bar` is already tracked, so a second observer on it changes nothing.
        let observer_id_2 = service.add_observer(observer_2).unwrap();
        assert!(local_state.should_sync(&service).is_none());

        // The third observer adds `omega`.
        let observer_id_3 = service.add_observer(observer_3).unwrap();
        let omega_table_id = service.get_table_id("omega").unwrap();
        let (version, tracker, ops) = pending_sync(&local_state);
        assert!(tracker[foo_table_id]);
        assert!(tracker[bar_table_id]);
        assert!(tracker[omega_table_id]);
        assert_eq!(
            ops,
            vec![ObservedTableOp::Add("omega".to_string(), omega_table_id)]
        );
        local_state.apply_sync_changes(tracker, version);

        // `bar` is still observed by observers 1 and 3.
        service.remove_observer(observer_id_2).unwrap();
        assert!(local_state.should_sync(&service).is_none());

        // Removing observer 3 drops `omega`.
        service.remove_observer(observer_id_3).unwrap();
        let (version, tracker, ops) = pending_sync(&local_state);
        assert!(tracker[foo_table_id]);
        assert!(tracker[bar_table_id]);
        assert!(!tracker[omega_table_id]);
        assert_eq!(
            ops,
            vec![ObservedTableOp::Remove("omega".to_string(), omega_table_id)]
        );
        local_state.apply_sync_changes(tracker, version);

        // Removing the last observer drops every remaining table.
        service.remove_observer(observer_id_1).unwrap();
        let (version, tracker, ops) = pending_sync(&local_state);
        assert!(!tracker[foo_table_id]);
        assert!(!tracker[bar_table_id]);
        assert!(!tracker[omega_table_id]);
        assert_eq!(
            ops,
            vec![
                ObservedTableOp::Remove("bar".to_string(), bar_table_id),
                ObservedTableOp::Remove("foo".to_string(), foo_table_id),
            ]
        );
        local_state.apply_sync_changes(tracker, version);
    }
}