#![ allow( clippy::missing_inline_in_public_items ) ]
#[ cfg( feature = "retry" ) ]
mod private
{
use crate::
{
error ::{ OpenAIError, Result },
};
use core::time::Duration;
use std::
{
sync ::{ Arc, Mutex },
time ::Instant,
};
use serde::{ Serialize, Deserialize };
use tokio::time::sleep;
use rand::Rng;
/// Tunable parameters for retrying failed requests with exponential backoff.
///
/// Every duration-like field is expressed in milliseconds.
#[ derive( Debug, Clone, Serialize, Deserialize ) ]
pub struct EnhancedRetryConfig
{
  /// Maximum number of times the operation is invoked, the first call included.
  pub max_attempts : u32,
  /// Delay before the first retry, prior to backoff scaling and jitter.
  pub base_delay_ms : u64,
  /// Upper bound applied to every computed delay (jitter included).
  pub max_delay_ms : u64,
  /// Overall wall-clock budget for one retry run.
  pub max_elapsed_time_ms : u64,
  /// Inclusive upper bound of the random extra delay added to each backoff.
  pub jitter_ms : u64,
  /// Factor the delay is multiplied by for each further retry.
  pub backoff_multiplier : f64,
}
impl Default for EnhancedRetryConfig
{
  /// Three attempts, a 1s base delay doubling per retry, a 30s cap per delay,
  /// a 2min overall budget and up to 100ms of jitter.
  fn default() -> Self
  {
    let base_delay_ms = 1_000;
    let max_delay_ms = 30_000;
    Self
    {
      max_attempts : 3,
      base_delay_ms,
      max_delay_ms,
      max_elapsed_time_ms : 120_000,
      jitter_ms : 100,
      backoff_multiplier : 2.0,
    }
  }
}
impl EnhancedRetryConfig
{
  /// Creates a configuration populated with the [`Default`] values.
  #[ must_use ]
  pub fn new() -> Self
  {
    Self::default()
  }

  /// Sets the maximum number of invocations (the first call counts as one).
  #[ must_use ]
  pub fn with_max_attempts( mut self, max_attempts : u32 ) -> Self
  {
    self.max_attempts = max_attempts;
    self
  }

  /// Sets the delay in milliseconds used for the first retry, before backoff
  /// scaling and jitter.
  #[ must_use ]
  pub fn with_base_delay( mut self, base_delay_ms : u64 ) -> Self
  {
    self.base_delay_ms = base_delay_ms;
    self
  }

  /// Sets the cap in milliseconds applied to every computed delay, jitter included.
  #[ must_use ]
  pub fn with_max_delay( mut self, max_delay_ms : u64 ) -> Self
  {
    self.max_delay_ms = max_delay_ms;
    self
  }

  /// Sets the overall wall-clock budget in milliseconds for one retry run.
  #[ must_use ]
  pub fn with_max_elapsed_time( mut self, max_elapsed_time_ms : u64 ) -> Self
  {
    self.max_elapsed_time_ms = max_elapsed_time_ms;
    self
  }

  /// Sets the inclusive upper bound in milliseconds of the random jitter.
  #[ must_use ]
  pub fn with_jitter( mut self, jitter_ms : u64 ) -> Self
  {
    self.jitter_ms = jitter_ms;
    self
  }

  /// Sets the factor the delay is multiplied by for each further retry.
  #[ must_use ]
  pub fn with_backoff_multiplier( mut self, multiplier : f64 ) -> Self
  {
    self.backoff_multiplier = multiplier;
    self
  }

  /// Computes the backoff delay for the zero-based retry index `attempt`.
  ///
  /// The result is `base_delay_ms * backoff_multiplier^attempt` plus a random
  /// jitter drawn from `0..=jitter_ms`, capped at `max_delay_ms`. So `attempt == 0`
  /// yields roughly `base_delay_ms`. All arithmetic saturates: very large attempt
  /// numbers or multipliers simply reach the cap instead of overflowing.
  #[ must_use ]
  pub fn calculate_delay( &self, attempt : u32 ) -> Duration
  {
    let max_delay = Duration::from_millis( self.max_delay_ms );
    let base_delay_f64 = self.base_delay_ms as f64;
    let attempt_i32 = i32::try_from( attempt ).unwrap_or( i32::MAX );
    let exponential_f64 = base_delay_f64 * self.backoff_multiplier.powi( attempt_i32 );
    // Clamp into the u64 range before casting; the result may legitimately be `u64::MAX`.
    #[ allow( clippy::cast_possible_truncation, clippy::cast_sign_loss ) ]
    let exponential_delay = exponential_f64.min( u64::MAX as f64 ).max( 0.0 ) as u64;
    let mut rng = rand::rng();
    let jitter = rng.random_range( 0..=self.jitter_ms );
    // `saturating_add`: the exponential term may already be `u64::MAX`, and a plain
    // `+` would overflow (panicking in debug builds).
    let total_delay = Duration::from_millis( exponential_delay.saturating_add( jitter ) );
    core::cmp::min( total_delay, max_delay )
  }

  /// Reports whether `error` is worth retrying.
  ///
  /// Network, timeout, rate-limit, stream and websocket failures count as
  /// transient. HTTP failures are retried only when their message carries a
  /// standalone `429` or `5xx` status code. Every other variant is permanent.
  #[ must_use ]
  pub fn is_retryable_error( &self, error : &OpenAIError ) -> bool
  {
    match error
    {
      OpenAIError::Network( _ )
      | OpenAIError::Timeout( _ )
      | OpenAIError::RateLimit( _ )
      | OpenAIError::Stream( _ )
      | OpenAIError::Ws( _ ) => true,
      OpenAIError::Http( message ) => Self::is_transient_http_message( message ),
      OpenAIError::Api( _ )
      | OpenAIError::WsInvalidMessage( _ )
      | OpenAIError::Internal( _ )
      | OpenAIError::InvalidArgument( _ )
      | OpenAIError::MissingArgument( _ )
      | OpenAIError::MissingEnvironment( _ )
      | OpenAIError::MissingHeader( _ )
      | OpenAIError::MissingFile( _ )
      | OpenAIError::File( _ )
      | OpenAIError::Unknown( _ ) => false,
    }
  }

  /// Returns `true` when `message` contains a three-digit status code that
  /// signals a transient condition: `429` or any code in `500..=599`.
  ///
  /// Codes must be delimited by non-digit characters, so unrelated numbers
  /// (`"415"`, `"1500ms"`, `"5 items"`) are not mistaken for server errors.
  fn is_transient_http_message( message : &str ) -> bool
  {
    message
    .split( | c : char | !c.is_ascii_digit() )
    .filter( | token | token.len() == 3 )
    .filter_map( | token | token.parse::< u16 >().ok() )
    .any( | code | code == 429 || ( 500..=599 ).contains( &code ) )
  }

  /// Checks that the configuration is internally consistent.
  ///
  /// # Errors
  ///
  /// Returns a description of the first violated constraint. The checks are:
  /// zero `max_attempts`, zero `base_delay_ms`, `max_delay_ms` below
  /// `base_delay_ms`, zero `max_elapsed_time_ms`, and a `backoff_multiplier`
  /// that is not a positive number (NaN included).
  pub fn validate( &self ) -> core::result::Result< (), String >
  {
    if self.max_attempts == 0
    {
      return Err( "max_attempts must be greater than 0".to_string() );
    }
    if self.base_delay_ms == 0
    {
      return Err( "base_delay_ms must be greater than 0".to_string() );
    }
    if self.max_delay_ms < self.base_delay_ms
    {
      return Err( "max_delay_ms must be greater than or equal to base_delay_ms".to_string() );
    }
    if self.max_elapsed_time_ms == 0
    {
      return Err( "max_elapsed_time_ms must be greater than 0".to_string() );
    }
    // NaN compares false against everything, so `<= 0.0` alone would let it through.
    if self.backoff_multiplier.is_nan() || self.backoff_multiplier <= 0.0
    {
      return Err( "backoff_multiplier must be greater than 0".to_string() );
    }
    Ok( () )
  }
}
/// Book-keeping for one run of retries.
#[ derive( Debug ) ]
pub struct RetryState
{
  /// Attempts started in the current run; incremented by `next_attempt`.
  pub attempt : u32,
  /// Attempts counted since the state was created or last reset.
  pub total_attempts : u32,
  /// Moment the state was created or last reset.
  pub start_time : Instant,
  /// Text of the most recently recorded failure, if any.
  pub last_error : Option< String >,
  /// Time since `start_time`, as of the last refresh.
  pub elapsed_time : Duration,
}
impl Default for RetryState
{
fn default() -> Self
{
Self::new()
}
}
impl RetryState
{
  /// Creates a state with no attempts and no error, with its clock started now.
  #[ must_use ]
  pub fn new() -> Self
  {
    let start_time = Instant::now();
    Self
    {
      attempt : 0,
      total_attempts : 0,
      start_time,
      last_error : None,
      elapsed_time : Duration::ZERO,
    }
  }

  /// Records the start of another attempt and refreshes `elapsed_time`.
  pub fn next_attempt( &mut self )
  {
    self.attempt += 1;
    self.total_attempts += 1;
    self.elapsed_time = self.start_time.elapsed();
  }

  /// Stores the description of the most recent failure.
  pub fn set_error( &mut self, error : String )
  {
    self.last_error = Some( error );
  }

  /// Returns the state to what [`RetryState::new`] produces, restarting the clock.
  pub fn reset( &mut self )
  {
    *self = Self::new();
  }

  /// Whether the recorded `elapsed_time` has reached `max_elapsed_time`.
  ///
  /// This compares the stored value only; it does not read the clock.
  #[ must_use ]
  pub fn is_elapsed_time_exceeded( &self, max_elapsed_time : Duration ) -> bool
  {
    max_elapsed_time <= self.elapsed_time
  }
}
/// Runs fallible async operations under an [`EnhancedRetryConfig`] policy and
/// records the progress of the most recent run.
#[ derive( Debug ) ]
pub struct EnhancedRetryExecutor
{
  /// Retry policy; validated when the executor is constructed.
  config : EnhancedRetryConfig,
  /// Progress of the latest run, shared behind a mutex.
  state : Arc< Mutex< RetryState > >,
}
impl EnhancedRetryExecutor
{
pub fn new( config : EnhancedRetryConfig ) -> core::result::Result< Self, String >
{
config.validate()?;
Ok( Self
{
config,
state : Arc::new( Mutex::new( RetryState::new() ) ),
} )
}
pub async fn execute< F, Fut, T >( &self, operation : F ) -> Result< T >
where
F : Fn() -> Fut,
Fut : core::future::Future< Output = Result< T > >,
{
{
let mut state = self.state.lock().unwrap();
state.reset();
}
let max_elapsed_time = Duration::from_millis( self.config.max_elapsed_time_ms );
loop
{
{
let state = self.state.lock().unwrap();
if state.is_elapsed_time_exceeded( max_elapsed_time )
{
return Err( error_tools::untyped::Error::msg( format!( "Max elapsed time exceeded : {max_elapsed_time:?}" ) ) );
}
}
{
let mut state = self.state.lock().unwrap();
state.next_attempt();
}
let current_attempt = {
let state = self.state.lock().unwrap();
state.attempt
};
match operation().await
{
Ok( result ) => return Ok( result ),
Err( error ) =>
{
{
let mut state = self.state.lock().unwrap();
state.set_error( error.to_string() );
}
let is_retryable = if let Some( openai_error ) = error.downcast_ref::< OpenAIError >()
{
self.config.is_retryable_error( openai_error )
}
else
{
let error_msg = error.to_string().to_lowercase();
error_msg.contains( "network" ) || error_msg.contains( "timeout" ) || error_msg.contains( "connection" )
};
if !is_retryable
{
return Err( error );
}
if current_attempt >= self.config.max_attempts
{
return Err( error );
}
let delay = self.config.calculate_delay( current_attempt - 1 );
#[ cfg( feature = "retry" ) ]
{
tracing ::debug!( "Retrying request attempt {} after {:?} delay", current_attempt, delay );
}
sleep( delay ).await;
}
}
}
}
#[ must_use ]
pub fn get_state( &self ) -> RetryState
{
let state = self.state.lock().unwrap();
RetryState
{
attempt : state.attempt,
total_attempts : state.total_attempts,
start_time : state.start_time,
last_error : state.last_error.clone(),
elapsed_time : state.elapsed_time,
}
}
#[ must_use ]
pub fn config( &self ) -> &EnhancedRetryConfig
{
&self.config
}
}
}
#[ cfg( feature = "retry" ) ]
pub use private::
{
EnhancedRetryConfig,
RetryState,
EnhancedRetryExecutor,
};
#[ cfg( not( feature = "retry" ) ) ]
pub mod private
{
  /// Settings-free placeholder compiled when the `retry` feature is disabled.
  /// Presumably it exists so that code naming the type still builds without
  /// the feature.
  #[ derive( Debug, Clone, Default ) ]
  pub struct EnhancedRetryConfig;

  impl EnhancedRetryConfig
  {
    /// Returns the placeholder configuration; identical to `Default::default()`.
    #[ must_use ]
    pub fn new() -> Self
    {
      Self::default()
    }
  }
}
#[ cfg( not( feature = "retry" ) ) ]
pub use private::EnhancedRetryConfig;
// Public surface of this module. With the `retry` feature enabled, the config,
// state and executor types are exposed. Without it, only the placeholder
// `EnhancedRetryConfig` from the stub module is exposed.
crate ::mod_interface!
{
  #[ cfg( feature = "retry" ) ]
  exposed use
  {
    EnhancedRetryConfig,
    RetryState,
    EnhancedRetryExecutor,
  };
  #[ cfg( not( feature = "retry" ) ) ]
  exposed use
  {
    EnhancedRetryConfig,
  };
}