use crate::error::{Error, ErrorCode, Result};
use crate::runtime::{self, RwLock};
use async_trait::async_trait;
use std::collections::HashMap;
use std::future::Future;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tracing::{debug, error, info, warn};
/// Outcome of a batch of recovery operations (see `BulkRecoveryHandler::execute_bulk`).
#[derive(Debug)]
pub enum BulkRecoveryResult {
    /// Every operation succeeded; values are in submission order.
    AllSuccess(Vec<serde_json::Value>),
    /// Some operations succeeded and the batch met its minimum success
    /// rate; each entry carries the original operation index.
    PartialSuccess {
        successes: Vec<(usize, serde_json::Value)>,
        failures: Vec<(usize, Error)>,
    },
    /// No operation succeeded, or the success rate fell below the minimum.
    AllFailed(Vec<Error>),
}
/// How random jitter is applied to a retry backoff delay
/// (interpreted by `JitterCalculator::calculate_delay`).
#[derive(Debug, Clone, Copy)]
pub enum JitterStrategy {
    /// Use the base delay unchanged.
    None,
    /// Scale the base delay by a pseudo-random factor in [0, 1).
    Full,
    /// Keep half the base delay, jitter the other half.
    Equal,
    /// Vary the base delay by roughly ±25%.
    Decorrelated,
}
/// Coarse health classification for a monitored component.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealthStatus {
    Healthy,
    Degraded,
    Unhealthy,
    /// Health has not been determined yet.
    Unknown,
}
/// Snapshot produced by a `HealthMonitor::check_health` probe.
#[derive(Debug, Clone)]
pub struct HealthCheckResult {
    /// Name of the component that was probed.
    pub component: String,
    pub status: HealthStatus,
    // Presumably the probe round-trip time in microseconds — no producer
    // is visible in this file, so confirm against the monitor impl.
    pub response_time_us: u64,
    /// Wall-clock time the result was produced.
    pub timestamp: SystemTime,
    /// Optional human-readable detail.
    pub message: Option<String>,
}
/// Notifications emitted by `RecoveryCoordinator` to registered event handlers.
#[derive(Debug, Clone)]
pub enum RecoveryEvent {
    /// A component's health status changed.
    HealthChanged {
        component: String,
        old_status: HealthStatus,
        new_status: HealthStatus,
    },
    /// A recovery attempt began for the named operation.
    RecoveryStarted {
        operation_id: String,
        /// Human-readable strategy name (e.g. "deadline_aware").
        strategy: String,
    },
    /// A recovery attempt finished, successfully or not.
    RecoveryCompleted {
        operation_id: String,
        success: bool,
        duration: Duration,
    },
    /// A component failure was detected to affect its dependents
    /// (emitted by `RecoveryCoordinator::detect_cascade`).
    CascadingFailure {
        trigger_component: String,
        affected_components: Vec<String>,
    },
}
/// Tracks a wall-clock budget for a recovery operation.
///
/// `remaining` and `exceeded` are snapshots taken at construction or at the
/// most recent `update()` call; call `update()` first when freshness matters.
#[derive(Debug, Clone)]
pub struct RecoveryDeadline {
    pub deadline: Instant,
    pub remaining: Duration,
    pub exceeded: bool,
}
impl RecoveryDeadline {
    /// Starts a new deadline expiring `timeout` from now.
    pub fn new(timeout: Duration) -> Self {
        Self {
            deadline: Instant::now() + timeout,
            remaining: timeout,
            exceeded: false,
        }
    }
    /// Refreshes `remaining`/`exceeded` against the current time and
    /// reports whether the deadline has passed.
    pub fn update(&mut self) -> bool {
        let now = Instant::now();
        // Saturates to zero once the deadline is in the past.
        self.remaining = self.deadline.saturating_duration_since(now);
        self.exceeded = now >= self.deadline;
        self.exceeded
    }
    /// True when the last-updated budget still allows `duration` of work.
    pub fn has_time_for(&self, duration: Duration) -> bool {
        !self.exceeded && self.remaining >= duration
    }
}
/// Lock-free counters describing recovery activity.
///
/// All counters use relaxed atomics: totals are eventually consistent,
/// which is sufficient for reporting.
#[derive(Debug, Default)]
pub struct RecoveryMetrics {
    // Recovery attempts started (via `record_attempt`).
    total_attempts: AtomicU64,
    // Completed recoveries that succeeded.
    successful_recoveries: AtomicU64,
    // Completed recoveries that failed.
    failed_recoveries: AtomicU64,
    // Sum of durations passed to record_success/record_failure, in µs.
    total_recovery_time_us: AtomicU64,
    circuit_breaker_trips: AtomicU64,
    fallback_executions: AtomicU64,
    cascade_preventions: AtomicU64,
}
impl RecoveryMetrics {
    pub fn new() -> Self {
        Self::default()
    }
    /// Counts a recovery attempt being started.
    pub fn record_attempt(&self) {
        self.total_attempts.fetch_add(1, Ordering::Relaxed);
    }
    /// Counts a successful recovery and accumulates its duration.
    pub fn record_success(&self, duration: Duration) {
        self.successful_recoveries.fetch_add(1, Ordering::Relaxed);
        self.total_recovery_time_us
            .fetch_add(duration.as_micros() as u64, Ordering::Relaxed);
    }
    /// Counts a failed recovery and accumulates its duration.
    pub fn record_failure(&self, duration: Duration) {
        self.failed_recoveries.fetch_add(1, Ordering::Relaxed);
        self.total_recovery_time_us
            .fetch_add(duration.as_micros() as u64, Ordering::Relaxed);
    }
    pub fn record_circuit_trip(&self) {
        self.circuit_breaker_trips.fetch_add(1, Ordering::Relaxed);
    }
    pub fn record_fallback(&self) {
        self.fallback_executions.fetch_add(1, Ordering::Relaxed);
    }
    pub fn record_cascade_prevention(&self) {
        self.cascade_preventions.fetch_add(1, Ordering::Relaxed);
    }
    /// Percentage of started attempts that ended in success
    /// (0.0 when no attempt has been recorded).
    pub fn success_rate(&self) -> f64 {
        let total = self.total_attempts.load(Ordering::Relaxed);
        let successes = self.successful_recoveries.load(Ordering::Relaxed);
        if total > 0 {
            (successes as f64 / total as f64) * 100.0
        } else {
            0.0
        }
    }
    /// Mean duration of *completed* recoveries.
    ///
    /// Fix: the original divided the accumulated time by `total_attempts`,
    /// but time is only accumulated by `record_success`/`record_failure`,
    /// and the executors in this file never call `record_attempt` — so the
    /// average always came out zero. Dividing by the number of completed
    /// recoveries matches what was actually measured.
    pub fn average_recovery_time(&self) -> Duration {
        let total_time = self.total_recovery_time_us.load(Ordering::Relaxed);
        let completed = self
            .successful_recoveries
            .load(Ordering::Relaxed)
            .saturating_add(self.failed_recoveries.load(Ordering::Relaxed));
        if completed > 0 {
            Duration::from_micros(total_time / completed)
        } else {
            Duration::ZERO
        }
    }
    pub fn circuit_breaker_trips(&self) -> u64 {
        self.circuit_breaker_trips.load(Ordering::Relaxed)
    }
    pub fn fallback_executions(&self) -> u64 {
        self.fallback_executions.load(Ordering::Relaxed)
    }
    pub fn cascade_preventions(&self) -> u64 {
        self.cascade_preventions.load(Ordering::Relaxed)
    }
}
/// Recovery strategies applied when an operation fails.
///
/// `RecoveryExecutor::execute_with_recovery` interprets these; several of
/// the richer variants currently degrade to simpler behavior there (see the
/// per-variant notes).
#[derive(Debug, Clone)]
pub enum RecoveryStrategy {
    /// Retry up to `attempts` times, sleeping `delay` before each attempt.
    RetryFixed {
        attempts: u32,
        delay: Duration,
    },
    /// Retry with a geometrically growing delay, capped at `max_delay`.
    RetryExponential {
        attempts: u32,
        initial_delay: Duration,
        max_delay: Duration,
        multiplier: f64,
    },
    /// Exponential retry with jitter applied to each delay.
    /// NOTE(review): `RecoveryExecutor` ignores the jitter and runs plain
    /// exponential backoff; only `AdvancedRecoveryExecutor::retry_adaptive`
    /// honors it.
    RetryAdaptive {
        attempts: u32,
        initial_delay: Duration,
        max_delay: Duration,
        multiplier: f64,
        jitter: JitterStrategy,
    },
    /// Delegate to the fallback handler registered for the operation id.
    Fallback,
    /// Classic circuit breaker: opens after `failure_threshold` failures,
    /// probes after `timeout`, closes after `success_threshold` successes.
    CircuitBreaker {
        failure_threshold: u32,
        success_threshold: u32,
        timeout: Duration,
    },
    /// Circuit breaker with health-check tuning knobs.
    /// NOTE(review): the extra fields are ignored by `RecoveryExecutor`,
    /// which treats this variant as a basic `CircuitBreaker`.
    AdvancedCircuitBreaker {
        failure_threshold: u32,
        success_threshold: u32,
        timeout: Duration,
        health_check_interval: Duration,
        response_time_threshold_ms: u64,
    },
    /// Run `base_strategy` under an overall recovery-time budget
    /// (see `AdvancedRecoveryExecutor::execute_with_deadline`).
    DeadlineAware {
        max_recovery_time: Duration,
        base_strategy: Box<Self>,
    },
    /// Apply `individual_strategy` per batch item; the batch passes when the
    /// success fraction reaches `min_success_rate`.
    /// NOTE(review): `RecoveryExecutor` fails fast on this variant.
    BulkRecovery {
        individual_strategy: Box<Self>,
        min_success_rate: f64,
        fail_fast: bool,
    },
    /// Run `base_strategy` while tracking dependent components.
    /// NOTE(review): `RecoveryExecutor` only honors a `RetryFixed` base.
    CascadeAware {
        base_strategy: Box<Self>,
        dependencies: Vec<String>,
        isolation_timeout: Duration,
    },
    /// Do not attempt recovery; surface the error immediately.
    FailFast,
}
/// Maps error codes to recovery strategies, with a default for codes that
/// have no specific entry.
#[derive(Debug, Clone)]
pub struct RecoveryPolicy {
    // Per-error-code overrides.
    strategies: HashMap<ErrorCode, RecoveryStrategy>,
    // Used when no override matches the error's code.
    default_strategy: RecoveryStrategy,
    // When true, executors emit per-attempt debug/warn/info logs.
    log_attempts: bool,
}
impl Default for RecoveryPolicy {
fn default() -> Self {
let mut strategies = HashMap::new();
strategies.insert(
ErrorCode::INTERNAL_ERROR,
RecoveryStrategy::RetryExponential {
attempts: 3,
initial_delay: Duration::from_millis(100),
max_delay: Duration::from_secs(5),
multiplier: 2.0,
},
);
strategies.insert(
ErrorCode::INVALID_REQUEST,
RecoveryStrategy::RetryFixed {
attempts: 2,
delay: Duration::from_millis(500),
},
);
Self {
strategies,
default_strategy: RecoveryStrategy::FailFast,
log_attempts: true,
}
}
}
impl RecoveryPolicy {
    /// Creates a policy with no per-code overrides; every error falls back
    /// to `default_strategy`. Attempt logging is enabled.
    pub fn new(default_strategy: RecoveryStrategy) -> Self {
        Self {
            strategies: HashMap::new(),
            default_strategy,
            log_attempts: true,
        }
    }
    /// Registers (or replaces) the strategy used for `error_code`.
    pub fn add_strategy(&mut self, error_code: ErrorCode, strategy: RecoveryStrategy) {
        self.strategies.insert(error_code, strategy);
    }
    /// Looks up the strategy for `error_code`, falling back to the default.
    pub fn get_strategy(&self, error_code: &ErrorCode) -> &RecoveryStrategy {
        match self.strategies.get(error_code) {
            Some(strategy) => strategy,
            None => &self.default_strategy,
        }
    }
}
/// Hook invoked when a `Fallback` strategy fires; implementations return a
/// substitute value for the failed operation (or an error if no recovery is
/// possible). Native builds require `Send + Sync` so handlers can be shared
/// across threads.
#[cfg(not(target_arch = "wasm32"))]
#[async_trait]
pub trait RecoveryHandler: Send + Sync {
    async fn recover(&self, error_msg: &str) -> Result<serde_json::Value>;
}
/// WASM variant of `RecoveryHandler` without `Send`/`Sync` bounds.
#[cfg(target_arch = "wasm32")]
#[async_trait(?Send)]
pub trait RecoveryHandler {
    async fn recover(&self, error_msg: &str) -> Result<serde_json::Value>;
}
#[derive(Debug)]
pub struct DefaultRecoveryHandler;
#[cfg(not(target_arch = "wasm32"))]
#[async_trait]
impl RecoveryHandler for DefaultRecoveryHandler {
async fn recover(&self, error_msg: &str) -> Result<serde_json::Value> {
Err(Error::internal(error_msg))
}
}
#[cfg(target_arch = "wasm32")]
#[async_trait(?Send)]
impl RecoveryHandler for DefaultRecoveryHandler {
async fn recover(&self, error_msg: &str) -> Result<serde_json::Value> {
Err(Error::internal(error_msg))
}
}
/// Recovery handler that produces a substitute value by invoking a
/// user-supplied async closure, ignoring the original error message.
pub struct FallbackHandler<F> {
    fallback: F,
}
// Manual `Debug` because the wrapped closure has no `Debug` impl.
impl<F> std::fmt::Debug for FallbackHandler<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FallbackHandler")
            .field("fallback", &"<function>")
            .finish()
    }
}
impl<F> FallbackHandler<F> {
    /// Wraps `fallback`, an async closure returning the substitute value.
    pub fn new(fallback: F) -> Self {
        Self { fallback }
    }
}
#[cfg(not(target_arch = "wasm32"))]
#[async_trait]
impl<F, Fut> RecoveryHandler for FallbackHandler<F>
where
    F: Fn() -> Fut + Send + Sync,
    Fut: Future<Output = Result<serde_json::Value>> + Send,
{
    async fn recover(&self, _error_msg: &str) -> Result<serde_json::Value> {
        // The error message is intentionally ignored: the fallback supplies
        // a replacement value regardless of the failure cause.
        (self.fallback)().await
    }
}
#[cfg(target_arch = "wasm32")]
#[async_trait(?Send)]
impl<F, Fut> RecoveryHandler for FallbackHandler<F>
where
    F: Fn() -> Fut,
    Fut: Future<Output = Result<serde_json::Value>>,
{
    async fn recover(&self, _error_msg: &str) -> Result<serde_json::Value> {
        // WASM variant: same behavior, no `Send` bounds.
        (self.fallback)().await
    }
}
/// Internal circuit-breaker state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitState {
    /// Normal operation; requests flow through.
    Closed,
    /// Tripped; requests are rejected until the cool-down timeout elapses.
    Open,
    /// Probing: requests are allowed through to test whether the protected
    /// operation has recovered.
    HalfOpen,
}
/// Async circuit breaker; state lives behind `Arc<RwLock<...>>` so handles
/// can be shared across tasks.
pub struct CircuitBreaker {
    state: Arc<RwLock<CircuitState>>,
    // Consecutive failures observed while Closed; reset on success.
    failure_count: Arc<RwLock<u32>>,
    // Consecutive successes observed while HalfOpen.
    success_count: Arc<RwLock<u32>>,
    // Instant of the most recent failure; None until the first trip.
    last_failure_time: Arc<RwLock<Option<std::time::Instant>>>,
    config: CircuitBreakerConfig,
}
// Manual `Debug`: lock contents are elided (placeholders only), so `fmt`
// never has to acquire the async locks.
impl std::fmt::Debug for CircuitBreaker {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CircuitBreaker")
            .field("state", &"Arc<RwLock<CircuitState>>")
            .field("failure_count", &"Arc<RwLock<u32>>")
            .field("success_count", &"Arc<RwLock<u32>>")
            .field("last_failure_time", &"Arc<RwLock<Option<Instant>>>")
            .field("config", &self.config)
            .finish()
    }
}
/// Thresholds and timing for a `CircuitBreaker`.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Failures (while closed) needed to open the circuit.
    pub failure_threshold: u32,
    /// Successes (while half-open) needed to close the circuit.
    pub success_threshold: u32,
    /// Cool-down before an open circuit allows a probe request.
    pub timeout: Duration,
}
impl CircuitBreaker {
    /// Creates a breaker in the `Closed` state with zeroed counters.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            state: Arc::new(RwLock::new(CircuitState::Closed)),
            failure_count: Arc::new(RwLock::new(0)),
            success_count: Arc::new(RwLock::new(0)),
            last_failure_time: Arc::new(RwLock::new(None)),
            config,
        }
    }
    /// Returns whether a request may proceed, transitioning
    /// `Open` -> `HalfOpen` once the configured timeout has elapsed.
    ///
    /// Fix: the original read the state under a read lock and then took
    /// write locks to transition, so two concurrent callers could both
    /// observe `Open` and race the transition (double-resetting the
    /// half-open success counter). The transition is now re-checked under
    /// the state write lock; the common Closed/HalfOpen case still uses the
    /// cheap read lock.
    pub async fn allow_request(&self) -> bool {
        // Fast path: shared lock for the common non-Open states.
        match *self.state.read().await {
            CircuitState::Closed | CircuitState::HalfOpen => return true,
            CircuitState::Open => {},
        }
        // Slow path: take the write lock and re-check so only one caller
        // performs the Open -> HalfOpen transition.
        let mut state = self.state.write().await;
        if *state != CircuitState::Open {
            // Another task transitioned while we waited for the lock.
            return true;
        }
        let last_failure_opt = *self.last_failure_time.read().await;
        if let Some(last_failure) = last_failure_opt {
            if last_failure.elapsed() >= self.config.timeout {
                *state = CircuitState::HalfOpen;
                *self.success_count.write().await = 0;
                info!("Circuit breaker transitioning to half-open");
                return true;
            }
        }
        // Still inside the cool-down window (or no failure timestamp was
        // recorded): keep rejecting.
        false
    }
    /// Records a successful call, closing the circuit when the half-open
    /// success threshold is reached. The state write lock is held across
    /// the whole update so counter changes and transitions are atomic.
    pub async fn record_success(&self) {
        let mut state = self.state.write().await;
        match *state {
            CircuitState::Closed => {
                // Success breaks any failure streak.
                *self.failure_count.write().await = 0;
            },
            CircuitState::HalfOpen => {
                let mut success_count = self.success_count.write().await;
                *success_count += 1;
                if *success_count >= self.config.success_threshold {
                    *state = CircuitState::Closed;
                    *self.failure_count.write().await = 0;
                    info!("Circuit breaker closed after successful recovery");
                }
            },
            CircuitState::Open => {
                // A success while open (request let through elsewhere)
                // closes the circuit immediately, as in the original.
                *state = CircuitState::Closed;
                *self.failure_count.write().await = 0;
            },
        }
    }
    /// Records a failed call, opening the circuit when the failure
    /// threshold is reached (or immediately when half-open).
    pub async fn record_failure(&self) {
        let mut state = self.state.write().await;
        match *state {
            CircuitState::Closed => {
                let mut failure_count = self.failure_count.write().await;
                *failure_count += 1;
                if *failure_count >= self.config.failure_threshold {
                    *state = CircuitState::Open;
                    *self.last_failure_time.write().await = Some(std::time::Instant::now());
                    warn!("Circuit breaker opened after {} failures", *failure_count);
                }
            },
            CircuitState::HalfOpen => {
                // One failed probe reopens the circuit and restarts the
                // failure streak at 1.
                *state = CircuitState::Open;
                *self.last_failure_time.write().await = Some(std::time::Instant::now());
                *self.failure_count.write().await = 1;
                warn!("Circuit breaker reopened after failure in half-open state");
            },
            CircuitState::Open => {
                // Refresh the cool-down window on further failures.
                *self.last_failure_time.write().await = Some(std::time::Instant::now());
            },
        }
    }
}
/// Source of component health information consumed by the recovery
/// coordinator.
#[cfg(not(target_arch = "wasm32"))]
#[async_trait]
pub trait HealthMonitor: Send + Sync {
    /// Actively probes `component` and returns the full result.
    async fn check_health(&self, component: &str) -> HealthCheckResult;
    /// Returns the current (possibly cached) status of `component`.
    async fn get_health_status(&self, component: &str) -> HealthStatus;
    /// Returns a future resolving to the next health-change event.
    async fn subscribe_health_changes(
        &self,
    ) -> Result<Box<dyn Future<Output = RecoveryEvent> + Send + Unpin>>;
}
/// WASM variant of `HealthMonitor` without `Send`/`Sync` bounds.
#[cfg(target_arch = "wasm32")]
#[async_trait(?Send)]
pub trait HealthMonitor {
    async fn check_health(&self, component: &str) -> HealthCheckResult;
    async fn get_health_status(&self, component: &str) -> HealthStatus;
    async fn subscribe_health_changes(
        &self,
    ) -> Result<Box<dyn Future<Output = RecoveryEvent> + Unpin>>;
}
// Shared list of event callbacks; native builds require thread-safe
// handlers, the WASM alias drops the Send/Sync bounds.
#[cfg(not(target_arch = "wasm32"))]
type EventHandlers = Arc<RwLock<Vec<Arc<dyn Fn(RecoveryEvent) + Send + Sync>>>>;
#[cfg(target_arch = "wasm32")]
type EventHandlers = Arc<RwLock<Vec<Arc<dyn Fn(RecoveryEvent)>>>>;
/// Central hub for recovery bookkeeping: optional health monitoring,
/// component dependency tracking, shared metrics, and event fan-out.
pub struct RecoveryCoordinator {
    // Optional monitor; not consulted by any method in this file.
    health_monitor: Option<Arc<dyn HealthMonitor>>,
    // component -> list of components it depends on.
    dependencies: Arc<RwLock<HashMap<String, Vec<String>>>>,
    metrics: Arc<RecoveryMetrics>,
    event_handlers: EventHandlers,
}
// Manual `Debug`: trait objects and lock contents are summarized with
// placeholders so `fmt` never acquires the async locks.
impl std::fmt::Debug for RecoveryCoordinator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RecoveryCoordinator")
            .field("health_monitor", &self.health_monitor.is_some())
            .field("dependencies", &"Arc<RwLock<HashMap<...>>>")
            .field("metrics", &self.metrics)
            .field("event_handlers", &"Arc<RwLock<Vec<...>>>")
            .finish()
    }
}
impl RecoveryCoordinator {
    /// Creates a coordinator with no health monitor, no known dependencies,
    /// and fresh metrics.
    pub fn new() -> Self {
        Self {
            health_monitor: None,
            dependencies: Arc::new(RwLock::new(HashMap::new())),
            metrics: Arc::new(RecoveryMetrics::new()),
            event_handlers: Arc::new(RwLock::new(Vec::new())),
        }
    }
    /// Attaches a health monitor (builder style).
    pub fn with_health_monitor(mut self, monitor: Arc<dyn HealthMonitor>) -> Self {
        self.health_monitor = Some(monitor);
        self
    }
    /// Records that `component` depends on each entry of `dependencies`,
    /// replacing any previous entry for that component.
    pub async fn add_dependency(&self, component: String, dependencies: Vec<String>) {
        self.dependencies
            .write()
            .await
            .insert(component, dependencies);
    }
    /// Registers a callback invoked for every emitted `RecoveryEvent`.
    #[cfg(not(target_arch = "wasm32"))]
    pub async fn add_event_handler(&self, handler: Arc<dyn Fn(RecoveryEvent) + Send + Sync>) {
        self.event_handlers.write().await.push(handler);
    }
    /// Registers a callback invoked for every emitted `RecoveryEvent`
    /// (WASM variant without `Send`/`Sync` bounds).
    #[cfg(target_arch = "wasm32")]
    pub async fn add_event_handler(&self, handler: Arc<dyn Fn(RecoveryEvent)>) {
        self.event_handlers.write().await.push(handler);
    }
    /// Synchronously invokes every registered handler with a clone of
    /// `event`. The handler list stays read-locked for the duration, so
    /// handlers must not register new handlers re-entrantly.
    pub async fn emit_event(&self, event: RecoveryEvent) {
        let handlers = self.event_handlers.read().await;
        for handler in handlers.iter() {
            handler(event.clone());
        }
    }
    /// Returns the components that directly depend on `failed_component`,
    /// emitting a `CascadingFailure` event when any are found.
    ///
    /// Fix: the original called `deps.contains(&failed_component.to_string())`,
    /// allocating a fresh `String` for every entry in the dependency map;
    /// comparing `&str` directly avoids the per-iteration allocation.
    pub async fn detect_cascade(&self, failed_component: &str) -> Vec<String> {
        let dependencies = self.dependencies.read().await;
        let affected: Vec<String> = dependencies
            .iter()
            .filter(|(_, deps)| deps.iter().any(|dep| dep == failed_component))
            .map(|(component, _)| component.clone())
            .collect();
        if !affected.is_empty() {
            self.emit_event(RecoveryEvent::CascadingFailure {
                trigger_component: failed_component.to_string(),
                affected_components: affected.clone(),
            })
            .await;
            // NOTE(review): this counts a *detected* cascade; nothing here
            // prevents it — the metric name may overstate what happened.
            self.metrics.record_cascade_prevention();
        }
        affected
    }
    /// Shared handle to this coordinator's metrics.
    pub fn get_metrics(&self) -> Arc<RecoveryMetrics> {
        self.metrics.clone()
    }
}
impl Default for RecoveryCoordinator {
    /// Same as `RecoveryCoordinator::new()`.
    fn default() -> Self {
        Self::new()
    }
}
/// Runs batches of operations and classifies the aggregate outcome,
/// reporting per-operation results through the coordinator's metrics.
pub struct BulkRecoveryHandler {
    coordinator: Arc<RecoveryCoordinator>,
    // Per-operation timeout budget.
    // NOTE(review): not applied anywhere in `execute_bulk` in this file —
    // confirm whether operations were meant to be bounded by it.
    operation_timeout: Duration,
}
// Manual `Debug` because the struct holds an `Arc` to a coordinator whose
// own Debug impl is summarized.
impl std::fmt::Debug for BulkRecoveryHandler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BulkRecoveryHandler")
            .field("coordinator", &self.coordinator)
            .field("operation_timeout", &self.operation_timeout)
            .finish()
    }
}
impl BulkRecoveryHandler {
pub fn new(coordinator: Arc<RecoveryCoordinator>, operation_timeout: Duration) -> Self {
Self {
coordinator,
operation_timeout,
}
}
pub async fn execute_bulk<F, Fut, T>(
&self,
operations: Vec<F>,
min_success_rate: f64,
fail_fast: bool,
) -> BulkRecoveryResult
where
F: Fn() -> Fut + Send,
Fut: Future<Output = Result<T>> + Send,
T: serde::Serialize + Send,
{
let total_operations = operations.len();
let mut successes = Vec::new();
let mut failures = Vec::new();
for (index, operation) in operations.into_iter().enumerate() {
let start_time = Instant::now();
match operation().await {
Ok(result) => {
let value = serde_json::to_value(result).unwrap_or(serde_json::Value::Null);
successes.push((index, value));
self.coordinator
.metrics
.record_success(start_time.elapsed());
},
Err(error) => {
failures.push((index, error));
self.coordinator
.metrics
.record_failure(start_time.elapsed());
if fail_fast {
break;
}
},
}
}
let success_rate = successes.len() as f64 / total_operations as f64;
if successes.is_empty() {
BulkRecoveryResult::AllFailed(failures.into_iter().map(|(_, e)| e).collect())
} else if failures.is_empty() {
BulkRecoveryResult::AllSuccess(successes.into_iter().map(|(_, v)| v).collect())
} else if success_rate >= min_success_rate {
BulkRecoveryResult::PartialSuccess {
successes,
failures,
}
} else {
BulkRecoveryResult::AllFailed(failures.into_iter().map(|(_, e)| e).collect())
}
}
}
/// Applies a `JitterStrategy` to a base backoff delay.
#[derive(Debug)]
pub struct JitterCalculator;
impl JitterCalculator {
    /// Pseudo-random value in [0, 1) with millisecond-ish granularity,
    /// derived by hashing the current wall-clock nanoseconds.
    ///
    /// NOTE(review): this is NOT a real RNG — consecutive calls are
    /// correlated with the clock. It is only suitable for spreading retry
    /// timing, which is all this module uses it for.
    fn pseudo_random_unit() -> f64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            // A pre-epoch clock is effectively impossible; degrade to a
            // fixed seed instead of panicking.
            .unwrap_or(Duration::ZERO)
            .as_nanos()
            .hash(&mut hasher);
        (hasher.finish() % 1000) as f64 / 1000.0
    }
    /// Returns `base_delay` adjusted according to `strategy`:
    /// - `None`:         unchanged
    /// - `Full`:         uniform in [0, base)
    /// - `Equal`:        base/2 plus uniform in [0, base/2)
    /// - `Decorrelated`: base scaled by a factor in [0.75, 1.25)
    ///
    /// (Refactor: the three jittered branches previously duplicated the
    /// hash-the-clock snippet verbatim; it now lives in one helper.)
    #[allow(clippy::cast_sign_loss)]
    pub fn calculate_delay(base_delay: Duration, strategy: JitterStrategy) -> Duration {
        let base_millis = base_delay.as_millis() as f64;
        match strategy {
            JitterStrategy::None => base_delay,
            JitterStrategy::Full => {
                let random = Self::pseudo_random_unit();
                Duration::from_millis((base_millis * random).max(0.0) as u64)
            },
            JitterStrategy::Equal => {
                let random = Self::pseudo_random_unit();
                let jittered_millis = base_millis.mul_add(0.5, base_millis * 0.5 * random);
                Duration::from_millis(jittered_millis.max(0.0) as u64)
            },
            JitterStrategy::Decorrelated => {
                // Shift the unit sample to [-0.25, 0.25) around the base.
                let random = (Self::pseudo_random_unit() - 0.5) * 0.5;
                let jittered_millis = base_millis * (1.0 + random);
                Duration::from_millis(jittered_millis.max(0.0) as u64)
            },
        }
    }
}
/// Executor with coordinator-backed recovery: adaptive (jittered) retries,
/// deadline-aware execution, and access to bulk recovery.
pub struct AdvancedRecoveryExecutor {
    policy: RecoveryPolicy,
    // Named fallback handlers; only mutated via register_handler here.
    handlers: HashMap<String, Arc<dyn RecoveryHandler>>,
    // Reserved for per-operation breakers; not read in this type's methods.
    #[allow(dead_code)]
    circuit_breakers: Arc<RwLock<HashMap<String, Arc<CircuitBreaker>>>>,
    coordinator: Arc<RecoveryCoordinator>,
    bulk_handler: Arc<BulkRecoveryHandler>,
}
// Manual `Debug`: handler map shown by key only; lock contents elided.
impl std::fmt::Debug for AdvancedRecoveryExecutor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AdvancedRecoveryExecutor")
            .field("policy", &self.policy)
            .field("handlers", &self.handlers.keys().collect::<Vec<_>>())
            .field("circuit_breakers", &"Arc<RwLock<HashMap<...>>>")
            .field("coordinator", &self.coordinator)
            .field("bulk_handler", &self.bulk_handler)
            .finish()
    }
}
impl AdvancedRecoveryExecutor {
    /// Builds an executor around `policy`, wiring a fresh coordinator into a
    /// bulk handler with a fixed 30-second operation timeout.
    pub fn new(policy: RecoveryPolicy) -> Self {
        let coordinator = Arc::new(RecoveryCoordinator::new());
        let bulk_handler = Arc::new(BulkRecoveryHandler::new(
            coordinator.clone(),
            Duration::from_secs(30),
        ));
        Self {
            policy,
            handlers: HashMap::new(),
            circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
            coordinator,
            bulk_handler,
        }
    }
    /// Shared handle to the coordinator (events, metrics, dependencies).
    pub fn coordinator(&self) -> Arc<RecoveryCoordinator> {
        self.coordinator.clone()
    }
    /// Shared handle to the bulk-recovery helper.
    pub fn bulk_handler(&self) -> Arc<BulkRecoveryHandler> {
        self.bulk_handler.clone()
    }
    /// Retries `operation` up to `attempts` times with exponential backoff
    /// (capped at `max_delay`) and per-attempt jitter. `error` seeds the
    /// returned error when every attempt fails (or `attempts` is zero).
    ///
    /// A jittered sleep happens *before* every attempt, including the first.
    /// NOTE(review): metrics record the backoff delay, not the operation's
    /// own runtime — confirm that is intended.
    #[allow(clippy::too_many_arguments)]
    pub async fn retry_adaptive<F, Fut>(
        &self,
        error: Error,
        attempts: u32,
        initial_delay: Duration,
        max_delay: Duration,
        multiplier: f64,
        jitter: JitterStrategy,
        operation: F,
    ) -> Result<serde_json::Value>
    where
        F: Fn() -> Fut,
        Fut: Future<Output = Result<serde_json::Value>>,
    {
        let mut last_error = error;
        let mut current_delay = initial_delay;
        for attempt in 1..=attempts {
            // Jitter is applied per attempt on top of the current backoff.
            let jittered_delay = JitterCalculator::calculate_delay(current_delay, jitter);
            if self.policy.log_attempts {
                debug!(
                    "Adaptive retry attempt {} of {} after {:?} (jitter: {:?})",
                    attempt, attempts, jittered_delay, jitter
                );
            }
            runtime::sleep(jittered_delay).await;
            match operation().await {
                Ok(result) => {
                    self.coordinator.metrics.record_success(jittered_delay);
                    return Ok(result);
                },
                Err(e) => {
                    last_error = e;
                    self.coordinator.metrics.record_failure(jittered_delay);
                    if self.policy.log_attempts {
                        warn!("Adaptive retry attempt {} failed: {}", attempt, last_error);
                    }
                    // Grow the backoff geometrically, capped at max_delay.
                    let next_delay = Duration::from_secs_f64(
                        (current_delay.as_secs_f64() * multiplier).min(max_delay.as_secs_f64()),
                    );
                    current_delay = next_delay;
                },
            }
        }
        Err(last_error)
    }
    /// Runs `operation` under `deadline`, emitting RecoveryStarted /
    /// RecoveryCompleted events around it.
    ///
    /// On native targets the first attempt is bounded by the remaining
    /// budget via `tokio::time::timeout`; on WASM the operation runs
    /// unbounded (no timer available) and its result is wrapped to match
    /// the native shape. On failure, a deadline-trimmed fixed retry is
    /// attempted when `base_strategy` is `RetryFixed`; any other base
    /// strategy fails fast with the original error.
    pub async fn execute_with_deadline<F, Fut>(
        &self,
        operation_id: &str,
        operation: F,
        deadline: &mut RecoveryDeadline,
        base_strategy: &RecoveryStrategy,
    ) -> Result<serde_json::Value>
    where
        F: Fn() -> Fut + Clone,
        Fut: Future<Output = Result<serde_json::Value>>,
    {
        // Out of budget before doing anything at all.
        if deadline.update() {
            return Err(Error::Timeout(0));
        }
        self.coordinator
            .emit_event(RecoveryEvent::RecoveryStarted {
                operation_id: operation_id.to_string(),
                strategy: "deadline_aware".to_string(),
            })
            .await;
        let start_time = Instant::now();
        #[cfg(not(target_arch = "wasm32"))]
        let timeout_result = tokio::time::timeout(deadline.remaining, operation()).await;
        #[cfg(target_arch = "wasm32")]
        let timeout_result: std::result::Result<Result<serde_json::Value>, ()> = {
            // WASM: no timeout support — run to completion and wrap so both
            // cfg arms yield the same nested-Result shape (outer Err = timed
            // out, which can never happen here).
            let res: Result<serde_json::Value> = operation().await;
            match res {
                Ok(value) => Ok(Ok(value)),
                Err(e) => Ok(Err(e)),
            }
        };
        let result = match timeout_result {
            Ok(Ok(value)) => Ok(value),
            Ok(Err(error)) => Err(error),
            Err(_) => {
                // The first attempt itself timed out.
                // NOTE(review): after `update()` the remaining budget is at
                // or near zero, so the Timeout payload is ~0 ms — confirm
                // whether the elapsed time was intended instead.
                deadline.update();
                return Err(Error::Timeout(deadline.remaining.as_millis() as u64));
            },
        };
        let result = match result {
            Ok(value) => {
                let duration = start_time.elapsed();
                self.coordinator
                    .emit_event(RecoveryEvent::RecoveryCompleted {
                        operation_id: operation_id.to_string(),
                        success: true,
                        duration,
                    })
                    .await;
                Ok(value)
            },
            Err(error) => {
                deadline.update();
                // Too little budget left for a meaningful retry: report
                // failure and give up.
                if !deadline.has_time_for(Duration::from_millis(100)) {
                    let duration = start_time.elapsed();
                    self.coordinator
                        .emit_event(RecoveryEvent::RecoveryCompleted {
                            operation_id: operation_id.to_string(),
                            success: false,
                            duration,
                        })
                        .await;
                    return Err(Error::Timeout(duration.as_millis() as u64));
                }
                match base_strategy {
                    RecoveryStrategy::RetryFixed { attempts, delay } => {
                        // Trim the attempt count to what fits in the
                        // remaining budget (delay-per-attempt granularity).
                        let max_attempts = (deadline.remaining.as_millis() / delay.as_millis())
                            .min(*attempts as u128)
                            as u32;
                        self.retry_fixed(error, max_attempts, *delay, operation.clone())
                            .await
                    },
                    // Other base strategies are not deadline-trimmed here.
                    _ => Err(error),
                }
            },
        };
        result
    }
    /// Registers a named fallback handler. NOTE(review): the `handlers` map
    /// is never read by this type's other methods — confirm whether a
    /// fallback path was meant to use it (cf. `RecoveryExecutor::fallback`).
    pub fn register_handler(&mut self, name: String, handler: Arc<dyn RecoveryHandler>) {
        self.handlers.insert(name, handler);
    }
    /// Retries `operation` up to `attempts` times, sleeping `delay` before
    /// each attempt and recording per-attempt metrics. `error` seeds the
    /// returned error when every attempt fails (or `attempts` is zero).
    async fn retry_fixed<F, Fut>(
        &self,
        error: Error,
        attempts: u32,
        delay: Duration,
        operation: F,
    ) -> Result<serde_json::Value>
    where
        F: Fn() -> Fut,
        Fut: Future<Output = Result<serde_json::Value>>,
    {
        let mut last_error = error;
        for attempt in 1..=attempts {
            if self.policy.log_attempts {
                debug!(
                    "Retry attempt {} of {} after {:?}",
                    attempt, attempts, delay
                );
            }
            runtime::sleep(delay).await;
            match operation().await {
                Ok(result) => {
                    self.coordinator.metrics.record_success(delay);
                    return Ok(result);
                },
                Err(e) => {
                    last_error = e;
                    self.coordinator.metrics.record_failure(delay);
                    if self.policy.log_attempts {
                        warn!("Retry attempt {} failed: {}", attempt, last_error);
                    }
                },
            }
        }
        Err(last_error)
    }
}
/// Policy-driven executor: runs an operation once and applies the strategy
/// registered for the resulting error code on failure.
pub struct RecoveryExecutor {
    policy: RecoveryPolicy,
    // Fallback handlers keyed by operation id (used by the Fallback strategy).
    handlers: HashMap<String, Arc<dyn RecoveryHandler>>,
    // Per-operation circuit breakers, created lazily on first use.
    #[allow(dead_code)]
    circuit_breakers: Arc<RwLock<HashMap<String, Arc<CircuitBreaker>>>>,
}
// Manual `Debug`: handler map shown by key only; lock contents elided.
impl std::fmt::Debug for RecoveryExecutor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RecoveryExecutor")
            .field("policy", &self.policy)
            .field("handlers", &self.handlers.keys().collect::<Vec<_>>())
            .field("circuit_breakers", &"Arc<RwLock<HashMap<...>>>")
            .finish()
    }
}
impl RecoveryExecutor {
    /// Creates an executor with the given policy and no registered
    /// fallback handlers.
    pub fn new(policy: RecoveryPolicy) -> Self {
        Self {
            policy,
            handlers: HashMap::new(),
            circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
        }
    }
    /// Registers the fallback handler used when the `Fallback` strategy
    /// fires for operation id `name`.
    pub fn register_handler(&mut self, name: String, handler: Arc<dyn RecoveryHandler>) {
        self.handlers.insert(name, handler);
    }
    /// Runs `operation` once; on failure, applies the policy strategy for
    /// the error's code (`INTERNAL_ERROR` when the error carries no code).
    ///
    /// Degraded handling in this executor:
    /// - `RetryAdaptive` runs as plain exponential backoff (jitter ignored);
    /// - `AdvancedCircuitBreaker` runs as a basic circuit breaker;
    /// - `DeadlineAware`/`CascadeAware` only honor a `RetryFixed` base;
    /// - `BulkRecovery` fails fast.
    pub async fn execute_with_recovery<F, Fut>(
        &self,
        operation_id: &str,
        operation: F,
    ) -> Result<serde_json::Value>
    where
        F: Fn() -> Fut,
        Fut: Future<Output = Result<serde_json::Value>>,
    {
        match operation().await {
            Ok(result) => {
                // Feed successes into this operation's breaker, if one was
                // created by an earlier CircuitBreaker strategy run.
                if let Some(cb) = self.circuit_breakers.read().await.get(operation_id) {
                    cb.record_success().await;
                }
                Ok(result)
            },
            Err(error) => {
                let error_code = error.error_code().unwrap_or(ErrorCode::INTERNAL_ERROR);
                let strategy = self.policy.get_strategy(&error_code);
                match strategy {
                    RecoveryStrategy::RetryFixed { attempts, delay } => {
                        self.retry_fixed(error, *attempts, *delay, operation).await
                    },
                    RecoveryStrategy::RetryExponential {
                        attempts,
                        initial_delay,
                        max_delay,
                        multiplier,
                    } => {
                        self.retry_exponential(
                            error,
                            *attempts,
                            *initial_delay,
                            *max_delay,
                            *multiplier,
                            operation,
                        )
                        .await
                    },
                    RecoveryStrategy::Fallback => {
                        self.fallback(&error.to_string(), operation_id).await
                    },
                    RecoveryStrategy::CircuitBreaker {
                        failure_threshold,
                        success_threshold,
                        timeout,
                    } => {
                        self.circuit_breaker(
                            error,
                            operation_id,
                            *failure_threshold,
                            *success_threshold,
                            *timeout,
                            operation,
                        )
                        .await
                    },
                    RecoveryStrategy::FailFast => Err(error),
                    // Jitter strategy is dropped here; adaptive retry is
                    // only fully supported by AdvancedRecoveryExecutor.
                    RecoveryStrategy::RetryAdaptive {
                        attempts,
                        initial_delay,
                        max_delay,
                        multiplier,
                        ..
                    } => {
                        self.retry_exponential(
                            error,
                            *attempts,
                            *initial_delay,
                            *max_delay,
                            *multiplier,
                            operation,
                        )
                        .await
                    },
                    // Health-check fields are ignored; treated as basic.
                    RecoveryStrategy::AdvancedCircuitBreaker {
                        failure_threshold,
                        success_threshold,
                        timeout,
                        ..
                    } => {
                        self.circuit_breaker(
                            error,
                            operation_id,
                            *failure_threshold,
                            *success_threshold,
                            *timeout,
                            operation,
                        )
                        .await
                    },
                    // Deadline budget is ignored; only a RetryFixed base is
                    // executed, anything else fails fast.
                    RecoveryStrategy::DeadlineAware { base_strategy, .. } => {
                        match base_strategy.as_ref() {
                            RecoveryStrategy::RetryFixed { attempts, delay } => {
                                self.retry_fixed(error, *attempts, *delay, operation).await
                            },
                            _ => Err(error),
                        }
                    },
                    // Bulk recovery needs a batch; a single operation just
                    // surfaces its error.
                    RecoveryStrategy::BulkRecovery { .. } => {
                        Err(error)
                    },
                    // Dependency/isolation fields are ignored; only a
                    // RetryFixed base is executed.
                    RecoveryStrategy::CascadeAware { base_strategy, .. } => {
                        match base_strategy.as_ref() {
                            RecoveryStrategy::RetryFixed { attempts, delay } => {
                                self.retry_fixed(error, *attempts, *delay, operation).await
                            },
                            _ => Err(error),
                        }
                    },
                }
            },
        }
    }
    /// Retries `operation` up to `attempts` times, sleeping `delay` before
    /// each attempt. `error` seeds the returned error when every attempt
    /// fails (or `attempts` is zero).
    async fn retry_fixed<F, Fut>(
        &self,
        error: Error,
        attempts: u32,
        delay: Duration,
        operation: F,
    ) -> Result<serde_json::Value>
    where
        F: Fn() -> Fut,
        Fut: Future<Output = Result<serde_json::Value>>,
    {
        let mut last_error = error;
        for attempt in 1..=attempts {
            if self.policy.log_attempts {
                debug!(
                    "Retry attempt {} of {} after {:?}",
                    attempt, attempts, delay
                );
            }
            runtime::sleep(delay).await;
            match operation().await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    last_error = e;
                    if self.policy.log_attempts {
                        warn!("Retry attempt {} failed: {}", attempt, last_error);
                    }
                },
            }
        }
        Err(last_error)
    }
    /// Retries with a geometrically growing delay (×`multiplier` after each
    /// failure, capped at `max_delay`); the sleep precedes every attempt.
    async fn retry_exponential<F, Fut>(
        &self,
        error: Error,
        attempts: u32,
        initial_delay: Duration,
        max_delay: Duration,
        multiplier: f64,
        operation: F,
    ) -> Result<serde_json::Value>
    where
        F: Fn() -> Fut,
        Fut: Future<Output = Result<serde_json::Value>>,
    {
        let mut last_error = error;
        let mut current_delay = initial_delay;
        for attempt in 1..=attempts {
            if self.policy.log_attempts {
                debug!(
                    "Exponential retry attempt {} of {} after {:?}",
                    attempt, attempts, current_delay
                );
            }
            runtime::sleep(current_delay).await;
            match operation().await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    last_error = e;
                    if self.policy.log_attempts {
                        warn!(
                            "Exponential retry attempt {} failed: {}",
                            attempt, last_error
                        );
                    }
                    let next_delay = Duration::from_secs_f64(
                        (current_delay.as_secs_f64() * multiplier).min(max_delay.as_secs_f64()),
                    );
                    current_delay = next_delay;
                },
            }
        }
        Err(last_error)
    }
    /// Delegates to the fallback handler registered for `operation_id`;
    /// surfaces an internal error when none is registered.
    async fn fallback(&self, error_msg: &str, operation_id: &str) -> Result<serde_json::Value> {
        if let Some(handler) = self.handlers.get(operation_id) {
            if self.policy.log_attempts {
                info!("Using fallback handler for operation: {}", operation_id);
            }
            handler.recover(error_msg).await
        } else {
            if self.policy.log_attempts {
                error!(
                    "No fallback handler registered for operation: {}",
                    operation_id
                );
            }
            Err(Error::internal(error_msg))
        }
    }
    /// Retries `operation` once through this operation's circuit breaker
    /// (created lazily with the given thresholds). The triggering error is
    /// discarded; a rejected request yields a protocol error instead.
    async fn circuit_breaker<F, Fut>(
        &self,
        _error: Error,
        operation_id: &str,
        failure_threshold: u32,
        success_threshold: u32,
        timeout: Duration,
        operation: F,
    ) -> Result<serde_json::Value>
    where
        F: Fn() -> Fut,
        Fut: Future<Output = Result<serde_json::Value>>,
    {
        // Clone the Arc out so the map lock is not held across the await.
        let cb = {
            let mut breakers = self.circuit_breakers.write().await;
            breakers
                .entry(operation_id.to_string())
                .or_insert_with(|| {
                    Arc::new(CircuitBreaker::new(CircuitBreakerConfig {
                        failure_threshold,
                        success_threshold,
                        timeout,
                    }))
                })
                .clone()
        };
        if !cb.allow_request().await {
            if self.policy.log_attempts {
                warn!("Circuit breaker open for operation: {}", operation_id);
            }
            return Err(Error::protocol(
                ErrorCode::INTERNAL_ERROR,
                "Circuit breaker is open",
            ));
        }
        match operation().await {
            Ok(result) => {
                cb.record_success().await;
                Ok(result)
            },
            Err(e) => {
                cb.record_failure().await;
                Err(e)
            },
        }
    }
}
pub async fn with_retry<F, Fut>(
attempts: u32,
delay: Duration,
operation: F,
) -> Result<serde_json::Value>
where
F: Fn() -> Fut,
Fut: Future<Output = Result<serde_json::Value>>,
{
let mut last_error = None;
for attempt in 0..attempts {
if attempt > 0 {
runtime::sleep(delay).await;
}
match operation().await {
Ok(result) => return Ok(result),
Err(e) => {
last_error = Some(e);
},
}
}
Err(last_error.unwrap_or_else(|| Error::internal("No attempts made")))
}
// Native-only tests (they rely on the tokio runtime).
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    /// Fixed retry succeeds on the third retry attempt: the operation fails
    /// twice, then returns Ok, so the counter ends at 3.
    #[tokio::test]
    async fn test_retry_fixed() {
        let policy = RecoveryPolicy::default();
        let executor = RecoveryExecutor::new(policy);
        let attempt_count = Arc::new(AtomicU32::new(0));
        let attempt_count_clone = attempt_count.clone();
        let result = executor
            .retry_fixed(
                Error::internal("test"),
                3,
                Duration::from_millis(10),
                || {
                    let count = attempt_count_clone.fetch_add(1, Ordering::Relaxed);
                    async move {
                        if count < 2 {
                            Err(Error::internal("retry"))
                        } else {
                            Ok(serde_json::json!({"success": true}))
                        }
                    }
                },
            )
            .await;
        assert!(result.is_ok());
        assert_eq!(attempt_count.load(Ordering::Relaxed), 3);
    }
    /// Full state-machine pass: two failures open the circuit, requests are
    /// rejected until the timeout elapses, then two half-open successes
    /// close it again.
    #[tokio::test]
    async fn test_circuit_breaker() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            success_threshold: 2,
            timeout: Duration::from_millis(100),
        };
        let cb = CircuitBreaker::new(config);
        assert!(cb.allow_request().await);
        cb.record_failure().await;
        cb.record_failure().await;
        assert!(!cb.allow_request().await);
        // Wait out the cool-down so the next request transitions to half-open.
        runtime::sleep(Duration::from_millis(150)).await;
        assert!(cb.allow_request().await);
        cb.record_success().await;
        cb.record_success().await;
        assert!(cb.allow_request().await);
    }
    /// The free-function helper behaves like retry_fixed: two failures, one
    /// success, three invocations total.
    #[tokio::test]
    async fn test_with_retry() {
        let attempt_count = Arc::new(AtomicU32::new(0));
        let attempt_count_clone = attempt_count.clone();
        let result = with_retry(3, Duration::from_millis(10), || {
            let count = attempt_count_clone.fetch_add(1, Ordering::Relaxed);
            async move {
                if count < 2 {
                    Err(Error::internal("retry"))
                } else {
                    Ok(serde_json::json!({"success": true}))
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(attempt_count.load(Ordering::Relaxed), 3);
    }
}