use std::time::{Duration, Instant};
use async_trait::async_trait;
use rs_genai::prelude::FunctionCall;
use super::Middleware;
use crate::error::{AgentError, ToolError};
/// One completed tool invocation as measured by [`LatencyMiddleware`].
///
/// Derives `PartialEq`/`Eq` so records can be compared directly (for example
/// in assertions) instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolLatency {
    /// Name of the tool that was called.
    pub name: String,
    /// Time between the `before_tool` hook and the completion hook. This is
    /// zero when no start instant was stored for the tool name.
    pub elapsed: Duration,
    /// `true` when recorded by `after_tool`, `false` when recorded by
    /// `on_tool_error`.
    pub success: bool,
}
/// Middleware that measures how long each tool call takes.
///
/// A start instant is stored when a call begins and turned into a
/// [`ToolLatency`] record when the call succeeds or fails.
///
/// In-flight calls are keyed by tool name only, so a second call to the same
/// tool that starts before the first one finishes replaces the first call's
/// start instant.
#[derive(Debug)]
pub struct LatencyMiddleware {
    /// Start instants of calls that have begun but not yet completed, keyed
    /// by tool name.
    in_flight: parking_lot::Mutex<std::collections::HashMap<String, Instant>>,
    /// Completed measurements, in the order the completion hooks ran.
    records: parking_lot::Mutex<Vec<ToolLatency>>,
}
impl LatencyMiddleware {
    /// Creates a middleware with no pending timers and no recorded latencies.
    pub fn new() -> Self {
        // Built field by field on purpose: `Default` delegates to `new`, so
        // calling `Self::default()` here would recurse forever.
        let in_flight = parking_lot::Mutex::new(std::collections::HashMap::new());
        let records = parking_lot::Mutex::new(Vec::new());
        Self { in_flight, records }
    }

    /// Returns a copy of every latency recorded so far.
    ///
    /// The internal list is left untouched; use [`Self::clear`] to reset it.
    pub fn tool_latencies(&self) -> Vec<ToolLatency> {
        let recorded = self.records.lock();
        recorded.to_vec()
    }

    /// Drops all pending start instants and all recorded latencies.
    ///
    /// The two locks are taken one after the other, never at the same time.
    pub fn clear(&self) {
        {
            let mut pending = self.in_flight.lock();
            pending.clear();
        }
        let mut recorded = self.records.lock();
        recorded.clear();
    }
}
impl Default for LatencyMiddleware {
fn default() -> Self {
Self::new()
}
}
impl LatencyMiddleware {
    /// Stops the timer stored under `tool_name` and appends the outcome to
    /// the recorded latencies.
    ///
    /// Shared by the success and error hooks so both paths measure the call
    /// the same way and differ only in the `success` flag. When no start
    /// instant is stored under `tool_name`, the elapsed time is recorded as
    /// zero.
    fn record_completion(&self, tool_name: &str, success: bool) {
        // The in-flight guard is a temporary of this statement, so it is
        // released before the records lock is taken below.
        let elapsed = self
            .in_flight
            .lock()
            .remove(tool_name)
            .map(|started| started.elapsed())
            .unwrap_or_default();

        self.records.lock().push(ToolLatency {
            name: tool_name.to_owned(),
            elapsed,
            success,
        });
    }
}

#[async_trait]
impl Middleware for LatencyMiddleware {
    /// Identifier of this middleware: `"latency"`.
    fn name(&self) -> &str {
        "latency"
    }

    /// Stores the current instant under the call's tool name.
    ///
    /// An entry already stored under the same name is overwritten. Always
    /// returns `Ok(())`.
    async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
        self.in_flight
            .lock()
            .insert(call.name.clone(), Instant::now());
        Ok(())
    }

    /// Records the call as a successful invocation. Always returns `Ok(())`.
    async fn after_tool(
        &self,
        call: &FunctionCall,
        _result: &serde_json::Value,
    ) -> Result<(), AgentError> {
        self.record_completion(&call.name, true);
        Ok(())
    }

    /// Records the call as a failed invocation. Always returns `Ok(())`.
    async fn on_tool_error(&self, call: &FunctionCall, _err: &ToolError) -> Result<(), AgentError> {
        self.record_completion(&call.name, false);
        Ok(())
    }
}