Skip to main content

pe_core/
provider_middleware.rs

1//! Provider middleware trait for composable LLM call wrappers.
2//!
3//! Middleware intercepts `complete()` calls, adding cross-cutting concerns
4//! like retry, timeout, circuit breaking, and fallback — without modifying
5//! the underlying provider.
6//!
7//! # Example
8//!
9//! ```ignore
10//! struct LoggingMiddleware;
11//!
12//! #[async_trait]
13//! impl ProviderMiddleware for LoggingMiddleware {
14//!     async fn wrap_complete(
15//!         &self,
16//!         messages: &[Message],
17//!         tools: &[ToolSchema],
18//!         next: &dyn LlmProvider,
19//!     ) -> Result<LlmResponse, PeError> {
20//!         println!("calling LLM with {} messages", messages.len());
21//!         next.complete(messages, tools).await
22//!     }
23//! }
24//! ```
25
26use async_trait::async_trait;
27
28use crate::error::PeError;
29use crate::llm::{LlmProvider, LlmResponse, ToolSchema};
30use crate::message::Message;
31
32/// Intercepts LLM completion calls for cross-cutting concerns.
33///
34/// Middlewares are composed into a [`MiddlewareStack`](super::middleware_stack::MiddlewareStack)
35/// and execute outside-in: the first middleware added wraps all subsequent ones.
36///
37/// # Implementors
38///
39/// - [`RetryMiddleware`](super::retry_middleware::RetryMiddleware) — exponential backoff
40/// - [`TimeoutMiddleware`](super::timeout_middleware::TimeoutMiddleware) — per-call deadline
41/// - [`CircuitBreaker`](super::circuit_breaker::CircuitBreaker) — fail-fast on repeated errors
42#[async_trait]
43pub trait ProviderMiddleware: Send + Sync + 'static {
44    /// Wrap a completion call. Call `next.complete(messages, tools).await`
45    /// to forward to the next layer (or the base provider).
46    async fn wrap_complete(
47        &self,
48        messages: &[Message],
49        tools: &[ToolSchema],
50        next: &dyn LlmProvider,
51    ) -> Result<LlmResponse, PeError>;
52}
53
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;

    /// Middleware with no behavior of its own: each call is handed straight
    /// to `next` and its result is returned untouched.
    struct PassthroughMiddleware;

    #[async_trait]
    impl ProviderMiddleware for PassthroughMiddleware {
        async fn wrap_complete(
            &self,
            messages: &[Message],
            tools: &[ToolSchema],
            next: &dyn LlmProvider,
        ) -> Result<LlmResponse, PeError> {
            next.complete(messages, tools).await
        }
    }

    /// A successful provider response must come back through the middleware.
    #[tokio::test]
    async fn test_passthrough_middleware_forwards_to_provider() {
        let provider = MockProvider::new().respond_with("hello");

        let reply = PassthroughMiddleware
            .wrap_complete(&[], &[], &provider)
            .await
            .unwrap();

        let text = reply.message.content.as_text();
        assert_eq!(text, Some("hello"));
    }

    /// A provider failure must surface to the caller as the same error variant.
    #[tokio::test]
    async fn test_middleware_receives_provider_errors() {
        let failure = PeError::LlmProvider {
            details: "rate limited".into(),
        };
        let provider = MockProvider::new().respond_with_error(failure);

        let outcome = PassthroughMiddleware
            .wrap_complete(&[], &[], &provider)
            .await;

        let err = outcome.unwrap_err();
        assert!(matches!(err, PeError::LlmProvider { .. }));
    }
}