use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::error::Result;
use crate::model::types::{ChatRequest, ChatResponse};
use crate::tool::ToolCall;
/// Verdict produced by a middleware hook once it has finished with its argument.
#[derive(Debug)]
pub enum MiddlewareAction {
    /// The hook has nothing further to add; this is what every default hook of
    /// `Middleware` resolves to.
    Continue,
    /// The middleware supplies its own [`ChatResponse`].
    ///
    /// NOTE(review): presumably the pipeline returns this response instead of running the
    /// remaining stages — the dispatcher is not in this file, so confirm there.
    ShortCircuit(ChatResponse),
}
/// Future shared by every default [`Middleware`] hook: it resolves immediately to
/// `Ok(MiddlewareAction::Continue)` without inspecting anything.
fn continue_chain() -> impl Future<Output = Result<MiddlewareAction>> + Send {
    async { Ok(MiddlewareAction::Continue) }
}

/// A pluggable hook that may inspect or rewrite chat requests, chat responses and tool
/// calls.
///
/// Every hook has a default implementation that leaves its argument untouched and
/// resolves to [`MiddlewareAction::Continue`], so an implementor only overrides the hooks
/// it cares about. Each hook's future is required to be `Send`, so it may be moved across
/// threads.
///
/// Because the hooks return `impl Future`, this trait cannot be used behind `dyn`; the
/// boxed [`ErasedMiddleware`] counterpart exists for that purpose.
pub trait Middleware: Send + Sync {
    /// Hook for a [`ChatRequest`]; the request may be modified in place through the
    /// mutable borrow.
    ///
    /// Default: resolves to `Ok(MiddlewareAction::Continue)`.
    fn on_request(
        &self,
        _request: &mut ChatRequest,
    ) -> impl Future<Output = Result<MiddlewareAction>> + Send {
        continue_chain()
    }

    /// Hook for a [`ChatResponse`]; the response may be modified in place through the
    /// mutable borrow.
    ///
    /// Default: resolves to `Ok(MiddlewareAction::Continue)`.
    fn on_response(
        &self,
        _response: &mut ChatResponse,
    ) -> impl Future<Output = Result<MiddlewareAction>> + Send {
        continue_chain()
    }

    /// Hook for a [`ToolCall`]; the call may be modified in place through the mutable
    /// borrow.
    ///
    /// Default: resolves to `Ok(MiddlewareAction::Continue)`.
    fn on_tool_call(
        &self,
        _call: &mut ToolCall,
    ) -> impl Future<Output = Result<MiddlewareAction>> + Send {
        continue_chain()
    }
}
/// Dyn-compatible mirror of [`Middleware`].
///
/// Each method corresponds to the `Middleware` hook with the same name minus the
/// `_erased` suffix. Instead of an opaque `impl Future`, it returns the future pinned in a
/// `Box`, which lets the trait be used as `dyn ErasedMiddleware` (see
/// [`SharedMiddleware`]). The returned future may borrow both `self` and the argument for
/// as long as it lives.
pub trait ErasedMiddleware: Send + Sync {
    /// Boxed counterpart of `Middleware::on_request`.
    fn on_request_erased<'call>(
        &'call self,
        request: &'call mut ChatRequest,
    ) -> Pin<Box<dyn Future<Output = Result<MiddlewareAction>> + Send + 'call>>;

    /// Boxed counterpart of `Middleware::on_response`.
    fn on_response_erased<'call>(
        &'call self,
        response: &'call mut ChatResponse,
    ) -> Pin<Box<dyn Future<Output = Result<MiddlewareAction>> + Send + 'call>>;

    /// Boxed counterpart of `Middleware::on_tool_call`.
    fn on_tool_call_erased<'call>(
        &'call self,
        call: &'call mut ToolCall,
    ) -> Pin<Box<dyn Future<Output = Result<MiddlewareAction>> + Send + 'call>>;
}
/// Blanket bridge: every [`Middleware`] is usable as an [`ErasedMiddleware`].
///
/// Each erased method forwards to the matching `Middleware` hook and wraps the resulting
/// future with `Box::pin`; the hook's own behavior is passed through unchanged.
impl<T> ErasedMiddleware for T
where
    T: Middleware,
{
    fn on_request_erased<'call>(
        &'call self,
        request: &'call mut ChatRequest,
    ) -> Pin<Box<dyn Future<Output = Result<MiddlewareAction>> + Send + 'call>> {
        let pending = <T as Middleware>::on_request(self, request);
        Box::pin(pending)
    }

    fn on_response_erased<'call>(
        &'call self,
        response: &'call mut ChatResponse,
    ) -> Pin<Box<dyn Future<Output = Result<MiddlewareAction>> + Send + 'call>> {
        let pending = <T as Middleware>::on_response(self, response);
        Box::pin(pending)
    }

    fn on_tool_call_erased<'call>(
        &'call self,
        call: &'call mut ToolCall,
    ) -> Pin<Box<dyn Future<Output = Result<MiddlewareAction>> + Send + 'call>> {
        let pending = <T as Middleware>::on_tool_call(self, call);
        Box::pin(pending)
    }
}
/// Reference-counted, type-erased middleware handle.
///
/// Cloning the `Arc` shares a single middleware instance. Because [`ErasedMiddleware`]
/// has `Send + Sync` as supertraits, the handle itself is `Send + Sync`.
// `'static` is already the default object lifetime for `Arc<dyn Trait>`; it is written
// out only to make explicit that a stored middleware cannot hold short-lived borrows.
pub type SharedMiddleware = Arc<dyn ErasedMiddleware + 'static>;