use serde_json::Value;
use crate::BoxFuture;
use crate::agents::error::AgentError;
use crate::tools::Tool;
/// Data threaded through a [`MiddlewareStack`].
///
/// Each middleware receives the input produced by the previous stage and may
/// hand a modified value on via [`MiddlewareResult::Continue`].
#[derive(Debug, Clone, PartialEq)]
pub struct MiddlewareInput {
    /// Messages as raw JSON values (schema defined by the caller).
    pub messages: Vec<Value>,
    /// Arbitrary JSON context passed along the pipeline with the messages.
    pub context: Value,
}
/// Outcome of a single [`Middleware::process`] call.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum MiddlewareResult {
    /// Pass this (possibly modified) input on to the next middleware.
    Continue(MiddlewareInput),
    /// Stop the pipeline; later middleware are not run.
    ///
    /// NOTE(review): the string's meaning (reason vs. final reply) is decided
    /// by callers — confirm against the agent loop that consumes it.
    Terminate(String),
}
/// One stage of a [`MiddlewareStack`].
///
/// Only [`Middleware::name`] is required. Every other hook has a default that
/// does nothing, so implementors override just the hooks they need.
pub trait Middleware: Send + Sync {
    /// Identifier for this middleware; `MiddlewareStack`'s `Debug` output lists these.
    fn name(&self) -> &str;

    /// Inspects or rewrites `input` and decides whether the pipeline continues.
    ///
    /// The default implementation hands `input` back unchanged inside
    /// [`MiddlewareResult::Continue`].
    ///
    /// # Errors
    ///
    /// Implementations may fail with an [`AgentError`]; the default never does.
    fn process(
        &self,
        input: MiddlewareInput,
    ) -> BoxFuture<'_, Result<MiddlewareResult, AgentError>> {
        let unchanged = MiddlewareResult::Continue(input);
        Box::pin(async move { Ok(unchanged) })
    }

    /// Extra tools this middleware contributes. Defaults to none.
    fn tools(&self) -> Vec<Box<dyn Tool>> {
        vec![]
    }

    /// Extra system-prompt fragments this middleware contributes. Defaults to none.
    fn system_prompt_additions(&self) -> Vec<String> {
        vec![]
    }
}
/// An ordered pipeline of [`Middleware`], run first-to-last by `run`.
pub struct MiddlewareStack {
    /// Middleware in insertion order (`push` appends to the end).
    components: Vec<Box<dyn Middleware>>,
}
impl std::fmt::Debug for MiddlewareStack {
    /// Shows each middleware by name only, because `dyn Middleware` does not
    /// require `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let names: Vec<&str> = self.components.iter().map(|mw| mw.name()).collect();
        f.debug_struct("MiddlewareStack")
            .field("components", &names)
            .finish()
    }
}
impl MiddlewareStack {
    /// Creates an empty stack.
    #[must_use]
    pub fn new() -> Self {
        Self {
            components: Vec::new(),
        }
    }

    /// Appends `middleware` to the end of the stack.
    ///
    /// It will run after every middleware pushed before it.
    pub fn push(&mut self, middleware: impl Middleware + 'static) {
        self.components.push(Box::new(middleware));
    }

    /// Number of middleware in the stack.
    #[must_use]
    pub fn len(&self) -> usize {
        self.components.len()
    }

    /// Returns `true` when no middleware has been pushed.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.components.is_empty()
    }

    /// Runs `input` through every middleware in insertion order.
    ///
    /// Each middleware receives the input produced by the previous one.
    ///
    /// - The first [`MiddlewareResult::Terminate`] stops the pipeline and is
    ///   returned as-is; later middleware are not invoked.
    /// - If every middleware continues, the final input is returned wrapped in
    ///   [`MiddlewareResult::Continue`]. An empty stack returns `input` unchanged.
    ///
    /// # Errors
    ///
    /// Returns the first [`AgentError`] produced by a middleware's `process`.
    /// Later middleware are not invoked.
    pub async fn run(&self, mut input: MiddlewareInput) -> Result<MiddlewareResult, AgentError> {
        for mw in &self.components {
            match mw.process(input).await? {
                MiddlewareResult::Continue(next_input) => input = next_input,
                term @ MiddlewareResult::Terminate(_) => return Ok(term),
            }
        }
        Ok(MiddlewareResult::Continue(input))
    }

    /// System-prompt fragments from every middleware, concatenated in stack order.
    #[must_use]
    pub fn system_prompt_additions(&self) -> Vec<String> {
        self.components
            .iter()
            .flat_map(|m| m.system_prompt_additions())
            .collect()
    }

    /// Tools from every middleware, concatenated in stack order.
    ///
    /// Each call invokes `Middleware::tools` on every component again.
    #[must_use]
    pub fn tools(&self) -> Vec<Box<dyn Tool>> {
        self.components.iter().flat_map(|m| m.tools()).collect()
    }
}
impl Default for MiddlewareStack {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::unnecessary_literal_bound
)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Shared log of middleware names in the order they were processed.
    type OrderLog = Arc<Mutex<Vec<&'static str>>>;

    /// Test middleware that appends its `name` to a shared log when processed.
    /// It also contributes `"[name]"` as a system-prompt addition.
    struct OrderRecorder {
        name: &'static str,
        order: OrderLog,
    }

    impl Middleware for OrderRecorder {
        fn name(&self) -> &str {
            self.name
        }

        fn process(
            &self,
            input: MiddlewareInput,
        ) -> BoxFuture<'_, Result<MiddlewareResult, AgentError>> {
            Box::pin(async move {
                // Fail loudly on a poisoned lock rather than silently skipping the
                // record, which would surface later as a confusing order mismatch.
                self.order
                    .lock()
                    .expect("order lock poisoned")
                    .push(self.name);
                Ok(MiddlewareResult::Continue(input))
            })
        }

        fn system_prompt_additions(&self) -> Vec<String> {
            vec![format!("[{}]", self.name)]
        }
    }

    /// Test middleware that always terminates with the message `"stop"`.
    struct EarlyTerminator;

    impl Middleware for EarlyTerminator {
        fn name(&self) -> &str {
            "terminator"
        }

        fn process(
            &self,
            _input: MiddlewareInput,
        ) -> BoxFuture<'_, Result<MiddlewareResult, AgentError>> {
            Box::pin(async { Ok(MiddlewareResult::Terminate("stop".to_string())) })
        }
    }

    fn base_input() -> MiddlewareInput {
        MiddlewareInput {
            messages: Vec::new(),
            context: serde_json::json!({}),
        }
    }

    fn recorder(name: &'static str, order: &OrderLog) -> OrderRecorder {
        OrderRecorder {
            name,
            order: Arc::clone(order),
        }
    }

    #[tokio::test]
    async fn test_stack_order() {
        let order: OrderLog = Arc::new(Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();
        stack.push(recorder("a", &order));
        stack.push(recorder("b", &order));
        let result = stack.run(base_input()).await.expect("run");
        assert!(matches!(result, MiddlewareResult::Continue(_)));
        let seen = order.lock().expect("lock").clone();
        assert_eq!(seen, vec!["a", "b"]);
    }

    #[tokio::test]
    async fn test_early_termination() {
        let order: OrderLog = Arc::new(Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();
        stack.push(EarlyTerminator);
        stack.push(recorder("after", &order));
        let result = stack.run(base_input()).await.expect("run");
        assert!(matches!(result, MiddlewareResult::Terminate(ref msg) if msg == "stop"));
        assert!(order.lock().expect("lock").is_empty());
    }

    #[tokio::test]
    async fn test_empty_stack_passes_input_through() {
        let stack = MiddlewareStack::default();
        match stack.run(base_input()).await.expect("run") {
            MiddlewareResult::Continue(out) => {
                assert!(out.messages.is_empty());
                assert_eq!(out.context, serde_json::json!({}));
            }
            MiddlewareResult::Terminate(msg) => panic!("unexpected termination: {msg}"),
        }
    }

    #[test]
    fn test_system_prompt_composition_order() {
        let order: OrderLog = Arc::new(Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();
        stack.push(recorder("first", &order));
        stack.push(recorder("second", &order));
        let additions = stack.system_prompt_additions();
        assert_eq!(additions, vec!["[first]", "[second]"]);
    }

    #[test]
    fn test_default_hooks_contribute_nothing() {
        let mut stack = MiddlewareStack::new();
        stack.push(EarlyTerminator);
        assert!(stack.tools().is_empty());
        assert!(stack.system_prompt_additions().is_empty());
    }

    #[test]
    fn test_debug_lists_middleware_names() {
        let order: OrderLog = Arc::new(Mutex::new(Vec::new()));
        let mut stack = MiddlewareStack::new();
        stack.push(recorder("a", &order));
        stack.push(EarlyTerminator);
        assert_eq!(
            format!("{stack:?}"),
            r#"MiddlewareStack { components: ["a", "terminator"] }"#
        );
    }
}