Skip to main content

hydracache_sqlx/
transaction.rs

1use std::error::Error;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use hydracache_core::CacheCodec;
9use hydracache_db::{
10    CollectedInvalidationReport, CommitPosition, DbCache, InvalidationCollector,
11    SqlxInvalidationOutbox,
12};
13use sqlx::{Sqlite, SqlitePool, Transaction};
14
15use crate::{SqlxTransactionError, TransactionResult};
16
/// Pinned, boxed, `Send` future that a transaction companion body returns.
///
/// The future may borrow the transaction and collector for `'tx` and resolves
/// to `Ok(())` on success or the caller-defined error `E` on failure.
pub type SqlxTransactionFuture<'tx, E> = Pin<Box<dyn Future<Output = std::result::Result<(), E>> + Send + 'tx>>;
20
21/// Extension trait that creates SQLx transaction companions from [`DbCache`].
22pub trait SqlxTransactionExt<C>
23where
24    C: CacheCodec,
25{
26    /// Create a transaction companion using this database cache namespace.
27    fn sqlx_transactions(&self) -> SqlxTransactionCompanion<C>;
28}
29
30impl<C> SqlxTransactionExt<C> for DbCache<C>
31where
32    C: CacheCodec,
33{
34    fn sqlx_transactions(&self) -> SqlxTransactionCompanion<C> {
35        SqlxTransactionCompanion::new(self.clone())
36    }
37}
38
39/// SQLx transaction companion.
40///
41/// The companion begins a transaction, gives user code explicit access to the
42/// SQLx transaction and an invalidation collector, then either enqueues durable
43/// outbox rows before commit or applies local-only invalidation after commit.
44#[derive(Debug)]
45pub struct SqlxTransactionCompanion<C = hydracache::PostcardCodec>
46where
47    C: CacheCodec,
48{
49    queries: DbCache<C>,
50    outbox: Option<SqlxInvalidationOutbox>,
51    counters: Arc<SqlxTransactionCounters>,
52}
53
54impl<C> Clone for SqlxTransactionCompanion<C>
55where
56    C: CacheCodec,
57{
58    fn clone(&self) -> Self {
59        Self {
60            queries: self.queries.clone(),
61            outbox: self.outbox.clone(),
62            counters: self.counters.clone(),
63        }
64    }
65}
66
67impl<C> SqlxTransactionCompanion<C>
68where
69    C: CacheCodec,
70{
71    /// Create a companion for the given query cache namespace.
72    pub fn new(queries: DbCache<C>) -> Self {
73        Self {
74            queries,
75            outbox: None,
76            counters: Arc::new(SqlxTransactionCounters::default()),
77        }
78    }
79
80    /// Attach a durable SQLx invalidation outbox.
81    pub fn with_outbox(mut self, outbox: SqlxInvalidationOutbox) -> Self {
82        self.outbox = Some(outbox);
83        self
84    }
85
86    /// Return transaction companion diagnostics.
87    pub fn diagnostics(&self) -> SqlxTransactionDiagnostics {
88        self.counters.snapshot()
89    }
90
91    /// Execute a SQLite transaction and enqueue invalidation intent before commit.
92    pub async fn sqlite_durable<F, E>(
93        &self,
94        pool: &SqlitePool,
95        reason: impl Into<String>,
96        body: F,
97    ) -> TransactionResult<SqlxTransactionReport, E>
98    where
99        F: for<'tx> FnOnce(
100                &'tx mut Transaction<'_, Sqlite>,
101                &'tx mut InvalidationCollector,
102            ) -> SqlxTransactionFuture<'tx, E>
103            + Send,
104        E: Error + Send + Sync + 'static,
105    {
106        let Some(outbox) = self.outbox.as_ref() else {
107            self.counters
108                .enqueue_failures
109                .fetch_add(1, Ordering::Relaxed);
110            return Err(SqlxTransactionError::MissingOutbox);
111        };
112
113        let mut tx = pool.begin().await.map_err(SqlxTransactionError::Sqlx)?;
114        let mut collector = InvalidationCollector::new(self.queries.namespace(), reason);
115
116        if let Err(error) = body(&mut tx, &mut collector).await {
117            self.counters.body_errors.fetch_add(1, Ordering::Relaxed);
118            rollback_and_count(tx, &self.counters).await;
119            return Err(SqlxTransactionError::Body(error));
120        }
121
122        let collected = collector.into_collected();
123        let intent_count = collected.len();
124        let inserted = if collected.is_empty() {
125            0
126        } else {
127            match outbox
128                .enqueue_in_sqlite_tx(
129                    &mut tx,
130                    collected.namespace(),
131                    &sqlite_commit_position(),
132                    collected.batch(),
133                )
134                .await
135            {
136                Ok(inserted) => inserted,
137                Err(error) => {
138                    self.counters
139                        .enqueue_failures
140                        .fetch_add(1, Ordering::Relaxed);
141                    rollback_and_count(tx, &self.counters).await;
142                    return Err(SqlxTransactionError::Outbox(error));
143                }
144            }
145        };
146
147        match tx.commit().await {
148            Ok(()) => {
149                self.counters.commits.fetch_add(1, Ordering::Relaxed);
150                Ok(SqlxTransactionReport {
151                    intent_count,
152                    durable_rows: inserted,
153                    local_report: None,
154                })
155            }
156            Err(error) => {
157                self.counters
158                    .commit_failures
159                    .fetch_add(1, Ordering::Relaxed);
160                Err(SqlxTransactionError::Sqlx(error))
161            }
162        }
163    }
164
165    /// Execute a SQLite transaction and apply invalidation directly after commit.
166    ///
167    /// This mode is intentionally non-durable. It is useful for local demos and
168    /// single-process tests, while production writes should use
169    /// [`SqlxTransactionCompanion::sqlite_durable`].
170    pub async fn sqlite_local<F, E>(
171        &self,
172        pool: &SqlitePool,
173        reason: impl Into<String>,
174        body: F,
175    ) -> TransactionResult<SqlxTransactionReport, E>
176    where
177        F: for<'tx> FnOnce(
178                &'tx mut Transaction<'_, Sqlite>,
179                &'tx mut InvalidationCollector,
180            ) -> SqlxTransactionFuture<'tx, E>
181            + Send,
182        E: Error + Send + Sync + 'static,
183    {
184        let mut tx = pool.begin().await.map_err(SqlxTransactionError::Sqlx)?;
185        let mut collector = InvalidationCollector::new(self.queries.namespace(), reason);
186
187        if let Err(error) = body(&mut tx, &mut collector).await {
188            self.counters.body_errors.fetch_add(1, Ordering::Relaxed);
189            rollback_and_count(tx, &self.counters).await;
190            return Err(SqlxTransactionError::Body(error));
191        }
192
193        let collected = collector.into_collected();
194        let intent_count = collected.len();
195
196        match tx.commit().await {
197            Ok(()) => {
198                self.counters.commits.fetch_add(1, Ordering::Relaxed);
199            }
200            Err(error) => {
201                self.counters
202                    .commit_failures
203                    .fetch_add(1, Ordering::Relaxed);
204                return Err(SqlxTransactionError::Sqlx(error));
205            }
206        }
207
208        let local_report = collected
209            .execute_local(self.queries.cache())
210            .await
211            .map_err(SqlxTransactionError::LocalInvalidation)?;
212        self.counters
213            .local_invalidations
214            .fetch_add(1, Ordering::Relaxed);
215
216        Ok(SqlxTransactionReport {
217            intent_count,
218            durable_rows: 0,
219            local_report: Some(local_report),
220        })
221    }
222}
223
224/// Result of a SQLx transaction companion run.
225#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
226pub struct SqlxTransactionReport {
227    /// Number of invalidation intents collected by user code.
228    pub intent_count: usize,
229    /// Number of durable outbox rows inserted.
230    pub durable_rows: usize,
231    /// Local invalidation report for non-durable mode.
232    pub local_report: Option<CollectedInvalidationReport>,
233}
234
/// Diagnostics for SQLx transaction companion runs.
///
/// Each field is read independently from a shared counter, so a snapshot is
/// not guaranteed to be consistent across fields while runs are in flight.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SqlxTransactionDiagnostics {
    /// Successful transaction commits.
    pub commits: u64,
    /// Rollbacks that completed successfully after a body or enqueue failure.
    /// A rollback that itself errors is not counted.
    pub rollbacks: u64,
    /// User closure failures.
    pub body_errors: u64,
    /// Outbox enqueue failures, including durable runs attempted without an
    /// attached outbox.
    pub enqueue_failures: u64,
    /// Commit failures after a successful user body.
    pub commit_failures: u64,
    /// Successful local (non-durable) invalidation applications.
    pub local_invalidations: u64,
}
251
/// Lock-free counters behind [`SqlxTransactionDiagnostics`]; all start at zero.
#[derive(Default, Debug)]
struct SqlxTransactionCounters {
    /// Successful commits.
    commits: AtomicU64,
    /// Successful rollbacks on failure paths.
    rollbacks: AtomicU64,
    /// User body failures.
    body_errors: AtomicU64,
    /// Enqueue failures, including a missing outbox.
    enqueue_failures: AtomicU64,
    /// Failed commits.
    commit_failures: AtomicU64,
    /// Successful local invalidation applications.
    local_invalidations: AtomicU64,
}
261
262impl SqlxTransactionCounters {
263    fn snapshot(&self) -> SqlxTransactionDiagnostics {
264        SqlxTransactionDiagnostics {
265            commits: self.commits.load(Ordering::Relaxed),
266            rollbacks: self.rollbacks.load(Ordering::Relaxed),
267            body_errors: self.body_errors.load(Ordering::Relaxed),
268            enqueue_failures: self.enqueue_failures.load(Ordering::Relaxed),
269            commit_failures: self.commit_failures.load(Ordering::Relaxed),
270            local_invalidations: self.local_invalidations.load(Ordering::Relaxed),
271        }
272    }
273}
274
275async fn rollback_and_count(tx: Transaction<'_, Sqlite>, counters: &SqlxTransactionCounters) {
276    if tx.rollback().await.is_ok() {
277        counters.rollbacks.fetch_add(1, Ordering::Relaxed);
278    }
279}
280
281fn sqlite_commit_position() -> CommitPosition {
282    static NEXT: AtomicU64 = AtomicU64::new(1);
283
284    let now = SystemTime::now()
285        .duration_since(UNIX_EPOCH)
286        .unwrap_or_default()
287        .as_nanos();
288    let sequence = NEXT.fetch_add(1, Ordering::Relaxed);
289    CommitPosition::new(format!("sqlite:{now}:{sequence}"))
290}