use std::cell::Cell;
use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;
use async_trait::async_trait;
use futures::Stream;
use tracing::{Level, event};
use crate::llm::{CallOptions, ChatModel, LlmError, Message, MessageChunk, ToolDefinition};
// Thread-local slot holding an optional start `Instant`.
//
// NOTE(review): thread-local storage is not a reliable carrier for per-call state in async code.
// A task may resume on a different worker thread after an `.await`, and overlapping or nested
// calls on the same thread overwrite each other's value. Prefer state owned by the middleware
// instance.
thread_local! {
static MIDDLEWARE_START_TIME: Cell<Option<Instant>> = const { Cell::new(None) };
}
/// A pair of hooks that can observe or rewrite an LLM call made through [`MiddlewareModel`].
///
/// Both hooks default to doing nothing, so an implementor only overrides the ones it needs.
/// [`MiddlewareModel`] runs `pre_invoke` hooks in registration order and `post_invoke` hooks in
/// reverse registration order.
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
pub trait LlmMiddleware: Send + Sync + 'static {
    /// Called before the wrapped model is invoked.
    ///
    /// May rewrite the outgoing messages and call options in place.
    ///
    /// # Errors
    ///
    /// Returning `Err` aborts the call. [`MiddlewareModel`] returns that error without running
    /// the remaining `pre_invoke` hooks, the wrapped model, or any `post_invoke` hook.
    async fn pre_invoke(
        &self,
        _msgs: &mut Vec<Message>,
        _opts: &mut CallOptions,
    ) -> Result<(), LlmError> {
        Ok(())
    }

    /// Called after the wrapped model returns, with that model's result.
    ///
    /// May replace or edit the result in place, for example turning an error into a fallback
    /// reply.
    ///
    /// # Errors
    ///
    /// Returning `Err` makes [`MiddlewareModel`] return that error immediately, skipping the
    /// remaining `post_invoke` hooks.
    async fn post_invoke(&self, _outcome: &mut Result<Message, LlmError>) -> Result<(), LlmError> {
        Ok(())
    }
}
/// A [`ChatModel`] wrapper that runs a chain of [`LlmMiddleware`] hooks around each `invoke`.
///
/// Hooks are held as `Arc<dyn LlmMiddleware>`, so cloning the wrapper (or calling `bind_tools`)
/// shares the same hook instances rather than copying them.
#[derive(Clone)]
pub struct MiddlewareModel<M: ChatModel> {
    /// The model that actually serves requests.
    inner: M,
    /// Hooks in registration order.
    middleware: Vec<Arc<dyn LlmMiddleware>>,
}

impl<M: ChatModel> fmt::Debug for MiddlewareModel<M> {
    /// Shows the wrapped model's name and the number of attached hooks.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = self.inner.model_name();
        let hook_count = self.middleware.len();
        f.debug_struct("MiddlewareModel")
            .field("inner", &name)
            .field("middleware_count", &hook_count)
            .finish()
    }
}

impl<M: ChatModel> MiddlewareModel<M> {
    /// Wraps `inner` with an empty hook chain.
    #[must_use]
    pub fn new(inner: M) -> Self {
        let middleware = Vec::new();
        Self { inner, middleware }
    }

    /// Appends one hook to the end of the chain.
    #[must_use]
    pub fn with_middleware(self, middleware: impl LlmMiddleware) -> Self {
        let Self {
            inner,
            middleware: mut chain,
        } = self;
        chain.push(Arc::new(middleware));
        Self {
            inner,
            middleware: chain,
        }
    }

    /// Appends already-shared hooks to the end of the chain, keeping their order.
    ///
    /// Only the `Arc`s are cloned; the hook instances themselves are shared.
    #[must_use]
    pub fn with_middlewares(mut self, middleware: &[Arc<dyn LlmMiddleware>]) -> Self {
        self.middleware.extend(middleware.iter().map(Arc::clone));
        self
    }
}

impl<M: ChatModel + Default> Default for MiddlewareModel<M> {
    /// Wraps `M::default()` with no hooks attached.
    fn default() -> Self {
        let inner = M::default();
        Self::new(inner)
    }
}
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
impl<M: ChatModel> ChatModel for MiddlewareModel<M> {
    /// Runs every `pre_invoke` hook in registration order, calls the wrapped model, then runs
    /// every `post_invoke` hook in reverse order.
    ///
    /// Hooks work on owned copies of `messages` and of `options` (`CallOptions::default()` when
    /// `options` is `None`). The wrapped model receives the possibly rewritten copies.
    ///
    /// NOTE(review): the wrapped model always receives `Some(..)` options, even when the caller
    /// passed `None`. Confirm this matches how `CallOptions` defaults are applied downstream.
    ///
    /// # Errors
    ///
    /// Returns the first error from a `pre_invoke` hook (the model and all `post_invoke` hooks are
    /// then skipped), or the first error from a `post_invoke` hook (later hooks are skipped).
    /// Otherwise returns the model's result as left by the hooks.
    async fn invoke(
        &self,
        messages: &[Message],
        options: Option<&CallOptions>,
    ) -> Result<Message, LlmError> {
        let mut outgoing: Vec<Message> = messages.to_vec();
        let mut call_options: CallOptions = options.cloned().unwrap_or_default();

        for hook in &self.middleware {
            hook.pre_invoke(&mut outgoing, &mut call_options).await?;
        }

        let mut outcome = self.inner.invoke(&outgoing, Some(&call_options)).await;

        // Unwind in reverse so the first-registered hook sees the final result last.
        for hook in self.middleware.iter().rev() {
            hook.post_invoke(&mut outcome).await?;
        }

        outcome
    }

    /// Forwards straight to the wrapped model's `stream`.
    ///
    /// NOTE(review): middleware is not applied to streaming calls. Neither `pre_invoke` nor
    /// `post_invoke` runs here.
    fn stream(
        &self,
        messages: &[Message],
        options: Option<&CallOptions>,
    ) -> Pin<Box<dyn Stream<Item = Result<MessageChunk, LlmError>> + Send + '_>> {
        self.inner.stream(messages, options)
    }

    /// Binds `tools` on the wrapped model and keeps the same (shared) hook chain.
    fn bind_tools(&self, tools: Vec<ToolDefinition>) -> Self {
        let bound = self.inner.bind_tools(tools);
        Self {
            inner: bound,
            middleware: self.middleware.clone(),
        }
    }

    /// The wrapped model's name.
    fn model_name(&self) -> &str {
        self.inner.model_name()
    }
}
/// Middleware that emits an INFO-level `tracing` event when an `invoke` call starts and another
/// when it completes.
#[derive(Clone, Debug)]
pub struct LoggingMiddleware {
    /// Recorded as the `model_name` field of each event; empty by default.
    model_name: String,
}

impl LoggingMiddleware {
    /// Creates a logger with an empty model name.
    #[must_use]
    pub const fn new() -> Self {
        let model_name = String::new();
        Self { model_name }
    }

    /// Returns this logger with its model name replaced by `model_name`.
    #[must_use]
    pub fn with_model_name(mut self, model_name: impl Into<String>) -> Self {
        let name: String = model_name.into();
        self.model_name = name;
        self
    }
}

impl Default for LoggingMiddleware {
    /// Same as [`LoggingMiddleware::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
impl LlmMiddleware for LoggingMiddleware {
    /// Emits `llm.invoke.started` with the model name and the number of outgoing messages.
    /// Never fails and never modifies the request.
    async fn pre_invoke(
        &self,
        messages: &mut Vec<Message>,
        _options: &mut CallOptions,
    ) -> Result<(), LlmError> {
        let message_count = messages.len();
        event!(
            name: "llm.invoke.started",
            Level::INFO,
            model_name = %self.model_name,
            message_count,
            "LLM invoke started",
        );
        Ok(())
    }

    /// Emits `llm.invoke.completed` with the model name and a `status` of `"ok"` or `"error"`.
    /// Never fails and never modifies the result.
    async fn post_invoke(&self, result: &mut Result<Message, LlmError>) -> Result<(), LlmError> {
        let status = match result {
            Ok(_) => "ok",
            Err(_) => "error",
        };
        event!(
            name: "llm.invoke.completed",
            Level::INFO,
            model_name = %self.model_name,
            status,
            "LLM invoke completed",
        );
        Ok(())
    }
}
/// Point-in-time snapshot of the counters kept by [`MetricsMiddleware`].
///
/// Produced by [`MetricsMiddleware::metrics`]. It is a plain value, so it can be compared,
/// hashed, or defaulted to an all-zero snapshot.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct LlmMetrics {
    /// Number of calls the metrics hook saw finish with `Ok`.
    pub invoke_count: u64,
    /// Number of calls the metrics hook saw finish with `Err`.
    pub error_count: u64,
    /// `total_duration_ms / invoke_count` (integer division), or `0` when `invoke_count` is `0`.
    pub avg_duration_ms: u64,
    /// Sum of the measured durations of successful calls, in milliseconds.
    pub total_duration_ms: u64,
}
/// Middleware that counts `invoke` outcomes and accumulates the wall-clock time of successful
/// calls.
///
/// Counters and pending start times sit behind `Arc`s, so clones share them. Keep one clone and
/// read it with [`MetricsMiddleware::metrics`] after handing another clone to a
/// [`MiddlewareModel`].
///
/// Start times are stored on the instance rather than in thread-local storage. A thread-local
/// slot loses the value when the awaiting task resumes on a different executor thread. It is also
/// overwritten when two metrics hooks (or two overlapping calls on one thread) are in flight at
/// once.
///
/// A per-instance LIFO stack pairs each `post_invoke` with its matching `pre_invoke` for
/// sequential and nested use. When calls through the same instance overlap, a `post_invoke` may
/// pop another in-flight call's start time, so durations are approximate in that case.
#[derive(Clone, Debug)]
pub struct MetricsMiddleware {
    /// Calls this hook saw finish with `Ok`.
    invoke_count: Arc<AtomicU64>,
    /// Calls this hook saw finish with `Err`.
    error_count: Arc<AtomicU64>,
    /// Summed duration of successful calls, in milliseconds.
    total_duration_ms: Arc<AtomicU64>,
    /// Start times pushed by `pre_invoke` and not yet popped by `post_invoke` (used as a stack).
    pending_starts: Arc<std::sync::Mutex<std::collections::VecDeque<Instant>>>,
}

impl MetricsMiddleware {
    /// Creates a middleware with all counters at zero and no pending start times.
    #[must_use]
    pub fn new() -> Self {
        Self {
            invoke_count: Arc::new(AtomicU64::new(0)),
            error_count: Arc::new(AtomicU64::new(0)),
            total_duration_ms: Arc::new(AtomicU64::new(0)),
            pending_starts: Arc::new(std::sync::Mutex::new(std::collections::VecDeque::new())),
        }
    }

    /// Returns a snapshot of the counters.
    ///
    /// `avg_duration_ms` is `total_duration_ms / invoke_count` (integer division), or `0` before
    /// any successful call. The counters are loaded one at a time, so a snapshot taken while calls
    /// are in flight may combine values from slightly different moments.
    #[must_use]
    pub fn metrics(&self) -> LlmMetrics {
        let invoke_count = self.invoke_count.load(Ordering::Relaxed);
        let error_count = self.error_count.load(Ordering::Relaxed);
        let total_duration_ms = self.total_duration_ms.load(Ordering::Relaxed);
        LlmMetrics {
            invoke_count,
            error_count,
            avg_duration_ms: total_duration_ms.checked_div(invoke_count).unwrap_or(0),
            total_duration_ms,
        }
    }

    /// Sets all counters back to zero.
    ///
    /// Start times of calls still in flight are kept, so their `post_invoke` can still pop them.
    pub fn reset(&self) {
        self.invoke_count.store(0, Ordering::Relaxed);
        self.error_count.store(0, Ordering::Relaxed);
        self.total_duration_ms.store(0, Ordering::Relaxed);
    }

    /// Records the start time of a call that is about to run.
    #[cfg(not(target_family = "wasm"))]
    fn push_start(&self, started_at: Instant) {
        // A start is left unmatched when a later middleware aborts in `pre_invoke` or the call's
        // future is dropped. Cap the stack so such leftovers cannot grow without bound.
        const MAX_PENDING_STARTS: usize = 1024;
        let mut pending = self
            .pending_starts
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if pending.len() >= MAX_PENDING_STARTS {
            pending.pop_front();
        }
        pending.push_back(started_at);
    }

    /// Removes and returns the most recently recorded start time, if any.
    fn pop_start(&self) -> Option<Instant> {
        self.pending_starts
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .pop_back()
    }
}

impl Default for MetricsMiddleware {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
impl LlmMiddleware for MetricsMiddleware {
    /// Records the call's start time on this instance.
    ///
    /// No start time is recorded on wasm targets, so successful calls there add 0 ms.
    async fn pre_invoke(
        &self,
        _messages: &mut Vec<Message>,
        _options: &mut CallOptions,
    ) -> Result<(), LlmError> {
        #[cfg(not(target_family = "wasm"))]
        self.push_start(Instant::now());
        Ok(())
    }

    /// Counts the outcome, and for successful calls adds the elapsed time since the matching
    /// `pre_invoke`. Never fails and never modifies the result.
    async fn post_invoke(&self, result: &mut Result<Message, LlmError>) -> Result<(), LlmError> {
        // Pop unconditionally so failed calls do not leave their start time behind.
        let duration_ms = self.pop_start().map_or(0, |started_at| {
            u64::try_from(started_at.elapsed().as_millis()).unwrap_or(u64::MAX)
        });
        if result.is_ok() {
            self.invoke_count.fetch_add(1, Ordering::Relaxed);
            self.total_duration_ms.fetch_add(duration_ms, Ordering::Relaxed);
        } else {
            self.error_count.fetch_add(1, Ordering::Relaxed);
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::mock::MockChatModel;

    #[test]
    fn test_middleware_model_new() {
        let base_model = MockChatModel::new("gpt-4");
        let model = MiddlewareModel::new(base_model);
        assert!(model.middleware.is_empty());
        assert_eq!(model.model_name(), "gpt-4");
    }

    #[test]
    fn test_middleware_model_builder() {
        let base_model = MockChatModel::new("gpt-4");
        let log_mw = LoggingMiddleware::new();
        let metrics_mw = MetricsMiddleware::new();
        let model = MiddlewareModel::new(base_model)
            .with_middleware(log_mw)
            .with_middleware(metrics_mw);
        assert_eq!(model.middleware.len(), 2);
    }

    #[test]
    fn test_with_middlewares_extends_chain() {
        let logging: Arc<dyn LlmMiddleware> = Arc::new(LoggingMiddleware::new());
        let metrics: Arc<dyn LlmMiddleware> = Arc::new(MetricsMiddleware::new());
        let shared = [logging, metrics];
        let model = MiddlewareModel::new(MockChatModel::new("gpt-4"))
            .with_middleware(LoggingMiddleware::new())
            .with_middlewares(&shared);
        assert_eq!(model.middleware.len(), 3);
    }

    #[test]
    fn test_debug_reports_name_and_count() {
        let model = MiddlewareModel::new(MockChatModel::new("gpt-4"))
            .with_middleware(LoggingMiddleware::new());
        let rendered = format!("{model:?}");
        assert!(rendered.contains("gpt-4"));
        assert!(rendered.contains("middleware_count: 1"));
    }

    #[test]
    fn test_logging_middleware() {
        let middleware = LoggingMiddleware::new();
        assert_eq!(middleware.model_name, "");
        let with_name = LoggingMiddleware::new().with_model_name("gpt-4");
        assert_eq!(with_name.model_name, "gpt-4");
    }

    #[test]
    fn test_metrics_middleware_new() {
        let middleware = MetricsMiddleware::new();
        let metrics = middleware.metrics();
        assert_eq!(metrics.invoke_count, 0);
        assert_eq!(metrics.error_count, 0);
        assert_eq!(metrics.total_duration_ms, 0);
        assert_eq!(metrics.avg_duration_ms, 0);
    }

    #[test]
    fn test_metrics_middleware_reset() {
        let middleware = MetricsMiddleware::new();
        middleware.invoke_count.fetch_add(5, Ordering::Relaxed);
        middleware.error_count.fetch_add(2, Ordering::Relaxed);
        middleware
            .total_duration_ms
            .fetch_add(100, Ordering::Relaxed);
        let metrics = middleware.metrics();
        assert_eq!(metrics.invoke_count, 5);
        assert_eq!(metrics.error_count, 2);
        assert_eq!(metrics.avg_duration_ms, 20);
        middleware.reset();
        let metrics_after = middleware.metrics();
        assert_eq!(metrics_after.invoke_count, 0);
        assert_eq!(metrics_after.error_count, 0);
        assert_eq!(metrics_after.total_duration_ms, 0);
    }

    /// Test hook whose `pre_invoke` always fails.
    struct AbortMiddleware;

    #[cfg_attr(target_family = "wasm", async_trait(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_trait)]
    impl LlmMiddleware for AbortMiddleware {
        async fn pre_invoke(
            &self,
            _messages: &mut Vec<Message>,
            _options: &mut CallOptions,
        ) -> Result<(), LlmError> {
            Err(LlmError::Other(Box::new(std::io::Error::other(
                "aborted by middleware",
            ))))
        }
    }

    #[tokio::test]
    async fn test_pre_invoke_abort() {
        let base_model = MockChatModel::new("gpt-4").with_response("Hello!");
        let model = MiddlewareModel::new(base_model).with_middleware(AbortMiddleware);
        let messages = vec![Message::human("Hi")];
        let result = model.invoke(&messages, None).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("aborted by middleware")
        );
    }

    /// Test hook that replaces an error result with a fixed AI reply.
    struct ResultModifierMiddleware;

    #[cfg_attr(target_family = "wasm", async_trait(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_trait)]
    impl LlmMiddleware for ResultModifierMiddleware {
        async fn post_invoke(
            &self,
            result: &mut Result<Message, LlmError>,
        ) -> Result<(), LlmError> {
            if result.is_err() {
                *result = Ok(Message::ai("Fallback response"));
            }
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_post_invoke_modifies_result() {
        let base_model = MockChatModel::new("gpt-4").with_error();
        let model = MiddlewareModel::new(base_model).with_middleware(ResultModifierMiddleware);
        let messages = vec![Message::human("Hi")];
        let result = model.invoke(&messages, None).await;
        assert!(result.is_ok());
        let response = result.unwrap();
        assert!(matches!(response.role, crate::llm::Role::Ai));
    }

    /// Test hook that appends `<name>_pre` / `<name>_post` to a shared log.
    struct OrderRecorder {
        order: Arc<std::sync::Mutex<Vec<String>>>,
        name: String,
    }

    #[cfg_attr(target_family = "wasm", async_trait(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_trait)]
    impl LlmMiddleware for OrderRecorder {
        async fn pre_invoke(
            &self,
            _messages: &mut Vec<Message>,
            _options: &mut CallOptions,
        ) -> Result<(), LlmError> {
            self.order
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .push(format!("{}_pre", self.name));
            Ok(())
        }

        async fn post_invoke(
            &self,
            _result: &mut Result<Message, LlmError>,
        ) -> Result<(), LlmError> {
            self.order
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .push(format!("{}_post", self.name));
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_post_invoke_order() {
        let order = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mw1 = OrderRecorder {
            order: Arc::clone(&order),
            name: "first".to_string(),
        };
        let mw2 = OrderRecorder {
            order: Arc::clone(&order),
            name: "second".to_string(),
        };
        let base_model = MockChatModel::new("gpt-4").with_response("Hello!");
        let model = MiddlewareModel::new(base_model)
            .with_middleware(mw1)
            .with_middleware(mw2);
        let messages = vec![Message::human("Hi")];
        let _ = model.invoke(&messages, None).await;
        let order_data = {
            let order_guard = order
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            order_guard.clone()
        };
        assert_eq!(
            order_data,
            vec![
                "first_pre".to_string(),
                "second_pre".to_string(),
                "second_post".to_string(),
                "first_post".to_string(),
            ]
        );
    }

    #[tokio::test]
    async fn test_pre_invoke_abort_skips_remaining_hooks() {
        let order = Arc::new(std::sync::Mutex::new(Vec::new()));
        let recorder = OrderRecorder {
            order: Arc::clone(&order),
            name: "outer".to_string(),
        };
        let model = MiddlewareModel::new(MockChatModel::new("gpt-4").with_response("Hello!"))
            .with_middleware(recorder)
            .with_middleware(AbortMiddleware);
        let messages = vec![Message::human("Hi")];
        let result = model.invoke(&messages, None).await;
        assert!(result.is_err());
        let order_data = order
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone();
        // The outer hook's `post_invoke` must not run once a later `pre_invoke` aborts.
        assert_eq!(order_data, vec!["outer_pre".to_string()]);
    }

    #[test]
    fn test_bind_tools_preserves_middleware() {
        let base_model = MockChatModel::new("gpt-4");
        let model = MiddlewareModel::new(base_model)
            .with_middleware(LoggingMiddleware::new())
            .with_middleware(MetricsMiddleware::new());
        let model_with_tools = model.bind_tools(vec![]);
        assert_eq!(model_with_tools.middleware.len(), 2);
    }

    #[tokio::test]
    async fn test_metrics_middleware_tracks_invocations() {
        let base_model = MockChatModel::new("gpt-4").with_response("Hello!");
        let middleware = MetricsMiddleware::new();
        let model = MiddlewareModel::new(base_model).with_middleware(middleware.clone());
        let messages = vec![Message::human("Hi")];
        let _ = model.invoke(&messages, None).await;
        let metrics = middleware.metrics();
        assert_eq!(metrics.invoke_count, 1);
        assert_eq!(metrics.error_count, 0);
        // With exactly one successful call, the average is the total.
        assert_eq!(metrics.avg_duration_ms, metrics.total_duration_ms);
    }

    #[tokio::test]
    async fn test_metrics_middleware_clones_share_counters() {
        let metrics = MetricsMiddleware::new();
        let model = MiddlewareModel::new(MockChatModel::new("gpt-4").with_response("Hello!"))
            .with_middleware(metrics.clone())
            .with_middleware(metrics.clone());
        let messages = vec![Message::human("Hi")];
        let _ = model.invoke(&messages, None).await;
        let snapshot = metrics.metrics();
        assert_eq!(snapshot.invoke_count, 2);
        assert_eq!(snapshot.error_count, 0);
    }

    #[tokio::test]
    async fn test_metrics_middleware_tracks_errors() {
        let base_model = MockChatModel::new("gpt-4").with_error();
        let middleware = MetricsMiddleware::new();
        let model = MiddlewareModel::new(base_model).with_middleware(middleware.clone());
        let messages = vec![Message::human("Hi")];
        let _ = model.invoke(&messages, None).await;
        let metrics = middleware.metrics();
        assert_eq!(metrics.invoke_count, 0);
        assert_eq!(metrics.error_count, 1);
        assert_eq!(metrics.total_duration_ms, 0);
        assert_eq!(metrics.avg_duration_ms, 0);
    }
}