use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::error::PeError;
use crate::llm::{LlmProvider, LlmResponse, StreamFuture, ToolSchema};
use crate::message::Message;
use crate::provider_middleware::ProviderMiddleware;
/// An [`LlmProvider`] built from a base provider plus zero or more
/// [`ProviderMiddleware`] layers.
///
/// Each call to [`MiddlewareStack::with`] wraps the current provider in a new
/// layer, so the most recently added middleware is the outermost one and is
/// the first to receive a `complete` call.
pub struct MiddlewareStack {
    /// Outermost provider: the base provider, or the last layer added.
    provider: Arc<dyn LlmProvider>,
}

impl MiddlewareStack {
    /// Creates a stack that contains only `base`, with no middleware applied.
    pub fn new(base: impl LlmProvider) -> Self {
        let provider: Arc<dyn LlmProvider> = Arc::new(base);
        Self { provider }
    }

    /// Consumes the stack and returns a new one in which `middleware` wraps
    /// everything added so far.
    #[must_use = "builder method returns modified stack"]
    pub fn with(self, middleware: impl ProviderMiddleware) -> Self {
        // The previous outermost provider becomes the `inner` of the new layer.
        let layer = WrappedLayer {
            middleware: Arc::new(middleware),
            inner: self.provider,
        };
        Self {
            provider: Arc::new(layer),
        }
    }
}
/// A single layer of the stack: one middleware paired with the provider it
/// wraps.
struct WrappedLayer {
    /// Middleware that intercepts `complete` calls for this layer.
    middleware: Arc<dyn ProviderMiddleware>,
    /// The wrapped provider, passed to the middleware as `next`.
    inner: Arc<dyn LlmProvider>,
}

impl LlmProvider for WrappedLayer {
    /// Routes the completion through this layer's middleware, handing it the
    /// inner provider as the `next` step.
    fn complete(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
    ) -> Pin<Box<dyn Future<Output = Result<LlmResponse, PeError>> + Send + '_>> {
        // The returned future is tied only to `&self`, so the borrowed slices
        // are copied into owned vectors that the future can carry.
        let owned_messages = messages.to_vec();
        let owned_tools = tools.to_vec();
        Box::pin(async move {
            let next: &dyn LlmProvider = &*self.inner;
            self.middleware
                .wrap_complete(&owned_messages, &owned_tools, next)
                .await
        })
    }

    /// Forwards straight to the inner provider; this layer's middleware is not
    /// involved in streaming calls.
    fn stream(&self, messages: &[Message], tools: &[ToolSchema]) -> StreamFuture<'_> {
        let Self { inner, .. } = self;
        inner.stream(messages, tools)
    }

    /// Forwards straight to the inner provider; this layer's middleware is not
    /// involved in embedding calls.
    fn embed(
        &self,
        text: &str,
    ) -> Pin<Box<dyn Future<Output = Result<Vec<f32>, PeError>> + Send + '_>> {
        let Self { inner, .. } = self;
        inner.embed(text)
    }

    /// Reports the name of the inner provider.
    fn provider_name(&self) -> &'static str {
        let Self { inner, .. } = self;
        inner.provider_name()
    }
}
/// Lets a whole stack be used wherever an [`LlmProvider`] is expected; every
/// method hands the call to the outermost provider held by the stack.
impl LlmProvider for MiddlewareStack {
    /// Starts the completion at the outermost provider of the stack.
    fn complete(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
    ) -> Pin<Box<dyn Future<Output = Result<LlmResponse, PeError>> + Send + '_>> {
        let Self { provider } = self;
        provider.complete(messages, tools)
    }

    /// Delegates the streaming call to the outermost provider.
    fn stream(&self, messages: &[Message], tools: &[ToolSchema]) -> StreamFuture<'_> {
        let Self { provider } = self;
        provider.stream(messages, tools)
    }

    /// Delegates the embedding call to the outermost provider.
    fn embed(
        &self,
        text: &str,
    ) -> Pin<Box<dyn Future<Output = Result<Vec<f32>, PeError>> + Send + '_>> {
        let Self { provider } = self;
        provider.embed(text)
    }

    /// Returns whatever name the outermost provider reports.
    fn provider_name(&self) -> &'static str {
        let Self { provider } = self;
        provider.provider_name()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Middleware that bumps a shared counter each time `wrap_complete` is
    /// entered, then forwards the call to `next`.
    struct CountingMiddleware {
        count: Arc<AtomicU32>,
    }

    #[async_trait]
    impl ProviderMiddleware for CountingMiddleware {
        async fn wrap_complete(
            &self,
            messages: &[Message],
            tools: &[ToolSchema],
            next: &dyn LlmProvider,
        ) -> Result<LlmResponse, PeError> {
            self.count.fetch_add(1, Ordering::SeqCst);
            next.complete(messages, tools).await
        }
    }

    /// A stack without any middleware returns the base provider's response.
    #[tokio::test]
    async fn test_stack_no_middleware_passes_through() {
        let base = MockProvider::new().respond_with("bare");
        let stack = MiddlewareStack::new(base);

        let response = stack.complete(&[], &[]).await.unwrap();

        assert_eq!(response.message.content.as_text(), Some("bare"));
    }

    /// A single middleware is entered exactly once per `complete` call.
    #[tokio::test]
    async fn test_stack_single_middleware_invoked() {
        let calls = Arc::new(AtomicU32::new(0));
        let counting = CountingMiddleware {
            count: Arc::clone(&calls),
        };
        let stack = MiddlewareStack::new(MockProvider::new().respond_with("ok")).with(counting);

        let response = stack.complete(&[], &[]).await.unwrap();

        assert_eq!(response.message.content.as_text(), Some("ok"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// The middleware added last is the outermost layer and runs first.
    #[tokio::test]
    async fn test_stack_multiple_middlewares_execute_outside_in() {
        /// Middleware that appends its `id` to a shared log on entry.
        struct OrderMiddleware {
            id: &'static str,
            order: Arc<std::sync::Mutex<Vec<&'static str>>>,
        }

        #[async_trait]
        impl ProviderMiddleware for OrderMiddleware {
            async fn wrap_complete(
                &self,
                messages: &[Message],
                tools: &[ToolSchema],
                next: &dyn LlmProvider,
            ) -> Result<LlmResponse, PeError> {
                self.order.lock().unwrap().push(self.id);
                next.complete(messages, tools).await
            }
        }

        let log: Arc<std::sync::Mutex<Vec<&'static str>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let first = OrderMiddleware {
            id: "first",
            order: Arc::clone(&log),
        };
        let second = OrderMiddleware {
            id: "second",
            order: Arc::clone(&log),
        };
        let stack = MiddlewareStack::new(MockProvider::new().respond_with("done"))
            .with(first)
            .with(second);

        stack.complete(&[], &[]).await.unwrap();

        let seen = log.lock().unwrap().clone();
        assert_eq!(seen, vec!["second", "first"]);
    }

    /// `provider_name` on a bare stack is the base provider's name.
    #[tokio::test]
    async fn test_stack_provider_name_delegates_to_base() {
        let base = MockProvider::new();
        let stack = MiddlewareStack::new(base);

        assert_eq!(stack.provider_name(), "mock");
    }

    /// `embed` on a bare stack returns the base provider's embedding.
    #[tokio::test]
    async fn test_stack_embed_delegates_to_base() {
        let base = MockProvider::new().with_embedding(vec![1.0, 2.0]);
        let stack = MiddlewareStack::new(base);

        let vector = stack.embed("test").await.unwrap();

        assert_eq!(vector, vec![1.0, 2.0]);
    }

    /// Circuit breaker, retry and timeout layered together still yield the
    /// provider's second (successful) response after a first error.
    #[tokio::test]
    async fn test_full_middleware_composition() {
        use crate::circuit_breaker::CircuitBreaker;
        use crate::retry_middleware::RetryMiddleware;
        use crate::timeout_middleware::TimeoutMiddleware;
        use std::time::Duration;

        // First scripted response is an error, the second one succeeds.
        let failing_once = MockProvider::new()
            .respond_with_error(PeError::LlmProvider {
                details: "503".into(),
            })
            .respond_with("recovered");

        let breaker = CircuitBreaker::new(5, Duration::from_secs(60));
        let retry = RetryMiddleware::new(3, Duration::from_millis(1));
        let timeout = TimeoutMiddleware::new(Duration::from_secs(5));
        let stack = MiddlewareStack::new(failing_once)
            .with(breaker)
            .with(retry)
            .with(timeout);

        let response = stack.complete(&[], &[]).await.unwrap();

        assert_eq!(response.message.content.as_text(), Some("recovered"));
    }
}