Skip to main content

pe_core/
middleware_stack.rs

//! Middleware stack -- composes middlewares over a base [`LlmProvider`].
//!
//! The stack itself implements `LlmProvider`, so it can be used anywhere
//! a provider is expected. Middlewares execute outside-in: the last one
//! added via [`with()`](MiddlewareStack::with) wraps all previous ones
//! and therefore runs first.
//!
//! Only `complete()` is routed through the middlewares; `stream()`,
//! `embed()` and `provider_name()` are forwarded unchanged to the base
//! provider.
//!
//! # Example
//!
//! ```ignore
//! let provider = MiddlewareStack::new(openai)
//!     .with(TimeoutMiddleware::new(Duration::from_secs(30)))
//!     .with(RetryMiddleware::new(3, Duration::from_millis(200)));
//! // `RetryMiddleware` was added last, so it is the outermost layer and runs first.
//! // Use `provider` as any LlmProvider
//! ```
15
16use std::future::Future;
17use std::pin::Pin;
18use std::sync::Arc;
19
20use crate::error::PeError;
21use crate::llm::{LlmProvider, LlmResponse, StreamFuture, ToolSchema};
22use crate::message::Message;
23use crate::provider_middleware::ProviderMiddleware;
24
25/// A stack of [`ProviderMiddleware`] layers wrapping a base [`LlmProvider`].
26///
27/// Implements `LlmProvider` itself, enabling transparent composition.
28/// Middlewares execute outside-in: the outermost (last added) runs first.
29pub struct MiddlewareStack {
30    /// The effective provider after all middleware is applied.
31    /// Each `with()` call wraps the current provider in a new layer.
32    provider: Arc<dyn LlmProvider>,
33}
34
35impl MiddlewareStack {
36    /// Create a new stack wrapping the given base provider.
37    pub fn new(base: impl LlmProvider) -> Self {
38        Self {
39            provider: Arc::new(base),
40        }
41    }
42
43    /// Add a middleware layer. Returns `self` for chaining.
44    ///
45    /// Layers execute outside-in: the last added runs first.
46    #[must_use = "builder method returns modified stack"]
47    pub fn with(self, middleware: impl ProviderMiddleware) -> Self {
48        Self {
49            provider: Arc::new(WrappedLayer {
50                middleware: Arc::new(middleware),
51                inner: self.provider,
52            }),
53        }
54    }
55}
56
57/// A single middleware layer wrapping an inner provider.
58/// Implements `LlmProvider` so it can be nested.
59struct WrappedLayer {
60    middleware: Arc<dyn ProviderMiddleware>,
61    inner: Arc<dyn LlmProvider>,
62}
63
64impl LlmProvider for WrappedLayer {
65    fn complete(
66        &self,
67        messages: &[Message],
68        tools: &[ToolSchema],
69    ) -> Pin<Box<dyn Future<Output = Result<LlmResponse, PeError>> + Send + '_>> {
70        let messages = messages.to_vec();
71        let tools = tools.to_vec();
72        Box::pin(async move {
73            self.middleware
74                .wrap_complete(&messages, &tools, self.inner.as_ref())
75                .await
76        })
77    }
78
79    fn stream(&self, messages: &[Message], tools: &[ToolSchema]) -> StreamFuture<'_> {
80        // Middleware only wraps complete(); stream passes through.
81        self.inner.stream(messages, tools)
82    }
83
84    fn embed(
85        &self,
86        text: &str,
87    ) -> Pin<Box<dyn Future<Output = Result<Vec<f32>, PeError>> + Send + '_>> {
88        self.inner.embed(text)
89    }
90
91    fn provider_name(&self) -> &'static str {
92        self.inner.provider_name()
93    }
94}
95
96impl LlmProvider for MiddlewareStack {
97    fn complete(
98        &self,
99        messages: &[Message],
100        tools: &[ToolSchema],
101    ) -> Pin<Box<dyn Future<Output = Result<LlmResponse, PeError>> + Send + '_>> {
102        self.provider.complete(messages, tools)
103    }
104
105    fn stream(&self, messages: &[Message], tools: &[ToolSchema]) -> StreamFuture<'_> {
106        self.provider.stream(messages, tools)
107    }
108
109    fn embed(
110        &self,
111        text: &str,
112    ) -> Pin<Box<dyn Future<Output = Result<Vec<f32>, PeError>> + Send + '_>> {
113        self.provider.embed(text)
114    }
115
116    fn provider_name(&self) -> &'static str {
117        self.provider.provider_name()
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Middleware that counts how often `wrap_complete` is entered and then
    /// forwards the call unchanged to the next provider.
    struct CountingMiddleware {
        count: Arc<AtomicU32>,
    }

    #[async_trait]
    impl ProviderMiddleware for CountingMiddleware {
        async fn wrap_complete(
            &self,
            messages: &[Message],
            tools: &[ToolSchema],
            next: &dyn LlmProvider,
        ) -> Result<LlmResponse, PeError> {
            self.count.fetch_add(1, Ordering::SeqCst);
            next.complete(messages, tools).await
        }
    }

    /// A stack without layers behaves exactly like its base provider.
    #[tokio::test]
    async fn test_stack_no_middleware_passes_through() {
        let stack = MiddlewareStack::new(MockProvider::new().respond_with("bare"));
        let resp = stack.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("bare"));
    }

    /// A single layer is entered exactly once per `complete()` call.
    #[tokio::test]
    async fn test_stack_single_middleware_invoked() {
        let count = Arc::new(AtomicU32::new(0));
        let stack =
            MiddlewareStack::new(MockProvider::new().respond_with("ok")).with(CountingMiddleware {
                count: count.clone(),
            });

        let resp = stack.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("ok"));
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    /// Layers run in reverse order of registration (outside-in).
    #[tokio::test]
    async fn test_stack_multiple_middlewares_execute_outside_in() {
        let order = Arc::new(std::sync::Mutex::new(Vec::new()));

        /// Records its `id` in the shared log before forwarding the call.
        struct OrderMiddleware {
            id: &'static str,
            order: Arc<std::sync::Mutex<Vec<&'static str>>>,
        }

        #[async_trait]
        impl ProviderMiddleware for OrderMiddleware {
            async fn wrap_complete(
                &self,
                messages: &[Message],
                tools: &[ToolSchema],
                next: &dyn LlmProvider,
            ) -> Result<LlmResponse, PeError> {
                // The guard is a temporary, so it is released before the await below.
                self.order.lock().unwrap().push(self.id);
                next.complete(messages, tools).await
            }
        }

        let stack = MiddlewareStack::new(MockProvider::new().respond_with("done"))
            .with(OrderMiddleware {
                id: "first",
                order: order.clone(),
            })
            .with(OrderMiddleware {
                id: "second",
                order: order.clone(),
            });

        stack.complete(&[], &[]).await.unwrap();

        let recorded = order.lock().unwrap().clone();
        // Outside-in: last added ("second") runs first
        assert_eq!(recorded, vec!["second", "first"]);
    }

    /// Without layers, `provider_name()` is the base provider's name.
    #[tokio::test]
    async fn test_stack_provider_name_delegates_to_base() {
        let stack = MiddlewareStack::new(MockProvider::new());
        assert_eq!(stack.provider_name(), "mock");
    }

    /// Without layers, `embed()` returns the base provider's embedding.
    #[tokio::test]
    async fn test_stack_embed_delegates_to_base() {
        let stack = MiddlewareStack::new(MockProvider::new().with_embedding(vec![1.0, 2.0]));
        let embedding = stack.embed("test").await.unwrap();
        assert_eq!(embedding, vec![1.0, 2.0]);
    }

    /// With a layer in place, `provider_name()` still reports the base provider.
    #[tokio::test]
    async fn test_stack_provider_name_delegates_through_layer() {
        let stack = MiddlewareStack::new(MockProvider::new()).with(CountingMiddleware {
            count: Arc::new(AtomicU32::new(0)),
        });
        assert_eq!(stack.provider_name(), "mock");
    }

    /// With a layer in place, `embed()` reaches the base provider without
    /// entering the middleware (only `complete()` is wrapped).
    #[tokio::test]
    async fn test_stack_embed_bypasses_middleware() {
        let count = Arc::new(AtomicU32::new(0));
        let stack = MiddlewareStack::new(MockProvider::new().with_embedding(vec![1.0, 2.0]))
            .with(CountingMiddleware {
                count: count.clone(),
            });

        let embedding = stack.embed("test").await.unwrap();
        assert_eq!(embedding, vec![1.0, 2.0]);
        assert_eq!(count.load(Ordering::SeqCst), 0);
    }

    /// Composition test: timeout + retry + circuit breaker in one stack.
    #[tokio::test]
    async fn test_full_middleware_composition() {
        use crate::circuit_breaker::CircuitBreaker;
        use crate::retry_middleware::RetryMiddleware;
        use crate::timeout_middleware::TimeoutMiddleware;
        use std::time::Duration;

        // Provider: fail once then succeed
        let provider = MockProvider::new()
            .respond_with_error(PeError::LlmProvider {
                details: "503".into(),
            })
            .respond_with("recovered");

        let stack = MiddlewareStack::new(provider)
            .with(CircuitBreaker::new(5, Duration::from_secs(60)))
            .with(RetryMiddleware::new(3, Duration::from_millis(1)))
            .with(TimeoutMiddleware::new(Duration::from_secs(5)));

        // Retry middleware catches the first transient failure and retries
        let resp = stack.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("recovered"));
    }
}