pe_core/
timeout_middleware.rs1use std::time::Duration;
17
18use async_trait::async_trait;
19
20use crate::error::PeError;
21use crate::llm::{LlmProvider, LlmResponse, ToolSchema};
22use crate::message::Message;
23use crate::provider_middleware::ProviderMiddleware;
24
/// Provider middleware that bounds how long a single `LlmProvider::complete`
/// call may take.
///
/// Construct it with [`TimeoutMiddleware::new`]. The value only holds the
/// configured limit, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimeoutMiddleware {
    /// Maximum time a wrapped `complete` call may run before the middleware
    /// gives up and returns `PeError::Timeout`.
    duration: Duration,
}
31
32impl TimeoutMiddleware {
33 pub fn new(duration: Duration) -> Self {
35 Self { duration }
36 }
37}
38
39#[async_trait]
40impl ProviderMiddleware for TimeoutMiddleware {
41 async fn wrap_complete(
42 &self,
43 messages: &[Message],
44 tools: &[ToolSchema],
45 next: &dyn LlmProvider,
46 ) -> Result<LlmResponse, PeError> {
47 tokio::time::timeout(self.duration, next.complete(messages, tools))
48 .await
49 .map_err(|_| PeError::Timeout {
50 seconds: self.duration.as_secs_f64(),
51 })?
52 }
53}
54
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_provider::MockProvider;

    /// Boxed future returned by `LlmProvider::complete`, aliased so the stub
    /// signatures stay readable.
    type CompleteFuture<'a> = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<LlmResponse, PeError>> + Send + 'a>,
    >;

    /// Boxed future returned by `LlmProvider::embed`.
    type EmbedFuture<'a> =
        std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<f32>, PeError>> + Send + 'a>>;

    /// Hand-rolled provider whose `complete` behaviour is chosen per test.
    /// Only `complete` is exercised; the other trait methods are stubs.
    enum StubProvider {
        /// Sleeps for 10s, far longer than any limit configured below, so the
        /// middleware is expected to give up first.
        Hang,
        /// Fails immediately with `PeError::Timeout { seconds }`. This is the
        /// same variant the middleware emits, but with a payload that differs
        /// from the configured limit. A test can therefore tell whether the
        /// inner error was forwarded or replaced.
        FailWith { seconds: f64 },
    }

    impl LlmProvider for StubProvider {
        fn complete(&self, _messages: &[Message], _tools: &[ToolSchema]) -> CompleteFuture<'_> {
            Box::pin(async move {
                match self {
                    StubProvider::Hang => {
                        tokio::time::sleep(Duration::from_secs(10)).await;
                        unreachable!("should have timed out")
                    }
                    StubProvider::FailWith { seconds } => {
                        Err::<LlmResponse, PeError>(PeError::Timeout { seconds: *seconds })
                    }
                }
            })
        }

        fn stream(
            &self,
            _messages: &[Message],
            _tools: &[ToolSchema],
        ) -> crate::llm::StreamFuture<'_> {
            unimplemented!()
        }

        fn embed(&self, _text: &str) -> EmbedFuture<'_> {
            unimplemented!()
        }

        fn provider_name(&self) -> &'static str {
            "stub"
        }
    }

    #[tokio::test]
    async fn test_timeout_fast_call_succeeds() {
        let timeout = TimeoutMiddleware::new(Duration::from_secs(5));
        let provider = MockProvider::new().respond_with("fast");

        let resp = timeout.wrap_complete(&[], &[], &provider).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("fast"));
    }

    #[tokio::test]
    async fn test_timeout_slow_call_returns_error() {
        let timeout = TimeoutMiddleware::new(Duration::from_millis(10));

        let err = timeout
            .wrap_complete(&[], &[], &StubProvider::Hang)
            .await
            .unwrap_err();
        match err {
            PeError::Timeout { seconds } => {
                assert!(
                    (seconds - 0.01).abs() < 0.001,
                    "expected ~0.01s, got {seconds}"
                );
            }
            other => panic!("expected Timeout, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_timeout_inner_error_is_forwarded_unchanged() {
        // Generous limit: the stub fails immediately, so the middleware must
        // not time out and must hand back the stub's own error untouched.
        let timeout = TimeoutMiddleware::new(Duration::from_secs(5));

        let err = timeout
            .wrap_complete(&[], &[], &StubProvider::FailWith { seconds: 42.0 })
            .await
            .unwrap_err();
        match err {
            PeError::Timeout { seconds } => {
                assert!(
                    (seconds - 42.0).abs() < f64::EPSILON,
                    "expected the inner error (42s), got {seconds}s"
                );
            }
            other => panic!("expected the inner Timeout error, got {other:?}"),
        }
    }
}