pub mod pipeline;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Debug;
use utoipa::ToSchema;
/// Outcome a middleware reports for a single invocation.
///
/// Serialized in lowercase: `"pass"`, `"warn"`, `"block"`.
// NOTE(review): how each variant is acted upon lives outside this file (presumably in
// `pipeline`) — confirm there before relying on specific enforcement semantics.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Verdict {
/// No objection raised.
Pass,
/// A concern was raised; `MiddlewareVerdict::warn` attaches a category and reason.
Warn,
/// A blocking outcome; `MiddlewareVerdict::block` attaches a category and reason.
Block,
}
/// Result returned by [`AgentMiddleware::execute`].
///
/// Every field other than `verdict` is omitted from serialized output when empty,
/// and falls back to its empty value when missing from the input.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct MiddlewareVerdict {
/// The outcome of the check.
pub verdict: Verdict,
/// Optional JSON payload; set by `MiddlewareVerdict::pass_with_content`.
// NOTE(review): presumably replaces `MiddlewareContext::content` downstream — confirm in `pipeline`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub content: Option<serde_json::Value>,
/// Human-readable explanation; set by `warn` and `block`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub reason: Option<String>,
/// Short label for the kind of finding (e.g. `pii`, `format`); set by `warn` and `block`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub category: Option<String>,
/// Arbitrary key/value state emitted by the middleware.
// NOTE(review): presumably merged into `MiddlewareContext::hook_state` for later hooks — confirm.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub hook_state: HashMap<String, serde_json::Value>,
}
impl MiddlewareVerdict {
pub fn pass() -> Self {
Self {
verdict: Verdict::Pass,
content: None,
reason: None,
category: None,
hook_state: HashMap::new(),
}
}
pub fn pass_with_content(content: serde_json::Value) -> Self {
Self {
verdict: Verdict::Pass,
content: Some(content),
reason: None,
category: None,
hook_state: HashMap::new(),
}
}
pub fn warn(category: impl Into<String>, reason: impl Into<String>) -> Self {
Self {
verdict: Verdict::Warn,
content: None,
reason: Some(reason.into()),
category: Some(category.into()),
hook_state: HashMap::new(),
}
}
pub fn block(category: impl Into<String>, reason: impl Into<String>) -> Self {
Self {
verdict: Verdict::Block,
content: None,
reason: Some(reason.into()),
category: Some(category.into()),
hook_state: HashMap::new(),
}
}
}
/// Point in an agent job at which a middleware may run.
///
/// Serialized as the lowercased variant name with no separator. For example,
/// `ProviderResponse` becomes `"providerresponse"` and `BeforePrompt` becomes `"beforeprompt"`.
// NOTE(review): these multi-word names are not snake_case. They are part of the wire/config
// format, so changing `rename_all` would break existing payloads. If snake_case spellings are
// ever needed, add `#[serde(alias = "...")]` per variant instead.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, ToSchema)]
#[serde(rename_all = "lowercase")]
pub enum MiddlewareStage {
Edit,
Release,
ProviderResponse,
BeforePrompt,
Completion,
}
/// Input handed to [`AgentMiddleware::execute`] for a single invocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MiddlewareContext {
/// Payload under inspection, as arbitrary JSON.
pub content: serde_json::Value,
/// Name of the action being performed (the tests use `"propose"`).
pub action: String,
/// Identifier of the agent this invocation concerns.
pub agent_id: String,
/// Identifier of the job this invocation belongs to.
pub job_id: String,
/// Round number for this invocation.
// NOTE(review): presumably 0-based (the tests use 0 and 2) — confirm with the caller.
pub round: u32,
/// Stage at which this invocation happens.
pub stage: MiddlewareStage,
/// Free-form extra data; deserializes to JSON `null` when the field is absent.
#[serde(default)]
pub metadata: serde_json::Value,
/// Key/value state shared between hooks; deserializes to an empty map when absent.
#[serde(default)]
pub hook_state: HashMap<String, serde_json::Value>,
}
#[async_trait]
pub trait AgentMiddleware: Send + Sync + Debug {
async fn execute(&self, ctx: &MiddlewareContext) -> MiddlewareVerdict;
fn stages(&self) -> Vec<MiddlewareStage> {
vec![
MiddlewareStage::Edit,
MiddlewareStage::Release,
MiddlewareStage::ProviderResponse,
]
}
fn name(&self) -> &str;
}
/// A warning attributed to a specific middleware.
// NOTE(review): presumably built from `Verdict::Warn` results by the pipeline — confirm in `pipeline`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Warning {
/// Name of the middleware that raised it (presumably its `AgentMiddleware::name`).
pub middleware: String,
/// Category copied from the verdict, if any.
pub category: Option<String>,
/// Reason copied from the verdict, if any.
pub reason: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal middleware used to observe the trait's default `stages()`.
    #[derive(Debug)]
    struct NoopMiddleware;

    #[async_trait::async_trait]
    impl AgentMiddleware for NoopMiddleware {
        async fn execute(&self, _ctx: &MiddlewareContext) -> MiddlewareVerdict {
            MiddlewareVerdict::pass()
        }

        fn name(&self) -> &str {
            "noop"
        }
    }

    #[test]
    fn verdict_pass_default() {
        let v = MiddlewareVerdict::pass();
        assert_eq!(v.verdict, Verdict::Pass);
        assert!(v.content.is_none());
        assert!(v.reason.is_none());
        assert!(v.category.is_none());
        assert!(v.hook_state.is_empty());
    }

    #[test]
    fn verdict_block_has_reason() {
        let v = MiddlewareVerdict::block("pii", "Contains email addresses");
        assert_eq!(v.verdict, Verdict::Block);
        assert_eq!(v.category.as_deref(), Some("pii"));
        assert_eq!(v.reason.as_deref(), Some("Contains email addresses"));
        assert!(v.content.is_none());
        assert!(v.hook_state.is_empty());
    }

    #[test]
    fn verdict_warn_has_category() {
        let v = MiddlewareVerdict::warn("format", "Response exceeds recommended length");
        assert_eq!(v.verdict, Verdict::Warn);
        assert_eq!(v.category.as_deref(), Some("format"));
        assert_eq!(
            v.reason.as_deref(),
            Some("Response exceeds recommended length")
        );
        assert!(v.content.is_none());
    }

    #[test]
    fn verdict_pass_with_content() {
        let content = serde_json::json!({"cleaned": true});
        let v = MiddlewareVerdict::pass_with_content(content.clone());
        assert_eq!(v.verdict, Verdict::Pass);
        assert!(v.reason.is_none());
        assert!(v.category.is_none());
        assert_eq!(v.content, Some(content));
    }

    #[test]
    fn verdict_serialization_skips_empty_fields() {
        let pass = serde_json::to_value(MiddlewareVerdict::pass()).unwrap();
        assert_eq!(pass, serde_json::json!({"verdict": "pass"}));

        let warn = serde_json::to_value(MiddlewareVerdict::warn("format", "too long")).unwrap();
        assert_eq!(
            warn,
            serde_json::json!({"verdict": "warn", "category": "format", "reason": "too long"})
        );
    }

    #[test]
    fn verdict_deserializes_from_verdict_only() {
        let v: MiddlewareVerdict = serde_json::from_str(r#"{"verdict":"block"}"#).unwrap();
        assert_eq!(v.verdict, Verdict::Block);
        assert!(v.content.is_none());
        assert!(v.reason.is_none());
        assert!(v.category.is_none());
        assert!(v.hook_state.is_empty());
    }

    #[test]
    fn stage_wire_names_are_lowercased_variant_names() {
        // Pins the wire format: multi-word variants have no separator.
        let cases = [
            (MiddlewareStage::Edit, "edit"),
            (MiddlewareStage::Release, "release"),
            (MiddlewareStage::ProviderResponse, "providerresponse"),
            (MiddlewareStage::BeforePrompt, "beforeprompt"),
            (MiddlewareStage::Completion, "completion"),
        ];
        for (stage, name) in &cases {
            let json = serde_json::to_string(stage).unwrap();
            assert_eq!(json, format!("\"{}\"", name));
            let back: MiddlewareStage = serde_json::from_str(&json).unwrap();
            assert_eq!(&back, stage);
        }
    }

    #[test]
    fn middleware_context_serde_roundtrip() {
        let ctx = MiddlewareContext {
            content: serde_json::json!({"text": "hello"}),
            action: "propose".to_string(),
            agent_id: "agent-1".to_string(),
            job_id: "job-42".to_string(),
            round: 2,
            stage: MiddlewareStage::Release,
            metadata: serde_json::json!({}),
            hook_state: HashMap::new(),
        };
        let json = serde_json::to_string(&ctx).unwrap();
        let deserialized: MiddlewareContext = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.agent_id, "agent-1");
        assert_eq!(deserialized.job_id, "job-42");
        assert_eq!(deserialized.action, "propose");
        assert_eq!(deserialized.stage, MiddlewareStage::Release);
        assert_eq!(deserialized.round, 2);
        assert_eq!(deserialized.content, serde_json::json!({"text": "hello"}));
    }

    #[test]
    fn middleware_context_defaults_optional_fields() {
        let json = r#"{"content":null,"action":"propose","agent_id":"a","job_id":"j","round":0,"stage":"edit"}"#;
        let ctx: MiddlewareContext = serde_json::from_str(json).unwrap();
        assert_eq!(ctx.metadata, serde_json::Value::Null);
        assert!(ctx.hook_state.is_empty());
        assert_eq!(ctx.stage, MiddlewareStage::Edit);
    }

    #[test]
    fn middleware_verdict_serde_roundtrip() {
        let v = MiddlewareVerdict::block("harassment", "Violates guidelines");
        let json = serde_json::to_string(&v).unwrap();
        let deserialized: MiddlewareVerdict = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.verdict, Verdict::Block);
        assert_eq!(deserialized.category.as_deref(), Some("harassment"));
        assert_eq!(deserialized.reason.as_deref(), Some("Violates guidelines"));
    }

    #[test]
    fn hook_state_propagation() {
        let mut ctx = MiddlewareContext {
            content: serde_json::json!(null),
            action: "propose".to_string(),
            agent_id: "a".to_string(),
            job_id: "j".to_string(),
            round: 0,
            stage: MiddlewareStage::Edit,
            metadata: serde_json::json!(null),
            hook_state: HashMap::new(),
        };
        ctx.hook_state
            .insert("pii_detected".to_string(), serde_json::json!(true));
        let json = serde_json::to_string(&ctx).unwrap();
        let deserialized: MiddlewareContext = serde_json::from_str(&json).unwrap();
        assert_eq!(
            deserialized.hook_state.get("pii_detected"),
            Some(&serde_json::json!(true))
        );
    }

    #[test]
    fn default_stages_exclude_before_prompt_and_completion() {
        let stages = NoopMiddleware.stages();
        assert_eq!(
            stages,
            vec![
                MiddlewareStage::Edit,
                MiddlewareStage::Release,
                MiddlewareStage::ProviderResponse,
            ]
        );
        assert!(!stages.contains(&MiddlewareStage::BeforePrompt));
        assert!(!stages.contains(&MiddlewareStage::Completion));
        assert_eq!(NoopMiddleware.name(), "noop");
    }
}
pub mod binary;
pub mod builtin;
pub mod config;
pub mod dylib;
pub use binary::BinaryMiddleware;
pub use config::{MiddlewareConfig, MiddlewareEntry};
pub use dylib::DylibMiddleware;