use super::{AgentMiddleware, MiddlewareContext, Verdict, Warning};
use serde::{Deserialize, Serialize};
/// Final outcome of running a [`MiddlewarePipeline`] over a context.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PipelineResult {
    /// No applicable middleware blocked the content. Any warnings raised along
    /// the way are kept here in the order they were produced.
    Passed { warnings: Vec<Warning> },
    /// A middleware rejected the content.
    Blocked {
        /// Category reported by the blocking middleware.
        category: String,
        /// Reason reported by the blocking middleware.
        reason: String,
        /// `name()` of the middleware that issued the block.
        blocked_by: String,
    },
}

impl PipelineResult {
    /// Returns `true` unless the result is [`PipelineResult::Blocked`].
    pub fn is_passed(&self) -> bool {
        match self {
            Self::Passed { .. } => true,
            Self::Blocked { .. } => false,
        }
    }

    /// Warnings collected during a passing run; a blocked result reports none.
    pub fn warnings(&self) -> &[Warning] {
        if let Self::Passed { warnings } = self {
            warnings.as_slice()
        } else {
            &[]
        }
    }
}
/// An ordered chain of [`AgentMiddleware`] run against a [`MiddlewareContext`].
pub struct MiddlewarePipeline {
    middleware: Vec<Box<dyn AgentMiddleware>>,
}

impl MiddlewarePipeline {
    /// Builds a pipeline that runs `middleware` in the given order.
    pub fn new(middleware: Vec<Box<dyn AgentMiddleware>>) -> Self {
        Self { middleware }
    }

    /// Builds a pipeline with no middleware; [`run`](Self::run) on it always passes.
    pub fn empty() -> Self {
        Self::new(Vec::new())
    }

    /// Returns `true` if no middleware is registered.
    pub fn is_empty(&self) -> bool {
        self.middleware.is_empty()
    }

    /// Number of registered middleware, counted across all stages.
    pub fn len(&self) -> usize {
        self.middleware.len()
    }

    /// Runs every middleware whose `stages()` include `ctx.stage`, in order.
    ///
    /// After each middleware executes, the `hook_state` entries in its verdict
    /// are merged into `ctx.hook_state` (existing keys are overwritten),
    /// whatever the verdict. Then:
    /// - `Block`: logs and returns [`PipelineResult::Blocked`] immediately, so
    ///   later middleware do not run. A missing category/reason becomes `""`.
    /// - `Warn`: logs, records a [`Warning`], and applies replacement content.
    /// - `Pass`: applies replacement content.
    ///
    /// Replacement content is written to `ctx.content`, so each later
    /// middleware sees the output of the ones before it.
    pub async fn run(&self, ctx: &mut MiddlewareContext) -> PipelineResult {
        let mut warnings = Vec::new();
        for mw in &self.middleware {
            if !mw.stages().contains(&ctx.stage) {
                continue;
            }
            let verdict = mw.execute(ctx).await;
            // The verdict is owned, so move its state in instead of cloning
            // each entry. This happens before the verdict is acted on, so state
            // survives a block.
            ctx.hook_state.extend(verdict.hook_state);
            match verdict.verdict {
                Verdict::Block => {
                    tracing::warn!(
                        middleware = mw.name(),
                        category = ?verdict.category,
                        reason = ?verdict.reason,
                        agent_id = %ctx.agent_id,
                        stage = ?ctx.stage,
                        "Middleware blocked content"
                    );
                    return PipelineResult::Blocked {
                        category: verdict.category.unwrap_or_default(),
                        reason: verdict.reason.unwrap_or_default(),
                        blocked_by: mw.name().to_string(),
                    };
                }
                Verdict::Warn => {
                    tracing::info!(
                        middleware = mw.name(),
                        category = ?verdict.category,
                        reason = ?verdict.reason,
                        agent_id = %ctx.agent_id,
                        stage = ?ctx.stage,
                        "Middleware warning"
                    );
                    warnings.push(Warning {
                        middleware: mw.name().to_string(),
                        category: verdict.category,
                        reason: verdict.reason,
                    });
                }
                Verdict::Pass => {}
            }
            // Only Warn and Pass reach this point; a replacement becomes the
            // input for the next middleware.
            if let Some(new_content) = verdict.content {
                ctx.content = new_content;
            }
        }
        PipelineResult::Passed { warnings }
    }
}
/// Debug output lists each registered middleware by its `name()`, in order.
impl std::fmt::Debug for MiddlewarePipeline {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let registered = self
            .middleware
            .iter()
            .map(|entry| entry.name())
            .collect::<Vec<&str>>();
        let mut out = f.debug_struct("MiddlewarePipeline");
        out.field("middleware", &registered);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::middleware::{MiddlewareStage, MiddlewareVerdict};
    use async_trait::async_trait;
    use std::collections::HashMap;

    /// Always passes without touching content.
    #[derive(Debug)]
    struct PassMiddleware;

    #[async_trait]
    impl AgentMiddleware for PassMiddleware {
        async fn execute(&self, _ctx: &MiddlewareContext) -> MiddlewareVerdict {
            MiddlewareVerdict::pass()
        }

        fn name(&self) -> &str {
            "pass"
        }
    }

    /// Always warns with the configured category.
    #[derive(Debug)]
    struct WarnMiddleware {
        category: String,
    }

    #[async_trait]
    impl AgentMiddleware for WarnMiddleware {
        async fn execute(&self, _ctx: &MiddlewareContext) -> MiddlewareVerdict {
            MiddlewareVerdict::warn(&self.category, "Just a warning")
        }

        fn name(&self) -> &str {
            "warn"
        }
    }

    /// Always blocks with the configured category.
    #[derive(Debug)]
    struct BlockMiddleware {
        category: String,
    }

    #[async_trait]
    impl AgentMiddleware for BlockMiddleware {
        async fn execute(&self, _ctx: &MiddlewareContext) -> MiddlewareVerdict {
            MiddlewareVerdict::block(&self.category, "Content rejected")
        }

        fn name(&self) -> &str {
            "block"
        }
    }

    /// Panics if executed. Placed after a blocking middleware to prove the
    /// pipeline stops at the first `Block` verdict.
    #[derive(Debug)]
    struct MustNotRunMiddleware;

    #[async_trait]
    impl AgentMiddleware for MustNotRunMiddleware {
        async fn execute(&self, _ctx: &MiddlewareContext) -> MiddlewareVerdict {
            panic!("middleware after a Block verdict must not run")
        }

        fn name(&self) -> &str {
            "must_not_run"
        }
    }

    /// Passes while replacing the content with `new_content`.
    #[derive(Debug)]
    struct TransformMiddleware {
        new_content: serde_json::Value,
    }

    #[async_trait]
    impl AgentMiddleware for TransformMiddleware {
        async fn execute(&self, _ctx: &MiddlewareContext) -> MiddlewareVerdict {
            MiddlewareVerdict::pass_with_content(self.new_content.clone())
        }

        fn name(&self) -> &str {
            "transform"
        }
    }

    /// Blocks unless the content it receives equals `expected`. Used to check
    /// what a downstream middleware actually observes.
    #[derive(Debug)]
    struct ExpectContentMiddleware {
        expected: serde_json::Value,
    }

    #[async_trait]
    impl AgentMiddleware for ExpectContentMiddleware {
        async fn execute(&self, ctx: &MiddlewareContext) -> MiddlewareVerdict {
            if ctx.content == self.expected {
                MiddlewareVerdict::pass()
            } else {
                MiddlewareVerdict::block("content_mismatch", "Unexpected content")
            }
        }

        fn name(&self) -> &str {
            "expect_content"
        }
    }

    /// Blocks, but only registers for the `Release` stage.
    #[derive(Debug)]
    struct ReleaseOnlyMiddleware;

    #[async_trait]
    impl AgentMiddleware for ReleaseOnlyMiddleware {
        async fn execute(&self, _ctx: &MiddlewareContext) -> MiddlewareVerdict {
            MiddlewareVerdict::block("release_only", "Should only run at release")
        }

        fn stages(&self) -> Vec<MiddlewareStage> {
            vec![MiddlewareStage::Release]
        }

        fn name(&self) -> &str {
            "release_only"
        }
    }

    /// Passes on a fresh context and warns once `hook_state` contains `"seen"`.
    #[derive(Debug)]
    struct StateReaderMiddleware;

    #[async_trait]
    impl AgentMiddleware for StateReaderMiddleware {
        async fn execute(&self, ctx: &MiddlewareContext) -> MiddlewareVerdict {
            if ctx.hook_state.contains_key("seen") {
                MiddlewareVerdict::warn("state", "Already seen")
            } else {
                MiddlewareVerdict::pass()
            }
        }

        fn name(&self) -> &str {
            "state_reader"
        }
    }

    fn make_ctx(stage: MiddlewareStage) -> MiddlewareContext {
        MiddlewareContext {
            content: serde_json::json!({"text": "original"}),
            action: "propose".to_string(),
            agent_id: "test-agent".to_string(),
            job_id: "test-job".to_string(),
            round: 1,
            stage,
            metadata: serde_json::json!({}),
            hook_state: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn pipeline_all_pass() {
        let pipeline =
            MiddlewarePipeline::new(vec![Box::new(PassMiddleware), Box::new(PassMiddleware)]);
        let mut ctx = make_ctx(MiddlewareStage::Release);
        let result = pipeline.run(&mut ctx).await;
        assert!(result.is_passed());
        assert!(result.warnings().is_empty());
    }

    #[tokio::test]
    async fn pipeline_pass_warn_pass() {
        let pipeline = MiddlewarePipeline::new(vec![
            Box::new(PassMiddleware),
            Box::new(WarnMiddleware {
                category: "format".to_string(),
            }),
            Box::new(PassMiddleware),
        ]);
        let mut ctx = make_ctx(MiddlewareStage::Release);
        let result = pipeline.run(&mut ctx).await;
        assert!(result.is_passed());
        assert_eq!(result.warnings().len(), 1);
        assert_eq!(result.warnings()[0].middleware, "warn");
        assert_eq!(result.warnings()[0].category.as_deref(), Some("format"));
    }

    #[tokio::test]
    async fn pipeline_short_circuits_on_block() {
        let pipeline = MiddlewarePipeline::new(vec![
            Box::new(PassMiddleware),
            Box::new(BlockMiddleware {
                category: "harassment".to_string(),
            }),
            // Panics (failing the test) if the pipeline keeps going after the block.
            Box::new(MustNotRunMiddleware),
        ]);
        let mut ctx = make_ctx(MiddlewareStage::Release);
        let result = pipeline.run(&mut ctx).await;
        assert!(!result.is_passed());
        assert!(result.warnings().is_empty());
        match result {
            PipelineResult::Blocked {
                category,
                reason,
                blocked_by,
            } => {
                assert_eq!(category, "harassment");
                assert_eq!(reason, "Content rejected");
                assert_eq!(blocked_by, "block");
            }
            PipelineResult::Passed { .. } => panic!("Expected Blocked"),
        }
    }

    #[tokio::test]
    async fn pipeline_content_transformation_flows_through() {
        let transformed = serde_json::json!({"text": "transformed"});
        let pipeline = MiddlewarePipeline::new(vec![
            Box::new(TransformMiddleware {
                new_content: transformed.clone(),
            }),
            // Blocks unless it sees the transformed content, proving the
            // rewrite is visible to downstream middleware, not just the caller.
            Box::new(ExpectContentMiddleware {
                expected: transformed.clone(),
            }),
            Box::new(PassMiddleware),
        ]);
        let mut ctx = make_ctx(MiddlewareStage::Release);
        let result = pipeline.run(&mut ctx).await;
        assert!(result.is_passed());
        assert_eq!(ctx.content, transformed);
    }

    #[tokio::test]
    async fn pipeline_stage_filtering() {
        let pipeline = MiddlewarePipeline::new(vec![
            Box::new(PassMiddleware),
            // Would block if it ran; it must be skipped at the Edit stage.
            Box::new(ReleaseOnlyMiddleware),
        ]);
        let mut ctx = make_ctx(MiddlewareStage::Edit);
        let result = pipeline.run(&mut ctx).await;
        assert!(result.is_passed());
        assert!(result.warnings().is_empty());
    }

    #[tokio::test]
    async fn pipeline_stage_filtering_blocks_at_correct_stage() {
        let pipeline = MiddlewarePipeline::new(vec![
            Box::new(PassMiddleware),
            Box::new(ReleaseOnlyMiddleware),
        ]);
        let mut ctx = make_ctx(MiddlewareStage::Release);
        let result = pipeline.run(&mut ctx).await;
        match result {
            PipelineResult::Blocked { blocked_by, .. } => {
                assert_eq!(blocked_by, "release_only");
            }
            PipelineResult::Passed { .. } => panic!("Expected Blocked"),
        }
    }

    #[tokio::test]
    async fn pipeline_empty_is_noop() {
        let pipeline = MiddlewarePipeline::empty();
        assert!(pipeline.is_empty());
        assert_eq!(pipeline.len(), 0);
        let mut ctx = make_ctx(MiddlewareStage::Release);
        let result = pipeline.run(&mut ctx).await;
        assert!(result.is_passed());
        assert!(result.warnings().is_empty());
    }

    #[tokio::test]
    async fn pipeline_hook_state_available() {
        let pipeline = MiddlewarePipeline::new(vec![Box::new(StateReaderMiddleware)]);
        let mut ctx = make_ctx(MiddlewareStage::Release);
        let result = pipeline.run(&mut ctx).await;
        assert!(result.is_passed());
        assert!(result.warnings().is_empty());

        ctx.hook_state
            .insert("seen".to_string(), serde_json::json!(true));
        let result = pipeline.run(&mut ctx).await;
        assert!(result.is_passed());
        assert_eq!(result.warnings().len(), 1);
    }

    #[tokio::test]
    async fn pipeline_multiple_warnings_accumulated() {
        let pipeline = MiddlewarePipeline::new(vec![
            Box::new(WarnMiddleware {
                category: "pii".to_string(),
            }),
            Box::new(WarnMiddleware {
                category: "length".to_string(),
            }),
            Box::new(PassMiddleware),
        ]);
        let mut ctx = make_ctx(MiddlewareStage::Release);
        let result = pipeline.run(&mut ctx).await;
        assert!(result.is_passed());
        assert_eq!(result.warnings().len(), 2);
        assert_eq!(result.warnings()[0].category.as_deref(), Some("pii"));
        assert_eq!(result.warnings()[1].category.as_deref(), Some("length"));
    }

    #[test]
    fn pipeline_debug_lists_middleware_names() {
        let pipeline = MiddlewarePipeline::new(vec![
            Box::new(PassMiddleware),
            Box::new(WarnMiddleware {
                category: "format".to_string(),
            }),
        ]);
        assert_eq!(
            format!("{:?}", pipeline),
            r#"MiddlewarePipeline { middleware: ["pass", "warn"] }"#
        );
    }
}