pe_core/provider_middleware.rs
//! Provider middleware trait for composable LLM call wrappers.
//!
//! Middleware intercepts `complete()` calls, adding cross-cutting concerns
//! like retry, timeout, circuit breaking, and fallback — without modifying
//! the underlying provider.
//!
//! # Example
//!
//! ```ignore
//! struct LoggingMiddleware;
//!
//! #[async_trait]
//! impl ProviderMiddleware for LoggingMiddleware {
//!     async fn wrap_complete(
//!         &self,
//!         messages: &[Message],
//!         tools: &[ToolSchema],
//!         next: &dyn LlmProvider,
//!     ) -> Result<LlmResponse, PeError> {
//!         println!("calling LLM with {} messages", messages.len());
//!         next.complete(messages, tools).await
//!     }
//! }
//! ```
25
26use async_trait::async_trait;
27
28use crate::error::PeError;
29use crate::llm::{LlmProvider, LlmResponse, ToolSchema};
30use crate::message::Message;
31
32/// Intercepts LLM completion calls for cross-cutting concerns.
33///
34/// Middlewares are composed into a [`MiddlewareStack`](super::middleware_stack::MiddlewareStack)
35/// and execute outside-in: the first middleware added wraps all subsequent ones.
36///
37/// # Implementors
38///
39/// - [`RetryMiddleware`](super::retry_middleware::RetryMiddleware) — exponential backoff
40/// - [`TimeoutMiddleware`](super::timeout_middleware::TimeoutMiddleware) — per-call deadline
41/// - [`CircuitBreaker`](super::circuit_breaker::CircuitBreaker) — fail-fast on repeated errors
42#[async_trait]
43pub trait ProviderMiddleware: Send + Sync + 'static {
44 /// Wrap a completion call. Call `next.complete(messages, tools).await`
45 /// to forward to the next layer (or the base provider).
46 async fn wrap_complete(
47 &self,
48 messages: &[Message],
49 tools: &[ToolSchema],
50 next: &dyn LlmProvider,
51 ) -> Result<LlmResponse, PeError>;
52}
53
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;

    /// Minimal middleware used by the tests below: it adds no behavior and
    /// hands every call straight to `next`.
    struct PassthroughMiddleware;

    #[async_trait]
    impl ProviderMiddleware for PassthroughMiddleware {
        async fn wrap_complete(
            &self,
            messages: &[Message],
            tools: &[ToolSchema],
            next: &dyn LlmProvider,
        ) -> Result<LlmResponse, PeError> {
            // Forward unchanged; the provider's result is returned as-is.
            next.complete(messages, tools).await
        }
    }

    /// A successful provider response must come back through the middleware.
    #[tokio::test]
    async fn test_passthrough_middleware_forwards_to_provider() {
        let provider = MockProvider::new().respond_with("hello");
        let middleware = PassthroughMiddleware;

        let outcome = middleware.wrap_complete(&[], &[], &provider).await;
        let response = outcome.unwrap();

        let text = response.message.content.as_text();
        assert_eq!(text, Some("hello"));
    }

    /// A provider failure must come back through the middleware as an error
    /// of the `PeError::LlmProvider` variant.
    #[tokio::test]
    async fn test_middleware_receives_provider_errors() {
        let failure = PeError::LlmProvider {
            details: "rate limited".into(),
        };
        let provider = MockProvider::new().respond_with_error(failure);
        let middleware = PassthroughMiddleware;

        let outcome = middleware.wrap_complete(&[], &[], &provider).await;
        let is_provider_error = matches!(outcome.unwrap_err(), PeError::LlmProvider { .. });

        assert!(is_provider_error);
    }
}