kaccy_db/
transaction.rs

//! Transaction management: isolation levels, savepoints, and retry on serialization failures.
2
3use sqlx::{PgPool, Postgres, Transaction};
4use std::future::Future;
5
6use crate::error::Result;
7
8/// Transaction manager with savepoint support for nested transactions
9pub struct TransactionManager {
10    pool: PgPool,
11}
12
13impl TransactionManager {
14    /// Create a new transaction manager
15    pub fn new(pool: PgPool) -> Self {
16        Self { pool }
17    }
18
19    /// Begin a new transaction
20    pub async fn begin(&self) -> Result<Transaction<'static, Postgres>> {
21        Ok(self.pool.begin().await?)
22    }
23
24    /// Get a reference to the pool
25    pub fn pool(&self) -> &PgPool {
26        &self.pool
27    }
28}
29
30/// Builder for complex transactional operations
31pub struct TransactionBuilder {
32    pool: PgPool,
33    isolation_level: IsolationLevel,
34}
35
/// Transaction isolation levels.
///
/// Used by [`TransactionBuilder`] and passed through by
/// [`retry_transaction_with_isolation`]. Comparable and hashable so callers can
/// branch on or key by the configured level.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum IsolationLevel {
    /// Read committed (PostgreSQL default)
    #[default]
    ReadCommitted,
    /// Repeatable read
    RepeatableRead,
    /// Serializable (strongest isolation)
    Serializable,
}

impl IsolationLevel {
    /// SQL keyword(s) for this level, as placed after
    /// `SET TRANSACTION ISOLATION LEVEL`.
    fn as_sql(&self) -> &'static str {
        match self {
            Self::ReadCommitted => "READ COMMITTED",
            Self::RepeatableRead => "REPEATABLE READ",
            Self::Serializable => "SERIALIZABLE",
        }
    }
}
57
58impl TransactionBuilder {
59    /// Create a new transaction builder
60    pub fn new(pool: PgPool) -> Self {
61        Self {
62            pool,
63            isolation_level: IsolationLevel::default(),
64        }
65    }
66
67    /// Set the isolation level
68    pub fn isolation_level(mut self, level: IsolationLevel) -> Self {
69        self.isolation_level = level;
70        self
71    }
72
73    /// Begin the transaction with configured isolation level
74    pub async fn begin(self) -> Result<Transaction<'static, Postgres>> {
75        let mut tx = self.pool.begin().await?;
76
77        // Set isolation level
78        let query = format!(
79            "SET TRANSACTION ISOLATION LEVEL {}",
80            self.isolation_level.as_sql()
81        );
82        sqlx::query(&query).execute(&mut *tx).await?;
83
84        Ok(tx)
85    }
86}
87
88/// Helper function to run an operation within a savepoint
89///
90/// Creates a savepoint before executing the operation.
91/// If the operation succeeds, the savepoint is released.
92/// If the operation fails, the transaction is rolled back to the savepoint.
93pub async fn with_savepoint<T, F, Fut>(
94    tx: &mut Transaction<'_, Postgres>,
95    name: &str,
96    f: F,
97) -> Result<T>
98where
99    F: FnOnce(&mut Transaction<'_, Postgres>) -> Fut,
100    Fut: Future<Output = Result<T>>,
101{
102    // Create savepoint
103    let create_query = format!("SAVEPOINT {}", name);
104    sqlx::query(&create_query).execute(&mut **tx).await?;
105
106    match f(tx).await {
107        Ok(result) => {
108            // Release savepoint on success
109            let release_query = format!("RELEASE SAVEPOINT {}", name);
110            sqlx::query(&release_query).execute(&mut **tx).await?;
111            Ok(result)
112        }
113        Err(e) => {
114            // Rollback to savepoint on failure
115            let rollback_query = format!("ROLLBACK TO SAVEPOINT {}", name);
116            sqlx::query(&rollback_query).execute(&mut **tx).await?;
117            Err(e)
118        }
119    }
120}
121
122/// Savepoint guard that tracks savepoint state
123pub struct SavepointGuard {
124    name: String,
125    committed: bool,
126}
127
128impl SavepointGuard {
129    /// Create a new savepoint within a transaction
130    pub async fn new(tx: &mut Transaction<'_, Postgres>, name: &str) -> Result<Self> {
131        let create_query = format!("SAVEPOINT {}", name);
132        sqlx::query(&create_query).execute(&mut **tx).await?;
133
134        Ok(Self {
135            name: name.to_string(),
136            committed: false,
137        })
138    }
139
140    /// Release the savepoint (commit the nested transaction)
141    pub async fn release(mut self, tx: &mut Transaction<'_, Postgres>) -> Result<()> {
142        let release_query = format!("RELEASE SAVEPOINT {}", self.name);
143        sqlx::query(&release_query).execute(&mut **tx).await?;
144        self.committed = true;
145        Ok(())
146    }
147
148    /// Rollback to the savepoint
149    pub async fn rollback(mut self, tx: &mut Transaction<'_, Postgres>) -> Result<()> {
150        let rollback_query = format!("ROLLBACK TO SAVEPOINT {}", self.name);
151        sqlx::query(&rollback_query).execute(&mut **tx).await?;
152        self.committed = true;
153        Ok(())
154    }
155
156    /// Get the savepoint name
157    pub fn name(&self) -> &str {
158        &self.name
159    }
160
161    /// Check if the savepoint has been committed or rolled back
162    pub fn is_handled(&self) -> bool {
163        self.committed
164    }
165}
166
167impl Drop for SavepointGuard {
168    fn drop(&mut self) {
169        if !self.committed {
170            tracing::warn!(
171                savepoint = %self.name,
172                "Savepoint guard dropped without release or rollback"
173            );
174        }
175    }
176}
177
178/// Extension trait for transactions with savepoint support
179#[async_trait::async_trait]
180pub trait TransactionExt {
181    /// Create a savepoint within this transaction
182    async fn create_savepoint(&mut self, name: &str) -> Result<()>;
183
184    /// Release a savepoint
185    async fn release_savepoint(&mut self, name: &str) -> Result<()>;
186
187    /// Rollback to a savepoint
188    async fn rollback_to_savepoint(&mut self, name: &str) -> Result<()>;
189}
190
191#[async_trait::async_trait]
192impl TransactionExt for Transaction<'_, Postgres> {
193    async fn create_savepoint(&mut self, name: &str) -> Result<()> {
194        let query = format!("SAVEPOINT {}", name);
195        sqlx::query(&query).execute(&mut **self).await?;
196        Ok(())
197    }
198
199    async fn release_savepoint(&mut self, name: &str) -> Result<()> {
200        let query = format!("RELEASE SAVEPOINT {}", name);
201        sqlx::query(&query).execute(&mut **self).await?;
202        Ok(())
203    }
204
205    async fn rollback_to_savepoint(&mut self, name: &str) -> Result<()> {
206        let query = format!("ROLLBACK TO SAVEPOINT {}", name);
207        sqlx::query(&query).execute(&mut **self).await?;
208        Ok(())
209    }
210}
211
/// Macro for executing code within a savepoint
///
/// Expands to an expression that:
/// 1. creates savepoint `$name` on `$tx`;
/// 2. evaluates `$body` as an async block, which must produce a `Result`;
/// 3. releases the savepoint on `Ok`, or rolls back to it on `Err`;
/// 4. yields the body's `Result`.
///
/// A `?` inside `$body` short-circuits only the body. By contrast, the
/// savepoint statements use `?` against the *enclosing* function, so the macro
/// must be used inside an `async` function whose error type can absorb them.
/// `$tx` and `$name` are expanded at each use.
///
/// The body runs as a directly awaited async block rather than as a future
/// returned from a closure. This lets `$body` borrow `$tx` mutably. A closure
/// would make that a "captured variable cannot escape `FnMut` closure body"
/// error.
///
/// Usage:
/// ```ignore
/// nested_transaction!(tx, "savepoint_name", {
///     // your code here
/// })
/// ```
#[macro_export]
macro_rules! nested_transaction {
    ($tx:expr, $name:expr, $body:block) => {{
        use $crate::transaction::TransactionExt;

        $tx.create_savepoint($name).await?;
        let body_future = async $body;
        let result = body_future.await;

        match result {
            Ok(value) => {
                $tx.release_savepoint($name).await?;
                Ok(value)
            }
            Err(e) => {
                $tx.rollback_to_savepoint($name).await?;
                Err(e)
            }
        }
    }};
}
240
/// Retry and backoff settings for [`retry_transaction`] and
/// [`retry_transaction_with_isolation`].
///
/// The backoff starts at `initial_backoff_ms` and is multiplied by
/// `backoff_multiplier` after each retry, bounded by `max_backoff_ms`. The
/// exact attempt accounting is defined by the retry functions.
#[derive(Debug, Clone)]
pub struct TransactionRetryConfig {
    /// Maximum number of retry attempts
    pub max_retries: u32,
    /// Initial backoff duration in milliseconds
    pub initial_backoff_ms: u64,
    /// Maximum backoff duration in milliseconds
    pub max_backoff_ms: u64,
    /// Backoff multiplier for exponential backoff
    pub backoff_multiplier: f64,
}

impl Default for TransactionRetryConfig {
    /// 3 retries, with a 10 ms initial backoff that doubles up to 1000 ms.
    fn default() -> Self {
        let initial_backoff_ms = 10;
        let max_backoff_ms = 1_000;
        TransactionRetryConfig {
            backoff_multiplier: 2.0,
            max_retries: 3,
            initial_backoff_ms,
            max_backoff_ms,
        }
    }
}
264
265/// Execute a transaction with automatic retry on serialization failures
266///
267/// This function automatically retries the transaction if it fails due to
268/// serialization errors (SQLSTATE 40001) or deadlock errors (SQLSTATE 40P01).
269///
270/// # Arguments
271/// * `pool` - Database connection pool (cloneable)
272/// * `config` - Retry configuration
273/// * `f` - Async function that performs the transactional work
274///
275/// # Returns
276/// Result of the transaction operation
277///
278/// # Example
279/// ```ignore
280/// use kaccy_db::transaction::{retry_transaction, TransactionRetryConfig};
281///
282/// let pool_clone = pool.clone();
283/// let result = retry_transaction(
284///     pool_clone,
285///     TransactionRetryConfig::default(),
286///     |pool| async move {
287///         let mut tx = pool.begin().await?;
288///         // Your transactional work here
289///         sqlx::query("UPDATE accounts SET balance = balance + $1 WHERE id = $2")
290///             .bind(amount)
291///             .bind(account_id)
292///             .execute(&mut *tx)
293///             .await?;
294///         tx.commit().await?;
295///         Ok(())
296///     }
297/// ).await?;
298/// ```
299pub async fn retry_transaction<T, F, Fut>(
300    pool: PgPool,
301    config: TransactionRetryConfig,
302    f: F,
303) -> Result<T>
304where
305    F: Fn(PgPool) -> Fut,
306    Fut: Future<Output = Result<T>>,
307{
308    let mut attempt = 0;
309    let mut backoff_ms = config.initial_backoff_ms;
310
311    loop {
312        attempt += 1;
313
314        match f(pool.clone()).await {
315            Ok(result) => {
316                return Ok(result);
317            }
318            Err(e) => {
319                // Check if the error is retriable
320                let is_retriable = is_retriable_error(&e);
321
322                if !is_retriable || attempt >= config.max_retries {
323                    tracing::warn!(
324                        attempt = attempt,
325                        max_retries = config.max_retries,
326                        error = %e,
327                        "Transaction failed after retries"
328                    );
329                    return Err(e);
330                }
331
332                // Exponential backoff with jitter
333                let jitter = (rand::random::<f64>() * 0.3) + 0.85; // 0.85-1.15 range
334                let sleep_ms = (backoff_ms as f64 * jitter) as u64;
335
336                tracing::debug!(
337                    attempt = attempt,
338                    max_retries = config.max_retries,
339                    backoff_ms = sleep_ms,
340                    error = %e,
341                    "Transaction failed, retrying"
342                );
343
344                tokio::time::sleep(tokio::time::Duration::from_millis(sleep_ms)).await;
345
346                // Increase backoff for next iteration
347                backoff_ms = ((backoff_ms as f64 * config.backoff_multiplier) as u64)
348                    .min(config.max_backoff_ms);
349            }
350        }
351    }
352}
353
354/// Check if an error is retriable (serialization or deadlock)
355fn is_retriable_error(error: &crate::error::DbError) -> bool {
356    match error {
357        crate::error::DbError::Sqlx(sqlx_error) => {
358            if let Some(db_error) = sqlx_error.as_database_error() {
359                let code = db_error.code();
360                // 40001 = serialization_failure
361                // 40P01 = deadlock_detected
362                code.as_deref() == Some("40001") || code.as_deref() == Some("40P01")
363            } else {
364                false
365            }
366        }
367        _ => false,
368    }
369}
370
371/// Execute a transaction with retry on serialization failures and custom isolation level
372///
373/// # Arguments
374/// * `pool` - Database connection pool (cloneable)
375/// * `config` - Retry configuration
376/// * `isolation_level` - Transaction isolation level
377/// * `f` - Async function that performs the transactional work
378pub async fn retry_transaction_with_isolation<T, F, Fut>(
379    pool: PgPool,
380    config: TransactionRetryConfig,
381    isolation_level: IsolationLevel,
382    f: F,
383) -> Result<T>
384where
385    F: Fn(PgPool, IsolationLevel) -> Fut,
386    Fut: Future<Output = Result<T>>,
387{
388    let mut attempt = 0;
389    let mut backoff_ms = config.initial_backoff_ms;
390
391    loop {
392        attempt += 1;
393
394        match f(pool.clone(), isolation_level).await {
395            Ok(result) => {
396                return Ok(result);
397            }
398            Err(e) => {
399                let is_retriable = is_retriable_error(&e);
400
401                if !is_retriable || attempt >= config.max_retries {
402                    tracing::warn!(
403                        attempt = attempt,
404                        max_retries = config.max_retries,
405                        isolation_level = ?isolation_level,
406                        error = %e,
407                        "Transaction failed after retries"
408                    );
409                    return Err(e);
410                }
411
412                let jitter = (rand::random::<f64>() * 0.3) + 0.85;
413                let sleep_ms = (backoff_ms as f64 * jitter) as u64;
414
415                tracing::debug!(
416                    attempt = attempt,
417                    max_retries = config.max_retries,
418                    backoff_ms = sleep_ms,
419                    isolation_level = ?isolation_level,
420                    error = %e,
421                    "Transaction failed, retrying"
422                );
423
424                tokio::time::sleep(tokio::time::Duration::from_millis(sleep_ms)).await;
425                backoff_ms = ((backoff_ms as f64 * config.backoff_multiplier) as u64)
426                    .min(config.max_backoff_ms);
427            }
428        }
429    }
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_default() {
        let TransactionRetryConfig {
            max_retries,
            initial_backoff_ms,
            max_backoff_ms,
            backoff_multiplier,
        } = TransactionRetryConfig::default();
        assert_eq!(max_retries, 3);
        assert_eq!(initial_backoff_ms, 10);
        assert_eq!(max_backoff_ms, 1000);
        assert_eq!(backoff_multiplier, 2.0);
    }

    #[test]
    fn test_retry_config_custom() {
        let custom = TransactionRetryConfig {
            max_retries: 5,
            initial_backoff_ms: 50,
            max_backoff_ms: 5000,
            backoff_multiplier: 1.5,
        };
        let TransactionRetryConfig {
            max_retries,
            initial_backoff_ms,
            max_backoff_ms,
            backoff_multiplier,
        } = custom;
        assert_eq!(max_retries, 5);
        assert_eq!(initial_backoff_ms, 50);
        assert_eq!(max_backoff_ms, 5000);
        assert_eq!(backoff_multiplier, 1.5);
    }

    #[test]
    fn test_isolation_level_sql() {
        let expected = [
            (IsolationLevel::ReadCommitted, "READ COMMITTED"),
            (IsolationLevel::RepeatableRead, "REPEATABLE READ"),
            (IsolationLevel::Serializable, "SERIALIZABLE"),
        ];
        for (level, sql) in expected.iter() {
            assert_eq!(level.as_sql(), *sql);
        }
    }
}