use crate::utils::RetryConfig;
use std::{
collections::HashMap,
sync::{Arc, Mutex},
time::{Duration, Instant},
};
/// Category assigned to a failed attempt.
///
/// The category drives two decisions: whether the failure is considered
/// retryable ([`FailureType::is_retryable`]) and how strongly it scales the
/// retry delay ([`FailureType::retry_multiplier`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FailureType {
    /// Retryable; delay multiplier `1.5`.
    Timeout,
    /// Retryable; delay multiplier `2.0`.
    ConnectionFailed,
    /// Not retryable; delay multiplier `3.0`.
    RateLimited,
    /// Not retryable; delay multiplier `0.5`.
    InvalidData,
    /// Retryable; delay multiplier `1.2`.
    ServerError,
    /// Not retryable; delay multiplier `1.0` (neutral).
    Unknown,
}

impl FailureType {
    /// Returns `true` for the failure categories that are treated as
    /// retryable: `Timeout`, `ConnectionFailed` and `ServerError`.
    #[must_use]
    #[inline]
    pub const fn is_retryable(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Timeout | Self::ConnectionFailed | Self::ServerError => true,
            Self::RateLimited | Self::InvalidData | Self::Unknown => false,
        }
    }

    /// Returns the factor by which this failure category scales a retry delay.
    ///
    /// Values above `1.0` lengthen the delay, values below `1.0` shorten it.
    #[must_use]
    #[inline]
    pub const fn retry_multiplier(&self) -> f64 {
        match self {
            Self::Timeout => 1.5,
            Self::ConnectionFailed => 2.0,
            Self::RateLimited => 3.0,
            Self::ServerError => 1.2,
            Self::InvalidData => 0.5,
            Self::Unknown => 1.0,
        }
    }
}
/// A single recorded failure: what kind it was and when it was recorded.
#[derive(Debug, Clone)]
struct FailureRecord {
    /// Category of the failure.
    failure_type: FailureType,
    /// Moment at which the failure was recorded.
    timestamp: Instant,
}

/// Running statistics kept for one target.
#[derive(Debug, Clone, Default)]
struct TargetStats {
    /// Number of recorded attempts (successes plus failures).
    total_attempts: u64,
    /// Number of recorded successes.
    successful_attempts: u64,
    /// Recorded failures, in the order they were pushed.
    recent_failures: Vec<FailureRecord>,
    /// Time of the most recent recorded success, if any.
    last_success: Option<Instant>,
    /// Number of failures recorded since the last success.
    consecutive_failures: u32,
}

impl TargetStats {
    /// Fraction of attempts that succeeded, in `[0.0, 1.0]`.
    ///
    /// Returns the neutral value `0.5` when no attempt has been recorded, so
    /// an unknown target is treated as neither healthy nor unhealthy.
    #[must_use]
    #[inline]
    fn success_rate(&self) -> f64 {
        if self.total_attempts == 0 {
            return 0.5;
        }
        self.successful_attempts as f64 / self.total_attempts as f64
    }

    /// Returns the failure type that occurs most often in `recent_failures`,
    /// or `None` when there are no recorded failures.
    ///
    /// Ties are resolved deterministically: among the most frequent types,
    /// the one whose latest occurrence is most recent wins. Selecting the
    /// maximum straight out of a `HashMap` would make the result depend on
    /// the map's randomized iteration order.
    #[must_use]
    #[inline]
    fn dominant_failure_type(&self) -> Option<FailureType> {
        // Per type: (occurrence count, index of its latest occurrence).
        let mut tallies: HashMap<FailureType, (usize, usize)> = HashMap::new();
        for (index, record) in self.recent_failures.iter().enumerate() {
            let tally = tallies.entry(record.failure_type).or_insert((0, index));
            tally.0 += 1;
            tally.1 = index;
        }
        // The latest index is unique per type, so the tuple key is never tied.
        tallies
            .into_iter()
            .max_by_key(|&(_, tally)| tally)
            .map(|(failure_type, _)| failure_type)
    }

    /// Returns `true` when the target looks unhealthy: more than three
    /// consecutive failures, or a success rate below 30%.
    #[must_use]
    #[inline]
    fn is_having_issues(&self) -> bool {
        self.consecutive_failures > 3 || self.success_rate() < 0.3
    }
}
pub struct AdaptiveRetryPolicy {
target_stats: Arc<Mutex<HashMap<String, TargetStats>>>,
base_config: RetryConfig,
history_window: Duration,
}
impl AdaptiveRetryPolicy {
#[must_use]
#[inline]
pub fn new() -> Self {
Self {
target_stats: Arc::new(Mutex::new(HashMap::new())),
base_config: RetryConfig::default(),
history_window: Duration::from_secs(300), }
}
#[must_use]
#[inline]
pub fn with_config(base_config: RetryConfig) -> Self {
Self {
target_stats: Arc::new(Mutex::new(HashMap::new())),
base_config,
history_window: Duration::from_secs(300),
}
}
pub fn record_success(&mut self, target: &str) {
let mut stats = self.target_stats.lock().unwrap();
let entry = stats.entry(target.to_string()).or_default();
entry.total_attempts += 1;
entry.successful_attempts += 1;
entry.consecutive_failures = 0;
entry.last_success = Some(Instant::now());
}
pub fn record_failure(&mut self, target: &str, failure_type: FailureType) {
let mut stats = self.target_stats.lock().unwrap();
let entry = stats.entry(target.to_string()).or_default();
entry.total_attempts += 1;
entry.consecutive_failures += 1;
entry.recent_failures.push(FailureRecord {
failure_type,
timestamp: Instant::now(),
});
self.cleanup_old_failures(entry);
}
fn cleanup_old_failures(&self, stats: &mut TargetStats) {
let cutoff = Instant::now() - self.history_window;
stats.recent_failures.retain(|f| f.timestamp > cutoff);
}
#[must_use]
#[inline]
pub fn should_retry(&self, target: &str, attempt: u32) -> bool {
let stats = self.target_stats.lock().unwrap();
if attempt >= self.base_config.max_attempts {
return false;
}
if let Some(target_stats) = stats.get(target) {
if target_stats.consecutive_failures > 5 {
return false;
}
if let Some(failure_type) = target_stats.dominant_failure_type() {
if !failure_type.is_retryable() {
return false;
}
}
}
true
}
#[must_use]
#[inline]
pub fn retry_delay(&self, target: &str, attempt: u32) -> Duration {
let base_delay = self.base_config.delay_for_attempt(attempt);
let stats = self.target_stats.lock().unwrap();
if let Some(target_stats) = stats.get(target) {
let mut multiplier = 1.0;
let success_rate = target_stats.success_rate();
if success_rate < 0.5 {
multiplier *= 1.5;
} else if success_rate < 0.7 {
multiplier *= 1.2;
}
if let Some(failure_type) = target_stats.dominant_failure_type() {
multiplier *= failure_type.retry_multiplier();
}
if target_stats.consecutive_failures > 2 {
multiplier *= 1.5f64.powi(target_stats.consecutive_failures as i32 - 2);
}
Duration::from_millis((base_delay.as_millis() as f64 * multiplier) as u64)
.min(Duration::from_millis(self.base_config.max_delay_ms))
} else {
base_delay
}
}
#[must_use]
#[inline]
pub fn success_rate(&self, target: &str) -> f64 {
let stats = self.target_stats.lock().unwrap();
stats.get(target).map(|s| s.success_rate()).unwrap_or(0.5)
}
#[must_use]
#[inline]
pub fn consecutive_failures(&self, target: &str) -> u32 {
let stats = self.target_stats.lock().unwrap();
stats
.get(target)
.map(|s| s.consecutive_failures)
.unwrap_or(0)
}
#[must_use]
#[inline]
pub fn is_target_having_issues(&self, target: &str) -> bool {
let stats = self.target_stats.lock().unwrap();
stats
.get(target)
.map(|s| s.is_having_issues())
.unwrap_or(false)
}
#[must_use]
#[inline]
pub fn recommended_config(&self, target: &str) -> RetryConfig {
let stats = self.target_stats.lock().unwrap();
if let Some(target_stats) = stats.get(target) {
let success_rate = target_stats.success_rate();
if success_rate > 0.8 {
RetryConfig::aggressive()
} else if success_rate > 0.5 {
self.base_config.clone()
} else {
RetryConfig::conservative()
}
} else {
self.base_config.clone()
}
}
pub fn reset_target(&mut self, target: &str) {
let mut stats = self.target_stats.lock().unwrap();
stats.remove(target);
}
pub fn reset_all(&mut self) {
let mut stats = self.target_stats.lock().unwrap();
stats.clear();
}
#[must_use]
#[inline]
pub fn tracked_targets_count(&self) -> usize {
let stats = self.target_stats.lock().unwrap();
stats.len()
}
#[must_use]
#[inline]
pub fn detect_failure_burst(&self, target: &str) -> bool {
let stats = self.target_stats.lock().unwrap();
if let Some(target_stats) = stats.get(target) {
let one_minute_ago = Instant::now() - Duration::from_secs(60);
let recent_count = target_stats
.recent_failures
.iter()
.filter(|f| f.timestamp > one_minute_ago)
.count();
return recent_count >= 5;
}
false
}
#[must_use]
#[inline]
pub fn failure_interval(&self, target: &str) -> Option<Duration> {
let stats = self.target_stats.lock().unwrap();
if let Some(target_stats) = stats.get(target) {
if target_stats.recent_failures.len() < 2 {
return None;
}
let failures = &target_stats.recent_failures;
let mut intervals = Vec::new();
for i in 1..failures.len() {
let interval = failures[i]
.timestamp
.saturating_duration_since(failures[i - 1].timestamp);
intervals.push(interval);
}
if intervals.is_empty() {
return None;
}
let total: Duration = intervals.iter().sum();
Some(total / intervals.len() as u32)
} else {
None
}
}
#[must_use]
#[inline]
pub fn predict_recovery_time(&self, target: &str) -> Option<Duration> {
let stats = self.target_stats.lock().unwrap();
if let Some(target_stats) = stats.get(target) {
if let Some(last_success) = target_stats.last_success {
let time_since_success = Instant::now().saturating_duration_since(last_success);
if target_stats.consecutive_failures > 0 {
let estimated_recovery = time_since_success * 2;
return Some(estimated_recovery);
}
}
if !target_stats.recent_failures.is_empty() {
return Some(Duration::from_secs(60)); }
}
None
}
#[must_use]
#[inline]
pub fn failure_patterns(&self, target: &str) -> Option<FailurePatterns> {
let stats = self.target_stats.lock().unwrap();
stats.get(target).map(|target_stats| {
let mut type_counts: HashMap<FailureType, usize> = HashMap::new();
for record in &target_stats.recent_failures {
*type_counts.entry(record.failure_type).or_insert(0) += 1;
}
let one_minute_ago = Instant::now() - Duration::from_secs(60);
let recent_count = target_stats
.recent_failures
.iter()
.filter(|f| f.timestamp > one_minute_ago)
.count();
let is_burst = recent_count >= 5;
FailurePatterns {
total_failures: target_stats.recent_failures.len(),
failure_types: type_counts,
consecutive_failures: target_stats.consecutive_failures,
success_rate: target_stats.success_rate(),
is_burst,
dominant_type: target_stats.dominant_failure_type(),
}
})
}
}
/// Snapshot of the failure statistics of a single target.
#[derive(Debug, Clone)]
pub struct FailurePatterns {
    /// Number of failure records the snapshot was built from.
    pub total_failures: usize,
    /// Number of recorded failures per failure type.
    pub failure_types: HashMap<FailureType, usize>,
    /// Number of failures recorded since the last success.
    pub consecutive_failures: u32,
    /// Fraction of attempts that succeeded.
    pub success_rate: f64,
    /// Whether the failures were classified as a burst.
    pub is_burst: bool,
    /// Most frequent failure type, if any failure was recorded.
    pub dominant_type: Option<FailureType>,
}

impl FailurePatterns {
    /// Returns `true` when the snapshot points at a systemic problem: a
    /// failure burst, more than five consecutive failures, or a success rate
    /// below 20%.
    #[must_use]
    #[inline]
    pub fn is_systemic_issue(&self) -> bool {
        let long_streak = self.consecutive_failures > 5;
        let mostly_failing = self.success_rate < 0.2;
        self.is_burst || long_streak || mostly_failing
    }

    /// Returns the share of `failure_type` among all counted failures.
    ///
    /// Despite the name, the result is a fraction in `[0.0, 1.0]`, not a
    /// value out of 100. Returns `0.0` when no failures are counted or the
    /// type never occurred.
    #[must_use]
    #[inline]
    pub fn failure_type_percentage(&self, failure_type: FailureType) -> f64 {
        match self.total_failures {
            0 => 0.0,
            total => {
                let occurrences = self.failure_types.get(&failure_type).map_or(0, |&n| n);
                occurrences as f64 / total as f64
            }
        }
    }
}
impl Default for AdaptiveRetryPolicy {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing computed `f64` ratios.
    const EPSILON: f64 = 1e-9;

    /// Every failure type is classified as retryable or not as expected.
    #[test]
    fn test_failure_type_retryable() {
        assert!(FailureType::Timeout.is_retryable());
        assert!(FailureType::ConnectionFailed.is_retryable());
        assert!(FailureType::ServerError.is_retryable());
        assert!(!FailureType::RateLimited.is_retryable());
        assert!(!FailureType::InvalidData.is_retryable());
        assert!(!FailureType::Unknown.is_retryable());
    }

    /// Two successes and one failure give a success rate of two thirds.
    #[test]
    fn test_adaptive_policy_success_rate() {
        let mut policy = AdaptiveRetryPolicy::new();
        policy.record_success("peer1");
        policy.record_success("peer1");
        policy.record_failure("peer1", FailureType::Timeout);
        let rate = policy.success_rate("peer1");
        assert!((rate - 2.0 / 3.0).abs() < EPSILON);
    }

    /// Retrying stops once the attempt number reaches the configured limit.
    #[test]
    fn test_should_retry_after_max_attempts() {
        let policy = AdaptiveRetryPolicy::new();
        // Read the limit from the default config instead of hard-coding a
        // number that silently depends on it.
        let limit = RetryConfig::default().max_attempts;
        assert!(!policy.should_retry("peer1", limit));
        assert!(!policy.should_retry("peer1", limit.saturating_add(1)));
    }

    /// The consecutive-failure counter grows on failure and resets on success.
    #[test]
    fn test_consecutive_failures_tracking() {
        let mut policy = AdaptiveRetryPolicy::new();
        policy.record_failure("peer1", FailureType::Timeout);
        policy.record_failure("peer1", FailureType::Timeout);
        assert_eq!(policy.consecutive_failures("peer1"), 2);
        policy.record_success("peer1");
        assert_eq!(policy.consecutive_failures("peer1"), 0);
    }

    /// A mostly healthy target gets a configuration with at least five attempts.
    #[test]
    fn test_recommended_config_adapts() {
        let mut policy = AdaptiveRetryPolicy::new();
        for _ in 0..10 {
            policy.record_success("peer1");
        }
        policy.record_failure("peer1", FailureType::Timeout);
        let config = policy.recommended_config("peer1");
        assert!(config.max_attempts >= 5);
    }

    /// Five failures in a row mark the target as having issues.
    #[test]
    fn test_target_having_issues() {
        let mut policy = AdaptiveRetryPolicy::new();
        for _ in 0..5 {
            policy.record_failure("peer1", FailureType::Timeout);
        }
        assert!(policy.is_target_having_issues("peer1"));
        assert!(!policy.is_target_having_issues("peer2"));
    }

    /// Resetting a target removes its statistics entirely.
    #[test]
    fn test_reset_target() {
        let mut policy = AdaptiveRetryPolicy::new();
        policy.record_failure("peer1", FailureType::Timeout);
        assert_eq!(policy.consecutive_failures("peer1"), 1);
        policy.reset_target("peer1");
        assert_eq!(policy.consecutive_failures("peer1"), 0);
        assert_eq!(policy.tracked_targets_count(), 0);
    }

    /// Six quick failures count as a burst; an untracked target never does.
    #[test]
    fn test_failure_burst_detection() {
        let mut policy = AdaptiveRetryPolicy::new();
        for _ in 0..6 {
            policy.record_failure("peer1", FailureType::Timeout);
        }
        assert!(policy.detect_failure_burst("peer1"));
        assert!(!policy.detect_failure_burst("peer2"));
    }

    /// The pattern summary reports totals, the dominant type and per-type shares.
    #[test]
    fn test_failure_patterns() {
        let mut policy = AdaptiveRetryPolicy::new();
        policy.record_failure("peer1", FailureType::Timeout);
        policy.record_failure("peer1", FailureType::Timeout);
        policy.record_failure("peer1", FailureType::ConnectionFailed);
        policy.record_success("peer1");
        let patterns = policy
            .failure_patterns("peer1")
            .expect("peer1 has recorded statistics");
        assert_eq!(patterns.total_failures, 3);
        assert_eq!(patterns.dominant_type, Some(FailureType::Timeout));
        let timeout_share = patterns.failure_type_percentage(FailureType::Timeout);
        assert!((timeout_share - 2.0 / 3.0).abs() < EPSILON);
    }

    /// Seven failures in a row are reported as a systemic issue.
    #[test]
    fn test_systemic_issue_detection() {
        let mut policy = AdaptiveRetryPolicy::new();
        for _ in 0..7 {
            policy.record_failure("peer1", FailureType::ServerError);
        }
        let patterns = policy
            .failure_patterns("peer1")
            .expect("peer1 has recorded statistics");
        assert!(patterns.is_systemic_issue());
    }

    /// During a failure streak the estimate is twice the time since the last success.
    #[test]
    fn test_predict_recovery_time() {
        let mut policy = AdaptiveRetryPolicy::new();
        policy.record_success("peer1");
        std::thread::sleep(Duration::from_millis(10));
        policy.record_failure("peer1", FailureType::Timeout);
        let recovery = policy
            .predict_recovery_time("peer1")
            .expect("a failing target with a prior success has an estimate");
        // At least 10 ms have passed since the success, and the estimate doubles that.
        assert!(recovery >= Duration::from_millis(20));
    }

    /// The average interval between two failures is at least the pause between them.
    #[test]
    fn test_failure_interval() {
        let mut policy = AdaptiveRetryPolicy::new();
        policy.record_failure("peer1", FailureType::Timeout);
        std::thread::sleep(Duration::from_millis(10));
        policy.record_failure("peer1", FailureType::Timeout);
        let interval = policy
            .failure_interval("peer1")
            .expect("two recorded failures yield an interval");
        assert!(interval >= Duration::from_millis(10));
    }
}