use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;
use thiserror::Error;
use crate::{
agent::context_engineering::{ModelRequest, ModelResponse},
agent::runtime::{Runtime, RuntimeRequest},
chain::ChainError,
language_models::GenerateResult,
prompt::PromptArgs,
schemas::agent::{AgentAction, AgentEvent, AgentFinish},
};
/// Mutable bookkeeping state handed to every [`Middleware`] hook.
///
/// Holds simple counters, the instant the context was created, and an
/// open-ended key/value map that middlewares can use for their own state.
#[derive(Debug, Clone)]
pub struct MiddlewareContext {
    /// Iteration counter; `new()` starts it at 0 and `increment_iteration` adds one.
    pub iteration: usize,
    /// Instant captured by `new()` when the context was created.
    pub start_time: std::time::Instant,
    /// Tool-call counter; `new()` starts it at 0 and `increment_tool_call_count` adds one.
    pub tool_call_count: usize,
    /// Free-form per-middleware state, read/written via `get_custom_data` / `set_custom_data`.
    pub custom_data: HashMap<String, Value>,
}
impl MiddlewareContext {
    /// Creates a context with both counters at 0, no custom data, and
    /// `start_time` set to the current instant.
    #[must_use]
    pub fn new() -> Self {
        Self {
            iteration: 0,
            start_time: std::time::Instant::now(),
            tool_call_count: 0,
            custom_data: HashMap::new(),
        }
    }

    /// Returns this context with `iteration` replaced by the given value.
    ///
    /// Consumes `self` builder-style. The result must be used, because
    /// discarding it drops the configured context.
    #[must_use]
    pub fn with_iteration(mut self, iteration: usize) -> Self {
        self.iteration = iteration;
        self
    }

    /// Adds one to the iteration counter.
    pub fn increment_iteration(&mut self) {
        self.iteration += 1;
    }

    /// Adds one to the tool-call counter.
    pub fn increment_tool_call_count(&mut self) {
        self.tool_call_count += 1;
    }

    /// Returns the custom value stored under `key`, or `None` if there is none.
    #[must_use]
    pub fn get_custom_data(&self, key: &str) -> Option<&Value> {
        self.custom_data.get(key)
    }

    /// Stores `value` under `key`, replacing any value previously stored there.
    pub fn set_custom_data(&mut self, key: String, value: Value) {
        self.custom_data.insert(key, value);
    }
}
impl Default for MiddlewareContext {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Error)]
pub enum MiddlewareError {
#[error("Middleware execution error: {0}")]
ExecutionError(String),
#[error("Middleware aborted execution: {0}")]
Aborted(String),
#[error("Middleware validation error: {0}")]
ValidationError(String),
#[error("Chain error: {0}")]
ChainError(#[from] ChainError),
#[error("Interrupt (human-in-the-loop)")]
Interrupt(serde_json::Value),
#[error("Tool call rejected by user")]
RejectTool,
}
/// Hooks named for the stages of an agent run: planning, tool calls, model
/// calls, and the finish step.
///
/// Every method has a default implementation, so an implementor overrides only
/// the hooks it needs:
/// - The plain hooks (`before_agent_plan`, `before_tool_call`, ...) default to a
///   no-op that returns `Ok(None)`, or `Ok(())` for `after_finish`.
/// - The `*_with_runtime` hooks default to forwarding to the matching plain
///   hook, ignoring the runtime argument. The plan variants pass
///   `request.input` through as the prompt args.
///
/// NOTE(review): returning `Ok(Some(..))` appears to supply a replacement value
/// for that stage. How the caller applies it, and how several middlewares are
/// composed, is not visible in this file; confirm in the `chain` module.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait Middleware: Send + Sync {
    /// Hook for the point before the agent plans.
    ///
    /// Receives the prompt args `input` and the `(AgentAction, String)` pairs in
    /// `steps` (presumably each action with its observation; confirm with the
    /// caller). The return type allows `Some` replacement `PromptArgs`.
    /// Default: does nothing and returns `Ok(None)`.
    async fn before_agent_plan(
        &self,
        input: &PromptArgs,
        steps: &[(AgentAction, String)],
        context: &mut MiddlewareContext,
    ) -> Result<Option<PromptArgs>, MiddlewareError> {
        // Bind the unused parameters so the no-op default has no unused-variable warnings.
        let _ = (input, steps, context);
        Ok(None)
    }

    /// Hook for the point after the agent has planned.
    ///
    /// Receives the prompt args and the `AgentEvent` that was produced, and may
    /// return `Some` replacement event.
    /// Default: does nothing and returns `Ok(None)`.
    async fn after_agent_plan(
        &self,
        input: &PromptArgs,
        event: &AgentEvent,
        context: &mut MiddlewareContext,
    ) -> Result<Option<AgentEvent>, MiddlewareError> {
        let _ = (input, event, context);
        Ok(None)
    }

    /// Hook for the point before a tool call described by `action`.
    ///
    /// May return `Some` replacement `AgentAction`.
    /// Default: does nothing and returns `Ok(None)`.
    async fn before_tool_call(
        &self,
        action: &AgentAction,
        context: &mut MiddlewareContext,
    ) -> Result<Option<AgentAction>, MiddlewareError> {
        let _ = (action, context);
        Ok(None)
    }

    /// Hook for the point after a tool call. Receives the action and the tool's
    /// `observation` text, and may return `Some` replacement observation string.
    /// Default: does nothing and returns `Ok(None)`.
    async fn after_tool_call(
        &self,
        action: &AgentAction,
        observation: &str,
        context: &mut MiddlewareContext,
    ) -> Result<Option<String>, MiddlewareError> {
        let _ = (action, observation, context);
        Ok(None)
    }

    /// Hook for the point before the agent finishes with `finish`.
    ///
    /// May return `Some` replacement `AgentFinish`.
    /// Default: does nothing and returns `Ok(None)`.
    async fn before_finish(
        &self,
        finish: &AgentFinish,
        context: &mut MiddlewareContext,
    ) -> Result<Option<AgentFinish>, MiddlewareError> {
        let _ = (finish, context);
        Ok(None)
    }

    /// Hook for the point after finishing. Receives the `AgentFinish` and the
    /// resulting `GenerateResult`. It cannot replace anything; it only returns
    /// success or an error.
    /// Default: does nothing and returns `Ok(())`.
    async fn after_finish(
        &self,
        finish: &AgentFinish,
        result: &GenerateResult,
        context: &mut MiddlewareContext,
    ) -> Result<(), MiddlewareError> {
        let _ = (finish, result, context);
        Ok(())
    }

    /// Runtime-aware variant of [`Middleware::before_agent_plan`].
    ///
    /// Default: forwards `request.input` and `steps` to `before_agent_plan`.
    async fn before_agent_plan_with_runtime(
        &self,
        request: &RuntimeRequest,
        steps: &[(AgentAction, String)],
        context: &mut MiddlewareContext,
    ) -> Result<Option<PromptArgs>, MiddlewareError> {
        self.before_agent_plan(&request.input, steps, context).await
    }

    /// Runtime-aware variant of [`Middleware::after_agent_plan`].
    ///
    /// Default: forwards `request.input` and `event` to `after_agent_plan`.
    async fn after_agent_plan_with_runtime(
        &self,
        request: &RuntimeRequest,
        event: &AgentEvent,
        context: &mut MiddlewareContext,
    ) -> Result<Option<AgentEvent>, MiddlewareError> {
        self.after_agent_plan(&request.input, event, context).await
    }

    /// Runtime-aware variant of [`Middleware::before_tool_call`].
    ///
    /// Default: ignores `_runtime` and forwards to `before_tool_call`.
    async fn before_tool_call_with_runtime(
        &self,
        action: &AgentAction,
        _runtime: Option<&Runtime>,
        context: &mut MiddlewareContext,
    ) -> Result<Option<AgentAction>, MiddlewareError> {
        self.before_tool_call(action, context).await
    }

    /// Runtime-aware variant of [`Middleware::after_tool_call`].
    ///
    /// Default: ignores `_runtime` and forwards to `after_tool_call`.
    async fn after_tool_call_with_runtime(
        &self,
        action: &AgentAction,
        observation: &str,
        _runtime: Option<&Runtime>,
        context: &mut MiddlewareContext,
    ) -> Result<Option<String>, MiddlewareError> {
        self.after_tool_call(action, observation, context).await
    }

    /// Runtime-aware variant of [`Middleware::before_finish`].
    ///
    /// Default: ignores `_runtime` and forwards to `before_finish`.
    async fn before_finish_with_runtime(
        &self,
        finish: &AgentFinish,
        _runtime: Option<&Runtime>,
        context: &mut MiddlewareContext,
    ) -> Result<Option<AgentFinish>, MiddlewareError> {
        self.before_finish(finish, context).await
    }

    /// Runtime-aware variant of [`Middleware::after_finish`].
    ///
    /// Default: ignores `_runtime` and forwards to `after_finish`.
    async fn after_finish_with_runtime(
        &self,
        finish: &AgentFinish,
        result: &GenerateResult,
        _runtime: Option<&Runtime>,
        context: &mut MiddlewareContext,
    ) -> Result<(), MiddlewareError> {
        self.after_finish(finish, result, context).await
    }

    /// Hook for the point before a model call described by `request`.
    ///
    /// May return `Some` replacement `ModelRequest`.
    /// Default: does nothing and returns `Ok(None)`.
    async fn before_model_call(
        &self,
        request: &ModelRequest,
        context: &mut MiddlewareContext,
    ) -> Result<Option<ModelRequest>, MiddlewareError> {
        let _ = (request, context);
        Ok(None)
    }

    /// Hook for the point after a model call. Receives the request and the
    /// model's `response`, and may return `Some` replacement `ModelResponse`.
    /// Default: does nothing and returns `Ok(None)`.
    async fn after_model_call(
        &self,
        request: &ModelRequest,
        response: &ModelResponse,
        context: &mut MiddlewareContext,
    ) -> Result<Option<ModelResponse>, MiddlewareError> {
        let _ = (request, response, context);
        Ok(None)
    }
}
pub mod content_filter;
pub mod guardrail_utils;
pub mod human_in_loop;
pub mod logging;
pub mod pii;
pub mod pii_detector;
pub mod rate_limit;
pub mod retry;
pub mod safety_guardrail;
pub mod skill_injection;
pub mod summarization;
pub mod tool_result_eviction;
pub use content_filter::ContentFilterMiddleware;
pub use guardrail_utils::*;
pub use human_in_loop::HumanInTheLoopMiddleware;
pub use logging::{LogLevel, LoggingMiddleware};
pub use pii::{PIIMiddleware, PIIStrategy};
pub use pii_detector::{detect_all_pii, PIIDetector, PIIMatch, PIIType};
pub use rate_limit::RateLimitMiddleware;
pub use retry::RetryMiddleware;
pub use safety_guardrail::SafetyGuardrailMiddleware;
pub use skill_injection::{build_skills_middleware, SkillsMiddleware};
pub use summarization::SummarizationMiddleware;
pub use tool_result_eviction::ToolResultEvictionMiddleware;
pub mod chain;
pub use chain::{MiddlewareChainConfig, MiddlewareChainExecutor, MiddlewareResult};
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_middleware_context() {
        let mut ctx = MiddlewareContext::new();
        assert_eq!(ctx.iteration, 0);
        assert_eq!(ctx.tool_call_count, 0);
        ctx.increment_iteration();
        ctx.increment_tool_call_count();
        assert_eq!(ctx.iteration, 1);
        assert_eq!(ctx.tool_call_count, 1);
        ctx.set_custom_data("key".to_string(), Value::String("value".to_string()));
        assert_eq!(
            ctx.get_custom_data("key"),
            Some(&Value::String("value".to_string()))
        );
    }

    /// `with_iteration` sets only the iteration and leaves the other fields at their defaults.
    #[test]
    fn test_with_iteration() {
        let ctx = MiddlewareContext::new().with_iteration(5);
        assert_eq!(ctx.iteration, 5);
        assert_eq!(ctx.tool_call_count, 0);
        assert!(ctx.custom_data.is_empty());
    }

    /// `Default` produces the same zeroed state as `new()`.
    #[test]
    fn test_default_matches_new() {
        let ctx = MiddlewareContext::default();
        assert_eq!(ctx.iteration, 0);
        assert_eq!(ctx.tool_call_count, 0);
        assert!(ctx.custom_data.is_empty());
    }

    /// Missing keys yield `None`; setting an existing key overwrites it in place.
    #[test]
    fn test_custom_data_missing_and_overwrite() {
        let mut ctx = MiddlewareContext::new();
        assert_eq!(ctx.get_custom_data("absent"), None);
        ctx.set_custom_data("flag".to_string(), Value::Bool(true));
        ctx.set_custom_data("flag".to_string(), Value::Bool(false));
        assert_eq!(ctx.get_custom_data("flag"), Some(&Value::Bool(false)));
        assert_eq!(ctx.custom_data.len(), 1);
    }

    /// Pins the user-visible `Display` text of each string-free or string-carrying variant.
    #[test]
    fn test_middleware_error_display() {
        assert_eq!(
            MiddlewareError::ExecutionError("boom".to_string()).to_string(),
            "Middleware execution error: boom"
        );
        assert_eq!(
            MiddlewareError::Aborted("stop".to_string()).to_string(),
            "Middleware aborted execution: stop"
        );
        assert_eq!(
            MiddlewareError::ValidationError("bad".to_string()).to_string(),
            "Middleware validation error: bad"
        );
        assert_eq!(
            MiddlewareError::Interrupt(Value::Null).to_string(),
            "Interrupt (human-in-the-loop)"
        );
        assert_eq!(
            MiddlewareError::RejectTool.to_string(),
            "Tool call rejected by user"
        );
    }
}