use anyhow::Result;
use async_trait::async_trait;
use super::super::{CommandContext, CommandOutput};
/// A hook that runs around the execution of a command.
///
/// Implementors must be `Send + Sync` so they can be shared behind
/// `Box<dyn Middleware>` in a [`MiddlewareStack`].
#[async_trait]
pub trait Middleware: Send + Sync {
    /// Identifier for this middleware; [`MiddlewareStack`] includes it when
    /// logging a failed `after` hook.
    ///
    /// Defaults to `"unnamed-middleware"`.
    fn name(&self) -> &str {
        "unnamed-middleware"
    }

    /// Decides whether this middleware participates for the given context.
    ///
    /// Defaults to `true` (always enabled).
    fn is_enabled(&self, _ctx: &CommandContext) -> bool {
        true
    }

    /// Hook invoked before the command runs; may mutate the context.
    ///
    /// # Errors
    ///
    /// Implementation-defined; an error aborts [`MiddlewareStack::run_before`].
    async fn before(&self, ctx: &mut CommandContext) -> Result<()>;

    /// Hook invoked after the command runs, receiving the command's outcome.
    ///
    /// # Errors
    ///
    /// Implementation-defined; [`MiddlewareStack::run_after`] logs and ignores it.
    async fn after(&self, ctx: &mut CommandContext, result: &Result<CommandOutput>) -> Result<()>;
}
/// An ordered collection of [`Middleware`] wrapped around command execution.
///
/// `before` hooks run in insertion order and `after` hooks run in reverse
/// insertion order, so the first middleware pushed is the outermost layer.
pub struct MiddlewareStack {
    /// Registered middleware, in the order they were pushed.
    middleware: Vec<Box<dyn Middleware>>,
}

impl MiddlewareStack {
    /// Creates an empty stack.
    pub fn new() -> Self {
        Self {
            middleware: Vec::new(),
        }
    }

    /// Appends `middleware` to the end of the stack and returns the stack,
    /// enabling builder-style chaining.
    #[must_use = "`push` consumes the stack and returns it; dropping the result discards it"]
    pub fn push(mut self, middleware: Box<dyn Middleware>) -> Self {
        self.middleware.push(middleware);
        self
    }

    /// Runs the `before` hook of every enabled middleware, in insertion order.
    ///
    /// Middleware whose [`Middleware::is_enabled`] returns `false` for `ctx`
    /// are skipped.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a `before` hook; middleware after
    /// the failing one do not run.
    pub async fn run_before(&self, ctx: &mut CommandContext) -> Result<()> {
        for mw in &self.middleware {
            if mw.is_enabled(ctx) {
                mw.before(ctx).await?;
            }
        }
        Ok(())
    }

    /// Runs the `after` hook of every enabled middleware, in reverse insertion
    /// order.
    ///
    /// A failing hook is logged at `warn` level (with the middleware's
    /// [`Middleware::name`]) and otherwise ignored, so every remaining hook
    /// still runs.
    ///
    /// NOTE(review): `is_enabled` is evaluated again here against the
    /// post-command `ctx`. If a hook or the command changes what `is_enabled`
    /// observes, the set of middleware whose `after` runs can differ from the
    /// set whose `before` ran.
    pub async fn run_after(&self, ctx: &mut CommandContext, result: &Result<CommandOutput>) {
        for mw in self.middleware.iter().rev() {
            if !mw.is_enabled(ctx) {
                continue;
            }
            if let Err(e) = mw.after(ctx, result).await {
                tracing::warn!(
                    middleware = mw.name(),
                    error = %e,
                    "Middleware after hook failed"
                );
            }
        }
    }

    /// Number of registered middleware (enabled or not).
    #[must_use]
    pub fn len(&self) -> usize {
        self.middleware.len()
    }

    /// Returns `true` when no middleware has been registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.middleware.is_empty()
    }
}

impl Default for MiddlewareStack {
    /// Equivalent to [`MiddlewareStack::new`].
    fn default() -> Self {
        Self::new()
    }
}

// `Box<dyn Middleware>` is not `Debug`, so the derive is unavailable; show the
// registered middleware by name instead.
impl std::fmt::Debug for MiddlewareStack {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let names: Vec<&str> = self.middleware.iter().map(|mw| mw.name()).collect();
        f.debug_struct("MiddlewareStack")
            .field("middleware", &names)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::{Arc, Mutex};

    /// Middleware that records (via shared flags) whether each hook was called.
    struct TestMiddleware {
        name: String,
        before_called: Arc<AtomicBool>,
        after_called: Arc<AtomicBool>,
    }

    impl TestMiddleware {
        fn new(name: impl Into<String>) -> Self {
            Self {
                name: name.into(),
                before_called: Arc::new(AtomicBool::new(false)),
                after_called: Arc::new(AtomicBool::new(false)),
            }
        }
    }

    #[async_trait]
    impl Middleware for TestMiddleware {
        async fn before(&self, _ctx: &mut CommandContext) -> Result<()> {
            self.before_called.store(true, Ordering::SeqCst);
            Ok(())
        }

        async fn after(
            &self,
            _ctx: &mut CommandContext,
            _result: &Result<CommandOutput>,
        ) -> Result<()> {
            self.after_called.store(true, Ordering::SeqCst);
            Ok(())
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    /// Middleware that appends `"before:<name>"` / `"after:<name>"` to a shared
    /// log, optionally failing either hook after recording it.
    struct RecordingMiddleware {
        name: &'static str,
        log: Arc<Mutex<Vec<String>>>,
        fail_before: bool,
        fail_after: bool,
    }

    impl RecordingMiddleware {
        fn new(name: &'static str, log: &Arc<Mutex<Vec<String>>>) -> Self {
            Self {
                name,
                log: Arc::clone(log),
                fail_before: false,
                fail_after: false,
            }
        }
    }

    #[async_trait]
    impl Middleware for RecordingMiddleware {
        async fn before(&self, _ctx: &mut CommandContext) -> Result<()> {
            self.log.lock().unwrap().push(format!("before:{}", self.name));
            if self.fail_before {
                anyhow::bail!("{} before failed", self.name);
            }
            Ok(())
        }

        async fn after(
            &self,
            _ctx: &mut CommandContext,
            _result: &Result<CommandOutput>,
        ) -> Result<()> {
            self.log.lock().unwrap().push(format!("after:{}", self.name));
            if self.fail_after {
                anyhow::bail!("{} after failed", self.name);
            }
            Ok(())
        }

        fn name(&self) -> &str {
            self.name
        }
    }

    #[tokio::test]
    async fn test_middleware_execution() {
        let mw = TestMiddleware::new("test");
        let before_called = mw.before_called.clone();
        let after_called = mw.after_called.clone();
        let mut ctx = CommandContext::mock();
        mw.before(&mut ctx).await.unwrap();
        assert!(before_called.load(Ordering::SeqCst));
        let result = Ok(CommandOutput::success("test"));
        mw.after(&mut ctx, &result).await.unwrap();
        assert!(after_called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_middleware_stack() {
        let mw1 = TestMiddleware::new("first");
        let mw2 = TestMiddleware::new("second");
        let before1 = mw1.before_called.clone();
        let before2 = mw2.before_called.clone();
        let after1 = mw1.after_called.clone();
        let after2 = mw2.after_called.clone();
        let stack = MiddlewareStack::new()
            .push(Box::new(mw1))
            .push(Box::new(mw2));
        assert_eq!(stack.len(), 2);
        let mut ctx = CommandContext::mock();
        stack.run_before(&mut ctx).await.unwrap();
        assert!(before1.load(Ordering::SeqCst));
        assert!(before2.load(Ordering::SeqCst));
        let result = Ok(CommandOutput::success("test"));
        stack.run_after(&mut ctx, &result).await;
        assert!(after1.load(Ordering::SeqCst));
        assert!(after2.load(Ordering::SeqCst));
    }

    /// `before` hooks run in insertion order, `after` hooks in reverse.
    #[tokio::test]
    async fn test_hook_ordering() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let stack = MiddlewareStack::new()
            .push(Box::new(RecordingMiddleware::new("outer", &log)))
            .push(Box::new(RecordingMiddleware::new("inner", &log)));
        let mut ctx = CommandContext::mock();
        stack.run_before(&mut ctx).await.unwrap();
        let result = Ok(CommandOutput::success("test"));
        stack.run_after(&mut ctx, &result).await;
        assert_eq!(
            *log.lock().unwrap(),
            vec!["before:outer", "before:inner", "after:inner", "after:outer"]
        );
    }

    /// A failing `before` hook aborts `run_before` and skips later middleware.
    #[tokio::test]
    async fn test_before_error_stops_remaining_hooks() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let mut failing = RecordingMiddleware::new("failing", &log);
        failing.fail_before = true;
        let stack = MiddlewareStack::new()
            .push(Box::new(failing))
            .push(Box::new(RecordingMiddleware::new("skipped", &log)));
        let mut ctx = CommandContext::mock();
        assert!(stack.run_before(&mut ctx).await.is_err());
        assert_eq!(*log.lock().unwrap(), vec!["before:failing"]);
    }

    /// A failing `after` hook is logged, and the remaining hooks still run.
    #[tokio::test]
    async fn test_after_error_does_not_stop_remaining_hooks() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let mut failing = RecordingMiddleware::new("failing", &log);
        failing.fail_after = true;
        let stack = MiddlewareStack::new()
            .push(Box::new(RecordingMiddleware::new("outer", &log)))
            .push(Box::new(failing));
        let mut ctx = CommandContext::mock();
        let result = Ok(CommandOutput::success("test"));
        stack.run_after(&mut ctx, &result).await;
        assert_eq!(*log.lock().unwrap(), vec!["after:failing", "after:outer"]);
    }

    /// Middleware whose participation is fixed by the `enabled` flag; its
    /// `before` hook marks the context so tests can tell whether it ran.
    struct ConditionalMiddleware {
        enabled: bool,
    }

    #[async_trait]
    impl Middleware for ConditionalMiddleware {
        async fn before(&self, ctx: &mut CommandContext) -> Result<()> {
            ctx.set_state("conditional_ran", true);
            Ok(())
        }

        async fn after(
            &self,
            _ctx: &mut CommandContext,
            _result: &Result<CommandOutput>,
        ) -> Result<()> {
            Ok(())
        }

        fn is_enabled(&self, _ctx: &CommandContext) -> bool {
            self.enabled
        }
    }

    /// The stack must skip middleware whose `is_enabled` returns `false`, and
    /// run it when `is_enabled` returns `true`.
    #[tokio::test]
    async fn test_conditional_middleware() {
        let mut ctx = CommandContext::mock();

        let disabled = MiddlewareStack::new()
            .push(Box::new(ConditionalMiddleware { enabled: false }));
        disabled.run_before(&mut ctx).await.unwrap();
        assert_eq!(ctx.get_state::<bool>("conditional_ran"), None);

        let enabled = MiddlewareStack::new()
            .push(Box::new(ConditionalMiddleware { enabled: true }));
        enabled.run_before(&mut ctx).await.unwrap();
        assert_eq!(ctx.get_state::<bool>("conditional_ran"), Some(&true));
    }
}