use async_trait::async_trait;
use super::Middleware;
use crate::error::AgentError;
/// Middleware that counts observed errors and recorded attempts so callers
/// can decide whether another attempt is permitted.
///
/// Both counters are atomics updated through `&self`, so no `&mut` access is
/// needed. Each counter is updated atomically on its own, but the two are not
/// read or written together as a single atomic unit.
#[derive(Debug)]
pub struct RetryMiddleware {
    /// Attempt count below which [`should_retry`](Self::should_retry) may return `true`.
    max_retries: u32,
    /// Errors observed since construction, the last
    /// [`record_attempt`](Self::record_attempt), or the last [`reset`](Self::reset).
    pub(crate) error_count: std::sync::atomic::AtomicU32,
    /// Attempts recorded since construction or the last [`reset`](Self::reset).
    pub(crate) attempt: std::sync::atomic::AtomicU32,
}

impl RetryMiddleware {
    /// Creates a middleware allowing retries while fewer than `max_retries`
    /// attempts have been recorded. Both counters start at zero.
    pub fn new(max_retries: u32) -> Self {
        Self {
            max_retries,
            error_count: std::sync::atomic::AtomicU32::new(0),
            attempt: std::sync::atomic::AtomicU32::new(0),
        }
    }

    /// Returns `true` when at least one error has been observed since the last
    /// recorded attempt and the attempt count is still below `max_retries`.
    ///
    /// The two counters are read with separate loads, so under concurrent
    /// updates the answer is a snapshot of two independently-read values.
    pub fn should_retry(&self) -> bool {
        let attempts = self.attempt.load(std::sync::atomic::Ordering::SeqCst);
        let errors = self.error_count.load(std::sync::atomic::Ordering::SeqCst);
        errors > 0 && attempts < self.max_retries
    }

    /// Records one more attempt and clears the error count.
    ///
    /// The attempt counter saturates at `u32::MAX` instead of wrapping: a
    /// wrapped counter would read as 0 and let [`should_retry`](Self::should_retry)
    /// permit retries again beyond the cap.
    pub fn record_attempt(&self) {
        // `checked_add` returns `None` at `u32::MAX`, so `fetch_update` leaves the
        // value unchanged and reports `Err`; discarding it is the saturation.
        let _ = self.attempt.fetch_update(
            std::sync::atomic::Ordering::SeqCst,
            std::sync::atomic::Ordering::SeqCst,
            |n| n.checked_add(1),
        );
        self.error_count
            .store(0, std::sync::atomic::Ordering::SeqCst);
    }

    /// Returns the number of attempts recorded so far.
    pub fn attempts(&self) -> u32 {
        self.attempt.load(std::sync::atomic::Ordering::SeqCst)
    }

    /// Returns the configured attempt limit.
    pub fn max_retries(&self) -> u32 {
        self.max_retries
    }

    /// Clears both the error count and the attempt count back to zero.
    pub fn reset(&self) {
        self.error_count
            .store(0, std::sync::atomic::Ordering::SeqCst);
        self.attempt.store(0, std::sync::atomic::Ordering::SeqCst);
    }
}
/// Plugs [`RetryMiddleware`] into the middleware chain under the name `"retry"`.
/// Every error passed to `on_error` is counted so that `should_retry` can see it.
#[async_trait]
impl Middleware for RetryMiddleware {
    fn name(&self) -> &str {
        "retry"
    }

    /// Counts the error without inspecting it and always returns `Ok(())`.
    async fn on_error(&self, _err: &AgentError) -> Result<(), AgentError> {
        // Saturate at `u32::MAX` instead of wrapping: a wrapped count of 0 would
        // make `should_retry` behave as if no error had been observed.
        let _ = self.error_count.fetch_update(
            std::sync::atomic::Ordering::SeqCst,
            std::sync::atomic::Ordering::SeqCst,
            |n| n.checked_add(1),
        );
        Ok(())
    }
}