#[ cfg( feature = "retry" ) ]
mod private
{
use std::time::{ Duration, Instant };
use std::sync::{ Arc, Mutex };
use std::pin::Pin;
use std::future::Future;
/// Settings that control how often, how long, and how noisily an operation is retried.
///
/// Start from [`RetryConfig::new`] (identical to `Default`) and refine the result with the
/// chainable `with_*` methods.
#[ derive( Debug, Clone ) ]
pub struct RetryConfig
{
  /// Maximum number of times the operation is invoked, the initial call included.
  pub max_attempts : u32,
  /// Delay before the first retry, in milliseconds; later delays are scaled by `backoff_multiplier`.
  pub base_delay_ms : u64,
  /// Total time budget; once it has elapsed no further attempt is started.
  pub max_elapsed_time : Duration,
  /// Inclusive upper bound, in milliseconds, of the random jitter added to each delay (`0` disables jitter).
  pub jitter_ms : u64,
  /// Factor by which the delay grows with every successive attempt.
  pub backoff_multiplier : f64,
  /// Whether retry progress is printed to standard output.
  pub log_attempts : bool,
}

impl Default for RetryConfig
{
  /// Three attempts, 1 s base delay, 30 s budget, up to 500 ms jitter, doubling back-off, logging on.
  #[ inline ]
  fn default() -> Self
  {
    Self
    {
      max_attempts : 3,
      base_delay_ms : 1_000,
      max_elapsed_time : Duration::from_secs( 30 ),
      jitter_ms : 500,
      backoff_multiplier : 2.0,
      log_attempts : true,
    }
  }
}

impl RetryConfig
{
  /// Creates a configuration holding the default values.
  #[ inline ]
  #[ must_use ]
  pub fn new() -> Self
  {
    Self::default()
  }

  /// Returns the configuration with `max_attempts` replaced.
  #[ inline ]
  #[ must_use ]
  pub fn with_max_attempts( self, max_attempts : u32 ) -> Self
  {
    Self { max_attempts, ..self }
  }

  /// Returns the configuration with `base_delay_ms` replaced.
  #[ inline ]
  #[ must_use ]
  pub fn with_base_delay_ms( self, base_delay_ms : u64 ) -> Self
  {
    Self { base_delay_ms, ..self }
  }

  /// Returns the configuration with `max_elapsed_time` replaced.
  #[ inline ]
  #[ must_use ]
  pub fn with_max_elapsed_time( self, max_elapsed_time : Duration ) -> Self
  {
    Self { max_elapsed_time, ..self }
  }

  /// Returns the configuration with `jitter_ms` replaced.
  #[ inline ]
  #[ must_use ]
  pub fn with_jitter_ms( self, jitter_ms : u64 ) -> Self
  {
    Self { jitter_ms, ..self }
  }

  /// Returns the configuration with `backoff_multiplier` replaced.
  #[ inline ]
  #[ must_use ]
  pub fn with_backoff_multiplier( self, backoff_multiplier : f64 ) -> Self
  {
    Self { backoff_multiplier, ..self }
  }

  /// Returns the configuration with `log_attempts` replaced.
  #[ inline ]
  #[ must_use ]
  pub fn with_logging( self, log_attempts : bool ) -> Self
  {
    Self { log_attempts, ..self }
  }
}
/// Outcome of inspecting an error message to decide whether the operation may be retried.
#[ derive( Debug, Clone, PartialEq ) ]
pub enum ErrorClassification
{
  /// The failure may be transient; trying again is allowed.
  Retryable,
  /// The failure will not go away by repeating the call; retries are aborted.
  NonRetryable,
  /// The operation ran out of time.
  Timeout,
}
/// Shared, thread-safe counters describing retry activity.
///
/// Each counter lives behind its own `Arc< Mutex< u64 > >`, so cloning this struct yields another
/// handle onto the same numbers rather than an independent copy.
#[ derive( Debug, Default, Clone ) ]
pub struct RetryMetrics
{
  /// Number of attempts recorded through `record_attempt`.
  pub total_attempts : Arc< Mutex< u64 > >,
  /// Number of successes recorded through `record_success`.
  pub successful_retries : Arc< Mutex< u64 > >,
  /// Number of failures recorded through `record_failure`.
  pub failed_operations : Arc< Mutex< u64 > >,
  /// Sum, in milliseconds, of all delays recorded through `record_delay`.
  pub total_delay_ms : Arc< Mutex< u64 > >,
}

impl RetryMetrics
{
  /// Creates a metrics set with every counter at zero.
  #[ inline ]
  #[ must_use ]
  pub fn new() -> Self
  {
    Self::default()
  }

  /// Locks `counter`, recovering the guard when the mutex is poisoned.
  ///
  /// The protected value is a plain integer with no invariant a panicking thread could have
  /// broken, so a poisoned lock is safe to keep using and must not take the metrics down.
  #[ inline ]
  fn lock_counter( counter : &Mutex< u64 > ) -> std::sync::MutexGuard< '_, u64 >
  {
    counter.lock().unwrap_or_else( std::sync::PoisonError::into_inner )
  }

  /// Adds `amount` to `counter`, saturating at `u64::MAX` instead of overflowing.
  #[ inline ]
  fn bump( counter : &Mutex< u64 >, amount : u64 )
  {
    let mut guard = Self::lock_counter( counter );
    *guard = guard.saturating_add( amount );
  }

  /// Counts one attempt.
  #[ inline ]
  pub fn record_attempt( &self )
  {
    Self::bump( &self.total_attempts, 1 );
  }

  /// Counts one success.
  #[ inline ]
  pub fn record_success( &self )
  {
    Self::bump( &self.successful_retries, 1 );
  }

  /// Counts one failed operation.
  #[ inline ]
  pub fn record_failure( &self )
  {
    Self::bump( &self.failed_operations, 1 );
  }

  /// Adds `delay`, in whole milliseconds, to the accumulated delay.
  #[ inline ]
  pub fn record_delay( &self, delay : Duration )
  {
    // `as_millis` yields a `u128`; clamp it so oversized durations saturate instead of being truncated.
    let millis = delay.as_millis().min( u128::from( u64::MAX ) ) as u64;
    Self::bump( &self.total_delay_ms, millis );
  }

  /// Takes a snapshot of all counters.
  ///
  /// `success_rate` is `successful_retries / total_attempts`, or `0.0` while no attempt has been
  /// recorded. Counters are read one after another, not under a common lock.
  #[ inline ]
  #[ must_use ]
  pub fn get_stats( &self ) -> RetryStats
  {
    let total_attempts = *Self::lock_counter( &self.total_attempts );
    let successful_retries = *Self::lock_counter( &self.successful_retries );
    let failed_operations = *Self::lock_counter( &self.failed_operations );
    let total_delay_ms = *Self::lock_counter( &self.total_delay_ms );
    let success_rate = if total_attempts > 0
    {
      successful_retries as f64 / total_attempts as f64
    }
    else
    {
      0.0
    };
    RetryStats
    {
      total_attempts,
      successful_retries,
      failed_operations,
      total_delay_ms,
      success_rate,
    }
  }
}
/// Point-in-time copy of the retry counters, detached from the live metrics.
#[ derive( Debug, Clone ) ]
pub struct RetryStats
{
  /// Attempts counted so far.
  pub total_attempts : u64,
  /// Successes counted so far.
  pub successful_retries : u64,
  /// Failed operations counted so far.
  pub failed_operations : u64,
  /// Accumulated delay, in milliseconds.
  pub total_delay_ms : u64,
  /// Ratio of `successful_retries` to `total_attempts` (`0.0` when there were no attempts).
  pub success_rate : f64,
}
/// Stateless helper that sorts error messages into retry categories by substring matching.
#[ derive( Debug ) ]
pub struct ErrorClassifier;

impl ErrorClassifier
{
  /// Classifies `error_message`, ignoring letter case.
  ///
  /// Rules, checked in this order:
  /// 1. A client-side HTTP status (`400`, `401`, `403`, `404`, `405`, `406`, `409`, `410`, `422`),
  ///    `unauthorized`, `forbidden`, `bad request`, or `invalid` (unless the text also contains
  ///    `invalid response`) gives `NonRetryable`.
  /// 2. `timed out`, `timeout` or `deadline exceeded` gives `Timeout`.
  /// 3. Everything else is `Retryable`. This covers connection, network and DNS failures,
  ///    `500`/`502`/`503`/`504`, `unreachable`, `refused`, `reset` and `aborted`, as well as any
  ///    message that matches none of the markers.
  #[ inline ]
  #[ must_use ]
  pub fn classify( error_message : &str ) -> ErrorClassification
  {
    const NON_RETRYABLE_MARKERS : &[ &str ] =
    &[
      "400", "401", "403", "404", "405", "406", "409", "410", "422",
      "unauthorized", "forbidden", "bad request",
    ];
    const TIMEOUT_MARKERS : &[ &str ] = &[ "timed out", "timeout", "deadline exceeded" ];

    let lowered = error_message.to_lowercase();

    // "invalid response" describes a broken reply from the other side, not a broken request,
    // so it is deliberately excluded from the non-retryable group.
    let invalid_input = lowered.contains( "invalid" ) && !lowered.contains( "invalid response" );
    let non_retryable = invalid_input
      || NON_RETRYABLE_MARKERS.iter().any( | marker | lowered.contains( *marker ) );

    if non_retryable
    {
      ErrorClassification::NonRetryable
    }
    else if TIMEOUT_MARKERS.iter().any( | marker | lowered.contains( *marker ) )
    {
      ErrorClassification::Timeout
    }
    else
    {
      ErrorClassification::Retryable
    }
  }
}
/// Computes how long to wait after the zero-based `attempt` has failed.
///
/// The delay is `base_delay_ms * backoff_multiplier ^ attempt`, truncated to whole milliseconds,
/// plus a random jitter drawn from `0..=jitter_ms` (no jitter when `jitter_ms` is `0`).
///
/// The result saturates at `u64::MAX` milliseconds: very large attempt numbers or multipliers
/// can neither overflow the sum nor wrap around to a short delay.
#[ inline ]
pub fn calculate_retry_delay( attempt : u32, config : &RetryConfig ) -> Duration
{
  // `powi` takes an `i32`; clamp so that huge attempt numbers cannot wrap to a negative exponent.
  let exponent = attempt.min( i32::MAX as u32 ) as i32;
  let backoff_ms = config.base_delay_ms as f64 * config.backoff_multiplier.powi( exponent );
  let jitter_ms = if config.jitter_ms > 0
  {
    fastrand::u64( 0..=config.jitter_ms )
  }
  else
  {
    0
  };
  // A float-to-integer `as` cast already saturates (and maps NaN to 0), so only the addition
  // of the jitter could still overflow.
  Duration::from_millis( ( backoff_ms as u64 ).saturating_add( jitter_ms ) )
}
/// Runs `operation` until it succeeds, fails with a non-retryable error, or the retry budget is spent.
///
/// # Behaviour
/// - Every failure is classified by `ErrorClassifier::classify` on the error's `Display` text;
///   a `NonRetryable` error is returned immediately.
/// - Between attempts the task sleeps for `calculate_retry_delay( attempt, &config )`.
/// - Retrying stops after `config.max_attempts` invocations, or once `config.max_elapsed_time`
///   has elapsed when the next attempt is about to start.
/// - The operation is always invoked at least once, even when `max_attempts` is `0` or the time
///   budget is already used up. Both cases previously ended in a panic because there was no
///   error to return.
/// - Progress is printed to standard output when `config.log_attempts` is set.
///
/// # Metrics
/// When `metrics` is supplied, every invocation counts as an attempt, a success that needed at
/// least one retry counts as a successful retry, every sleep is added to the total delay, and
/// running out of budget counts as one failed operation. A non-retryable error returns early and
/// is not counted as a failed operation.
///
/// # Errors
/// Returns the non-retryable error that aborted the run, or the last error observed once the
/// budget is exhausted.
pub async fn execute_with_retries< F, T, E >(
  operation : F,
  config : RetryConfig,
  metrics : Option< &RetryMetrics >
) -> std::result::Result< T, E >
where
  F : Fn() -> Pin< Box< dyn Future< Output = std::result::Result< T, E > > + Send > > + Send + Sync,
  E : std::fmt::Display + Send + Sync,
{
  let start_time = Instant::now();
  // A budget of zero attempts would leave nothing to return, so it is treated as one attempt.
  let max_attempts = config.max_attempts.max( 1 );
  let mut last_error = None;

  for attempt in 0..max_attempts
  {
    // The time budget limits retries only; the first attempt always runs so that the caller
    // gets either a value or a real error back.
    if attempt > 0 && start_time.elapsed() > config.max_elapsed_time
    {
      if config.log_attempts
      {
        println!( "⚠ Retry abandoned due to max elapsed time : {:?}", config.max_elapsed_time );
      }
      break;
    }

    if let Some( m ) = metrics
    {
      m.record_attempt();
    }

    match operation().await
    {
      Ok( result ) =>
      {
        // Only a success that followed at least one failure counts as a "successful retry".
        if attempt > 0
        {
          if let Some( m ) = metrics
          {
            m.record_success();
          }
          if config.log_attempts
          {
            println!( "✓ Operation succeeded after {} retry attempts", attempt );
          }
        }
        return Ok( result );
      }
      Err( error ) =>
      {
        let error_str = error.to_string();
        let classification = ErrorClassifier::classify( &error_str );
        if config.log_attempts
        {
          println!( "⚠ Attempt {} failed : {} (classification : {:?})", attempt + 1, error_str, classification );
        }

        if classification == ErrorClassification::NonRetryable
        {
          if config.log_attempts
          {
            println!( "⚠ Error classified as non-retryable, aborting retries" );
          }
          return Err( error );
        }

        last_error = Some( error );

        // No point in sleeping after the final attempt.
        if attempt + 1 < max_attempts
        {
          let delay = calculate_retry_delay( attempt, &config );
          if config.log_attempts
          {
            println!( "⏳ Waiting {:?} before retry attempt {}", delay, attempt + 2 );
          }
          if let Some( m ) = metrics
          {
            m.record_delay( delay );
          }
          tokio::time::sleep( delay ).await;
        }
      }
    }
  }

  if let Some( m ) = metrics
  {
    m.record_failure();
  }
  if config.log_attempts
  {
    println!( "⚠ All {} retry attempts exhausted", max_attempts );
  }
  Err( last_error.expect( "the first attempt always runs, so reaching this point means an error was stored" ) )
}
/// Runs operations with an optional retry policy and keeps metrics about the retries.
///
/// Clones share the same metrics because they are held behind an `Arc`.
#[ derive( Debug, Clone ) ]
pub struct RetryableHttpClient
{
  /// Retry policy; `None` means operations are executed exactly once, without retries.
  pub config : Option< RetryConfig >,
  /// Counters updated by retried executions.
  pub metrics : Arc< RetryMetrics >,
}

impl RetryableHttpClient
{
  /// Creates a client with the given retry policy and fresh, zeroed metrics.
  #[ inline ]
  #[ must_use ]
  pub fn new( config : Option< RetryConfig > ) -> Self
  {
    Self
    {
      config,
      metrics : Arc::new( RetryMetrics::new() ),
    }
  }

  /// Executes `operation`, retrying it according to the configured policy.
  ///
  /// Without a policy the operation is awaited once and its result is returned unchanged; in
  /// that case no metrics are recorded.
  ///
  /// # Errors
  /// Returns the error produced by `operation` (the one that ended the retry loop when a policy
  /// is configured).
  pub async fn execute< F, T, E >( &self, operation : F ) -> std::result::Result< T, E >
  where
    F : Fn() -> Pin< Box< dyn Future< Output = std::result::Result< T, E > > + Send > > + Send + Sync,
    E : std::fmt::Display + Send + Sync,
  {
    match &self.config
    {
      Some( config ) => execute_with_retries( operation, config.clone(), Some( &self.metrics ) ).await,
      None => operation().await,
    }
  }

  /// Returns a snapshot of the metrics collected so far.
  #[ inline ]
  #[ must_use ]
  pub fn get_metrics( &self ) -> RetryStats
  {
    self.metrics.get_stats()
  }

  /// Sets every counter back to zero.
  ///
  /// Counters are reset one after another, not atomically as a group. A poisoned counter lock is
  /// recovered instead of causing a panic: the counters are plain integers, so there is no
  /// broken state to protect, and resetting is exactly how a caller would want to recover.
  #[ inline ]
  pub fn reset_metrics( &self )
  {
    let counters =
    [
      &self.metrics.total_attempts,
      &self.metrics.successful_retries,
      &self.metrics.failed_operations,
      &self.metrics.total_delay_ms,
    ];
    for counter in counters.iter()
    {
      *counter.lock().unwrap_or_else( std::sync::PoisonError::into_inner ) = 0;
    }
  }
}

impl Default for RetryableHttpClient
{
  /// Creates a client without a retry policy.
  #[ inline ]
  fn default() -> Self
  {
    Self::new( None )
  }
}
/// Wraps `operation` in a closure that forwards every call to it.
///
/// The returned closure has the shape expected by the retry entry points in this module. The
/// operation is invoked by reference on each call; it is not cloned per call (the `Clone` bound
/// is kept only so that existing callers keep compiling unchanged).
#[ inline ]
pub fn retry_operation< F, T, E >( operation : F ) -> impl Fn() -> Pin< Box< dyn Future< Output = std::result::Result< T, E > > + Send > >
where
  F : Fn() -> Pin< Box< dyn Future< Output = std::result::Result< T, E > > + Send > > + Send + Sync + Clone,
{
  move || operation()
}
/// Canned operations for exercising the retry machinery in tests.
#[ cfg( test ) ]
pub mod test_utils
{
  use super::*;

  /// Builds an operation whose first `failure_count` runs fail and whose later runs succeed.
  ///
  /// Runs are counted across all futures produced by the returned closure; the count advances
  /// when a future is polled, not when the closure is called. Failures carry the text
  /// `Test failure on attempt N`; successes carry `success_message` followed by the attempt number.
  pub fn test_operation_with_failures(
    failure_count : u32,
    success_message : String
  ) -> impl Fn() -> Pin< Box< dyn Future< Output = std::result::Result< String, String > > + Send > >
  {
    use std::sync::atomic::{ AtomicU32, Ordering };

    let runs_so_far = Arc::new( AtomicU32::new( 0 ) );
    move ||
    {
      let runs = Arc::clone( &runs_so_far );
      let msg = success_message.clone();
      Box::pin( async move
      {
        // `fetch_add` hands back the previous value, so add one to get a 1-based attempt number.
        let current_attempt = runs.fetch_add( 1, Ordering::SeqCst ) + 1;
        if current_attempt > failure_count
        {
          Ok( format!( "{msg} (succeeded on attempt {current_attempt})" ) )
        }
        else
        {
          Err( format!( "Test failure on attempt {current_attempt}" ) )
        }
      } )
    }
  }

  /// Builds an operation that fails with `error_message` on every run.
  pub fn test_operation_always_fails(
    error_message : String
  ) -> impl Fn() -> Pin< Box< dyn Future< Output = std::result::Result< String, String > > + Send > >
  {
    move ||
    {
      let failure = error_message.clone();
      Box::pin( async move { Err( failure ) } )
    }
  }
}
}
// Export list for the retry API; compiled only when the `retry` feature is enabled.
// NOTE(review): `mod_interface!` is a project macro whose definition is not visible here;
// presumably each `exposed use` line re-exports the named item from `private` - confirm
// against the macro's documentation.
#[ cfg( feature = "retry" ) ]
crate ::mod_interface!
{
exposed use private::RetryConfig;
exposed use private::ErrorClassification;
exposed use private::RetryMetrics;
exposed use private::RetryStats;
exposed use private::ErrorClassifier;
exposed use private::RetryableHttpClient;
exposed use private::execute_with_retries;
exposed use private::calculate_retry_delay;
exposed use private::retry_operation;
// The test helpers exist only in test builds, so their export is gated the same way.
#[ cfg( test ) ]
exposed use private::test_utils;
}