use super::http_client::HttpResponse;
use crate::error::Result;
use async_trait::async_trait;
#[async_trait]
pub trait Middleware: Send + Sync {
async fn before_request(&self, _url: &str, _method: &str) -> Result<()> {
Ok(())
}
async fn after_response(&self, _url: &str, _response: &HttpResponse) -> Result<()> {
Ok(())
}
async fn on_error(&self, _url: &str, _error: &crate::error::AcmeError) -> Result<()> {
Ok(())
}
}
/// An ordered list of boxed [`Middleware`] hooks.
pub struct MiddlewareChain {
    middlewares: Vec<Box<dyn Middleware>>,
}

impl std::fmt::Debug for MiddlewareChain {
    // `dyn Middleware` is not `Debug`, so `#[derive(Debug)]` is impossible here;
    // report how many hooks are installed instead of their contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MiddlewareChain")
            .field("len", &self.middlewares.len())
            .finish_non_exhaustive()
    }
}
impl MiddlewareChain {
    /// Creates an empty chain with no middlewares installed.
    #[must_use]
    pub fn new() -> Self {
        tracing::debug!("Creating new MiddlewareChain");
        Self {
            middlewares: Vec::new(),
        }
    }

    /// Appends `middleware` to the end of the chain and returns the chain.
    ///
    /// This consumes `self` builder-style. Discarding the return value drops the
    /// whole chain, including every middleware already added, so the result must be used.
    #[must_use = "`push` consumes the chain and returns it; dropping the result discards every middleware"]
    pub fn push<M: Middleware + 'static>(mut self, middleware: M) -> Self {
        self.middlewares.push(Box::new(middleware));
        self
    }

    /// Runs each middleware's `before_request` hook, in insertion order.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a hook. Middlewares after the
    /// failing one are not called.
    pub async fn before_request(&self, url: &str, method: &str) -> Result<()> {
        for middleware in &self.middlewares {
            middleware.before_request(url, method).await?;
        }
        Ok(())
    }

    /// Runs each middleware's `after_response` hook, in insertion order.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a hook. Middlewares after the
    /// failing one are not called.
    pub async fn after_response(&self, url: &str, response: &HttpResponse) -> Result<()> {
        for middleware in &self.middlewares {
            middleware.after_response(url, response).await?;
        }
        Ok(())
    }

    /// Runs each middleware's `on_error` hook, in insertion order.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a hook. Middlewares after the
    /// failing one are not notified of `error`.
    pub async fn on_error(&self, url: &str, error: &crate::error::AcmeError) -> Result<()> {
        for middleware in &self.middlewares {
            middleware.on_error(url, error).await?;
        }
        Ok(())
    }
}
impl Default for MiddlewareChain {
fn default() -> Self {
Self::new()
}
}
/// Middleware that emits a `tracing` event for each request, response, and failure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LoggingMiddleware {
    // NOTE(review): `log_body` is stored but never read in this module yet;
    // the allow keeps the build warning-free until body logging is implemented.
    #[allow(dead_code)]
    log_body: bool,
}

impl LoggingMiddleware {
    /// Creates a logging middleware, recording whether bodies should be logged.
    #[must_use]
    pub const fn new(log_body: bool) -> Self {
        Self { log_body }
    }
}
#[async_trait]
impl Middleware for LoggingMiddleware {
async fn before_request(&self, url: &str, method: &str) -> Result<()> {
tracing::info!("HTTP Request: {} {}", method, url);
Ok(())
}
async fn after_response(&self, url: &str, response: &HttpResponse) -> Result<()> {
tracing::info!("HTTP Response: {} (Status: {})", url, response.status);
Ok(())
}
async fn on_error(&self, url: &str, error: &crate::error::AcmeError) -> Result<()> {
tracing::error!("HTTP Request Failed: {} - Error: {:?}", url, error);
Ok(())
}
}
/// Middleware that carries a request timeout setting, in seconds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimeoutMiddleware {
    // NOTE(review): `timeout_secs` is stored but never read in this module;
    // the allow suppresses the dead-code warning until it is actually applied.
    #[allow(dead_code)]
    timeout_secs: u64,
}

impl TimeoutMiddleware {
    /// Creates a timeout middleware holding `timeout_secs`.
    #[must_use]
    pub const fn new(timeout_secs: u64) -> Self {
        Self { timeout_secs }
    }
}
/// Only `before_request` is overridden; the other hooks use the trait's no-op defaults.
///
/// NOTE(review): despite the log wording, this hook does not apply `timeout_secs`
/// to anything. It only emits a DEBUG event and returns `Ok(())`.
#[async_trait]
impl Middleware for TimeoutMiddleware {
    async fn before_request(&self, target: &str, _verb: &str) -> Result<()> {
        tracing::debug!("Enforcing timeout for: {}", target);
        Ok(())
    }
}
/// Middleware that carries a maximum retry count.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RetryMiddleware {
    // NOTE(review): `max_retries` is stored but never read in this module;
    // the allow suppresses the dead-code warning until retry logic uses it.
    #[allow(dead_code)]
    max_retries: u32,
}

impl RetryMiddleware {
    /// Creates a retry middleware holding `max_retries`.
    #[must_use]
    pub const fn new(max_retries: u32) -> Self {
        Self { max_retries }
    }
}
#[async_trait]
impl Middleware for RetryMiddleware {
async fn on_error(&self, url: &str, error: &crate::error::AcmeError) -> Result<()> {
tracing::debug!(
"Retry middleware intercepted error for {}: {:?}",
url,
error
);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::{Arc, Mutex};

    /// Sets a shared flag when its `before_request` hook runs.
    struct TestMiddleware {
        called: Arc<AtomicBool>,
    }

    #[async_trait]
    impl Middleware for TestMiddleware {
        async fn before_request(&self, _url: &str, _method: &str) -> Result<()> {
            self.called.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Appends its `id` to a shared log when its `before_request` hook runs.
    struct OrderMiddleware {
        id: usize,
        log: Arc<Mutex<Vec<usize>>>,
    }

    #[async_trait]
    impl Middleware for OrderMiddleware {
        async fn before_request(&self, _url: &str, _method: &str) -> Result<()> {
            // The guard is a temporary, so it is released before any `.await`.
            self.log.lock().expect("order log mutex poisoned").push(self.id);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_middleware_chain() {
        let called = Arc::new(AtomicBool::new(false));
        let chain = MiddlewareChain::new().push(TestMiddleware {
            called: Arc::clone(&called),
        });
        // Surface any hook error instead of discarding it, so a failing chain fails the test.
        chain
            .before_request("http://example.com", "GET")
            .await
            .expect("before_request should succeed");
        assert!(called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_middleware_chain_runs_in_insertion_order() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let chain = (0..3).fold(MiddlewareChain::new(), |chain, id| {
            chain.push(OrderMiddleware {
                id,
                log: Arc::clone(&log),
            })
        });
        chain
            .before_request("http://example.com", "GET")
            .await
            .expect("before_request should succeed");
        assert_eq!(*log.lock().expect("order log mutex poisoned"), vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn test_empty_chain_is_ok() {
        MiddlewareChain::default()
            .before_request("http://example.com", "GET")
            .await
            .expect("an empty chain should succeed");
    }
}