use crate::tools::improvements_errors::{EventType, ObservabilityContext};
use std::sync::Arc;
// Names each middleware appends to `ExecutionMetadata::layers_executed`; they match the
// values returned by the corresponding `Middleware::name` implementations.
/// Layer name recorded by `LoggingMiddleware`.
const LAYER_LOGGING: &str = "logging";
/// Layer name recorded by `CachingMiddleware`.
const LAYER_CACHING: &str = "caching";
/// Layer name recorded by `RetryMiddleware`.
const LAYER_RETRY: &str = "retry";
/// Layer name recorded by `ValidationMiddleware`.
const LAYER_VALIDATION: &str = "validation";
/// Layer name recorded by `MetricsMiddleware`.
const LAYER_METRICS: &str = "metrics";
/// Layer name recorded by `CircuitBreakerMiddleware`.
const LAYER_CIRCUIT_BREAKER: &str = "circuit_breaker";
/// Result of running a [`ToolRequest`] through a synchronous middleware layer or executor.
#[deprecated(since = "0.1.0", note = "Use async_middleware::ToolResult instead")]
#[derive(Debug, Clone)]
pub struct MiddlewareResult {
/// Whether the call succeeded; caching, retry, metrics and the circuit breaker branch on it.
pub success: bool,
/// Tool output, if any. Only successful results with `Some` output are cached.
pub result: Option<String>,
/// Failure cause, if any.
pub error: Option<MiddlewareError>,
/// Bookkeeping accumulated by the layers the result passed through.
pub metadata: ExecutionMetadata,
}
impl MiddlewareResult {
    /// Returns a copy of this result whose metadata is replaced by `new_metadata`.
    ///
    /// `success`, `result` and `error` are copied from `self`; the current metadata is
    /// not cloned at all.
    pub fn clone_with_metadata_update(&self, new_metadata: ExecutionMetadata) -> Self {
        let Self {
            success,
            result,
            error,
            ..
        } = self;
        Self {
            metadata: new_metadata,
            error: error.clone(),
            result: result.clone(),
            success: *success,
        }
    }
}
/// Errors surfaced by the legacy synchronous middleware stack.
///
/// Payloads are `&'static str`, so messages are fixed strings and cannot carry
/// per-request detail (the circuit breaker, for instance, puts the tool name into
/// `ExecutionMetadata::warnings` instead).
#[deprecated(since = "0.1.0", note = "Use async_middleware error handling instead")]
#[derive(Debug, Clone, thiserror::Error)]
pub enum MiddlewareError {
/// The tool or a guarding layer failed (used by `CircuitBreakerMiddleware` when blocking).
#[error("execution failed: {0}")]
ExecutionFailed(&'static str),
/// Used by `ValidationMiddleware` for rejected requests.
#[error("validation failed: {0}")]
ValidationFailed(&'static str),
/// Caching failed. NOTE(review): not constructed anywhere in this file.
#[error("cache failed: {0}")]
CacheFailed(&'static str),
/// Execution exceeded its time budget. NOTE(review): not constructed in this file, and
/// `RequestMetadata::timeout_ms` is not enforced by any middleware here.
#[error("execution timeout exceeded")]
TimeoutExceeded,
/// Execution was cancelled. NOTE(review): not constructed anywhere in this file.
#[error("execution cancelled")]
Cancelled,
}
/// Bookkeeping attached to a [`MiddlewareResult`] by the layers it passed through.
#[deprecated(since = "0.1.0", note = "Use async_middleware types instead")]
#[derive(Debug, Clone, Default)]
pub struct ExecutionMetadata {
/// Wall-clock duration in milliseconds, as measured by `LoggingMiddleware`.
pub duration_ms: u64,
/// `true` when `CachingMiddleware` served the result from its cache.
pub from_cache: bool,
/// Number of retries recorded by `RetryMiddleware`.
pub retry_count: u32,
/// Layer names (the `LAYER_*` constants). Layers append after the inner chain returns,
/// so entries appear innermost-first.
pub layers_executed: Vec<String>,
/// Non-fatal notes, e.g. an oversized payload that was not cached or a circuit-breaker block.
pub warnings: Vec<String>,
}
/// A synchronous layer in a [`MiddlewareChain`].
///
/// `execute` receives the request and `next`, a continuation that runs the rest of the
/// chain. Implementations in this file call `next` once (logging, metrics), skip it to
/// short-circuit (cache hit, failed validation, open circuit breaker), or call it
/// repeatedly (retry).
#[deprecated(
since = "0.1.0",
note = "Use async_middleware::AsyncMiddleware instead"
)]
pub trait Middleware: Send + Sync {
/// Short identifier for this layer (e.g. `"logging"`).
fn name(&self) -> &str;
/// Handles `request`, optionally delegating to `next`, and returns the result,
/// typically annotated in `metadata` by this layer.
fn execute(
&self,
request: ToolRequest,
next: Box<dyn Fn(ToolRequest) -> MiddlewareResult + Send + Sync>,
) -> MiddlewareResult;
}
/// A tool invocation travelling through the middleware chain.
#[deprecated(since = "0.1.0", note = "Use async_middleware::ToolRequest instead")]
#[derive(Debug, Clone)]
pub struct ToolRequest {
/// Name of the tool to invoke; must be non-empty to pass `ValidationMiddleware`.
pub tool_name: String,
/// Tool arguments as a string (presumably serialized; confirm the format). Hashed into
/// the cache key and logged at DEBUG by `LoggingMiddleware`.
pub arguments: String,
/// Extra invocation context. NOTE(review): no middleware in this file reads it, so it
/// is not part of the cache key; confirm that is intended.
pub context: String,
/// Per-request bookkeeping (id, priority, timeout, tags).
pub metadata: RequestMetadata,
}
/// Per-request bookkeeping carried alongside a [`ToolRequest`].
#[deprecated(since = "0.1.0", note = "Use async_middleware types instead")]
#[derive(Debug, Clone)]
pub struct RequestMetadata {
    /// Request identifier. `Default` produces `req-<n>` with `n` unique and strictly
    /// increasing within the process.
    pub request_id: String,
    /// Identifier of a related parent request, if any (presumably the request that
    /// spawned this one; confirm with callers).
    pub parent_request_id: Option<String>,
    /// Scheduling priority; `Default` uses 50.
    pub priority: u32,
    /// Timeout in milliseconds; `Default` uses 30 000.
    pub timeout_ms: u64,
    /// Free-form labels.
    pub tags: Vec<String>,
}

impl Default for RequestMetadata {
    /// Builds metadata with a fresh request id, priority 50, a 30 000 ms timeout and no tags.
    ///
    /// The id is `req-<n>`, where `n` is nanoseconds since the Unix epoch, bumped past
    /// the previously issued id when necessary. A wall-clock reading alone is not unique:
    /// two calls can observe the same value on coarse-resolution clocks, and `SystemTime`
    /// may step backwards. Either case would hand out duplicate ids.
    fn default() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering};
        use std::time::SystemTime;

        // Last id handed out by this process.
        static LAST_ID: AtomicU64 = AtomicU64::new(0);

        let now_nanos = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            // Clamp before narrowing; u64 nanoseconds last until the year ~2554.
            .map(|elapsed| elapsed.as_nanos().min(u128::from(u64::MAX)) as u64)
            .unwrap_or_default();
        // Every read-modify-write on a single atomic observes the latest value, so even
        // `Relaxed` ordering guarantees each caller a distinct, increasing id.
        let previous = LAST_ID
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |last| {
                Some(now_nanos.max(last.saturating_add(1)))
            })
            .unwrap_or_else(|last| last);
        let id = now_nanos.max(previous.saturating_add(1));

        Self {
            request_id: format!("req-{}", id),
            parent_request_id: None,
            priority: 50,
            timeout_ms: 30000,
            tags: Vec::with_capacity(3),
        }
    }
}
/// Emits `tracing` events around the wrapped call and stamps its duration (ms) onto
/// `ExecutionMetadata::duration_ms`.
#[deprecated(
since = "0.1.0",
note = "Use async_middleware::AsyncLoggingMiddleware instead"
)]
pub struct LoggingMiddleware {
// NOTE(review): stored but never read. Events are emitted at DEBUG (start, success)
// and ERROR (failure) regardless of this value.
#[allow(dead_code)]
level: tracing::Level,
}
impl LoggingMiddleware {
/// Creates the middleware. `level` is currently only stored (see the field note).
pub fn new(level: tracing::Level) -> Self {
Self { level }
}
}
impl Middleware for LoggingMiddleware {
    fn name(&self) -> &str {
        "logging"
    }

    /// Logs the start of the call, runs the rest of the chain, logs the outcome, and
    /// records the elapsed wall-clock milliseconds in the result metadata.
    fn execute(
        &self,
        request: ToolRequest,
        next: Box<dyn Fn(ToolRequest) -> MiddlewareResult + Send + Sync>,
    ) -> MiddlewareResult {
        tracing::debug!(
            tool = %request.tool_name,
            request_id = %request.metadata.request_id,
            arguments = %request.arguments,
            "tool_execution_started"
        );
        // `request` is moved into `next`, so keep a copy of the tool name for the
        // completion event.
        let tool = request.tool_name.clone();

        let timer = std::time::Instant::now();
        let mut outcome = next(request);
        let elapsed_ms = timer.elapsed().as_millis() as u64;

        if !outcome.success {
            tracing::error!(
                tool = %tool,
                error = ?outcome.error,
                "tool_execution_failed"
            );
        } else {
            tracing::debug!(
                tool = %tool,
                duration_ms = elapsed_ms,
                "tool_execution_completed"
            );
        }

        outcome.metadata.duration_ms = elapsed_ms;
        outcome.metadata.layers_executed.push(LAYER_LOGGING.into());
        outcome
    }
}
/// A cached tool payload plus the moment it was stored.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached tool output.
    value: Arc<String>,
    /// Insertion time; drives both TTL expiry and oldest-first eviction.
    timestamp: std::time::Instant,
}

/// Memoizes successful tool results in memory, keyed by tool name plus a hash of the
/// arguments.
///
/// Entries expire once older than `max_age_secs`. At most `max_entries` are kept, with the
/// oldest evicted first, and payloads longer than `max_value_bytes` bytes are never stored.
#[deprecated(
    since = "0.1.0",
    note = "Use async_middleware::AsyncCachingMiddleware instead"
)]
pub struct CachingMiddleware {
    cache: Arc<std::sync::Mutex<hashbrown::HashMap<String, CacheEntry>>>,
    max_age_secs: u64,
    max_entries: usize,
    max_value_bytes: usize,
}

impl CachingMiddleware {
    /// Creates a cache with a 300 s TTL, room for 256 entries and a 512 KiB payload limit.
    pub fn new() -> Self {
        Self {
            cache: Arc::new(std::sync::Mutex::new(hashbrown::HashMap::new())),
            max_age_secs: 300,
            max_entries: 256,
            max_value_bytes: 512 * 1024,
        }
    }

    /// Sets the time-to-live: an entry expires once its age exceeds `max_age_secs` seconds.
    pub fn with_max_age(mut self, max_age_secs: u64) -> Self {
        self.max_age_secs = max_age_secs;
        self
    }

    /// Sets the maximum number of cached entries (clamped to at least 1).
    pub fn with_max_entries(mut self, max_entries: usize) -> Self {
        self.max_entries = max_entries.max(1);
        self
    }

    /// Sets the largest payload, in bytes, that will be cached (clamped to at least 1024).
    pub fn with_max_value_bytes(mut self, max_value_bytes: usize) -> Self {
        self.max_value_bytes = max_value_bytes.max(1024);
        self
    }

    /// Builds the map key `"<tool>:<64-bit hash of args>"`.
    ///
    /// NOTE(review): only a hash of `args` is kept, so two argument strings whose hashes
    /// collide would share a cache entry.
    fn cache_key(tool: &str, args: &str) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::Hasher;
        let mut hasher = DefaultHasher::new();
        hasher.write(args.as_bytes());
        format!("{}:{}", tool, hasher.finish())
    }

    /// Returns `true` once `entry` is older than `max_age_secs`.
    fn is_stale(&self, entry: &CacheEntry) -> bool {
        // Compare whole `Duration`s. `elapsed().as_secs()` truncates to whole seconds, which
        // would keep entries alive up to one second past the TTL (a 1 s TTL would still
        // serve a 1.9 s-old entry, and a 0 s TTL would still serve hits for a full second).
        entry.timestamp.elapsed() > std::time::Duration::from_secs(self.max_age_secs)
    }
}

impl Default for CachingMiddleware {
    fn default() -> Self {
        Self::new()
    }
}
impl Middleware for CachingMiddleware {
    fn name(&self) -> &str {
        "caching"
    }

    /// Serves a fresh cached payload when one exists. Otherwise it runs the rest of the
    /// chain and caches a successful payload that fits within `max_value_bytes`.
    fn execute(
        &self,
        request: ToolRequest,
        next: Box<dyn Fn(ToolRequest) -> MiddlewareResult + Send + Sync>,
    ) -> MiddlewareResult {
        let key = Self::cache_key(&request.tool_name, &request.arguments);

        // Lookup phase: purge expired entries, then probe the key. The guard is released
        // at the end of this match, before `next` runs.
        let hit = match self.cache.lock() {
            Ok(mut cache) => {
                cache.retain(|_, entry| !self.is_stale(entry));
                cache
                    .get(&key)
                    .filter(|entry| !self.is_stale(entry))
                    .map(|entry| (*entry.value).clone())
            }
            Err(_) => None,
        };

        if let Some(cached) = hit {
            return MiddlewareResult {
                success: true,
                result: Some(cached),
                error: None,
                metadata: ExecutionMetadata {
                    from_cache: true,
                    layers_executed: vec![LAYER_CACHING.into()],
                    ..Default::default()
                },
            };
        }

        let mut outcome = next(request);

        // Only successful calls that produced output are candidates for caching.
        let payload = if outcome.success {
            outcome.result.as_ref()
        } else {
            None
        };
        match payload {
            Some(text) if text.len() <= self.max_value_bytes => {
                if let Ok(mut cache) = self.cache.lock() {
                    cache.insert(
                        key,
                        CacheEntry {
                            value: Arc::new(text.clone()),
                            timestamp: std::time::Instant::now(),
                        },
                    );
                    // Evict oldest entries until back within the size bound.
                    while cache.len() > self.max_entries {
                        let oldest = cache
                            .iter()
                            .min_by_key(|(_, entry)| entry.timestamp)
                            .map(|(stored_key, _)| stored_key.clone());
                        match oldest {
                            Some(stale_key) => {
                                cache.remove(&stale_key);
                            }
                            None => break,
                        }
                    }
                }
            }
            Some(_) => outcome
                .metadata
                .warnings
                .push("Skipped caching: payload exceeds cache size limit".to_string()),
            None => {}
        }

        outcome.metadata.layers_executed.push(LAYER_CACHING.into());
        outcome
    }
}
/// Retries failed calls with capped exponential backoff.
#[deprecated(
    since = "0.1.0",
    note = "Use async_middleware::AsyncRetryMiddleware instead"
)]
pub struct RetryMiddleware {
    /// Total calls to the inner chain, including the first. The first call always
    /// happens, so 0 behaves like 1.
    max_attempts: u32,
    /// Delay before the first retry, in milliseconds.
    initial_backoff_ms: u64,
    /// Ceiling applied to every computed delay, in milliseconds.
    max_backoff_ms: u64,
}

impl RetryMiddleware {
    /// Creates a retry layer.
    ///
    /// * `max_attempts` – total calls including the first; values <= 1 disable retries.
    /// * `initial_backoff_ms` – delay before the first retry.
    /// * `max_backoff_ms` – cap on any single delay.
    pub fn new(max_attempts: u32, initial_backoff_ms: u64, max_backoff_ms: u64) -> Self {
        Self {
            max_attempts,
            initial_backoff_ms,
            max_backoff_ms,
        }
    }

    /// Delay in milliseconds for backoff step `attempt` (0-based), computed as
    /// `initial_backoff_ms * 2^attempt` and capped at `max_backoff_ms`.
    ///
    /// Uses saturating arithmetic. A plain `2_u64.pow(attempt)` panics in debug builds
    /// (and wraps in release) once `attempt >= 64`, and the multiply can overflow for
    /// large initial delays. Both cases would otherwise yield a panic or a bogus delay.
    fn backoff_duration(&self, attempt: u32) -> u64 {
        let factor = 2_u64.checked_pow(attempt).unwrap_or(u64::MAX);
        self.initial_backoff_ms
            .saturating_mul(factor)
            .min(self.max_backoff_ms)
    }
}
impl Middleware for RetryMiddleware {
    fn name(&self) -> &str {
        "retry"
    }

    /// Runs the rest of the chain, retrying failed attempts with exponential backoff.
    ///
    /// `next` is called at least once and at most `max(max_attempts, 1)` times. The
    /// thread sleeps for `backoff_duration` between attempts. When a retry happens,
    /// `metadata.retry_count` on the returned result holds the number of retries performed.
    fn execute(
        &self,
        request: ToolRequest,
        next: Box<dyn Fn(ToolRequest) -> MiddlewareResult + Send + Sync>,
    ) -> MiddlewareResult {
        let mut result = next(request.clone());
        for attempt in 1..self.max_attempts {
            if result.success {
                break;
            }
            let backoff = self.backoff_duration(attempt - 1);
            std::thread::sleep(std::time::Duration::from_millis(backoff));
            result = next(request.clone());
            // Must be written after the reassignment: setting it on the previous result
            // would be discarded when that result is replaced.
            result.metadata.retry_count = attempt;
        }
        result.metadata.layers_executed.push(LAYER_RETRY.into());
        result
    }
}
/// Rejects malformed requests before they reach the tool.
#[deprecated(since = "0.1.0", note = "Use async_middleware types instead")]
pub struct ValidationMiddleware {
    /// Sink for non-fatal validation events.
    obs_context: Arc<ObservabilityContext>,
}

impl ValidationMiddleware {
    /// Creates a validator that reports soft warnings through `obs_context`.
    pub fn new(obs_context: Arc<ObservabilityContext>) -> Self {
        Self { obs_context }
    }

    /// Checks `request` before it is forwarded.
    ///
    /// # Errors
    /// Returns `MiddlewareError::ValidationFailed` when `tool_name` is empty. An empty
    /// `arguments` string is not rejected; it only emits an `EventType::ErrorOccurred`
    /// event on the observability context.
    fn validate_request(&self, request: &ToolRequest) -> Result<(), MiddlewareError> {
        match (request.tool_name.is_empty(), request.arguments.is_empty()) {
            (true, _) => Err(MiddlewareError::ValidationFailed("tool_name is empty")),
            (false, true) => {
                self.obs_context.event(
                    EventType::ErrorOccurred,
                    "validation",
                    "arguments is empty",
                    None,
                );
                Ok(())
            }
            (false, false) => Ok(()),
        }
    }
}
impl Middleware for ValidationMiddleware {
fn name(&self) -> &str {
"validation"
}
fn execute(
&self,
request: ToolRequest,
next: Box<dyn Fn(ToolRequest) -> MiddlewareResult + Send + Sync>,
) -> MiddlewareResult {
if let Err(err) = self.validate_request(&request) {
return MiddlewareResult {
success: false,
result: None,
error: Some(err),
metadata: ExecutionMetadata::default(),
};
}
let mut result = next(request);
result
.metadata
.layers_executed
.push(LAYER_VALIDATION.into());
result
}
}
/// Reports each call's outcome to a shared `AgentBehaviorAnalyzer`: successes via
/// `record_tool_usage`, failures via `record_tool_failure`.
pub struct MetricsMiddleware {
analyzer: Arc<std::sync::RwLock<crate::exec::agent_optimization::AgentBehaviorAnalyzer>>,
}
impl MetricsMiddleware {
/// Creates the middleware around a shared analyzer handle.
pub fn new(
analyzer: Arc<std::sync::RwLock<crate::exec::agent_optimization::AgentBehaviorAnalyzer>>,
) -> Self {
Self { analyzer }
}
}
impl Middleware for MetricsMiddleware {
    fn name(&self) -> &str {
        "metrics"
    }

    /// Runs the rest of the chain, then records the outcome for the tool in the analyzer.
    fn execute(
        &self,
        request: ToolRequest,
        next: Box<dyn Fn(ToolRequest) -> MiddlewareResult + Send + Sync>,
    ) -> MiddlewareResult {
        let tool = request.tool_name.clone();
        let mut outcome = next(request);

        // Best-effort: if the analyzer lock is poisoned, this call is simply not recorded.
        match self.analyzer.write() {
            Ok(mut stats) if outcome.success => {
                stats.record_tool_usage(&tool);
            }
            Ok(mut stats) => {
                let reason = match &outcome.error {
                    Some(err) => err.to_string(),
                    None => "unknown error".to_string(),
                };
                stats.record_tool_failure(&tool, &reason);
            }
            Err(_) => {}
        }

        outcome.metadata.layers_executed.push(LAYER_METRICS.into());
        outcome
    }
}
/// Blocks calls to tools whose recorded failures have tripped a per-tool circuit breaker.
pub struct CircuitBreakerMiddleware {
    breaker: crate::tools::circuit_breaker::CircuitBreaker,
}

impl CircuitBreakerMiddleware {
    /// Creates the middleware with `CircuitBreakerConfig::default()`.
    ///
    /// NOTE(review): `failure_threshold` is currently ignored; the breaker always uses the
    /// default config. The parameter is kept for API compatibility. Wire it into the
    /// config once the mapping from this rate to the config's threshold is decided.
    pub fn new(failure_threshold: f64) -> Self {
        // Discard the argument explicitly so the no-op is visible to readers and the
        // compiler, rather than hiding it behind an empty conditional.
        let _ = failure_threshold;
        let config = crate::tools::circuit_breaker::CircuitBreakerConfig::default();
        Self {
            breaker: crate::tools::circuit_breaker::CircuitBreaker::new(config),
        }
    }
}
impl Middleware for CircuitBreakerMiddleware {
    fn name(&self) -> &str {
        "circuit_breaker"
    }

    /// Forwards the request when the breaker allows it, feeding the outcome back into the
    /// breaker. Otherwise it returns a failed result without calling `next`.
    fn execute(
        &self,
        request: ToolRequest,
        next: Box<dyn Fn(ToolRequest) -> MiddlewareResult + Send + Sync>,
    ) -> MiddlewareResult {
        let tool = request.tool_name.clone();

        if self.breaker.allow_request_for_tool(&tool) {
            let mut outcome = next(request);
            if outcome.success {
                self.breaker.record_success_for_tool(&tool);
            } else {
                self.breaker.record_failure_for_tool(&tool, false);
            }
            outcome
                .metadata
                .layers_executed
                .push(LAYER_CIRCUIT_BREAKER.into());
            return outcome;
        }

        // Blocked: report how long until the breaker may allow the tool again, if known.
        let wait_time = match self.breaker.remaining_backoff(&tool) {
            Some(remaining) => format!("{}s", remaining.as_secs()),
            None => "unknown".to_string(),
        };
        MiddlewareResult {
            success: false,
            result: None,
            error: Some(MiddlewareError::ExecutionFailed(
                "circuit breaker open - tool has high failure rate",
            )),
            metadata: ExecutionMetadata {
                layers_executed: vec![LAYER_CIRCUIT_BREAKER.into()],
                warnings: vec![format!(
                    "Tool {} blocked by circuit breaker due to high failure rate. Wait time: {}",
                    tool, wait_time
                )],
                ..Default::default()
            },
        }
    }
}
/// Ordered stack of synchronous middlewares wrapped around a tool executor.
///
/// The first middleware added is the outermost layer: it sees the request first and the
/// result last.
#[deprecated(
    since = "0.1.0",
    note = "Use async_middleware::AsyncMiddlewareChain instead"
)]
pub struct MiddlewareChain {
    middlewares: Vec<Arc<dyn Middleware>>,
}

impl MiddlewareChain {
    /// Creates an empty chain.
    pub fn new() -> Self {
        Self {
            middlewares: Vec::with_capacity(5),
        }
    }

    /// Appends `middleware` inside every layer added so far.
    pub fn with_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
        self.middlewares.push(middleware);
        self
    }

    /// Appends a [`MetricsMiddleware`] reporting into `analyzer`.
    pub fn with_metrics(
        self,
        analyzer: Arc<std::sync::RwLock<crate::exec::agent_optimization::AgentBehaviorAnalyzer>>,
    ) -> Self {
        let layer: Arc<dyn Middleware> = Arc::new(MetricsMiddleware::new(analyzer));
        self.with_middleware(layer)
    }

    /// Appends a [`CircuitBreakerMiddleware`] built from `threshold`.
    pub fn with_circuit_breaker(self, threshold: f64) -> Self {
        let layer: Arc<dyn Middleware> = Arc::new(CircuitBreakerMiddleware::new(threshold));
        self.with_middleware(layer)
    }

    /// Runs `request` through every middleware, outermost first, and finally `executor`.
    pub fn execute_sync<F>(&self, request: ToolRequest, executor: F) -> MiddlewareResult
    where
        F: Fn(ToolRequest) -> MiddlewareResult + Send + Sync + 'static,
    {
        // Shared continuation meaning "everything from this point inward".
        type Next = Arc<dyn Fn(ToolRequest) -> MiddlewareResult + Send + Sync>;

        let innermost: Next = Arc::new(executor);
        // Wrap from the innermost layer outward, so the first middleware in the list ends
        // up as the outermost wrapper.
        let entry = self
            .middlewares
            .iter()
            .rev()
            .fold(innermost, |inner, layer| -> Next {
                let layer = Arc::clone(layer);
                Arc::new(move |req: ToolRequest| {
                    // Each invocation hands the layer a fresh boxed handle to the inner
                    // chain, so layers that call `next` repeatedly (retry) keep working.
                    let inner = Arc::clone(&inner);
                    layer.execute(req, Box::new(move |r: ToolRequest| inner(r)))
                })
            });
        entry(request)
    }
}

impl Default for MiddlewareChain {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Executor shape expected by `Middleware::execute`.
    type Executor = Box<dyn Fn(ToolRequest) -> MiddlewareResult + Send + Sync>;

    /// Builds a request with default metadata.
    fn make_request(tool_name: impl Into<String>, arguments: &str, context: &str) -> ToolRequest {
        ToolRequest {
            tool_name: tool_name.into(),
            arguments: arguments.into(),
            context: context.into(),
            metadata: RequestMetadata::default(),
        }
    }

    /// Executor stub that always succeeds with `output`.
    fn succeed_with(output: &'static str) -> Executor {
        Box::new(move |_req: ToolRequest| MiddlewareResult {
            success: true,
            result: Some(output.into()),
            error: None,
            metadata: ExecutionMetadata::default(),
        })
    }

    /// Executor stub that always fails with `ExecutionFailed(reason)`.
    fn fail_with(reason: &'static str) -> Executor {
        Box::new(move |_req: ToolRequest| MiddlewareResult {
            success: false,
            result: None,
            error: Some(MiddlewareError::ExecutionFailed(reason)),
            metadata: ExecutionMetadata::default(),
        })
    }

    /// Whether `layer` appears in the result's executed-layer list.
    fn ran_layer(result: &MiddlewareResult, layer: &str) -> bool {
        result
            .metadata
            .layers_executed
            .iter()
            .any(|name| name == layer)
    }

    #[test]
    fn test_logging_middleware() {
        let middleware = LoggingMiddleware::new(tracing::Level::INFO);
        let request = make_request(
            crate::config::constants::tools::UNIFIED_SEARCH,
            "pattern:test",
            "src/",
        );
        let result = middleware.execute(request, succeed_with("found test"));
        assert!(result.success);
        assert!(ran_layer(&result, "logging"));
    }

    #[test]
    fn test_caching_middleware() {
        let middleware = CachingMiddleware::new();
        let request = make_request("test_tool", "arg1", "ctx");

        let first = middleware.execute(request.clone(), succeed_with("result"));
        assert!(!first.metadata.from_cache);

        let second = middleware.execute(request, succeed_with("new result"));
        assert!(second.metadata.from_cache);
    }

    #[test]
    fn test_validation_middleware() {
        let middleware = ValidationMiddleware::new(Arc::new(ObservabilityContext::noop()));
        let result = middleware.execute(make_request("", "arg", "ctx"), succeed_with("result"));
        assert!(!result.success);
    }

    #[test]
    fn test_metrics_middleware() {
        use crate::exec::agent_optimization::AgentBehaviorAnalyzer;
        let analyzer = Arc::new(std::sync::RwLock::new(AgentBehaviorAnalyzer::new()));
        let middleware = MetricsMiddleware::new(Arc::clone(&analyzer));

        let result = middleware.execute(make_request("test_tool", "arg", "ctx"), succeed_with("result"));
        assert!(result.success);
        assert!(ran_layer(&result, "metrics"));

        let stats = analyzer.read().unwrap();
        assert_eq!(
            *stats.tool_stats().usage_frequency.get("test_tool").unwrap(),
            1
        );
    }

    #[test]
    fn test_circuit_breaker_middleware() {
        let middleware = CircuitBreakerMiddleware::new(0.5);
        let threshold =
            crate::tools::circuit_breaker::CircuitBreakerConfig::default().failure_threshold;
        let request = make_request("failing_tool", "arg", "ctx");

        // Trip the breaker with consecutive failures.
        for _ in 0..threshold {
            let _ = middleware.execute(request.clone(), fail_with("test error"));
        }

        let result = middleware.execute(request, succeed_with("should not execute"));
        assert!(!result.success);
        assert!(ran_layer(&result, "circuit_breaker"));
    }

    #[test]
    fn test_caching_middleware_staleness() {
        let middleware = CachingMiddleware::new().with_max_age(1);
        let request = make_request("test_tool", "arg", "ctx");

        let first = middleware.execute(request.clone(), succeed_with("result1"));
        assert!(!first.metadata.from_cache);

        let second = middleware.execute(request.clone(), succeed_with("result2"));
        assert!(second.metadata.from_cache);
        assert_eq!(second.result.unwrap(), "result1");

        // Outlive the 1 s TTL so the next call must miss.
        std::thread::sleep(std::time::Duration::from_secs(2));

        let third = middleware.execute(request, succeed_with("result3"));
        assert!(!third.metadata.from_cache);
        assert_eq!(third.result.unwrap(), "result3");
    }

    #[test]
    fn test_middleware_chain_with_metrics_and_circuit_breaker() {
        use crate::exec::agent_optimization::AgentBehaviorAnalyzer;
        let analyzer = Arc::new(std::sync::RwLock::new(AgentBehaviorAnalyzer::new()));
        let chain = MiddlewareChain::new()
            .with_metrics(Arc::clone(&analyzer))
            .with_circuit_breaker(0.8);

        let result = chain.execute_sync(make_request("test_tool", "arg", "ctx"), succeed_with("result"));
        assert!(result.success);
        assert!(ran_layer(&result, "metrics"));
        assert!(ran_layer(&result, "circuit_breaker"));
    }
}