1use std::error::Error;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use hydracache_core::CacheCodec;
9use hydracache_db::{
10 CollectedInvalidationReport, CommitPosition, DbCache, InvalidationCollector,
11 SqlxInvalidationOutbox,
12};
13use sqlx::{Sqlite, SqlitePool, Transaction};
14
15use crate::{SqlxTransactionError, TransactionResult};
16
/// Boxed, pinned, `Send` future that resolves to `Result<(), E>` and may borrow
/// data for the lifetime `'tx`.
///
/// This is the return type required of the `body` closures accepted by
/// `SqlxTransactionCompanion::sqlite_durable` and
/// `SqlxTransactionCompanion::sqlite_local`.
pub type SqlxTransactionFuture<'tx, E> =
    Pin<Box<dyn Future<Output = std::result::Result<(), E>> + Send + 'tx>>;
20
21pub trait SqlxTransactionExt<C>
23where
24 C: CacheCodec,
25{
26 fn sqlx_transactions(&self) -> SqlxTransactionCompanion<C>;
28}
29
30impl<C> SqlxTransactionExt<C> for DbCache<C>
31where
32 C: CacheCodec,
33{
34 fn sqlx_transactions(&self) -> SqlxTransactionCompanion<C> {
35 SqlxTransactionCompanion::new(self.clone())
36 }
37}
38
/// Helper that runs SQLite transactions through sqlx while collecting cache
/// invalidation intents for a [`DbCache`].
///
/// The codec parameter defaults to `hydracache::PostcardCodec`. Clones share
/// the same counters because they are held behind an [`Arc`].
#[derive(Debug)]
pub struct SqlxTransactionCompanion<C = hydracache::PostcardCodec>
where
    C: CacheCodec,
{
    // Cache handle; supplies the namespace for collectors and the cache used
    // for local invalidation.
    queries: DbCache<C>,
    // Outbox used by `sqlite_durable`; `None` until `with_outbox` is called.
    outbox: Option<SqlxInvalidationOutbox>,
    // Shared counters exposed through `diagnostics()`.
    counters: Arc<SqlxTransactionCounters>,
}
53
54impl<C> Clone for SqlxTransactionCompanion<C>
55where
56 C: CacheCodec,
57{
58 fn clone(&self) -> Self {
59 Self {
60 queries: self.queries.clone(),
61 outbox: self.outbox.clone(),
62 counters: self.counters.clone(),
63 }
64 }
65}
66
67impl<C> SqlxTransactionCompanion<C>
68where
69 C: CacheCodec,
70{
71 pub fn new(queries: DbCache<C>) -> Self {
73 Self {
74 queries,
75 outbox: None,
76 counters: Arc::new(SqlxTransactionCounters::default()),
77 }
78 }
79
80 pub fn with_outbox(mut self, outbox: SqlxInvalidationOutbox) -> Self {
82 self.outbox = Some(outbox);
83 self
84 }
85
86 pub fn diagnostics(&self) -> SqlxTransactionDiagnostics {
88 self.counters.snapshot()
89 }
90
91 pub async fn sqlite_durable<F, E>(
93 &self,
94 pool: &SqlitePool,
95 reason: impl Into<String>,
96 body: F,
97 ) -> TransactionResult<SqlxTransactionReport, E>
98 where
99 F: for<'tx> FnOnce(
100 &'tx mut Transaction<'_, Sqlite>,
101 &'tx mut InvalidationCollector,
102 ) -> SqlxTransactionFuture<'tx, E>
103 + Send,
104 E: Error + Send + Sync + 'static,
105 {
106 let Some(outbox) = self.outbox.as_ref() else {
107 self.counters
108 .enqueue_failures
109 .fetch_add(1, Ordering::Relaxed);
110 return Err(SqlxTransactionError::MissingOutbox);
111 };
112
113 let mut tx = pool.begin().await.map_err(SqlxTransactionError::Sqlx)?;
114 let mut collector = InvalidationCollector::new(self.queries.namespace(), reason);
115
116 if let Err(error) = body(&mut tx, &mut collector).await {
117 self.counters.body_errors.fetch_add(1, Ordering::Relaxed);
118 rollback_and_count(tx, &self.counters).await;
119 return Err(SqlxTransactionError::Body(error));
120 }
121
122 let collected = collector.into_collected();
123 let intent_count = collected.len();
124 let inserted = if collected.is_empty() {
125 0
126 } else {
127 match outbox
128 .enqueue_in_sqlite_tx(
129 &mut tx,
130 collected.namespace(),
131 &sqlite_commit_position(),
132 collected.batch(),
133 )
134 .await
135 {
136 Ok(inserted) => inserted,
137 Err(error) => {
138 self.counters
139 .enqueue_failures
140 .fetch_add(1, Ordering::Relaxed);
141 rollback_and_count(tx, &self.counters).await;
142 return Err(SqlxTransactionError::Outbox(error));
143 }
144 }
145 };
146
147 match tx.commit().await {
148 Ok(()) => {
149 self.counters.commits.fetch_add(1, Ordering::Relaxed);
150 Ok(SqlxTransactionReport {
151 intent_count,
152 durable_rows: inserted,
153 local_report: None,
154 })
155 }
156 Err(error) => {
157 self.counters
158 .commit_failures
159 .fetch_add(1, Ordering::Relaxed);
160 Err(SqlxTransactionError::Sqlx(error))
161 }
162 }
163 }
164
165 pub async fn sqlite_local<F, E>(
171 &self,
172 pool: &SqlitePool,
173 reason: impl Into<String>,
174 body: F,
175 ) -> TransactionResult<SqlxTransactionReport, E>
176 where
177 F: for<'tx> FnOnce(
178 &'tx mut Transaction<'_, Sqlite>,
179 &'tx mut InvalidationCollector,
180 ) -> SqlxTransactionFuture<'tx, E>
181 + Send,
182 E: Error + Send + Sync + 'static,
183 {
184 let mut tx = pool.begin().await.map_err(SqlxTransactionError::Sqlx)?;
185 let mut collector = InvalidationCollector::new(self.queries.namespace(), reason);
186
187 if let Err(error) = body(&mut tx, &mut collector).await {
188 self.counters.body_errors.fetch_add(1, Ordering::Relaxed);
189 rollback_and_count(tx, &self.counters).await;
190 return Err(SqlxTransactionError::Body(error));
191 }
192
193 let collected = collector.into_collected();
194 let intent_count = collected.len();
195
196 match tx.commit().await {
197 Ok(()) => {
198 self.counters.commits.fetch_add(1, Ordering::Relaxed);
199 }
200 Err(error) => {
201 self.counters
202 .commit_failures
203 .fetch_add(1, Ordering::Relaxed);
204 return Err(SqlxTransactionError::Sqlx(error));
205 }
206 }
207
208 let local_report = collected
209 .execute_local(self.queries.cache())
210 .await
211 .map_err(SqlxTransactionError::LocalInvalidation)?;
212 self.counters
213 .local_invalidations
214 .fetch_add(1, Ordering::Relaxed);
215
216 Ok(SqlxTransactionReport {
217 intent_count,
218 durable_rows: 0,
219 local_report: Some(local_report),
220 })
221 }
222}
223
/// Summary returned by a successful `sqlite_durable` or `sqlite_local` call.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SqlxTransactionReport {
    /// Number of invalidation intents collected while the body ran.
    pub intent_count: usize,
    /// Value returned by the outbox enqueue call in `sqlite_durable`; `0` when
    /// nothing was collected and always `0` for `sqlite_local`.
    pub durable_rows: usize,
    /// Result of the local invalidation pass; `Some` for `sqlite_local`,
    /// `None` for `sqlite_durable`.
    pub local_report: Option<CollectedInvalidationReport>,
}
234
/// Plain-value snapshot of the companion's counters, as returned by
/// `SqlxTransactionCompanion::diagnostics`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SqlxTransactionDiagnostics {
    /// Transactions whose commit succeeded.
    pub commits: u64,
    /// Rollbacks that completed successfully after a body or enqueue error.
    pub rollbacks: u64,
    /// Times the caller-supplied body returned an error.
    pub body_errors: u64,
    /// Times the durable path failed because the outbox was missing or the
    /// enqueue call returned an error.
    pub enqueue_failures: u64,
    /// Times the commit call returned an error.
    pub commit_failures: u64,
    /// Times the post-commit local invalidation completed successfully.
    pub local_invalidations: u64,
}
251
/// Atomic counters behind `SqlxTransactionDiagnostics`; one field per public
/// diagnostic value, all starting at zero via `Default`.
#[derive(Debug, Default)]
struct SqlxTransactionCounters {
    // Successful commits.
    commits: AtomicU64,
    // Successful rollbacks (see `rollback_and_count`).
    rollbacks: AtomicU64,
    // Errors returned by the caller-supplied body.
    body_errors: AtomicU64,
    // Missing-outbox and enqueue errors on the durable path.
    enqueue_failures: AtomicU64,
    // Errors returned by the commit call.
    commit_failures: AtomicU64,
    // Successful post-commit local invalidations.
    local_invalidations: AtomicU64,
}
261
262impl SqlxTransactionCounters {
263 fn snapshot(&self) -> SqlxTransactionDiagnostics {
264 SqlxTransactionDiagnostics {
265 commits: self.commits.load(Ordering::Relaxed),
266 rollbacks: self.rollbacks.load(Ordering::Relaxed),
267 body_errors: self.body_errors.load(Ordering::Relaxed),
268 enqueue_failures: self.enqueue_failures.load(Ordering::Relaxed),
269 commit_failures: self.commit_failures.load(Ordering::Relaxed),
270 local_invalidations: self.local_invalidations.load(Ordering::Relaxed),
271 }
272 }
273}
274
275async fn rollback_and_count(tx: Transaction<'_, Sqlite>, counters: &SqlxTransactionCounters) {
276 if tx.rollback().await.is_ok() {
277 counters.rollbacks.fetch_add(1, Ordering::Relaxed);
278 }
279}
280
281fn sqlite_commit_position() -> CommitPosition {
282 static NEXT: AtomicU64 = AtomicU64::new(1);
283
284 let now = SystemTime::now()
285 .duration_since(UNIX_EPOCH)
286 .unwrap_or_default()
287 .as_nanos();
288 let sequence = NEXT.fetch_add(1, Ordering::Relaxed);
289 CommitPosition::new(format!("sqlite:{now}:{sequence}"))
290}