pe-core 0.1.0

Core types for Potential Expectations — messages, channels, state, traits
Documentation
//! Provider middleware trait for composable LLM call wrappers.
//!
//! A middleware intercepts `complete()` calls to add cross-cutting concerns
//! such as retry, timeout, circuit breaking, or fallback, without modifying
//! the underlying provider.
//!
//! # Example
//!
//! ```ignore
//! struct LoggingMiddleware;
//!
//! #[async_trait]
//! impl ProviderMiddleware for LoggingMiddleware {
//!     async fn wrap_complete(
//!         &self,
//!         messages: &[Message],
//!         tools: &[ToolSchema],
//!         next: &dyn LlmProvider,
//!     ) -> Result<LlmResponse, PeError> {
//!         println!("calling LLM with {} messages", messages.len());
//!         next.complete(messages, tools).await
//!     }
//! }
//! ```

use async_trait::async_trait;

use crate::error::PeError;
use crate::llm::{LlmProvider, LlmResponse, ToolSchema};
use crate::message::Message;

/// Intercepts LLM completion calls for cross-cutting concerns.
///
/// Middlewares are composed into a [`MiddlewareStack`](super::middleware_stack::MiddlewareStack)
/// and execute outside-in: the first middleware added wraps all subsequent ones.
///
/// The trait is declared with `#[async_trait]`, so implementations must also
/// annotate their `impl` block with `#[async_trait]` (as in the module-level
/// example).
///
/// The `Send + Sync + 'static` supertraits mean an implementor holds no
/// non-`'static` borrows and can be moved to and shared between threads.
/// NOTE(review): presumably this is so the stack can store middlewares as
/// shared trait objects — confirm against `MiddlewareStack`.
///
/// # Implementors
///
/// - [`RetryMiddleware`](super::retry_middleware::RetryMiddleware) — exponential backoff
/// - [`TimeoutMiddleware`](super::timeout_middleware::TimeoutMiddleware) — per-call deadline
/// - [`CircuitBreaker`](super::circuit_breaker::CircuitBreaker) — fail-fast on repeated errors
#[async_trait]
pub trait ProviderMiddleware: Send + Sync + 'static {
    /// Wrap a completion call. Call `next.complete(messages, tools).await`
    /// to forward to the next layer (or the base provider).
    ///
    /// The trait does not require `next` to be called exactly once. An
    /// implementation can act before and/or after forwarding, forward more
    /// than once, or return without forwarding at all.
    ///
    /// # Arguments
    ///
    /// * `messages` — the messages handed to this layer; forward them to
    ///   `next` to continue the call.
    /// * `tools` — the tool schemas for this call, forwarded to `next`
    ///   alongside `messages`.
    /// * `next` — the next layer in the stack, or the base provider when this
    ///   middleware is innermost.
    ///
    /// # Errors
    ///
    /// Returns a [`PeError`] when the wrapped call fails. A pass-through
    /// implementation returns `next`'s error unchanged. A middleware may also
    /// produce its own error. What it returns in that case depends on the
    /// implementor.
    async fn wrap_complete(
        &self,
        messages: &[Message],
        tools: &[ToolSchema],
        next: &dyn LlmProvider,
    ) -> Result<LlmResponse, PeError>;
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;

    /// Test double that adds nothing of its own: each call is handed straight
    /// to the inner layer, and its result (success or error) comes back as-is.
    struct PassthroughMiddleware;

    #[async_trait]
    impl ProviderMiddleware for PassthroughMiddleware {
        async fn wrap_complete(
            &self,
            conversation: &[Message],
            schemas: &[ToolSchema],
            inner: &dyn LlmProvider,
        ) -> Result<LlmResponse, PeError> {
            inner.complete(conversation, schemas).await
        }
    }

    /// A successful provider reply must reach the caller unchanged.
    #[tokio::test]
    async fn test_passthrough_middleware_forwards_to_provider() {
        let middleware = PassthroughMiddleware;
        let mock = MockProvider::new().respond_with("hello");

        let reply = middleware.wrap_complete(&[], &[], &mock).await.unwrap();
        let text = reply.message.content.as_text();
        assert_eq!(text, Some("hello"));
    }

    /// An error raised by the provider must surface through the middleware.
    #[tokio::test]
    async fn test_middleware_receives_provider_errors() {
        let middleware = PassthroughMiddleware;
        let rate_limited = PeError::LlmProvider {
            details: "rate limited".into(),
        };
        let mock = MockProvider::new().respond_with_error(rate_limited);

        let failure = middleware.wrap_complete(&[], &[], &mock).await.unwrap_err();
        assert!(matches!(failure, PeError::LlmProvider { .. }));
    }
}