// agent_runtime/timeout.rs

1use crate::error::RuntimeError;
2use std::time::Duration;
3use tokio::time::timeout;
4
/// Configuration for operation timeouts.
///
/// Each limit is optional; `None` means "no limit" for that phase.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TimeoutConfig {
    /// Overall timeout for the entire operation; `None` disables it.
    pub total: Option<Duration>,

    /// Budget for the first response (useful for streaming), consulted by
    /// `TimeoutConfig::execute_with_first_response`; `None` disables it.
    pub first_response: Option<Duration>,
}
14
15impl Default for TimeoutConfig {
16    fn default() -> Self {
17        Self {
18            total: Some(Duration::from_secs(300)), // 5 minutes default
19            first_response: Some(Duration::from_secs(30)), // 30 seconds for first response
20        }
21    }
22}
23
24impl TimeoutConfig {
25    /// No timeout (operations can run indefinitely)
26    pub fn none() -> Self {
27        Self {
28            total: None,
29            first_response: None,
30        }
31    }
32
33    /// Quick timeout for fast operations
34    pub fn quick() -> Self {
35        Self {
36            total: Some(Duration::from_secs(30)),
37            first_response: Some(Duration::from_secs(5)),
38        }
39    }
40
41    /// Long timeout for expensive operations
42    pub fn long() -> Self {
43        Self {
44            total: Some(Duration::from_secs(600)), // 10 minutes
45            first_response: Some(Duration::from_secs(60)),
46        }
47    }
48
49    /// Custom timeout
50    pub fn custom(total: Duration, first_response: Option<Duration>) -> Self {
51        Self {
52            total: Some(total),
53            first_response,
54        }
55    }
56
57    /// Execute an async operation with timeout protection
58    ///
59    /// # Example
60    /// ```no_run
61    /// use agent_runtime::timeout::TimeoutConfig;
62    /// use std::time::Duration;
63    ///
64    /// # async fn example() -> Result<String, agent_runtime::RuntimeError> {
65    /// let config = TimeoutConfig::default();
66    /// let result = config.execute(
67    ///     "fetch_data",
68    ///     async {
69    ///         // Your operation here
70    ///         Ok("success".to_string())
71    ///     }
72    /// ).await?;
73    /// # Ok(result)
74    /// # }
75    /// ```
76    pub async fn execute<F, T>(&self, operation_name: &str, operation: F) -> Result<T, RuntimeError>
77    where
78        F: std::future::Future<Output = Result<T, RuntimeError>>,
79    {
80        if let Some(timeout_duration) = self.total {
81            let start = std::time::Instant::now();
82
83            match timeout(timeout_duration, operation).await {
84                Ok(result) => result,
85                Err(_) => Err(RuntimeError::Timeout {
86                    operation: operation_name.to_string(),
87                    duration_ms: start.elapsed().as_millis() as u64,
88                }),
89            }
90        } else {
91            // No timeout configured
92            operation.await
93        }
94    }
95
96    /// Execute with first response timeout (useful for streaming)
97    ///
98    /// Returns a tuple of (first_chunk, remaining_stream)
99    pub async fn execute_with_first_response<F, T>(
100        &self,
101        operation_name: &str,
102        mut operation: F,
103    ) -> Result<T, RuntimeError>
104    where
105        F: std::future::Future<Output = Result<T, RuntimeError>> + Unpin,
106    {
107        if let Some(first_timeout) = self.first_response {
108            let start = std::time::Instant::now();
109
110            // Wait for first response with timeout
111            match timeout(first_timeout, &mut operation).await {
112                Ok(result) => result,
113                Err(_) => Err(RuntimeError::Timeout {
114                    operation: format!("{} (first response)", operation_name),
115                    duration_ms: start.elapsed().as_millis() as u64,
116                }),
117            }
118        } else {
119            operation.await
120        }
121    }
122}
123
124/// Execute an operation with a specific timeout duration
125///
126/// Convenience function for one-off timeouts
127///
128/// # Example
129/// ```no_run
130/// use agent_runtime::timeout::with_timeout;
131/// use std::time::Duration;
132///
133/// # async fn example() -> Result<String, agent_runtime::RuntimeError> {
134/// let result = with_timeout(
135///     Duration::from_secs(30),
136///     "api_call",
137///     async {
138///         // Your operation
139///         Ok("done".to_string())
140///     }
141/// ).await?;
142/// # Ok(result)
143/// # }
144/// ```
145pub async fn with_timeout<F, T>(
146    duration: Duration,
147    operation_name: &str,
148    operation: F,
149) -> Result<T, RuntimeError>
150where
151    F: std::future::Future<Output = Result<T, RuntimeError>>,
152{
153    let config = TimeoutConfig::custom(duration, None);
154    config.execute(operation_name, operation).await
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::LlmError;

    /// An operation finishing well inside the default budget succeeds.
    #[tokio::test]
    async fn test_operation_completes_within_timeout() {
        let config = TimeoutConfig::default();

        let result: Result<&str, RuntimeError> = config
            .execute("test_op", async {
                tokio::time::sleep(Duration::from_millis(10)).await;
                Ok("success")
            })
            .await;

        assert_eq!(result.unwrap(), "success");
    }

    /// Exceeding `total` yields a Timeout tagged with the operation name.
    #[tokio::test]
    async fn test_operation_exceeds_timeout() {
        let config = TimeoutConfig::custom(Duration::from_millis(50), None);

        let result: Result<&str, RuntimeError> = config
            .execute("test_op", async {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Ok("success")
            })
            .await;

        match result.unwrap_err() {
            RuntimeError::Timeout {
                operation,
                duration_ms,
            } => {
                assert_eq!(operation, "test_op");
                assert!(duration_ms >= 50);
            }
            _ => panic!("Expected Timeout error"),
        }
    }

    /// With no limits configured, the operation is simply awaited.
    #[tokio::test]
    async fn test_no_timeout_allows_long_operations() {
        let config = TimeoutConfig::none();

        let result: Result<&str, RuntimeError> = config
            .execute("test_op", async {
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok("success")
            })
            .await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_with_timeout_convenience_function() {
        let result: Result<&str, RuntimeError> =
            with_timeout(Duration::from_secs(1), "test_op", async {
                tokio::time::sleep(Duration::from_millis(10)).await;
                Ok("success")
            })
            .await;

        assert_eq!(result.unwrap(), "success");
    }

    #[tokio::test]
    async fn test_with_timeout_reports_timeout() {
        let result: Result<&str, RuntimeError> =
            with_timeout(Duration::from_millis(50), "api_call", async {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Ok("done")
            })
            .await;

        match result.unwrap_err() {
            RuntimeError::Timeout {
                operation,
                duration_ms,
            } => {
                assert_eq!(operation, "api_call");
                assert!(duration_ms >= 50);
            }
            _ => panic!("Expected Timeout error"),
        }
    }

    /// An operation's own error is propagated unchanged, not turned into a timeout.
    #[tokio::test]
    async fn test_timeout_with_error_result() {
        let config = TimeoutConfig::default();

        let result: Result<&str, RuntimeError> = config
            .execute("test_op", async {
                Err(LlmError::network("Network error").into())
            })
            .await;

        assert!(
            matches!(result, Err(RuntimeError::Llm(_))),
            "Expected LLM error, not timeout"
        );
    }

    #[tokio::test]
    async fn test_first_response_within_timeout() {
        let config = TimeoutConfig::default();

        let result = config
            .execute_with_first_response(
                "test_op",
                Box::pin(async { Ok::<&str, RuntimeError>("success") }),
            )
            .await;

        assert_eq!(result.unwrap(), "success");
    }

    #[tokio::test]
    async fn test_first_response_timeout_exceeded() {
        let config =
            TimeoutConfig::custom(Duration::from_secs(5), Some(Duration::from_millis(50)));

        let result = config
            .execute_with_first_response(
                "test_op",
                Box::pin(async {
                    tokio::time::sleep(Duration::from_millis(200)).await;
                    Ok::<&str, RuntimeError>("success")
                }),
            )
            .await;

        match result.unwrap_err() {
            RuntimeError::Timeout {
                operation,
                duration_ms,
            } => {
                assert_eq!(operation, "test_op (first response)");
                assert!(duration_ms >= 50);
            }
            _ => panic!("Expected first-response Timeout error"),
        }
    }

    // The preset checks below do no async work, so they need no runtime.

    #[test]
    fn test_default_timeout_config() {
        let config = TimeoutConfig::default();
        assert_eq!(config.total, Some(Duration::from_secs(300)));
        assert_eq!(config.first_response, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_none_timeout_config() {
        let config = TimeoutConfig::none();
        assert_eq!(config.total, None);
        assert_eq!(config.first_response, None);
    }

    #[test]
    fn test_custom_timeout_config() {
        let config = TimeoutConfig::custom(Duration::from_secs(7), None);
        assert_eq!(config.total, Some(Duration::from_secs(7)));
        assert_eq!(config.first_response, None);
    }

    #[test]
    fn test_quick_timeout_config() {
        let config = TimeoutConfig::quick();
        assert_eq!(config.total, Some(Duration::from_secs(30)));
        assert_eq!(config.first_response, Some(Duration::from_secs(5)));
    }

    #[test]
    fn test_long_timeout_config() {
        let config = TimeoutConfig::long();
        assert_eq!(config.total, Some(Duration::from_secs(600)));
        assert_eq!(config.first_response, Some(Duration::from_secs(60)));
    }
}