pe-core 0.1.0

Core types for Potential Expectations — messages, channels, state, traits
Documentation
//! Middleware stack -- composes middlewares over a base [`LlmProvider`].
//!
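//! The stack itself implements `LlmProvider`, so it can be used anywhere
//! a provider is expected. Middlewares execute outside-in: the last one
//! added via [`with()`](MiddlewareStack::with) wraps all previous ones.
//! Only `complete()` passes through the middleware; `stream()`, `embed()`,
//! and `provider_name()` are forwarded directly to the base provider.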
//!
//! # Example
//!
//! ```ignore
//! let provider = MiddlewareStack::new(openai)
//!     .with(TimeoutMiddleware::new(Duration::from_secs(30)))
//!     .with(RetryMiddleware::new(3, Duration::from_millis(200)));
//! // Use `provider` as any LlmProvider
//! ```

use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use crate::error::PeError;
use crate::llm::{LlmProvider, LlmResponse, StreamFuture, ToolSchema};
use crate::message::Message;
use crate::provider_middleware::ProviderMiddleware;

/// A stack of [`ProviderMiddleware`] layers wrapping a base [`LlmProvider`].
///
/// Implements `LlmProvider` itself, enabling transparent composition.
/// Middlewares execute outside-in: the outermost (last added) runs first.
///
/// Only `complete` is routed through the middleware layers. `stream`,
/// `embed`, and `provider_name` are forwarded straight down to the base
/// provider, so middleware (timeouts, retries, ...) does not apply to them.
pub struct MiddlewareStack {
    /// The effective provider after all middleware is applied.
    /// Each `with()` call wraps the current provider in a new layer.
    provider: Arc<dyn LlmProvider>,
}

impl std::fmt::Debug for MiddlewareStack {
    /// Prints the base provider's name. The middleware layers are opaque
    /// trait objects and are not listed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MiddlewareStack")
            .field("provider_name", &self.provider.provider_name())
            .finish_non_exhaustive()
    }
}

impl MiddlewareStack {
    /// Start a stack whose only provider is `base`, with no middleware yet.
    pub fn new(base: impl LlmProvider) -> Self {
        let provider: Arc<dyn LlmProvider> = Arc::new(base);
        Self { provider }
    }

    /// Wrap everything built so far in one more middleware layer and return
    /// the grown stack, so calls can be chained.
    ///
    /// The newest layer becomes the outermost one. On `complete`, it runs
    /// first and is handed the previously built stack as its `next` provider.
    #[must_use = "builder method returns modified stack"]
    pub fn with(self, middleware: impl ProviderMiddleware) -> Self {
        let layer = WrappedLayer {
            inner: self.provider,
            middleware: Arc::new(middleware),
        };
        Self {
            provider: Arc::new(layer),
        }
    }
}

/// A single middleware layer wrapping an inner provider.
/// Implements `LlmProvider` so it can be nested.
struct WrappedLayer {
    /// Middleware that intercepts `complete` calls made on this layer.
    middleware: Arc<dyn ProviderMiddleware>,
    /// The provider one level down. It is either another `WrappedLayer` or
    /// the base provider given to `MiddlewareStack::new`.
    inner: Arc<dyn LlmProvider>,
}

impl WrappedLayer {
    /// The provider sitting directly beneath this layer.
    fn downstream(&self) -> &dyn LlmProvider {
        self.inner.as_ref()
    }
}

impl LlmProvider for WrappedLayer {
    /// Run the request through this layer's middleware, which receives the
    /// downstream provider as `next`.
    fn complete(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
    ) -> Pin<Box<dyn Future<Output = Result<LlmResponse, PeError>> + Send + '_>> {
        // The returned future may only borrow `self`, not the argument
        // slices. Owned copies of the request are therefore moved into it.
        let owned_messages = Vec::from(messages);
        let owned_tools = Vec::from(tools);
        Box::pin(async move {
            self.middleware
                .wrap_complete(&owned_messages, &owned_tools, self.downstream())
                .await
        })
    }

    /// Streaming is not intercepted by middleware; it goes straight down.
    fn stream(&self, messages: &[Message], tools: &[ToolSchema]) -> StreamFuture<'_> {
        self.downstream().stream(messages, tools)
    }

    /// Embeddings are not intercepted by middleware; they go straight down.
    fn embed(
        &self,
        text: &str,
    ) -> Pin<Box<dyn Future<Output = Result<Vec<f32>, PeError>> + Send + '_>> {
        self.downstream().embed(text)
    }

    /// Reports the name of the provider beneath this layer.
    fn provider_name(&self) -> &'static str {
        self.downstream().provider_name()
    }
}

impl LlmProvider for MiddlewareStack {
    /// Hand the request to the outermost layer. With no middleware added,
    /// that is the base provider itself.
    fn complete(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
    ) -> Pin<Box<dyn Future<Output = Result<LlmResponse, PeError>> + Send + '_>> {
        let outermost = &*self.provider;
        outermost.complete(messages, tools)
    }

    /// Forward streaming to the outermost layer.
    fn stream(&self, messages: &[Message], tools: &[ToolSchema]) -> StreamFuture<'_> {
        let outermost = &*self.provider;
        outermost.stream(messages, tools)
    }

    /// Forward embedding requests to the outermost layer.
    fn embed(
        &self,
        text: &str,
    ) -> Pin<Box<dyn Future<Output = Result<Vec<f32>, PeError>> + Send + '_>> {
        let outermost = &*self.provider;
        outermost.embed(text)
    }

    /// Name reported by the outermost layer, which resolves to the base
    /// provider's name.
    fn provider_name(&self) -> &'static str {
        let outermost = &*self.provider;
        outermost.provider_name()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Middleware that counts the `complete` calls passing through it.
    struct CountingMiddleware {
        count: Arc<AtomicU32>,
    }

    #[async_trait]
    impl ProviderMiddleware for CountingMiddleware {
        async fn wrap_complete(
            &self,
            messages: &[Message],
            tools: &[ToolSchema],
            next: &dyn LlmProvider,
        ) -> Result<LlmResponse, PeError> {
            self.count.fetch_add(1, Ordering::SeqCst);
            next.complete(messages, tools).await
        }
    }

    #[tokio::test]
    async fn test_stack_no_middleware_passes_through() {
        let stack = MiddlewareStack::new(MockProvider::new().respond_with("bare"));
        let resp = stack.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("bare"));
    }

    #[tokio::test]
    async fn test_stack_single_middleware_invoked() {
        let count = Arc::new(AtomicU32::new(0));
        let stack =
            MiddlewareStack::new(MockProvider::new().respond_with("ok")).with(CountingMiddleware {
                count: count.clone(),
            });

        let resp = stack.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("ok"));
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_stack_multiple_middlewares_execute_outside_in() {
        let order = Arc::new(std::sync::Mutex::new(Vec::new()));

        struct OrderMiddleware {
            id: &'static str,
            order: Arc<std::sync::Mutex<Vec<&'static str>>>,
        }

        #[async_trait]
        impl ProviderMiddleware for OrderMiddleware {
            async fn wrap_complete(
                &self,
                messages: &[Message],
                tools: &[ToolSchema],
                next: &dyn LlmProvider,
            ) -> Result<LlmResponse, PeError> {
                // The guard is a temporary dropped at the end of this
                // statement, so it is never held across the await below.
                self.order.lock().unwrap().push(self.id);
                next.complete(messages, tools).await
            }
        }

        let stack = MiddlewareStack::new(MockProvider::new().respond_with("done"))
            .with(OrderMiddleware {
                id: "first",
                order: order.clone(),
            })
            .with(OrderMiddleware {
                id: "second",
                order: order.clone(),
            });

        stack.complete(&[], &[]).await.unwrap();

        let recorded = order.lock().unwrap().clone();
        // Outside-in: last added ("second") runs first
        assert_eq!(recorded, vec!["second", "first"]);
    }

    #[tokio::test]
    async fn test_stack_provider_name_delegates_to_base() {
        let stack = MiddlewareStack::new(MockProvider::new());
        assert_eq!(stack.provider_name(), "mock");
    }

    /// The name must survive being wrapped: each layer forwards
    /// `provider_name` to the provider beneath it.
    #[tokio::test]
    async fn test_stack_provider_name_delegates_through_middleware() {
        let stack = MiddlewareStack::new(MockProvider::new()).with(CountingMiddleware {
            count: Arc::new(AtomicU32::new(0)),
        });
        assert_eq!(stack.provider_name(), "mock");
    }

    #[tokio::test]
    async fn test_stack_embed_delegates_to_base() {
        let stack = MiddlewareStack::new(MockProvider::new().with_embedding(vec![1.0, 2.0]));
        let embedding = stack.embed("test").await.unwrap();
        assert_eq!(embedding, vec![1.0, 2.0]);
    }

    /// `embed` passes through a wrapped layer to the base provider without
    /// invoking the middleware, which only intercepts `complete`.
    #[tokio::test]
    async fn test_stack_embed_bypasses_middleware() {
        let count = Arc::new(AtomicU32::new(0));
        let stack = MiddlewareStack::new(MockProvider::new().with_embedding(vec![1.0, 2.0]))
            .with(CountingMiddleware {
                count: Arc::clone(&count),
            });

        let embedding = stack.embed("test").await.unwrap();
        assert_eq!(embedding, vec![1.0, 2.0]);
        assert_eq!(count.load(Ordering::SeqCst), 0);
    }

    /// Composition test: timeout + retry + circuit breaker in one stack.
    ///
    /// Call order, outermost first: Timeout -> Retry -> CircuitBreaker ->
    /// mock provider.
    #[tokio::test]
    async fn test_full_middleware_composition() {
        use crate::circuit_breaker::CircuitBreaker;
        use crate::retry_middleware::RetryMiddleware;
        use crate::timeout_middleware::TimeoutMiddleware;
        use std::time::Duration;

        // Provider: fail once then succeed
        let provider = MockProvider::new()
            .respond_with_error(PeError::LlmProvider {
                details: "503".into(),
            })
            .respond_with("recovered");

        let stack = MiddlewareStack::new(provider)
            .with(CircuitBreaker::new(5, Duration::from_secs(60)))
            .with(RetryMiddleware::new(3, Duration::from_millis(1)))
            .with(TimeoutMiddleware::new(Duration::from_secs(5)));

        // Retry middleware catches the first transient failure and retries.
        // A single failure stays below the breaker's configured threshold (5).
        let resp = stack.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("recovered"));
    }
}