1use crate::error::RuntimeError;
2use std::time::Duration;
3use tokio::time::timeout;
4
/// Timeout settings applied to asynchronous operations.
///
/// Both limits are optional; `None` disables the corresponding deadline.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TimeoutConfig {
    /// Overall deadline enforced by `TimeoutConfig::execute`; `None` disables it.
    pub total: Option<Duration>,

    /// Deadline used by `TimeoutConfig::execute_with_first_response`; `None`
    /// disables this particular limit.
    pub first_response: Option<Duration>,
}
14
15impl Default for TimeoutConfig {
16 fn default() -> Self {
17 Self {
18 total: Some(Duration::from_secs(300)), first_response: Some(Duration::from_secs(30)), }
21 }
22}
23
24impl TimeoutConfig {
25 pub fn none() -> Self {
27 Self {
28 total: None,
29 first_response: None,
30 }
31 }
32
33 pub fn quick() -> Self {
35 Self {
36 total: Some(Duration::from_secs(30)),
37 first_response: Some(Duration::from_secs(5)),
38 }
39 }
40
41 pub fn long() -> Self {
43 Self {
44 total: Some(Duration::from_secs(600)), first_response: Some(Duration::from_secs(60)),
46 }
47 }
48
49 pub fn custom(total: Duration, first_response: Option<Duration>) -> Self {
51 Self {
52 total: Some(total),
53 first_response,
54 }
55 }
56
57 pub async fn execute<F, T>(&self, operation_name: &str, operation: F) -> Result<T, RuntimeError>
77 where
78 F: std::future::Future<Output = Result<T, RuntimeError>>,
79 {
80 if let Some(timeout_duration) = self.total {
81 let start = std::time::Instant::now();
82
83 match timeout(timeout_duration, operation).await {
84 Ok(result) => result,
85 Err(_) => Err(RuntimeError::Timeout {
86 operation: operation_name.to_string(),
87 duration_ms: start.elapsed().as_millis() as u64,
88 }),
89 }
90 } else {
91 operation.await
93 }
94 }
95
96 pub async fn execute_with_first_response<F, T>(
100 &self,
101 operation_name: &str,
102 mut operation: F,
103 ) -> Result<T, RuntimeError>
104 where
105 F: std::future::Future<Output = Result<T, RuntimeError>> + Unpin,
106 {
107 if let Some(first_timeout) = self.first_response {
108 let start = std::time::Instant::now();
109
110 match timeout(first_timeout, &mut operation).await {
112 Ok(result) => result,
113 Err(_) => Err(RuntimeError::Timeout {
114 operation: format!("{} (first response)", operation_name),
115 duration_ms: start.elapsed().as_millis() as u64,
116 }),
117 }
118 } else {
119 operation.await
120 }
121 }
122}
123
124pub async fn with_timeout<F, T>(
146 duration: Duration,
147 operation_name: &str,
148 operation: F,
149) -> Result<T, RuntimeError>
150where
151 F: std::future::Future<Output = Result<T, RuntimeError>>,
152{
153 let config = TimeoutConfig::custom(duration, None);
154 config.execute(operation_name, operation).await
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::LlmError;

    #[tokio::test]
    async fn test_operation_completes_within_timeout() {
        let config = TimeoutConfig::default();

        let result: Result<&str, RuntimeError> = config
            .execute("test_op", async {
                tokio::time::sleep(Duration::from_millis(10)).await;
                Ok("success")
            })
            .await;

        assert_eq!(result.unwrap(), "success");
    }

    #[tokio::test]
    async fn test_operation_exceeds_timeout() {
        let config = TimeoutConfig::custom(Duration::from_millis(50), None);

        let result: Result<&str, RuntimeError> = config
            .execute("test_op", async {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Ok("success")
            })
            .await;

        match result {
            Err(RuntimeError::Timeout {
                operation,
                duration_ms,
            }) => {
                assert_eq!(operation, "test_op");
                assert!(duration_ms >= 50, "reported {} ms", duration_ms);
            }
            other => panic!("Expected Timeout error, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_no_timeout_allows_long_operations() {
        let config = TimeoutConfig::none();

        let result: Result<&str, RuntimeError> = config
            .execute("test_op", async {
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok("success")
            })
            .await;

        assert_eq!(result.unwrap(), "success");
    }

    #[tokio::test]
    async fn test_with_timeout_convenience_function() {
        let result: Result<&str, RuntimeError> =
            with_timeout(Duration::from_secs(1), "test_op", async {
                tokio::time::sleep(Duration::from_millis(10)).await;
                Ok("success")
            })
            .await;

        assert_eq!(result.unwrap(), "success");
    }

    #[tokio::test]
    async fn test_timeout_with_error_result() {
        let config = TimeoutConfig::default();

        let result: Result<&str, RuntimeError> = config
            .execute("test_op", async {
                Err(LlmError::network("Network error").into())
            })
            .await;

        // The operation's own error must pass through untouched, not be
        // reported as a timeout.
        assert!(
            matches!(result, Err(RuntimeError::Llm(_))),
            "Expected LLM error, not timeout"
        );
    }

    #[tokio::test]
    async fn test_first_response_timeout_is_labelled() {
        let config =
            TimeoutConfig::custom(Duration::from_secs(1), Some(Duration::from_millis(20)));

        let result: Result<&str, RuntimeError> = config
            .execute_with_first_response(
                "test_op",
                Box::pin(async {
                    tokio::time::sleep(Duration::from_millis(200)).await;
                    Ok::<_, RuntimeError>("success")
                }),
            )
            .await;

        match result {
            Err(RuntimeError::Timeout { operation, .. }) => {
                assert_eq!(operation, "test_op (first response)");
            }
            other => panic!("Expected first-response Timeout error, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_first_response_passes_through_success() {
        let config = TimeoutConfig::quick();

        let result: Result<&str, RuntimeError> = config
            .execute_with_first_response(
                "test_op",
                Box::pin(async { Ok::<_, RuntimeError>("fast") }),
            )
            .await;

        assert_eq!(result.unwrap(), "fast");
    }

    // Pure constructor checks need no async runtime.
    #[test]
    fn test_default_timeout_config() {
        let config = TimeoutConfig::default();
        assert_eq!(config.total, Some(Duration::from_secs(300)));
        assert_eq!(config.first_response, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_none_timeout_config() {
        let config = TimeoutConfig::none();
        assert_eq!(config.total, None);
        assert_eq!(config.first_response, None);
    }

    #[test]
    fn test_custom_timeout_config() {
        let config = TimeoutConfig::custom(Duration::from_secs(7), None);
        assert_eq!(config.total, Some(Duration::from_secs(7)));
        assert_eq!(config.first_response, None);
    }

    #[test]
    fn test_quick_timeout_config() {
        let config = TimeoutConfig::quick();
        assert_eq!(config.total, Some(Duration::from_secs(30)));
        assert_eq!(config.first_response, Some(Duration::from_secs(5)));
    }

    #[test]
    fn test_long_timeout_config() {
        let config = TimeoutConfig::long();
        assert_eq!(config.total, Some(Duration::from_secs(600)));
        assert_eq!(config.first_response, Some(Duration::from_secs(60)));
    }
}