use std::time::Duration;
/// Maximum number of times [`retry_serializable`] invokes its closure, counting the
/// initial call (so `3` means one try plus up to two retries).
pub const MAX_ATTEMPTS: u32 = 3;
/// Classifies an error as transient database contention.
///
/// [`retry_serializable`] keeps retrying an operation only while the error it produced
/// answers `true` here; an error answering `false` is returned to the caller at once.
pub trait IsRetryableBusy {
    /// Returns `true` if this error should be treated as a retryable busy/locked condition.
    fn is_retryable_busy(&self) -> bool;
}
impl IsRetryableBusy for sqlx::Error {
fn is_retryable_busy(&self) -> bool {
crate::errors::is_retryable_sqlite_busy(self)
}
}
/// Backoff to wait after failed attempt number `attempt` (0-based).
///
/// The delay is 5 ms doubled once per earlier attempt: 5 ms, 10 ms, 20 ms, and so on.
/// It saturates instead of overflowing, so any value of [`MAX_ATTEMPTS`] is safe.
fn busy_backoff(attempt: u32) -> Duration {
    let factor = 1u32.checked_shl(attempt).unwrap_or(u32::MAX);
    Duration::from_millis(5).saturating_mul(factor)
}

/// Runs `f`, retrying it while it fails with a transient busy/locked error.
///
/// `f` is called at least once and at most [`MAX_ATTEMPTS`] times. Suppose an attempt fails
/// with an error whose [`IsRetryableBusy::is_retryable_busy`] is `true` and attempts remain.
/// The task then sleeps with exponential backoff (5 ms, then 10 ms, ...) before calling `f`
/// again. No sleep follows the final attempt.
///
/// # Errors
///
/// A non-retryable error is returned as soon as it occurs. If every attempt fails with a
/// retryable error, the error from the last attempt is returned.
pub async fn retry_serializable<F, Fut, T, E>(mut f: F) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
    E: IsRetryableBusy,
{
    let mut attempt: u32 = 0;
    loop {
        match f().await {
            Ok(value) => return Ok(value),
            // Transient contention with attempts left: back off, then try again.
            Err(err) if err.is_retryable_busy() && attempt + 1 < MAX_ATTEMPTS => {
                drop(err);
                tokio::time::sleep(busy_backoff(attempt)).await;
                attempt += 1;
            }
            // Non-retryable failure, or the final attempt's error: surface it unchanged.
            Err(err) => return Err(err),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::error::{DatabaseError, ErrorKind};
    use std::borrow::Cow;
    use std::error::Error as StdError;
    use std::fmt;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Minimal `DatabaseError` reporting a fixed result code. It lets the classifier and
    /// the retry loop be exercised without a real database.
    #[derive(Debug)]
    struct MockSqliteError {
        code: String,
    }

    impl fmt::Display for MockSqliteError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "(mock code: {})", self.code)
        }
    }

    impl StdError for MockSqliteError {}

    impl DatabaseError for MockSqliteError {
        fn message(&self) -> &str {
            "mock"
        }

        fn code(&self) -> Option<Cow<'_, str>> {
            Some(Cow::Borrowed(self.code.as_str()))
        }

        fn as_error(&self) -> &(dyn StdError + Send + Sync + 'static) {
            self
        }

        fn as_error_mut(&mut self) -> &mut (dyn StdError + Send + Sync + 'static) {
            self
        }

        fn into_error(self: Box<Self>) -> Box<dyn StdError + Send + Sync + 'static> {
            self
        }

        fn kind(&self) -> ErrorKind {
            ErrorKind::Other
        }
    }

    /// Wraps a mock error carrying `code` in `sqlx::Error::Database`.
    fn db_err(code: &str) -> sqlx::Error {
        sqlx::Error::Database(Box::new(MockSqliteError {
            code: code.to_string(),
        }))
    }

    /// Returns the database code carried by `err`, or `None` for non-database errors.
    fn code_of(err: &sqlx::Error) -> Option<String> {
        match err {
            sqlx::Error::Database(db) => db.code().map(|c| c.into_owned()),
            _ => None,
        }
    }

    #[test]
    fn classifier_matches_busy() {
        assert!(db_err("5").is_retryable_busy(), "SQLITE_BUSY (5)");
        assert!(db_err("6").is_retryable_busy(), "SQLITE_LOCKED (6)");
        assert!(
            db_err("261").is_retryable_busy(),
            "SQLITE_BUSY_RECOVERY (261 = 5 | 1<<8)"
        );
        assert!(
            db_err("517").is_retryable_busy(),
            "SQLITE_BUSY_SNAPSHOT (517 = 5 | 2<<8)"
        );
        assert!(
            db_err("773").is_retryable_busy(),
            "SQLITE_BUSY_TIMEOUT (773 = 5 | 3<<8)"
        );
        assert!(
            db_err("262").is_retryable_busy(),
            "SQLITE_LOCKED_SHAREDCACHE (262 = 6 | 1<<8)"
        );
        assert!(
            db_err("518").is_retryable_busy(),
            "SQLITE_LOCKED_VTAB (518 = 6 | 2<<8)"
        );
    }

    #[test]
    fn classifier_rejects_corrupt() {
        assert!(!db_err("11").is_retryable_busy());
    }

    #[test]
    fn classifier_rejects_full() {
        assert!(!db_err("13").is_retryable_busy());
    }

    #[test]
    fn classifier_rejects_misuse() {
        assert!(!db_err("21").is_retryable_busy());
    }

    #[test]
    fn classifier_rejects_non_db_error() {
        assert!(!sqlx::Error::RowNotFound.is_retryable_busy());
    }

    #[test]
    fn classifier_rejects_extended_non_busy_family() {
        assert!(!db_err("266").is_retryable_busy());
        assert!(!db_err("2067").is_retryable_busy());
    }

    #[tokio::test(start_paused = true)]
    async fn retry_succeeds_on_first_try() {
        let start = tokio::time::Instant::now();
        let calls = Arc::new(AtomicU32::new(0));
        let calls_c = calls.clone();
        let result: Result<u32, sqlx::Error> = retry_serializable(|| {
            let calls = calls_c.clone();
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok(42)
            }
        })
        .await;
        assert_eq!(result.unwrap(), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(
            start.elapsed(),
            Duration::ZERO,
            "no backoff when the first call succeeds"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn retry_exhausts_after_max_attempts() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_c = calls.clone();
        let result: Result<(), sqlx::Error> = retry_serializable(|| {
            let calls = calls_c.clone();
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err(db_err("5"))
            }
        })
        .await;
        let err = result.unwrap_err();
        assert!(err.is_retryable_busy());
        assert_eq!(code_of(&err), Some("5".to_string()));
        assert_eq!(calls.load(Ordering::SeqCst), MAX_ATTEMPTS);
    }

    #[tokio::test(start_paused = true)]
    async fn retry_returns_non_retryable_immediately() {
        let start = tokio::time::Instant::now();
        let calls = Arc::new(AtomicU32::new(0));
        let calls_c = calls.clone();
        let result: Result<(), sqlx::Error> = retry_serializable(|| {
            let calls = calls_c.clone();
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err(db_err("11"))
            }
        })
        .await;
        let err = result.unwrap_err();
        assert_eq!(code_of(&err), Some("11".to_string()));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(
            start.elapsed(),
            Duration::ZERO,
            "non-retryable errors must not trigger a backoff sleep"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn retry_succeeds_on_second_attempt() {
        let start = tokio::time::Instant::now();
        let calls = Arc::new(AtomicU32::new(0));
        let calls_c = calls.clone();
        let result: Result<u32, sqlx::Error> = retry_serializable(|| {
            let calls = calls_c.clone();
            async move {
                let n = calls.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    Err(db_err("5"))
                } else {
                    Ok(7)
                }
            }
        })
        .await;
        assert_eq!(result.unwrap(), 7);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
        assert_eq!(
            start.elapsed(),
            Duration::from_millis(5),
            "expected exactly one 5ms backoff before the second attempt"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn retry_backoff_matches_pg_shape() {
        let start = tokio::time::Instant::now();
        let _: Result<(), sqlx::Error> =
            retry_serializable(|| async { Err::<(), _>(db_err("5")) }).await;
        let elapsed = start.elapsed();
        assert_eq!(
            elapsed,
            Duration::from_millis(15),
            "expected 5ms + 10ms between three attempts"
        );
    }
}