pe_core/
middleware_stack.rs1use std::future::Future;
17use std::pin::Pin;
18use std::sync::Arc;
19
20use crate::error::PeError;
21use crate::llm::{LlmProvider, LlmResponse, StreamFuture, ToolSchema};
22use crate::message::Message;
23use crate::provider_middleware::ProviderMiddleware;
24
25pub struct MiddlewareStack {
30 provider: Arc<dyn LlmProvider>,
33}
34
35impl MiddlewareStack {
36 pub fn new(base: impl LlmProvider) -> Self {
38 Self {
39 provider: Arc::new(base),
40 }
41 }
42
43 #[must_use = "builder method returns modified stack"]
47 pub fn with(self, middleware: impl ProviderMiddleware) -> Self {
48 Self {
49 provider: Arc::new(WrappedLayer {
50 middleware: Arc::new(middleware),
51 inner: self.provider,
52 }),
53 }
54 }
55}
56
57struct WrappedLayer {
60 middleware: Arc<dyn ProviderMiddleware>,
61 inner: Arc<dyn LlmProvider>,
62}
63
64impl LlmProvider for WrappedLayer {
65 fn complete(
66 &self,
67 messages: &[Message],
68 tools: &[ToolSchema],
69 ) -> Pin<Box<dyn Future<Output = Result<LlmResponse, PeError>> + Send + '_>> {
70 let messages = messages.to_vec();
71 let tools = tools.to_vec();
72 Box::pin(async move {
73 self.middleware
74 .wrap_complete(&messages, &tools, self.inner.as_ref())
75 .await
76 })
77 }
78
79 fn stream(&self, messages: &[Message], tools: &[ToolSchema]) -> StreamFuture<'_> {
80 self.inner.stream(messages, tools)
82 }
83
84 fn embed(
85 &self,
86 text: &str,
87 ) -> Pin<Box<dyn Future<Output = Result<Vec<f32>, PeError>> + Send + '_>> {
88 self.inner.embed(text)
89 }
90
91 fn provider_name(&self) -> &'static str {
92 self.inner.provider_name()
93 }
94}
95
96impl LlmProvider for MiddlewareStack {
97 fn complete(
98 &self,
99 messages: &[Message],
100 tools: &[ToolSchema],
101 ) -> Pin<Box<dyn Future<Output = Result<LlmResponse, PeError>> + Send + '_>> {
102 self.provider.complete(messages, tools)
103 }
104
105 fn stream(&self, messages: &[Message], tools: &[ToolSchema]) -> StreamFuture<'_> {
106 self.provider.stream(messages, tools)
107 }
108
109 fn embed(
110 &self,
111 text: &str,
112 ) -> Pin<Box<dyn Future<Output = Result<Vec<f32>, PeError>> + Send + '_>> {
113 self.provider.embed(text)
114 }
115
116 fn provider_name(&self) -> &'static str {
117 self.provider.provider_name()
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Middleware that counts how many `complete` calls pass through it.
    struct CountingMiddleware {
        count: Arc<AtomicU32>,
    }

    #[async_trait]
    impl ProviderMiddleware for CountingMiddleware {
        async fn wrap_complete(
            &self,
            messages: &[Message],
            tools: &[ToolSchema],
            next: &dyn LlmProvider,
        ) -> Result<LlmResponse, PeError> {
            self.count.fetch_add(1, Ordering::SeqCst);
            next.complete(messages, tools).await
        }
    }

    #[tokio::test]
    async fn test_stack_no_middleware_passes_through() {
        let stack = MiddlewareStack::new(MockProvider::new().respond_with("bare"));
        let resp = stack.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("bare"));
    }

    #[tokio::test]
    async fn test_stack_single_middleware_invoked() {
        let count = Arc::new(AtomicU32::new(0));
        let stack =
            MiddlewareStack::new(MockProvider::new().respond_with("ok")).with(CountingMiddleware {
                count: count.clone(),
            });

        let resp = stack.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("ok"));
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_stack_multiple_middlewares_execute_outside_in() {
        let order = Arc::new(std::sync::Mutex::new(Vec::new()));

        /// Middleware that records its id when entered.
        struct OrderMiddleware {
            id: &'static str,
            order: Arc<std::sync::Mutex<Vec<&'static str>>>,
        }

        #[async_trait]
        impl ProviderMiddleware for OrderMiddleware {
            async fn wrap_complete(
                &self,
                messages: &[Message],
                tools: &[ToolSchema],
                next: &dyn LlmProvider,
            ) -> Result<LlmResponse, PeError> {
                // The guard is a temporary dropped at the end of this
                // statement, so the lock is not held across the await below.
                self.order.lock().unwrap().push(self.id);
                next.complete(messages, tools).await
            }
        }

        let stack = MiddlewareStack::new(MockProvider::new().respond_with("done"))
            .with(OrderMiddleware {
                id: "first",
                order: order.clone(),
            })
            .with(OrderMiddleware {
                id: "second",
                order: order.clone(),
            });

        stack.complete(&[], &[]).await.unwrap();

        // The last middleware added is the outermost layer, so it is entered first.
        let recorded = order.lock().unwrap().clone();
        assert_eq!(recorded, vec!["second", "first"]);
    }

    #[tokio::test]
    async fn test_stack_provider_name_delegates_to_base() {
        let stack = MiddlewareStack::new(MockProvider::new());
        assert_eq!(stack.provider_name(), "mock");
    }

    #[tokio::test]
    async fn test_stack_embed_delegates_to_base() {
        let stack = MiddlewareStack::new(MockProvider::new().with_embedding(vec![1.0, 2.0]));
        let embedding = stack.embed("test").await.unwrap();
        assert_eq!(embedding, vec![1.0, 2.0]);
    }

    /// `embed` and `provider_name` must reach the base provider through every
    /// wrapped layer, and must not invoke middleware (which wraps only
    /// `complete`).
    #[tokio::test]
    async fn test_layered_stack_forwards_embed_and_name_without_middleware() {
        let count = Arc::new(AtomicU32::new(0));
        let stack = MiddlewareStack::new(MockProvider::new().with_embedding(vec![1.0, 2.0]))
            .with(CountingMiddleware {
                count: count.clone(),
            })
            .with(CountingMiddleware {
                count: count.clone(),
            });

        assert_eq!(stack.provider_name(), "mock");
        let embedding = stack.embed("test").await.unwrap();
        assert_eq!(embedding, vec![1.0, 2.0]);
        assert_eq!(count.load(Ordering::SeqCst), 0);
    }

    /// End-to-end check that the circuit breaker, retry, and timeout middleware
    /// compose in one stack.
    #[tokio::test]
    async fn test_full_middleware_composition() {
        use crate::circuit_breaker::CircuitBreaker;
        use crate::retry_middleware::RetryMiddleware;
        use crate::timeout_middleware::TimeoutMiddleware;
        use std::time::Duration;

        // The mock first returns a 503 error, then "recovered".
        let provider = MockProvider::new()
            .respond_with_error(PeError::LlmProvider {
                details: "503".into(),
            })
            .respond_with("recovered");

        // Layer order, outermost first: timeout -> retry -> circuit breaker -> mock.
        let stack = MiddlewareStack::new(provider)
            .with(CircuitBreaker::new(5, Duration::from_secs(60)))
            .with(RetryMiddleware::new(3, Duration::from_millis(1)))
            .with(TimeoutMiddleware::new(Duration::from_secs(5)));

        // The initial 503 must be absorbed by the stack, not surfaced to the caller.
        let resp = stack.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("recovered"));
    }
}