api_openai/
enhanced_retry.rs1#![ allow( clippy::missing_inline_in_public_items ) ]
8
9#[ cfg( feature = "retry" ) ]
10mod private
11{
12 use crate::
13 {
14 error ::{ OpenAIError, Result },
15 };
16
17 use core::time::Duration;
18 use std::
19 {
20 sync ::{ Arc, Mutex },
21 time ::Instant,
22 };
23
24 use serde::{ Serialize, Deserialize };
25 use tokio::time::sleep;
26 use rand::Rng;
27
28 #[ derive( Debug, Clone, Serialize, Deserialize ) ]
30 pub struct EnhancedRetryConfig
31 {
32 pub max_attempts : u32,
34 pub base_delay_ms : u64,
36 pub max_delay_ms : u64,
38 pub max_elapsed_time_ms : u64,
40 pub jitter_ms : u64,
42 pub backoff_multiplier : f64,
44 }
45
46 impl Default for EnhancedRetryConfig
47 {
48 fn default() -> Self
49 {
50 Self
51 {
52 max_attempts : 3,
53 base_delay_ms : 1000,
54 max_delay_ms : 30000,
55 max_elapsed_time_ms : 120_000,
56 jitter_ms : 100,
57 backoff_multiplier : 2.0,
58 }
59 }
60 }
61
62 impl EnhancedRetryConfig
63 {
64 #[ must_use ]
66 pub fn new() -> Self
67 {
68 Self::default()
69 }
70
71 #[ must_use ]
73 pub fn with_max_attempts( mut self, max_attempts : u32 ) -> Self
74 {
75 self.max_attempts = max_attempts;
76 self
77 }
78
79 #[ must_use ]
81 pub fn with_base_delay( mut self, base_delay_ms : u64 ) -> Self
82 {
83 self.base_delay_ms = base_delay_ms;
84 self
85 }
86
87 #[ must_use ]
89 pub fn with_max_delay( mut self, max_delay_ms : u64 ) -> Self
90 {
91 self.max_delay_ms = max_delay_ms;
92 self
93 }
94
95 #[ must_use ]
97 pub fn with_max_elapsed_time( mut self, max_elapsed_time_ms : u64 ) -> Self
98 {
99 self.max_elapsed_time_ms = max_elapsed_time_ms;
100 self
101 }
102
103 #[ must_use ]
105 pub fn with_jitter( mut self, jitter_ms : u64 ) -> Self
106 {
107 self.jitter_ms = jitter_ms;
108 self
109 }
110
111 #[ must_use ]
113 pub fn with_backoff_multiplier( mut self, multiplier : f64 ) -> Self
114 {
115 self.backoff_multiplier = multiplier;
116 self
117 }
118
119 #[ must_use ]
122 pub fn calculate_delay( &self, attempt : u32 ) -> Duration
123 {
124 let max_delay = Duration::from_millis( self.max_delay_ms );
125
126 let base_delay_f64 = self.base_delay_ms as f64;
128 let attempt_i32 = i32::try_from( attempt ).unwrap_or( i32::MAX );
129 let exponential_f64 = base_delay_f64 * self.backoff_multiplier.powi( attempt_i32 );
130 #[ allow(clippy::cast_possible_truncation, clippy::cast_sign_loss) ]
131 let exponential_delay = exponential_f64.min( u64::MAX as f64 ).max( 0.0 ) as u64;
132
133 let mut rng = rand::rng();
135 let jitter = rng.random_range( 0..=self.jitter_ms );
136
137 let total_delay_ms = exponential_delay + jitter;
138 let total_delay = Duration::from_millis( total_delay_ms );
139
140 core ::cmp::min( total_delay, max_delay )
142 }
143
144 #[ must_use ]
146 pub fn is_retryable_error( &self, error : &OpenAIError ) -> bool
147 {
148 match error
149 {
150 OpenAIError::Network( _ ) | OpenAIError::Timeout( _ ) | OpenAIError::RateLimit( _ ) | OpenAIError::Stream( _ ) | OpenAIError::Ws( _ ) => true,
152 OpenAIError::Http( message ) =>
154 {
155 message.contains( '5' ) || message.contains( "429" ) || message.contains( "500" ) || message.contains( "502" ) || message.contains( "503" ) || message.contains( "504" )
156 },
157 OpenAIError::Api( _ ) | OpenAIError::WsInvalidMessage( _ ) | OpenAIError::Internal( _ ) |
159 OpenAIError::InvalidArgument( _ ) | OpenAIError::MissingArgument( _ ) | OpenAIError::MissingEnvironment( _ ) |
160 OpenAIError::MissingHeader( _ ) | OpenAIError::MissingFile( _ ) | OpenAIError::File( _ ) | OpenAIError::Unknown( _ ) => false,
161 }
162 }
163
164 pub fn validate( &self ) -> core::result::Result< (), String >
170 {
171 if self.max_attempts == 0
172 {
173 return Err( "max_attempts must be greater than 0".to_string() );
174 }
175
176 if self.base_delay_ms == 0
177 {
178 return Err( "base_delay_ms must be greater than 0".to_string() );
179 }
180
181 if self.max_delay_ms < self.base_delay_ms
182 {
183 return Err( "max_delay_ms must be greater than or equal to base_delay_ms".to_string() );
184 }
185
186 if self.max_elapsed_time_ms == 0
187 {
188 return Err( "max_elapsed_time_ms must be greater than 0".to_string() );
189 }
190
191 if self.backoff_multiplier <= 0.0
192 {
193 return Err( "backoff_multiplier must be greater than 0".to_string() );
194 }
195
196 Ok( () )
197 }
198 }
199
200 #[ derive( Debug ) ]
202 pub struct RetryState
203 {
204 pub attempt : u32,
206 pub total_attempts : u32,
208 pub start_time : Instant,
210 pub last_error : Option< String >,
212 pub elapsed_time : Duration,
214 }
215
216 impl Default for RetryState
217 {
218 fn default() -> Self
219 {
220 Self::new()
221 }
222 }
223
224 impl RetryState
225 {
226 #[ must_use ]
228 pub fn new() -> Self
229 {
230 Self
231 {
232 attempt : 0,
233 total_attempts : 0,
234 start_time : Instant::now(),
235 last_error : None,
236 elapsed_time : Duration::ZERO,
237 }
238 }
239
240 pub fn next_attempt( &mut self )
242 {
243 self.attempt += 1;
244 self.total_attempts += 1;
245 self.elapsed_time = self.start_time.elapsed();
246 }
247
248 pub fn set_error( &mut self, error : String )
250 {
251 self.last_error = Some( error );
252 }
253
254 pub fn reset( &mut self )
256 {
257 self.attempt = 0;
258 self.total_attempts = 0;
259 self.start_time = Instant::now();
260 self.last_error = None;
261 self.elapsed_time = Duration::ZERO;
262 }
263
264 #[ must_use ]
266 pub fn is_elapsed_time_exceeded( &self, max_elapsed_time : Duration ) -> bool
267 {
268 self.elapsed_time >= max_elapsed_time
269 }
270 }
271
272 #[ derive( Debug ) ]
274 pub struct EnhancedRetryExecutor
275 {
276 config : EnhancedRetryConfig,
277 state : Arc< Mutex< RetryState > >,
278 }
279
280 impl EnhancedRetryExecutor
281 {
282 pub fn new( config : EnhancedRetryConfig ) -> core::result::Result< Self, String >
288 {
289 config.validate()?;
290
291 Ok( Self
292 {
293 config,
294 state : Arc::new( Mutex::new( RetryState::new() ) ),
295 } )
296 }
297
298 pub async fn execute< F, Fut, T >( &self, operation : F ) -> Result< T >
308 where
309 F : Fn() -> Fut,
310 Fut : core::future::Future< Output = Result< T > >,
311 {
312 {
314 let mut state = self.state.lock().unwrap();
315 state.reset();
316 }
317
318 let max_elapsed_time = Duration::from_millis( self.config.max_elapsed_time_ms );
319
320 loop
321 {
322 {
324 let state = self.state.lock().unwrap();
325 if state.is_elapsed_time_exceeded( max_elapsed_time )
326 {
327 return Err( error_tools::untyped::Error::msg( format!( "Max elapsed time exceeded : {max_elapsed_time:?}" ) ) );
328 }
329 }
330
331 {
333 let mut state = self.state.lock().unwrap();
334 state.next_attempt();
335 }
336
337 let current_attempt = {
339 let state = self.state.lock().unwrap();
340 state.attempt
341 };
342
343 match operation().await
345 {
346 Ok( result ) => return Ok( result ),
347 Err( error ) =>
348 {
349 {
351 let mut state = self.state.lock().unwrap();
352 state.set_error( error.to_string() );
353 }
354
355 let is_retryable = if let Some( openai_error ) = error.downcast_ref::< OpenAIError >()
357 {
358 self.config.is_retryable_error( openai_error )
359 }
360 else
361 {
362 let error_msg = error.to_string().to_lowercase();
364 error_msg.contains( "network" ) || error_msg.contains( "timeout" ) || error_msg.contains( "connection" )
365 };
366
367 if !is_retryable
369 {
370 return Err( error );
371 }
372
373 if current_attempt >= self.config.max_attempts
375 {
376 return Err( error );
377 }
378
379 let delay = self.config.calculate_delay( current_attempt - 1 );
381
382 #[ cfg( feature = "retry" ) ]
384 {
385 tracing ::debug!( "Retrying request attempt {} after {:?} delay", current_attempt, delay );
386 }
387
388 sleep( delay ).await;
390 }
391 }
392 }
393 }
394
395 #[ must_use ]
401 pub fn get_state( &self ) -> RetryState
402 {
403 let state = self.state.lock().unwrap();
404 RetryState
405 {
406 attempt : state.attempt,
407 total_attempts : state.total_attempts,
408 start_time : state.start_time,
409 last_error : state.last_error.clone(),
410 elapsed_time : state.elapsed_time,
411 }
412 }
413
414 #[ must_use ]
416 pub fn config( &self ) -> &EnhancedRetryConfig
417 {
418 &self.config
419 }
420 }
421}
422
423#[ cfg( feature = "retry" ) ]
425pub use private::
426{
427 EnhancedRetryConfig,
428 RetryState,
429 EnhancedRetryExecutor,
430};
431
/// Placeholder definitions compiled only when the `retry` feature is disabled.
#[ cfg( not( feature = "retry" ) ) ]
pub mod private
{
  /// Field-less stand-in for the retry configuration; it carries no settings without `retry`.
  #[ derive( Debug, Clone, Default ) ]
  pub struct EnhancedRetryConfig;

  impl EnhancedRetryConfig
  {
    /// Returns the placeholder configuration.
    #[ must_use ]
    pub fn new() -> Self
    {
      Self::default()
    }
  }
}
459
460#[ cfg( not( feature = "retry" ) ) ]
461pub use private::EnhancedRetryConfig;
462
// Module interface declared through the crate's `mod_interface!` macro.
// Exactly one of the two `exposed use` records below is compiled, selected by the `retry` feature.
// NOTE(review): the same names are also re-exported by the plain `pub use private::...` statements
// earlier in this file; confirm both export paths are intended.
crate::mod_interface!
{
  // Feature `retry` enabled: the full retry API.
  #[ cfg( feature = "retry" ) ]
  exposed use
  {
    EnhancedRetryConfig,
    RetryState,
    EnhancedRetryExecutor,
  };

  // Feature `retry` disabled: only the placeholder `EnhancedRetryConfig` from the stub module.
  #[ cfg( not( feature = "retry" ) ) ]
  exposed use
  {
    EnhancedRetryConfig,
  };
}