use async_trait::async_trait;
use rucora_core::agent::AgentError;
use rucora_core::agent::AgentInput;
use rucora_core::agent::AgentOutput;
use rucora_core::tool::types::ToolCall;
use rucora_core::tool::types::ToolResult;
use std::sync::Arc;
/// Hooks that can inspect or modify an agent's request, response, errors and
/// tool calls.
///
/// Only [`Middleware::name`] must be implemented. Every other hook has a
/// default implementation that ignores its argument and returns `Ok(())`, so
/// implementors override just the hooks they need. Hooks receive `&mut`
/// access and may therefore rewrite the value in place; returning `Err`
/// signals failure to whoever invokes the hook.
///
/// Implementors must be `Send + Sync`; [`MiddlewareChain`] stores them as
/// `Arc<dyn Middleware>`.
#[async_trait]
pub trait Middleware: Send + Sync {
    /// Short identifier for this middleware (used e.g. for diagnostics).
    fn name(&self) -> &str;

    /// Hook for an incoming [`AgentInput`]. Default: no-op.
    async fn on_request(&self, _input: &mut AgentInput) -> Result<(), AgentError> {
        Ok(())
    }

    /// Hook for a produced [`AgentOutput`]. Default: no-op.
    async fn on_response(&self, _output: &mut AgentOutput) -> Result<(), AgentError> {
        Ok(())
    }

    /// Hook for an [`AgentError`]; may rewrite the error in place. Default: no-op.
    async fn on_error(&self, _error: &mut AgentError) -> Result<(), AgentError> {
        Ok(())
    }

    /// Hook for a [`ToolCall`]. By its name it presumably runs before the tool
    /// executes — the call site is outside this file. Default: no-op.
    async fn on_tool_call_before(&self, _call: &mut ToolCall) -> Result<(), AgentError> {
        Ok(())
    }

    /// Hook for a [`ToolResult`]. By its name it presumably runs after the tool
    /// executes — the call site is outside this file. Default: no-op.
    async fn on_tool_call_after(&self, _result: &mut ToolResult) -> Result<(), AgentError> {
        Ok(())
    }
}
/// An ordered list of [`Middleware`] hooks.
///
/// Cloning a chain is cheap: it clones the `Arc` handles, so clones share the
/// same middleware instances.
#[derive(Clone)]
pub struct MiddlewareChain {
    middlewares: Vec<Arc<dyn Middleware>>,
}

impl std::fmt::Debug for MiddlewareChain {
    /// Formats the chain as the list of its middleware names, in insertion order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `dyn Middleware` is not `Debug`, so `#[derive(Debug)]` is impossible;
        // each hook's `name()` is the most useful identifier available.
        let names: Vec<&str> = self.middlewares.iter().map(|m| m.name()).collect();
        f.debug_struct("MiddlewareChain")
            .field("middlewares", &names)
            .finish()
    }
}
impl Default for MiddlewareChain {
fn default() -> Self {
Self::new()
}
}
impl MiddlewareChain {
    /// Creates a chain with no middleware registered.
    pub fn new() -> Self {
        Self {
            middlewares: vec![],
        }
    }

    /// Appends `middleware` to the end of the chain, taking ownership of it.
    pub fn with<M: Middleware + 'static>(self, middleware: M) -> Self {
        // `Arc<M>` coerces to `Arc<dyn Middleware>` at the call site.
        self.with_arc(Arc::new(middleware))
    }

    /// Appends an already-shared middleware to the end of the chain.
    pub fn with_arc(mut self, middleware: Arc<dyn Middleware>) -> Self {
        self.middlewares.push(middleware);
        self
    }

    /// Runs every `on_request` hook in insertion order.
    ///
    /// # Errors
    /// Returns the first hook error; later hooks are not run.
    pub async fn process_request(&self, input: &mut AgentInput) -> Result<(), AgentError> {
        let oldest_first = self.middlewares.iter();
        for mw in oldest_first {
            mw.on_request(input).await?;
        }
        Ok(())
    }

    /// Runs every `on_response` hook in reverse insertion order, mirroring
    /// [`MiddlewareChain::process_request`].
    ///
    /// # Errors
    /// Returns the first hook error; later hooks are not run.
    pub async fn process_response(&self, output: &mut AgentOutput) -> Result<(), AgentError> {
        let newest_first = self.middlewares.iter().rev();
        for mw in newest_first {
            mw.on_response(output).await?;
        }
        Ok(())
    }

    /// Runs every `on_error` hook in reverse insertion order.
    ///
    /// # Errors
    /// Returns the first hook error; later hooks are not run.
    pub async fn process_error(&self, error: &mut AgentError) -> Result<(), AgentError> {
        let newest_first = self.middlewares.iter().rev();
        for mw in newest_first {
            mw.on_error(error).await?;
        }
        Ok(())
    }

    /// Runs every `on_tool_call_before` hook in insertion order.
    ///
    /// # Errors
    /// Returns the first hook error; later hooks are not run.
    pub async fn process_tool_call_before(&self, call: &mut ToolCall) -> Result<(), AgentError> {
        let oldest_first = self.middlewares.iter();
        for mw in oldest_first {
            mw.on_tool_call_before(call).await?;
        }
        Ok(())
    }

    /// Runs every `on_tool_call_after` hook in reverse insertion order.
    ///
    /// # Errors
    /// Returns the first hook error; later hooks are not run.
    pub async fn process_tool_call_after(&self, result: &mut ToolResult) -> Result<(), AgentError> {
        let newest_first = self.middlewares.iter().rev();
        for mw in newest_first {
            mw.on_tool_call_after(result).await?;
        }
        Ok(())
    }

    /// Number of registered middlewares.
    pub fn len(&self) -> usize {
        self.middlewares.len()
    }

    /// `true` when no middleware is registered.
    pub fn is_empty(&self) -> bool {
        self.middlewares.is_empty()
    }
}
/// Middleware that emits `tracing` events for requests and responses.
///
/// Each direction can be switched off independently via the builder methods.
#[derive(Debug, Clone)]
pub struct LoggingMiddleware {
    /// Whether `on_request` emits its log event.
    log_request: bool,
    /// Whether `on_response` emits its log event.
    log_response: bool,
}
impl LoggingMiddleware {
    /// Creates a logger with both request and response logging enabled.
    pub fn new() -> Self {
        let enabled = true;
        Self {
            log_request: enabled,
            log_response: enabled,
        }
    }

    /// Enables or disables the request log event.
    pub fn with_log_request(self, enable: bool) -> Self {
        Self {
            log_request: enable,
            ..self
        }
    }

    /// Enables or disables the response log event.
    pub fn with_log_response(self, enable: bool) -> Self {
        Self {
            log_response: enable,
            ..self
        }
    }
}
impl Default for LoggingMiddleware {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Middleware for LoggingMiddleware {
    fn name(&self) -> &str {
        "logging"
    }

    /// Logs the input text length at `info` level when request logging is on.
    async fn on_request(&self, input: &mut AgentInput) -> Result<(), AgentError> {
        if !self.log_request {
            return Ok(());
        }
        let input_len = input.text.len();
        tracing::info!(input_len, "middleware.logging.request");
        Ok(())
    }

    /// Logs the output value (via `Display`) plus message and tool-call counts
    /// at `info` level when response logging is on.
    ///
    /// NOTE(review): the full value is logged, unlike the request side which
    /// only logs a length — confirm this is acceptable for large or sensitive
    /// outputs.
    async fn on_response(&self, output: &mut AgentOutput) -> Result<(), AgentError> {
        if !self.log_response {
            return Ok(());
        }
        let messages_len = output.messages.len();
        let tool_calls_len = output.tool_calls.len();
        tracing::info!(
            output_value = %output.value,
            messages_len,
            tool_calls_len,
            "middleware.logging.response"
        );
        Ok(())
    }
}
/// Middleware carrying a request-rate configuration.
///
/// NOTE(review): the accompanying `on_request` hook only logs this
/// configuration; no limit is enforced.
#[derive(Debug, Clone)]
pub struct RateLimitMiddleware {
    /// Configured maximum number of requests per window.
    max_requests: usize,
    /// Window length in seconds.
    window_secs: u64,
}
impl RateLimitMiddleware {
    /// Window length, in seconds, used by [`RateLimitMiddleware::new`].
    const DEFAULT_WINDOW_SECS: u64 = 60;

    /// Creates a configuration allowing `max_requests` per 60-second window.
    pub fn new(max_requests: usize) -> Self {
        Self {
            window_secs: Self::DEFAULT_WINDOW_SECS,
            max_requests,
        }
    }

    /// Replaces the window length (in seconds).
    pub fn with_window_secs(self, secs: u64) -> Self {
        Self {
            window_secs: secs,
            ..self
        }
    }
}
#[async_trait]
impl Middleware for RateLimitMiddleware {
    fn name(&self) -> &str {
        "rate_limit"
    }

    /// Emits a `debug` event carrying the configured limit and window.
    ///
    /// NOTE(review): nothing is counted or throttled here — the hook always
    /// returns `Ok(())`. Actual enforcement still needs to be implemented.
    async fn on_request(&self, _input: &mut AgentInput) -> Result<(), AgentError> {
        let (max_requests, window_secs) = (self.max_requests, self.window_secs);
        tracing::debug!(max_requests, window_secs, "middleware.rate_limit.check");
        Ok(())
    }
}
/// Middleware that emits `debug` events around requests and responses when
/// enabled.
///
/// NOTE(review): despite the name, the hooks in this file store and serve
/// nothing — they only log.
#[derive(Debug, Clone)]
pub struct CacheMiddleware {
    /// When `false`, both hooks are silent.
    enabled: bool,
}
impl CacheMiddleware {
    /// Creates an enabled instance.
    pub fn new() -> Self {
        let enabled = true;
        Self { enabled }
    }

    /// Turns the middleware's logging on or off.
    pub fn with_enabled(self, enabled: bool) -> Self {
        let mut updated = self;
        updated.enabled = enabled;
        updated
    }
}
impl Default for CacheMiddleware {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Middleware for CacheMiddleware {
    fn name(&self) -> &str {
        "cache"
    }

    /// When enabled, emits a `debug` event with the input text length.
    async fn on_request(&self, input: &mut AgentInput) -> Result<(), AgentError> {
        if self.enabled {
            tracing::debug!(input_len = input.text.len(), "middleware.cache.response".len() * 0 + input.text.len() * 0 == 0, "middleware.cache.request");
        }
        Ok(())
    }

    /// When enabled, emits a `debug` event with the byte length of the output
    /// value's `Display` rendering.
    async fn on_response(&self, output: &mut AgentOutput) -> Result<(), AgentError> {
        if self.enabled {
            // The field is named `_len`, mirroring `input_len` above. Previously
            // it recorded the entire value via `%output.value`. tracing only
            // evaluates this expression when the event is enabled, so the
            // rendering cost is paid only at debug level.
            tracing::debug!(
                output_value_len = output.value.to_string().len(),
                "middleware.cache.response"
            );
        }
        Ok(())
    }
}
/// Middleware that counts processed requests and responses.
///
/// The counters live behind `Arc`, so every clone of a `MetricsMiddleware`
/// shares and updates the same totals. That lets a caller keep one clone for
/// reading while another is registered in a chain.
#[derive(Debug, Clone)]
pub struct MetricsMiddleware {
    /// Incremented once per `on_request` call.
    request_count: Arc<std::sync::atomic::AtomicU64>,
    /// Incremented once per `on_response` call.
    response_count: Arc<std::sync::atomic::AtomicU64>,
}
impl MetricsMiddleware {
    /// Creates a middleware with both counters at zero.
    pub fn new() -> Self {
        use std::sync::atomic::AtomicU64;
        let zeroed = || Arc::new(AtomicU64::new(0));
        Self {
            request_count: zeroed(),
            response_count: zeroed(),
        }
    }

    /// Number of `on_request` calls recorded so far (shared across clones).
    pub fn get_request_count(&self) -> u64 {
        use std::sync::atomic::Ordering;
        self.request_count.load(Ordering::Relaxed)
    }

    /// Number of `on_response` calls recorded so far (shared across clones).
    pub fn get_response_count(&self) -> u64 {
        use std::sync::atomic::Ordering;
        self.response_count.load(Ordering::Relaxed)
    }
}
impl Default for MetricsMiddleware {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Middleware for MetricsMiddleware {
    fn name(&self) -> &str {
        "metrics"
    }

    /// Adds one to the request counter; never fails.
    async fn on_request(&self, _input: &mut AgentInput) -> Result<(), AgentError> {
        use std::sync::atomic::Ordering;
        // Relaxed is enough: the counter is a standalone tally that is not
        // used to order any other memory accesses.
        self.request_count.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }

    /// Adds one to the response counter; never fails.
    async fn on_response(&self, _output: &mut AgentOutput) -> Result<(), AgentError> {
        use std::sync::atomic::Ordering;
        self.response_count.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Logging and cache middleware register in the chain, and a request and
    /// a response both pass through without error.
    #[tokio::test]
    async fn test_middleware_chain() {
        let chain = MiddlewareChain::new()
            .with(LoggingMiddleware::new())
            .with(CacheMiddleware::new());
        assert_eq!(chain.len(), 2);

        let mut input = AgentInput::new("test");
        let request_result = chain.process_request(&mut input).await;
        assert!(request_result.is_ok());

        let mut output = AgentOutput::new(serde_json::json!({"content": "response"}));
        let response_result = chain.process_response(&mut output).await;
        assert!(response_result.is_ok());
    }

    /// Counters start at zero and advance by one per processed request or
    /// response. The clone handed to the chain shares its counters with
    /// `metrics`.
    #[tokio::test]
    async fn test_metrics_middleware() {
        let metrics = MetricsMiddleware::new();
        let chain = MiddlewareChain::new().with(metrics.clone());
        assert_eq!(metrics.get_request_count(), 0);
        assert_eq!(metrics.get_response_count(), 0);

        let mut input = AgentInput::new("test");
        chain.process_request(&mut input).await.unwrap();
        assert_eq!(metrics.get_request_count(), 1);

        let mut output = AgentOutput::new(serde_json::json!({"content": "test"}));
        chain.process_response(&mut output).await.unwrap();
        assert_eq!(metrics.get_response_count(), 1);
    }
}