#![allow(clippy::unwrap_used, clippy::expect_used)]
#![cfg(feature = "sqlite")]
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use modkit_db::{ConnectOpts, DEFAULT_TX_RETRY_ATTEMPTS, DbError, connect_db, secure::TxConfig};
use sea_orm::{DbErr, RuntimeErr};
#[derive(Debug)]
enum TestError {
Retryable(DbErr),
Permanent,
#[allow(dead_code)]
Db(DbError),
}
impl From<DbError> for TestError {
fn from(e: DbError) -> Self {
TestError::Db(e)
}
}
fn extract_db_err(e: &TestError) -> Option<&DbErr> {
match e {
TestError::Retryable(err) => Some(err),
_ => None,
}
}
fn sqlite_busy_err() -> DbErr {
DbErr::Exec(RuntimeErr::Internal(
"Execution Error: error returned from database: (code: 5) database is locked".to_owned(),
))
}
#[tokio::test]
async fn retry_default_succeeds_after_transient_failures() {
let db = connect_db("sqlite::memory:", ConnectOpts::default())
.await
.expect("connect sqlite memory");
let counter = Arc::new(AtomicU32::new(0));
let counter_for_body = Arc::clone(&counter);
let result: Result<u32, TestError> = db
.transaction_with_retry(TxConfig::default(), extract_db_err, move |_tx| {
let counter = Arc::clone(&counter_for_body);
Box::pin(async move {
let n = counter.fetch_add(1, Ordering::SeqCst) + 1;
if n < DEFAULT_TX_RETRY_ATTEMPTS {
Err(TestError::Retryable(sqlite_busy_err()))
} else {
Ok(n)
}
})
})
.await;
assert!(
matches!(result, Ok(n) if n == DEFAULT_TX_RETRY_ATTEMPTS),
"got {result:?}"
);
assert_eq!(counter.load(Ordering::SeqCst), DEFAULT_TX_RETRY_ATTEMPTS);
}
#[tokio::test]
async fn retry_returns_last_error_on_exhaustion() {
let db = connect_db("sqlite::memory:", ConnectOpts::default())
.await
.expect("connect sqlite memory");
let counter = Arc::new(AtomicU32::new(0));
let counter_for_body = Arc::clone(&counter);
let result: Result<(), TestError> = db
.transaction_with_retry_max(TxConfig::default(), 3, extract_db_err, move |_tx| {
let counter = Arc::clone(&counter_for_body);
Box::pin(async move {
counter.fetch_add(1, Ordering::SeqCst);
Err(TestError::Retryable(sqlite_busy_err()))
})
})
.await;
assert!(
matches!(result, Err(TestError::Retryable(_))),
"got {result:?}"
);
assert_eq!(counter.load(Ordering::SeqCst), 3);
}
#[tokio::test]
async fn non_retryable_error_returns_immediately() {
let db = connect_db("sqlite::memory:", ConnectOpts::default())
.await
.expect("connect sqlite memory");
let counter = Arc::new(AtomicU32::new(0));
let counter_for_body = Arc::clone(&counter);
let result: Result<(), TestError> = db
.transaction_with_retry_max(TxConfig::default(), 3, extract_db_err, move |_tx| {
let counter = Arc::clone(&counter_for_body);
Box::pin(async move {
counter.fetch_add(1, Ordering::SeqCst);
Err(TestError::Permanent)
})
})
.await;
assert!(
matches!(result, Err(TestError::Permanent)),
"got {result:?}"
);
assert_eq!(counter.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn extractor_returning_none_skips_retry() {
let db = connect_db("sqlite::memory:", ConnectOpts::default())
.await
.expect("connect sqlite memory");
let counter = Arc::new(AtomicU32::new(0));
let counter_for_body = Arc::clone(&counter);
let result: Result<(), TestError> = db
.transaction_with_retry_max(TxConfig::default(), 3, extract_db_err, move |_tx| {
let counter = Arc::clone(&counter_for_body);
Box::pin(async move {
counter.fetch_add(1, Ordering::SeqCst);
Err(TestError::Db(DbError::InvalidConfig("boom".to_owned())))
})
})
.await;
assert!(matches!(result, Err(TestError::Db(_))), "got {result:?}");
assert_eq!(counter.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn zero_max_attempts_treated_as_one() {
let db = connect_db("sqlite::memory:", ConnectOpts::default())
.await
.expect("connect sqlite memory");
let counter = Arc::new(AtomicU32::new(0));
let counter_for_body = Arc::clone(&counter);
let result: Result<(), TestError> = db
.transaction_with_retry_max(TxConfig::default(), 0, extract_db_err, move |_tx| {
let counter = Arc::clone(&counter_for_body);
Box::pin(async move {
counter.fetch_add(1, Ordering::SeqCst);
Err(TestError::Retryable(sqlite_busy_err()))
})
})
.await;
assert!(
matches!(result, Err(TestError::Retryable(_))),
"got {result:?}"
);
assert_eq!(counter.load(Ordering::SeqCst), 1);
}