use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, Semaphore};
use tokio::task::JoinHandle;
use tokio::time::Instant;
use crate::error::{AcmeError, Error, Result};
/// Default back-off schedule: the wait applied after each consecutive failed attempt.
///
/// The schedule starts at one minute and grows to six hours over 25 entries.
/// It is the `intervals` value used by `RetryConfig::default()`.
pub const DEFAULT_RETRY_INTERVALS: &[Duration] = &[
    // 1 minute, then 2 minutes twice.
    Duration::from_secs(60),
    Duration::from_secs(2 * 60),
    Duration::from_secs(2 * 60),
    // 5 minutes.
    Duration::from_secs(5 * 60),
    // 10 minutes, three times.
    Duration::from_secs(10 * 60),
    Duration::from_secs(10 * 60),
    Duration::from_secs(10 * 60),
    // 20 minutes, four times.
    Duration::from_secs(20 * 60),
    Duration::from_secs(20 * 60),
    Duration::from_secs(20 * 60),
    Duration::from_secs(20 * 60),
    // 30 minutes, six times.
    Duration::from_secs(30 * 60),
    Duration::from_secs(30 * 60),
    Duration::from_secs(30 * 60),
    Duration::from_secs(30 * 60),
    Duration::from_secs(30 * 60),
    Duration::from_secs(30 * 60),
    // 1 hour, three times.
    Duration::from_secs(60 * 60),
    Duration::from_secs(60 * 60),
    Duration::from_secs(60 * 60),
    // 2 hours, twice.
    Duration::from_secs(2 * 60 * 60),
    Duration::from_secs(2 * 60 * 60),
    // 3 hours, twice.
    Duration::from_secs(3 * 60 * 60),
    Duration::from_secs(3 * 60 * 60),
    // 6 hours.
    Duration::from_secs(6 * 60 * 60),
];
/// Default overall retry budget: 30 days, expressed in seconds.
pub const DEFAULT_MAX_DURATION: Duration = Duration::from_secs(30 * 86_400);
/// Tunable limits for `do_with_retry`.
#[derive(Debug, Clone)]
pub struct RetryConfig {
/// Waits applied between consecutive attempts, in order. Once the list is
/// used up, the last entry keeps being reused for every further retry.
pub intervals: Vec<Duration>,
/// Total time budget, measured from the start of the retry loop. Once it has
/// elapsed no further attempt is made and a timeout error is returned.
pub max_duration: Duration,
/// Maximum number of failed attempts before giving up. `0` disables this
/// limit, leaving `max_duration` as the only bound.
pub max_retries: usize,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
intervals: DEFAULT_RETRY_INTERVALS.to_vec(),
max_duration: DEFAULT_MAX_DURATION,
max_retries: 0,
}
}
}
pub fn should_retry(err: &Error) -> bool {
match err {
Error::NoRetry(_) => false,
Error::Acme(AcmeError::RateLimited { .. }) => true,
Error::Acme(AcmeError::Challenge { .. }) => false,
Error::Acme(AcmeError::Authorization(_)) => false,
Error::Acme(_) => true,
Error::Storage(_) => true,
Error::Crypto(_) => false,
Error::Cert(_) => false,
Error::Config(_) => false,
Error::Timeout(_) => true,
Error::Other(_) => true,
}
}
pub fn no_retry(err: Error) -> Error {
Error::NoRetry(err.to_string())
}
pub async fn do_with_retry<F, Fut>(config: &RetryConfig, f: F) -> Result<()>
where
F: Fn(usize) -> Fut,
Fut: Future<Output = Result<()>>,
{
let start = Instant::now();
let mut interval_index: isize = -1;
let mut attempts: usize = 0;
loop {
if interval_index >= 0 {
let idx = (interval_index as usize).min(config.intervals.len().saturating_sub(1));
let wait = config.intervals[idx];
tokio::time::sleep(wait).await;
}
if start.elapsed() >= config.max_duration {
tracing::error!(
attempts,
elapsed = ?start.elapsed(),
max_duration = ?config.max_duration,
"retry budget exhausted; giving up",
);
return Err(Error::Timeout(format!(
"retry budget of {:?} exhausted after {attempts} attempts",
config.max_duration,
)));
}
match f(attempts).await {
Ok(()) => return Ok(()),
Err(err) => {
attempts += 1;
if !should_retry(&err) {
tracing::warn!(
%err,
attempts,
"non-retriable error; will not retry",
);
return Err(err);
}
if config.max_retries > 0 && attempts >= config.max_retries {
tracing::error!(
%err,
attempts,
max_retries = config.max_retries,
"max retries reached; giving up",
);
return Err(err);
}
if interval_index < config.intervals.len() as isize - 1 {
interval_index += 1;
}
let next_wait = config.intervals[interval_index.max(0) as usize];
tracing::error!(
%err,
attempts,
retrying_in = ?next_wait,
elapsed = ?start.elapsed(),
max_duration = ?config.max_duration,
"will retry",
);
}
}
}
}
/// A named collection of background tasks, deduplicated by job name and
/// optionally limited in how many run at the same time.
pub struct JobQueue {
    /// Tracked jobs keyed by name. Each entry stores the id of the submission
    /// that created it next to the task handle, so a finishing task can tell
    /// whether the entry under its name is still its own.
    jobs: Arc<Mutex<HashMap<String, (u64, JoinHandle<()>)>>>,
    /// Queue name, used only in log output.
    name: String,
    /// Concurrency limiter; `None` means unlimited.
    semaphore: Option<Arc<Semaphore>>,
    /// Source of unique submission ids.
    next_job_id: std::sync::atomic::AtomicU64,
}

impl JobQueue {
    /// Creates a queue with no limit on concurrently running jobs.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            jobs: Arc::new(Mutex::new(HashMap::new())),
            name: name.into(),
            semaphore: None,
            next_job_id: std::sync::atomic::AtomicU64::new(0),
        }
    }

    /// Creates a queue in which at most `max_concurrent` jobs execute at once.
    /// A value of `0` means unlimited.
    pub fn with_max_concurrent(name: impl Into<String>, max_concurrent: usize) -> Self {
        Self {
            jobs: Arc::new(Mutex::new(HashMap::new())),
            name: name.into(),
            semaphore: if max_concurrent > 0 {
                Some(Arc::new(Semaphore::new(max_concurrent)))
            } else {
                None
            },
            next_job_id: std::sync::atomic::AtomicU64::new(0),
        }
    }

    /// Spawns `f` as a background job called `name`.
    ///
    /// If an unfinished job with the same name is already tracked, the new
    /// submission is dropped. When the queue has a concurrency limit, the
    /// spawned task waits for a permit before calling `f`; a job that is
    /// waiting for a permit still counts as running.
    pub async fn submit<F, Fut>(&self, name: String, f: F)
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let mut jobs = self.jobs.lock().await;
        // Drop handles of tasks that already finished so they do not block a
        // resubmission under the same name.
        jobs.retain(|_, (_, handle)| !handle.is_finished());
        if jobs.contains_key(&name) {
            tracing::debug!(
                queue = %self.name,
                job = %name,
                "job already running; skipping duplicate submission",
            );
            return;
        }
        tracing::debug!(queue = %self.name, job = %name, "submitting background job");

        let job_id = self
            .next_job_id
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let job_name = name.clone();
        let queue_name = self.name.clone();
        let jobs_ref = Arc::clone(&self.jobs);
        let semaphore = self.semaphore.clone();

        let handle = tokio::spawn(async move {
            let permit = match &semaphore {
                Some(sem) => Some(sem.acquire().await.expect("semaphore should not be closed")),
                None => None,
            };
            f().await;
            // Free the concurrency slot before waiting on the map lock.
            drop(permit);

            let mut jobs = jobs_ref.lock().await;
            // Remove the entry only if it still belongs to this submission.
            // After `wait` has taken this task's handle out of the map, a new
            // job may have been submitted under the same name; deleting its
            // entry would leave that job untracked.
            if matches!(jobs.get(&job_name), Some((id, _)) if *id == job_id) {
                jobs.remove(&job_name);
            }
            tracing::debug!(
                queue = %queue_name,
                job = %job_name,
                "background job completed",
            );
        });
        // The map lock is still held, so the task cannot reach its cleanup
        // step before this insert happens.
        jobs.insert(name, (job_id, handle));
    }

    /// Waits for the job called `name` to finish. Returns immediately when no
    /// such job is tracked. The job stops being tracked as soon as the wait
    /// begins.
    pub async fn wait(&self, name: &str) {
        // Take the handle out in its own scope so the lock is released before
        // awaiting; the task needs that lock for its own cleanup.
        let handle = {
            let mut jobs = self.jobs.lock().await;
            jobs.remove(name).map(|(_, handle)| handle)
        };
        if let Some(handle) = handle {
            // A panicked or aborted job is not an error for the waiter.
            let _ = handle.await;
        }
    }

    /// Returns `true` when a job called `name` is tracked and has not finished.
    pub async fn is_running(&self, name: &str) -> bool {
        let jobs = self.jobs.lock().await;
        match jobs.get(name) {
            Some((_, handle)) => !handle.is_finished(),
            None => false,
        }
    }

    /// Aborts the job called `name` and stops tracking it. Does nothing when
    /// no such job is tracked.
    pub async fn cancel(&self, name: &str) {
        let mut jobs = self.jobs.lock().await;
        if let Some((_, handle)) = jobs.remove(name) {
            handle.abort();
            tracing::debug!(
                queue = %self.name,
                job = %name,
                "background job cancelled",
            );
        }
    }

    /// Returns the number of tracked jobs that have not finished, pruning
    /// finished ones first.
    pub async fn len(&self) -> usize {
        let mut jobs = self.jobs.lock().await;
        jobs.retain(|_, (_, handle)| !handle.is_finished());
        jobs.len()
    }

    /// Returns `true` when no unfinished job is tracked.
    pub async fn is_empty(&self) -> bool {
        self.len().await == 0
    }
}
// Unit tests for the retry helper and the background job queue.
#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicUsize, Ordering};
use super::*;
// The default configuration exposes the full 25-step schedule, a 30-day
// budget, and no attempt cap.
#[test]
fn default_retry_config() {
let cfg = RetryConfig::default();
assert_eq!(cfg.intervals.len(), 25);
assert_eq!(cfg.intervals[0], Duration::from_secs(60));
assert_eq!(*cfg.intervals.last().unwrap(), Duration::from_secs(21600));
assert_eq!(cfg.max_duration, Duration::from_secs(30 * 24 * 3600));
assert_eq!(cfg.max_retries, 0);
}
// Rate-limit responses are classified as retriable.
#[test]
fn should_retry_rate_limited() {
let err = Error::Acme(AcmeError::RateLimited {
retry_after: None,
message: "slow down".into(),
});
assert!(should_retry(&err));
}
// Configuration errors are classified as permanent.
#[test]
fn should_not_retry_config() {
let err = Error::Config("bad".into());
assert!(!should_retry(&err));
}
// Challenge failures are classified as permanent.
#[test]
fn should_not_retry_challenge() {
let err = Error::Acme(AcmeError::Challenge {
challenge_type: "http-01".into(),
message: "failed".into(),
});
assert!(!should_retry(&err));
}
// Timeouts are classified as retriable.
#[test]
fn should_retry_timeout() {
let err = Error::Timeout("timed out".into());
assert!(should_retry(&err));
}
// Storage errors are classified as retriable.
#[test]
fn should_retry_storage() {
let err = Error::Storage(crate::error::StorageError::NotFound("x".into()));
assert!(should_retry(&err));
}
// A closure that succeeds immediately yields Ok.
#[tokio::test]
async fn retry_succeeds_on_first_try() {
let cfg = RetryConfig {
intervals: vec![Duration::from_millis(10)],
max_duration: Duration::from_secs(5),
max_retries: 3,
};
let result: Result<()> = do_with_retry(&cfg, |_| async { Ok(()) }).await;
assert!(result.is_ok());
}
// Two retriable failures followed by a success: the closure runs three times
// and the overall result is Ok.
#[tokio::test]
async fn retry_succeeds_after_transient_failures() {
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
let cfg = RetryConfig {
intervals: vec![Duration::from_millis(10)],
max_duration: Duration::from_secs(5),
max_retries: 5,
};
let result = do_with_retry(&cfg, move |_| {
let c = Arc::clone(&counter_clone);
async move {
let attempt = c.fetch_add(1, Ordering::SeqCst);
if attempt < 2 {
Err(Error::Timeout("transient".into()))
} else {
Ok(())
}
}
})
.await;
assert!(result.is_ok());
assert_eq!(counter.load(Ordering::SeqCst), 3);
}
// A non-retriable error ends the loop after a single attempt.
#[tokio::test]
async fn retry_stops_on_non_retriable() {
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
let cfg = RetryConfig {
intervals: vec![Duration::from_millis(10)],
max_duration: Duration::from_secs(5),
max_retries: 10,
};
let result = do_with_retry(&cfg, move |_| {
let c = Arc::clone(&counter_clone);
async move {
c.fetch_add(1, Ordering::SeqCst);
Err(Error::Config("permanent".into()))
}
})
.await;
assert!(result.is_err());
assert_eq!(counter.load(Ordering::SeqCst), 1);
}
// With max_retries = 3 and an always-failing closure, exactly three attempts
// are made before the error is returned.
#[tokio::test]
async fn retry_respects_max_retries() {
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = Arc::clone(&counter);
let cfg = RetryConfig {
intervals: vec![Duration::from_millis(10)],
max_duration: Duration::from_secs(60),
max_retries: 3,
};
let result = do_with_retry(&cfg, move |_| {
let c = Arc::clone(&counter_clone);
async move {
c.fetch_add(1, Ordering::SeqCst);
Err(Error::Timeout("always fails".into()))
}
})
.await;
assert!(result.is_err());
assert_eq!(counter.load(Ordering::SeqCst), 3);
}
// wait() returns only after the submitted job has run to completion.
#[tokio::test]
async fn job_queue_submit_and_wait() {
let queue = JobQueue::new("test");
let flag = Arc::new(AtomicUsize::new(0));
let flag_clone = Arc::clone(&flag);
queue
.submit("job1".into(), move || {
let f = Arc::clone(&flag_clone);
async move {
f.store(42, Ordering::SeqCst);
}
})
.await;
queue.wait("job1").await;
assert_eq!(flag.load(Ordering::SeqCst), 42);
}
// A second submission under the name of a still-running job is dropped, so
// the counter is incremented only once.
#[tokio::test]
async fn job_queue_deduplicates() {
let queue = JobQueue::new("test");
let counter = Arc::new(AtomicUsize::new(0));
let c1 = Arc::clone(&counter);
queue
.submit("dup".into(), move || {
let c = Arc::clone(&c1);
async move {
tokio::time::sleep(Duration::from_millis(200)).await;
c.fetch_add(1, Ordering::SeqCst);
}
})
.await;
let c2 = Arc::clone(&counter);
queue
.submit("dup".into(), move || {
let c = Arc::clone(&c2);
async move {
c.fetch_add(1, Ordering::SeqCst);
}
})
.await;
queue.wait("dup").await;
assert_eq!(counter.load(Ordering::SeqCst), 1);
}
// Cancelling a sleeping job stops it before it can increment the counter and
// removes it from the queue.
#[tokio::test]
async fn job_queue_cancel() {
let queue = JobQueue::new("test");
let counter = Arc::new(AtomicUsize::new(0));
let c = Arc::clone(&counter);
queue
.submit("slow".into(), move || {
let c = Arc::clone(&c);
async move {
tokio::time::sleep(Duration::from_secs(10)).await;
c.fetch_add(1, Ordering::SeqCst);
}
})
.await;
// Give the spawned task a moment to start before checking on it.
tokio::time::sleep(Duration::from_millis(50)).await;
assert!(queue.is_running("slow").await);
queue.cancel("slow").await;
assert!(!queue.is_running("slow").await);
assert_eq!(counter.load(Ordering::SeqCst), 0);
}
// is_running is false for unknown names, true while a job is in flight, and
// false again once the job has been waited on.
#[tokio::test]
async fn job_queue_is_running() {
let queue = JobQueue::new("test");
assert!(!queue.is_running("nope").await);
queue
.submit("task".into(), || async {
tokio::time::sleep(Duration::from_millis(200)).await;
})
.await;
assert!(queue.is_running("task").await);
queue.wait("task").await;
assert!(!queue.is_running("task").await);
}
// len/is_empty reflect the number of unfinished jobs.
#[tokio::test]
async fn job_queue_len() {
let queue = JobQueue::new("test");
assert!(queue.is_empty().await);
queue
.submit("a".into(), || async {
tokio::time::sleep(Duration::from_millis(200)).await;
})
.await;
queue
.submit("b".into(), || async {
tokio::time::sleep(Duration::from_millis(200)).await;
})
.await;
assert_eq!(queue.len().await, 2);
queue.wait("a").await;
queue.wait("b").await;
assert!(queue.is_empty().await);
}
}