Skip to main content

pe_core/
timeout_middleware.rs

1//! Timeout middleware — enforces a per-call time limit on LLM completions.
2//!
3//! Wraps each `complete()` call with [`tokio::time::timeout`]. If the call
4//! exceeds the deadline, returns [`PeError::Timeout`].
5//!
6//! # Example
7//!
8//! ```ignore
9//! use std::time::Duration;
10//! use pe_core::timeout_middleware::TimeoutMiddleware;
11//!
12//! let timeout = TimeoutMiddleware::new(Duration::from_secs(30));
13//! let stack = MiddlewareStack::new(provider).with(timeout);
14//! ```
15
16use std::time::Duration;
17
18use async_trait::async_trait;
19
20use crate::error::PeError;
21use crate::llm::{LlmProvider, LlmResponse, ToolSchema};
22use crate::message::Message;
23use crate::provider_middleware::ProviderMiddleware;
24
/// Enforces a per-call time limit on LLM completion requests.
///
/// Returns [`PeError::Timeout`] if the inner call exceeds the configured duration.
///
/// The middleware stores nothing but the configured [`Duration`], so it is
/// trivially copyable and comparable; the derives below make it usable in
/// logs (`Debug`), in configuration structs (`Clone`/`Copy`) and in test
/// assertions (`PartialEq`/`Eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimeoutMiddleware {
    /// Maximum wall-clock time allowed for a single wrapped call.
    duration: Duration,
}

impl TimeoutMiddleware {
    /// Create a timeout middleware with the given deadline per call.
    ///
    /// `duration` is stored as-is and applied to each wrapped call
    /// individually. The constructor is `const`, so a middleware can be
    /// declared as a constant or static.
    #[must_use]
    pub const fn new(duration: Duration) -> Self {
        Self { duration }
    }
}
38
39#[async_trait]
40impl ProviderMiddleware for TimeoutMiddleware {
41    async fn wrap_complete(
42        &self,
43        messages: &[Message],
44        tools: &[ToolSchema],
45        next: &dyn LlmProvider,
46    ) -> Result<LlmResponse, PeError> {
47        tokio::time::timeout(self.duration, next.complete(messages, tools))
48            .await
49            .map_err(|_| PeError::Timeout {
50                seconds: self.duration.as_secs_f64(),
51            })?
52    }
53}
54
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::pin::Pin;

    use super::*;
    use crate::mock_provider::MockProvider;

    /// Boxed, sendable future yielding `Result<T, PeError>`; spelled once here
    /// so the hand-written provider below stays readable.
    type BoxedResult<'a, T> = Pin<Box<dyn Future<Output = Result<T, PeError>> + Send + 'a>>;

    /// Provider whose `complete()` sleeps for 10 seconds, far longer than the
    /// 10 ms deadline used in the slow-call test. Only `complete()` and
    /// `provider_name()` are implemented; the other methods are stubs.
    struct SlowProvider;

    impl LlmProvider for SlowProvider {
        fn complete(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
        ) -> BoxedResult<'_, LlmResponse> {
            Box::pin(async {
                tokio::time::sleep(Duration::from_secs(10)).await;
                // Reaching this line means the middleware failed to cut the call short.
                unreachable!("should have timed out")
            })
        }

        fn stream(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
        ) -> crate::llm::StreamFuture<'_> {
            unimplemented!()
        }

        fn embed(&self, _text: &str) -> BoxedResult<'_, Vec<f32>> {
            unimplemented!()
        }

        fn provider_name(&self) -> &'static str {
            "slow"
        }
    }

    /// A call that finishes well inside the deadline passes through untouched.
    #[tokio::test]
    async fn test_timeout_fast_call_succeeds() {
        let provider = MockProvider::new().respond_with("fast");
        let middleware = TimeoutMiddleware::new(Duration::from_secs(5));

        let outcome = middleware.wrap_complete(&[], &[], &provider).await;

        let resp = outcome.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("fast"));
    }

    /// A call that outlives the deadline yields `PeError::Timeout` carrying
    /// the configured limit in seconds.
    #[tokio::test]
    async fn test_timeout_slow_call_returns_error() {
        let middleware = TimeoutMiddleware::new(Duration::from_millis(10));

        let outcome = middleware.wrap_complete(&[], &[], &SlowProvider).await;
        let err = outcome.unwrap_err();

        match &err {
            PeError::Timeout { seconds } => {
                // 10 ms expressed as fractional seconds, with a small float tolerance.
                let drift = (seconds - 0.01).abs();
                assert!(drift < 0.001, "expected ~0.01s, got {seconds}");
            }
            _ => panic!("expected Timeout, got {err:?}"),
        }
    }
}