use async_trait::async_trait;
use crate::error::PeError;
use crate::llm::{LlmProvider, LlmResponse, ToolSchema};
use crate::message::Message;
/// A hook that wraps a completion call made against an [`LlmProvider`].
///
/// An implementation is handed the same `messages` and `tools` a provider
/// would receive, together with `next`, the provider it may delegate to.
/// It is free to:
///
/// * forward the request untouched (`next.complete(messages, tools).await`),
/// * act before and/or after delegating (inspect, log, transform the result),
/// * or produce its own result without ever calling `next`.
///
/// The `Send + Sync + 'static` supertraits mean an implementor can be shared
/// across threads and holds no borrowed, non-`'static` data.
#[async_trait]
pub trait ProviderMiddleware: Send + Sync + 'static {
    /// Handle one completion request, optionally delegating to `next`.
    ///
    /// # Errors
    ///
    /// Returns a [`PeError`] either propagated from `next` or raised by the
    /// middleware itself.
    async fn wrap_complete(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
        next: &dyn LlmProvider,
    ) -> Result<LlmResponse, PeError>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;

    /// Middleware that forwards every call to `next` unchanged.
    struct PassthroughMiddleware;

    #[async_trait]
    impl ProviderMiddleware for PassthroughMiddleware {
        async fn wrap_complete(
            &self,
            messages: &[Message],
            tools: &[ToolSchema],
            next: &dyn LlmProvider,
        ) -> Result<LlmResponse, PeError> {
            next.complete(messages, tools).await
        }
    }

    /// Middleware that never calls `next` and always fails; used to check
    /// that a middleware can short-circuit the wrapped provider.
    struct RejectingMiddleware;

    #[async_trait]
    impl ProviderMiddleware for RejectingMiddleware {
        async fn wrap_complete(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
            _next: &dyn LlmProvider,
        ) -> Result<LlmResponse, PeError> {
            Err(PeError::LlmProvider {
                details: "rejected by middleware".into(),
            })
        }
    }

    #[tokio::test]
    async fn test_passthrough_middleware_forwards_to_provider() {
        let provider = MockProvider::new().respond_with("hello");
        let mw = PassthroughMiddleware;
        let resp = mw.wrap_complete(&[], &[], &provider).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("hello"));
    }

    #[tokio::test]
    async fn test_middleware_receives_provider_errors() {
        let provider = MockProvider::new().respond_with_error(PeError::LlmProvider {
            details: "rate limited".into(),
        });
        let mw = PassthroughMiddleware;
        let err = mw.wrap_complete(&[], &[], &provider).await.unwrap_err();
        assert!(matches!(err, PeError::LlmProvider { .. }));
    }

    #[tokio::test]
    async fn test_middleware_can_short_circuit_provider() {
        // The provider is configured to succeed, so an error here can only
        // have come from the middleware itself rather than from `next`.
        let provider = MockProvider::new().respond_with("hello");
        let mw = RejectingMiddleware;
        let err = mw.wrap_complete(&[], &[], &provider).await.unwrap_err();
        assert!(matches!(err, PeError::LlmProvider { .. }));
    }
}