1use sqlx::{PgPool, Postgres, Transaction};
4use std::future::Future;
5
6use crate::error::Result;
7
8pub struct TransactionManager {
10 pool: PgPool,
11}
12
13impl TransactionManager {
14 pub fn new(pool: PgPool) -> Self {
16 Self { pool }
17 }
18
19 pub async fn begin(&self) -> Result<Transaction<'static, Postgres>> {
21 Ok(self.pool.begin().await?)
22 }
23
24 pub fn pool(&self) -> &PgPool {
26 &self.pool
27 }
28}
29
30pub struct TransactionBuilder {
32 pool: PgPool,
33 isolation_level: IsolationLevel,
34}
35
/// PostgreSQL transaction isolation level used by `TransactionBuilder`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum IsolationLevel {
    /// `READ COMMITTED`. This is the value returned by `IsolationLevel::default()`.
    #[default]
    ReadCommitted,
    /// `REPEATABLE READ`.
    RepeatableRead,
    /// `SERIALIZABLE`.
    Serializable,
}

impl IsolationLevel {
    /// Returns the SQL keywords naming this level, in the form expected after
    /// `SET TRANSACTION ISOLATION LEVEL`.
    fn as_sql(&self) -> &'static str {
        match self {
            Self::ReadCommitted => "READ COMMITTED",
            Self::RepeatableRead => "REPEATABLE READ",
            Self::Serializable => "SERIALIZABLE",
        }
    }
}
57
58impl TransactionBuilder {
59 pub fn new(pool: PgPool) -> Self {
61 Self {
62 pool,
63 isolation_level: IsolationLevel::default(),
64 }
65 }
66
67 pub fn isolation_level(mut self, level: IsolationLevel) -> Self {
69 self.isolation_level = level;
70 self
71 }
72
73 pub async fn begin(self) -> Result<Transaction<'static, Postgres>> {
75 let mut tx = self.pool.begin().await?;
76
77 let query = format!(
79 "SET TRANSACTION ISOLATION LEVEL {}",
80 self.isolation_level.as_sql()
81 );
82 sqlx::query(&query).execute(&mut *tx).await?;
83
84 Ok(tx)
85 }
86}
87
88pub async fn with_savepoint<T, F, Fut>(
94 tx: &mut Transaction<'_, Postgres>,
95 name: &str,
96 f: F,
97) -> Result<T>
98where
99 F: FnOnce(&mut Transaction<'_, Postgres>) -> Fut,
100 Fut: Future<Output = Result<T>>,
101{
102 let create_query = format!("SAVEPOINT {}", name);
104 sqlx::query(&create_query).execute(&mut **tx).await?;
105
106 match f(tx).await {
107 Ok(result) => {
108 let release_query = format!("RELEASE SAVEPOINT {}", name);
110 sqlx::query(&release_query).execute(&mut **tx).await?;
111 Ok(result)
112 }
113 Err(e) => {
114 let rollback_query = format!("ROLLBACK TO SAVEPOINT {}", name);
116 sqlx::query(&rollback_query).execute(&mut **tx).await?;
117 Err(e)
118 }
119 }
120}
121
/// Tracks a savepoint that must be explicitly released or rolled back.
///
/// Dropping an unhandled guard does not issue any SQL. Discarding one is
/// almost certainly a mistake, hence `#[must_use]`.
#[must_use = "a SavepointGuard must be released or rolled back"]
#[derive(Debug)]
pub struct SavepointGuard {
    /// Savepoint name exactly as it was sent in `SAVEPOINT <name>`.
    name: String,
    /// Set once `release` or `rollback` has run. Despite the name, it is also
    /// set by `rollback`.
    committed: bool,
}
127
128impl SavepointGuard {
129 pub async fn new(tx: &mut Transaction<'_, Postgres>, name: &str) -> Result<Self> {
131 let create_query = format!("SAVEPOINT {}", name);
132 sqlx::query(&create_query).execute(&mut **tx).await?;
133
134 Ok(Self {
135 name: name.to_string(),
136 committed: false,
137 })
138 }
139
140 pub async fn release(mut self, tx: &mut Transaction<'_, Postgres>) -> Result<()> {
142 let release_query = format!("RELEASE SAVEPOINT {}", self.name);
143 sqlx::query(&release_query).execute(&mut **tx).await?;
144 self.committed = true;
145 Ok(())
146 }
147
148 pub async fn rollback(mut self, tx: &mut Transaction<'_, Postgres>) -> Result<()> {
150 let rollback_query = format!("ROLLBACK TO SAVEPOINT {}", self.name);
151 sqlx::query(&rollback_query).execute(&mut **tx).await?;
152 self.committed = true;
153 Ok(())
154 }
155
156 pub fn name(&self) -> &str {
158 &self.name
159 }
160
161 pub fn is_handled(&self) -> bool {
163 self.committed
164 }
165}
166
167impl Drop for SavepointGuard {
168 fn drop(&mut self) {
169 if !self.committed {
170 tracing::warn!(
171 savepoint = %self.name,
172 "Savepoint guard dropped without release or rollback"
173 );
174 }
175 }
176}
177
/// Extension methods for managing savepoints on an open transaction.
///
/// The implementation for `Transaction<'_, Postgres>` in this module sends
/// `SAVEPOINT`, `RELEASE SAVEPOINT` and `ROLLBACK TO SAVEPOINT` statements with
/// `name` spliced into the SQL text verbatim. Callers should only pass trusted
/// identifiers.
#[async_trait::async_trait]
pub trait TransactionExt {
    /// Declares a savepoint called `name` in the current transaction.
    async fn create_savepoint(&mut self, name: &str) -> Result<()>;

    /// Releases the savepoint `name`, keeping the work done since it was created.
    async fn release_savepoint(&mut self, name: &str) -> Result<()>;

    /// Undoes the work executed since the savepoint `name` was created.
    async fn rollback_to_savepoint(&mut self, name: &str) -> Result<()>;
}
190
191#[async_trait::async_trait]
192impl TransactionExt for Transaction<'_, Postgres> {
193 async fn create_savepoint(&mut self, name: &str) -> Result<()> {
194 let query = format!("SAVEPOINT {}", name);
195 sqlx::query(&query).execute(&mut **self).await?;
196 Ok(())
197 }
198
199 async fn release_savepoint(&mut self, name: &str) -> Result<()> {
200 let query = format!("RELEASE SAVEPOINT {}", name);
201 sqlx::query(&query).execute(&mut **self).await?;
202 Ok(())
203 }
204
205 async fn rollback_to_savepoint(&mut self, name: &str) -> Result<()> {
206 let query = format!("ROLLBACK TO SAVEPOINT {}", name);
207 sqlx::query(&query).execute(&mut **self).await?;
208 Ok(())
209 }
210}
211
/// Runs `$body` inside a savepoint named `$name` on `$tx`.
///
/// The expansion sends `SAVEPOINT`, then evaluates `$body` as an async block.
/// It then sends `RELEASE SAVEPOINT` if the body yields `Ok`, or
/// `ROLLBACK TO SAVEPOINT` if it yields `Err`. The body's `Result` is the value
/// of the macro.
///
/// Errors from the savepoint statements are propagated with `?`. The macro
/// must therefore be used inside an `async` function whose error type can be
/// built from this crate's error.
///
/// `$name` is evaluated exactly once. `$tx` is re-evaluated for each savepoint
/// statement, so it should be a plain place expression such as `tx`.
#[macro_export]
macro_rules! nested_transaction {
    ($tx:expr, $name:expr, $body:block) => {{
        use $crate::transaction::TransactionExt;

        // Bind `$name` as a match scrutinee. This evaluates it only once, while
        // temporaries it creates (e.g. `format!(..).as_str()`) stay alive until
        // the end of the `match`.
        match $name {
            savepoint_name => {
                $tx.create_savepoint(savepoint_name).await?;
                let result = (async $body).await;

                match result {
                    Ok(value) => {
                        $tx.release_savepoint(savepoint_name).await?;
                        Ok(value)
                    }
                    Err(e) => {
                        $tx.rollback_to_savepoint(savepoint_name).await?;
                        Err(e)
                    }
                }
            }
        }
    }};
}
240
/// Attempt limit and backoff settings for `retry_transaction` and
/// `retry_transaction_with_isolation`.
#[derive(Debug, Clone, PartialEq)]
pub struct TransactionRetryConfig {
    /// Upper bound on the total number of attempts, counting the first one.
    ///
    /// Despite the name, the retry helpers stop once their attempt counter
    /// reaches this value. So `3` means at most three calls (two retries), and
    /// `0` or `1` both mean a single attempt with no retry.
    pub max_retries: u32,
    /// Base delay in milliseconds before the first retry, before jitter is applied.
    pub initial_backoff_ms: u64,
    /// Ceiling in milliseconds applied to the delay as it grows. A ±15% jitter
    /// is applied on top of it when sleeping.
    pub max_backoff_ms: u64,
    /// Factor the delay is multiplied by after each retry.
    pub backoff_multiplier: f64,
}

impl Default for TransactionRetryConfig {
    /// At most 3 attempts, starting at 10 ms and doubling up to 1 s.
    fn default() -> Self {
        TransactionRetryConfig {
            backoff_multiplier: 2.0,
            max_backoff_ms: 1_000,
            initial_backoff_ms: 10,
            max_retries: 3,
        }
    }
}
264
265pub async fn retry_transaction<T, F, Fut>(
300 pool: PgPool,
301 config: TransactionRetryConfig,
302 f: F,
303) -> Result<T>
304where
305 F: Fn(PgPool) -> Fut,
306 Fut: Future<Output = Result<T>>,
307{
308 let mut attempt = 0;
309 let mut backoff_ms = config.initial_backoff_ms;
310
311 loop {
312 attempt += 1;
313
314 match f(pool.clone()).await {
315 Ok(result) => {
316 return Ok(result);
317 }
318 Err(e) => {
319 let is_retriable = is_retriable_error(&e);
321
322 if !is_retriable || attempt >= config.max_retries {
323 tracing::warn!(
324 attempt = attempt,
325 max_retries = config.max_retries,
326 error = %e,
327 "Transaction failed after retries"
328 );
329 return Err(e);
330 }
331
332 let jitter = (rand::random::<f64>() * 0.3) + 0.85; let sleep_ms = (backoff_ms as f64 * jitter) as u64;
335
336 tracing::debug!(
337 attempt = attempt,
338 max_retries = config.max_retries,
339 backoff_ms = sleep_ms,
340 error = %e,
341 "Transaction failed, retrying"
342 );
343
344 tokio::time::sleep(tokio::time::Duration::from_millis(sleep_ms)).await;
345
346 backoff_ms = ((backoff_ms as f64 * config.backoff_multiplier) as u64)
348 .min(config.max_backoff_ms);
349 }
350 }
351 }
352}
353
354fn is_retriable_error(error: &crate::error::DbError) -> bool {
356 match error {
357 crate::error::DbError::Sqlx(sqlx_error) => {
358 if let Some(db_error) = sqlx_error.as_database_error() {
359 let code = db_error.code();
360 code.as_deref() == Some("40001") || code.as_deref() == Some("40P01")
363 } else {
364 false
365 }
366 }
367 _ => false,
368 }
369}
370
371pub async fn retry_transaction_with_isolation<T, F, Fut>(
379 pool: PgPool,
380 config: TransactionRetryConfig,
381 isolation_level: IsolationLevel,
382 f: F,
383) -> Result<T>
384where
385 F: Fn(PgPool, IsolationLevel) -> Fut,
386 Fut: Future<Output = Result<T>>,
387{
388 let mut attempt = 0;
389 let mut backoff_ms = config.initial_backoff_ms;
390
391 loop {
392 attempt += 1;
393
394 match f(pool.clone(), isolation_level).await {
395 Ok(result) => {
396 return Ok(result);
397 }
398 Err(e) => {
399 let is_retriable = is_retriable_error(&e);
400
401 if !is_retriable || attempt >= config.max_retries {
402 tracing::warn!(
403 attempt = attempt,
404 max_retries = config.max_retries,
405 isolation_level = ?isolation_level,
406 error = %e,
407 "Transaction failed after retries"
408 );
409 return Err(e);
410 }
411
412 let jitter = (rand::random::<f64>() * 0.3) + 0.85;
413 let sleep_ms = (backoff_ms as f64 * jitter) as u64;
414
415 tracing::debug!(
416 attempt = attempt,
417 max_retries = config.max_retries,
418 backoff_ms = sleep_ms,
419 isolation_level = ?isolation_level,
420 error = %e,
421 "Transaction failed, retrying"
422 );
423
424 tokio::time::sleep(tokio::time::Duration::from_millis(sleep_ms)).await;
425 backoff_ms = ((backoff_ms as f64 * config.backoff_multiplier) as u64)
426 .min(config.max_backoff_ms);
427 }
428 }
429 }
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_default() {
        let config = TransactionRetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.initial_backoff_ms, 10);
        assert_eq!(config.max_backoff_ms, 1000);
        assert_eq!(config.backoff_multiplier, 2.0);
    }

    #[test]
    fn test_retry_config_custom() {
        let config = TransactionRetryConfig {
            max_retries: 5,
            initial_backoff_ms: 50,
            max_backoff_ms: 5000,
            backoff_multiplier: 1.5,
        };
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_backoff_ms, 50);
        assert_eq!(config.max_backoff_ms, 5000);
        assert_eq!(config.backoff_multiplier, 1.5);
    }

    #[test]
    fn test_isolation_level_sql() {
        assert_eq!(IsolationLevel::ReadCommitted.as_sql(), "READ COMMITTED");
        assert_eq!(IsolationLevel::RepeatableRead.as_sql(), "REPEATABLE READ");
        assert_eq!(IsolationLevel::Serializable.as_sql(), "SERIALIZABLE");
    }

    /// `TransactionBuilder::new` relies on `IsolationLevel::default()`. Pin it
    /// so a change to the `#[default]` variant is caught.
    #[test]
    fn test_isolation_level_default_is_read_committed() {
        assert_eq!(IsolationLevel::default().as_sql(), "READ COMMITTED");
    }
}