use actix_web::HttpResponse;
use actix_web::http::StatusCode;
use std::time::Duration;
use tracing::{debug, warn};
use crate::api::response::{bad_gateway, internal_error, service_unavailable};
use crate::drivers::scylla::health::HostOffline;
use crate::error::{ErrorCategory, ProcessedError, generate_trace_id};
/// Tunables governing timeout and retry behavior for a class of backend
/// operations (constructed via `for_reads`/`for_writes` or from config).
#[derive(Clone, Debug)]
pub struct ResiliencePolicy {
    /// Wall-clock budget for a single attempt of the operation.
    pub operation_timeout: Duration,
    /// Additional attempts after the first failure (0 = fail immediately).
    pub max_retries: u32,
    /// Backoff before the first retry; later retries scale this value.
    pub initial_backoff: Duration,
    /// Master switch: when false, no retries happen regardless of `max_retries`.
    pub retry_enabled: bool,
}
impl Default for ResiliencePolicy {
fn default() -> Self {
Self {
operation_timeout: Duration::from_secs(30),
max_retries: 0,
initial_backoff: Duration::from_millis(100),
retry_enabled: false,
}
}
}
impl ResiliencePolicy {
    /// Policy for read paths: one retry with a short backoff is acceptable
    /// because reads are idempotent.
    pub fn for_reads() -> Self {
        Self {
            operation_timeout: Duration::from_secs(30),
            max_retries: 1,
            initial_backoff: Duration::from_millis(100),
            retry_enabled: true,
        }
    }

    /// Policy for write paths: never retry — a timed-out write may already
    /// have been applied, and a blind retry could duplicate it.
    pub fn for_writes() -> Self {
        Self::default()
    }

    /// Build a read policy from configuration. `retry_enabled` is derived
    /// from `max_retries` so the two fields can never disagree.
    pub fn from_config_reads(timeout_secs: u64, max_retries: u32, initial_backoff_ms: u64) -> Self {
        Self {
            operation_timeout: Duration::from_secs(timeout_secs),
            max_retries,
            initial_backoff: Duration::from_millis(initial_backoff_ms),
            retry_enabled: max_retries > 0,
        }
    }

    /// Build a write policy from configuration: only the timeout is
    /// configurable. The remaining fields come from `Default` via struct
    /// update syntax instead of re-stating the no-retry values by hand,
    /// so they cannot drift out of sync with `Default`.
    pub fn from_config_writes(timeout_secs: u64) -> Self {
        Self {
            operation_timeout: Duration::from_secs(timeout_secs),
            ..Self::default()
        }
    }
}
/// Backend failures bucketed by how the gateway should respond and whether
/// a retry is worthwhile (see `classify_error` / `execute_with_retry`).
#[derive(Debug, Clone)]
pub enum ClassifiedError {
    // Host is known-offline (circuit broken); `until_secs` is the suggested
    // retry-after horizon in seconds.
    Offline { host: String, until_secs: u64 },
    // Operation exceeded its deadline; treated as retryable by the retry loop.
    Timeout { message: String },
    // Connection/pool-level failure likely to clear on retry.
    Transient { message: String },
    // Deterministic failure (e.g. constraint violation); retrying cannot help.
    NonTransient { message: String },
    // Failure reported by an upstream dependency; not retried here.
    Upstream { message: String },
}
impl ClassifiedError {
pub fn to_response(&self, trace_id: &str) -> HttpResponse {
match self {
ClassifiedError::Offline { host, until_secs } => {
let msg = format!(
"Backend '{}' temporarily unavailable; retry after {}s",
host, until_secs
);
service_unavailable(&msg, msg.clone())
}
ClassifiedError::Timeout { message } => {
let processed = ProcessedError::new(
ErrorCategory::DatabaseConnection,
StatusCode::GATEWAY_TIMEOUT,
"gateway_timeout",
format!("Backend operation timed out: {}", message),
trace_id.to_string(),
)
.with_metadata("trace_id", serde_json::json!(trace_id));
HttpResponse::build(StatusCode::GATEWAY_TIMEOUT).json(processed.to_json())
}
ClassifiedError::Transient { message } => bad_gateway(
"Transient backend failure",
format!("{} (trace_id: {})", message, trace_id),
),
ClassifiedError::NonTransient { message } => internal_error(
"Backend error",
format!("{} (trace_id: {})", message, trace_id),
),
ClassifiedError::Upstream { message } => bad_gateway(
"Upstream backend failure",
format!("{} (trace_id: {})", message, trace_id),
),
}
}
}
/// Classify a backend error into a `(retryable, ClassifiedError)` pair.
///
/// `HostOffline` is detected structurally via downcast; everything else is
/// classified by substring-matching the lowercased error text, which is
/// brittle but the only signal available for opaque driver errors.
pub fn classify_error(err: &(dyn std::error::Error + 'static)) -> (bool, ClassifiedError) {
    // Structural check first — also avoids allocating the message strings
    // on this path, where they are never used.
    if let Some(offline) = err.downcast_ref::<HostOffline>() {
        let remaining = offline
            .until()
            .checked_duration_since(std::time::Instant::now())
            .unwrap_or(Duration::ZERO);
        return (
            false,
            ClassifiedError::Offline {
                host: offline.host().to_string(),
                // Report at least 1s so clients never see "retry after 0s".
                until_secs: remaining.as_secs().max(1),
            },
        );
    }

    let msg = err.to_string();
    let msg_lower = msg.to_lowercase();
    let matches_any = |needles: &[&str]| needles.iter().any(|n| msg_lower.contains(n));

    if matches_any(&["timed out", "timeout", "deadline exceeded", "elapsed"]) {
        // `msg` is moved, not re-cloned: only one branch ever returns.
        return (true, ClassifiedError::Timeout { message: msg });
    }
    if matches_any(&[
        "connection",
        "pool",
        "checkout",
        "broken pipe",
        "connection reset",
        "refused",
    ]) {
        return (true, ClassifiedError::Transient { message: msg });
    }
    // Everything else — including constraint violations ("unique",
    // "constraint", Postgres 23505), which previously had their own branch
    // returning the identical value — is a deterministic, non-retryable error.
    (false, ClassifiedError::NonTransient { message: msg })
}
/// Run `fut` under `timeout_duration`, mapping both an elapsed deadline and
/// an inner error into a [`ClassifiedError`].
///
/// The retryability flag from [`classify_error`] is intentionally dropped
/// here; callers such as `execute_with_retry` re-derive it from the variant.
pub async fn with_timeout<T, F, E>(timeout_duration: Duration, fut: F) -> Result<T, ClassifiedError>
where
    F: std::future::Future<Output = Result<T, E>>,
    E: std::error::Error + Send + Sync + 'static,
{
    match tokio::time::timeout(timeout_duration, fut).await {
        Ok(Ok(value)) => Ok(value),
        Ok(Err(e)) => {
            let (_, classified) = classify_error(&e);
            Err(classified)
        }
        // `{:?}` on Duration renders sub-second budgets correctly ("500ms")
        // where the previous `as_secs()` reported a misleading "0s"; whole
        // seconds still render as e.g. "30s".
        Err(_) => Err(ClassifiedError::Timeout {
            message: format!("Operation exceeded {:?}", timeout_duration),
        }),
    }
}
/// Execute `operation` under the policy's timeout, retrying timeouts and
/// transient failures up to `policy.max_retries` additional times.
///
/// Offline, non-transient, and upstream errors are returned immediately —
/// retrying them cannot succeed. Each attempt obtains a fresh future from
/// the `operation` factory and gets its own full `operation_timeout` budget.
pub async fn execute_with_retry<T, F, E>(
    policy: &ResiliencePolicy,
    mut operation: impl FnMut() -> F,
) -> Result<T, ClassifiedError>
where
    F: std::future::Future<Output = Result<T, E>>,
    E: std::error::Error + Send + Sync + 'static,
{
    let mut attempt: u32 = 0;
    loop {
        match with_timeout(policy.operation_timeout, operation()).await {
            Ok(value) => return Ok(value),
            Err(classified) => {
                match &classified {
                    // Fail fast: these cannot be fixed by retrying.
                    ClassifiedError::Offline { .. }
                    | ClassifiedError::NonTransient { .. }
                    | ClassifiedError::Upstream { .. } => return Err(classified),
                    // Retry candidates — fall through to the budget check.
                    ClassifiedError::Timeout { .. } | ClassifiedError::Transient { .. } => {}
                }
                if !policy.retry_enabled || attempt >= policy.max_retries {
                    return Err(classified);
                }
                attempt += 1;
                // NOTE(review): backoff grows linearly (initial * attempt),
                // although the field name `initial_backoff` could suggest an
                // exponential schedule — confirm linear growth is intended.
                let backoff: Duration = policy.initial_backoff * attempt;
                debug!(
                    attempt = attempt,
                    max_retries = policy.max_retries,
                    backoff_ms = backoff.as_millis(),
                    "Retrying after transient failure"
                );
                tokio::time::sleep(backoff).await;
            }
        }
    }
}
/// Log a failed backend operation at WARN and convert it into an HTTP
/// response.
///
/// A fresh trace id is generated here so the warning log line and the
/// response body share the same identifier, letting operators correlate a
/// client-reported error with server logs.
pub fn gateway_error_response(
    classified: &ClassifiedError,
    operation: &str,
    backend: &str,
) -> HttpResponse {
    let trace_id = generate_trace_id();
    warn!(
        trace_id = %trace_id,
        operation = %operation,
        backend = %backend,
        "Gateway backend operation failed"
    );
    classified.to_response(&trace_id)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;

    #[test]
    fn classify_timeout_is_transient() {
        let err: Box<dyn std::error::Error + Send + Sync> = Box::new(io::Error::new(
            io::ErrorKind::TimedOut,
            "operation timed out",
        ));
        let (transient, classified) = classify_error(err.as_ref());
        assert!(transient);
        // Bug fix: the `matches!` result was previously discarded as a bare
        // statement, so the variant was never actually checked.
        assert!(matches!(classified, ClassifiedError::Timeout { .. }));
    }

    #[test]
    fn classify_connection_reset_is_transient() {
        let err: Box<dyn std::error::Error + Send + Sync> = Box::new(io::Error::new(
            io::ErrorKind::ConnectionReset,
            "connection reset by peer",
        ));
        let (transient, classified) = classify_error(err.as_ref());
        assert!(transient);
        // Bug fix: same discarded-`matches!` issue as above.
        assert!(matches!(classified, ClassifiedError::Transient { .. }));
    }

    #[test]
    fn classify_unique_violation_is_non_transient() {
        let err: Box<dyn std::error::Error + Send + Sync> =
            Box::new(io::Error::other("duplicate key violates unique constraint"));
        let (transient, classified) = classify_error(err.as_ref());
        assert!(!transient);
        assert!(matches!(classified, ClassifiedError::NonTransient { .. }));
    }

    #[test]
    fn policy_for_writes_has_no_retries() {
        let p = ResiliencePolicy::for_writes();
        assert!(!p.retry_enabled);
        assert_eq!(p.max_retries, 0);
    }

    #[test]
    fn policy_for_reads_allows_retry() {
        let p = ResiliencePolicy::for_reads();
        assert!(p.retry_enabled);
        assert!(p.max_retries >= 1);
    }
}