pforge_runtime/
timeout.rs1use crate::{Error, Middleware, Result};
2use serde_json::Value;
3use std::time::Duration;
4use tokio::time::timeout;
5
/// Middleware that carries a time limit for handler execution.
///
/// The limit is exposed through [`TimeoutMiddleware::duration`]. The
/// `Middleware` hooks implemented for this type pass requests and responses
/// through unchanged, so this type does not enforce the limit by itself;
/// [`with_timeout`] is the helper in this module that bounds a future.
#[derive(Debug, Clone)]
pub struct TimeoutMiddleware {
    /// The time limit this middleware was configured with.
    duration: Duration,
}
11
12impl TimeoutMiddleware {
13 pub fn new(duration: Duration) -> Self {
14 Self { duration }
15 }
16
17 pub fn from_millis(millis: u64) -> Self {
18 Self::new(Duration::from_millis(millis))
19 }
20
21 pub fn from_secs(secs: u64) -> Self {
22 Self::new(Duration::from_secs(secs))
23 }
24
25 pub fn duration(&self) -> Duration {
26 self.duration
27 }
28}
29
30#[async_trait::async_trait]
31impl Middleware for TimeoutMiddleware {
32 async fn before(&self, request: Value) -> Result<Value> {
33 Ok(request)
34 }
35
36 async fn after(&self, _request: Value, response: Value) -> Result<Value> {
37 Ok(response)
38 }
39}
40
/// Configuration for retrying a fallible async operation with exponential
/// backoff between attempts.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Upper bound on the total number of calls, including the first one.
    pub max_attempts: u32,
    /// Delay used after the first failure; later delays grow from this value.
    pub initial_backoff: Duration,
    /// Upper limit for the exponentially grown delay.
    pub max_backoff: Duration,
    /// Factor the delay is multiplied by for each further attempt.
    pub backoff_multiplier: f64,
    /// Whether a random component is added to each delay.
    pub use_jitter: bool,
}
55
56impl RetryPolicy {
57 pub fn new(max_attempts: u32) -> Self {
58 Self {
59 max_attempts,
60 initial_backoff: Duration::from_millis(100),
61 max_backoff: Duration::from_secs(30),
62 backoff_multiplier: 2.0,
63 use_jitter: true,
64 }
65 }
66
67 pub fn with_backoff(mut self, initial: Duration, max: Duration) -> Self {
68 self.initial_backoff = initial;
69 self.max_backoff = max;
70 self
71 }
72
73 pub fn with_multiplier(mut self, multiplier: f64) -> Self {
74 self.backoff_multiplier = multiplier;
75 self
76 }
77
78 pub fn with_jitter(mut self, use_jitter: bool) -> Self {
79 self.use_jitter = use_jitter;
80 self
81 }
82
83 pub fn backoff_duration(&self, attempt: u32) -> Duration {
85 let base_duration = self.initial_backoff.as_millis() as f64
86 * self.backoff_multiplier.powi(attempt as i32);
87
88 let capped = base_duration.min(self.max_backoff.as_millis() as f64);
89
90 let duration = if self.use_jitter {
91 let jitter = rand::random::<f64>() * capped * 0.1; Duration::from_millis((capped + jitter) as u64)
93 } else {
94 Duration::from_millis(capped as u64)
95 };
96
97 duration
98 }
99
100 pub fn is_retryable(&self, error: &Error) -> bool {
102 match error {
104 Error::Handler(msg) => {
105 msg.contains("timeout")
107 || msg.contains("timed out")
108 || msg.contains("connection")
109 || msg.contains("temporary")
110 }
111 _ => false,
112 }
113 }
114}
115
116impl Default for RetryPolicy {
117 fn default() -> Self {
118 Self::new(3)
119 }
120}
121
122pub struct RetryMiddleware {
125 policy: RetryPolicy,
126}
127
128impl RetryMiddleware {
129 pub fn new(policy: RetryPolicy) -> Self {
130 Self { policy }
131 }
132
133 pub fn with_max_attempts(max_attempts: u32) -> Self {
134 Self::new(RetryPolicy::new(max_attempts))
135 }
136
137 pub fn policy(&self) -> &RetryPolicy {
138 &self.policy
139 }
140}
141
142#[async_trait::async_trait]
143impl Middleware for RetryMiddleware {
144 async fn on_error(&self, _request: Value, error: Error) -> Result<Value> {
145 Err(error)
149 }
150}
151
152pub async fn retry_with_policy<F, Fut, T>(
154 policy: &RetryPolicy,
155 mut operation: F,
156) -> Result<T>
157where
158 F: FnMut() -> Fut,
159 Fut: std::future::Future<Output = Result<T>>,
160{
161 let mut attempt = 0;
162 let mut last_error = None;
163
164 while attempt < policy.max_attempts {
165 match operation().await {
166 Ok(result) => return Ok(result),
167 Err(error) => {
168 if !policy.is_retryable(&error) {
169 return Err(error);
170 }
171
172 last_error = Some(error);
173 attempt += 1;
174
175 if attempt < policy.max_attempts {
176 let backoff = policy.backoff_duration(attempt - 1);
177 tokio::time::sleep(backoff).await;
178 }
179 }
180 }
181 }
182
183 Err(last_error.unwrap_or_else(|| Error::Handler("All retry attempts failed".to_string())))
184}
185
186pub async fn with_timeout<F>(duration: Duration, future: F) -> Result<F::Output>
188where
189 F: std::future::Future,
190{
191 timeout(duration, future)
192 .await
193 .map_err(|_| Error::Handler(format!("Operation timed out after {:?}", duration)))
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Policy with short, deterministic backoff so the async tests stay fast.
    fn fast_policy(max_attempts: u32) -> RetryPolicy {
        RetryPolicy::new(max_attempts)
            .with_backoff(Duration::from_millis(10), Duration::from_millis(50))
            .with_jitter(false)
    }

    #[test]
    fn test_timeout_middleware_constructors() {
        assert_eq!(
            TimeoutMiddleware::new(Duration::from_millis(7)).duration(),
            Duration::from_millis(7)
        );
        assert_eq!(
            TimeoutMiddleware::from_millis(1500).duration(),
            Duration::from_millis(1500)
        );
        assert_eq!(TimeoutMiddleware::from_secs(2).duration(), Duration::from_secs(2));
    }

    #[test]
    fn test_retry_policy_defaults() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_backoff, Duration::from_millis(100));
        assert_eq!(policy.max_backoff, Duration::from_secs(30));
        assert_eq!(policy.backoff_multiplier, 2.0);
        assert!(policy.use_jitter);
    }

    #[test]
    fn test_retry_policy_backoff() {
        let policy = RetryPolicy::new(3)
            .with_backoff(Duration::from_millis(100), Duration::from_secs(5))
            .with_multiplier(2.0)
            .with_jitter(false);

        let backoff1 = policy.backoff_duration(0);
        let backoff2 = policy.backoff_duration(1);
        let backoff3 = policy.backoff_duration(2);

        assert_eq!(backoff1.as_millis(), 100);
        assert_eq!(backoff2.as_millis(), 200);
        assert_eq!(backoff3.as_millis(), 400);
    }

    #[test]
    fn test_retry_policy_max_backoff() {
        let policy = RetryPolicy::new(10)
            .with_backoff(Duration::from_millis(100), Duration::from_secs(1))
            .with_multiplier(2.0)
            .with_jitter(false);

        // 100ms * 2^10 would be ~102s; the cap must hold it at 1s.
        let backoff = policy.backoff_duration(10);
        assert!(backoff <= Duration::from_secs(1));
    }

    #[test]
    fn test_retry_policy_jitter_bounds() {
        // Well below the cap, jitter adds at most 10% on top of the base delay.
        let policy = RetryPolicy::new(3)
            .with_backoff(Duration::from_millis(100), Duration::from_secs(30));
        for _ in 0..100 {
            let backoff = policy.backoff_duration(0);
            assert!(backoff >= Duration::from_millis(100));
            assert!(backoff <= Duration::from_millis(110));
        }
    }

    #[test]
    fn test_is_retryable_classification() {
        let policy = RetryPolicy::default();
        for msg in &[
            "timeout error",
            "Operation timed out after 1s",
            "connection reset",
            "temporary failure",
        ] {
            assert!(
                policy.is_retryable(&Error::Handler(msg.to_string())),
                "expected retryable: {}",
                msg
            );
        }
        assert!(!policy.is_retryable(&Error::Handler("fatal error".to_string())));
    }

    #[tokio::test]
    async fn test_timeout_middleware_hooks_pass_through() {
        let middleware = TimeoutMiddleware::from_secs(1);
        let request = Value::from(1);
        let response = Value::from("ok");

        assert_eq!(middleware.before(request.clone()).await.unwrap(), request);
        assert_eq!(
            middleware.after(request, response.clone()).await.unwrap(),
            response
        );
    }

    #[tokio::test]
    async fn test_retry_middleware_propagates_error() {
        let middleware = RetryMiddleware::with_max_attempts(5);
        assert_eq!(middleware.policy().max_attempts, 5);

        let err = middleware
            .on_error(Value::Null, Error::Handler("boom".to_string()))
            .await
            .unwrap_err();
        assert!(matches!(err, Error::Handler(ref msg) if msg == "boom"));
    }

    #[tokio::test]
    async fn test_retry_with_policy_success() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = retry_with_policy(&fast_policy(3), || {
            let counter = counter_clone.clone();
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(Error::Handler("timeout error".to_string()))
                } else {
                    Ok(42)
                }
            }
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        // Two retryable failures followed by one success.
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_with_policy_max_attempts() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = retry_with_policy(&fast_policy(3), || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(Error::Handler("timeout error".to_string()))
            }
        })
        .await;

        assert!(result.is_err());
        // Every attempt is used, and no more.
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_non_retryable_error() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let policy = RetryPolicy::new(3);

        let result = retry_with_policy(&policy, || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(Error::Handler("fatal error".to_string()))
            }
        })
        .await;

        assert!(result.is_err());
        // A permanent error stops after the first call.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_with_timeout_success() {
        let result = with_timeout(Duration::from_secs(1), async {
            tokio::time::sleep(Duration::from_millis(10)).await;
            42
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_with_timeout_exceeded() {
        let result = with_timeout(Duration::from_millis(50), async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            42
        })
        .await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("timed out"));
    }

    #[tokio::test]
    async fn test_combined_timeout_and_retry() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = retry_with_policy(&fast_policy(3), || {
            let counter = counter_clone.clone();
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    with_timeout(Duration::from_millis(10), async {
                        tokio::time::sleep(Duration::from_secs(10)).await;
                        42
                    })
                    .await
                } else {
                    Ok(100)
                }
            }
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 100);
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }
}