// athena_rs 3.3.0
//
// Database gateway API
// Documentation
//! Shared gateway resilience helpers: timeout, transient classification, bounded retry, error mapping.
//!
//! Provides consistent handling of backend failures across gateway endpoints.
//! Write retries are disabled by default to avoid duplicate side effects.

use actix_web::HttpResponse;
use actix_web::http::StatusCode;
use std::time::Duration;
use tracing::{debug, warn};

use crate::api::response::{bad_gateway, internal_error, service_unavailable};
use crate::drivers::scylla::health::HostOffline;
use crate::error::{ErrorCategory, ProcessedError, generate_trace_id};

/// Tuning knobs applied to every gateway call into a backend.
///
/// Reads typically opt into a small number of retries; writes keep retries
/// switched off so that a request is never applied twice.
#[derive(Debug, Clone)]
pub struct ResiliencePolicy {
    /// Upper bound on the duration of a single backend attempt.
    pub operation_timeout: Duration,
    /// How many extra attempts are allowed after a transient failure
    /// (0 for writes by default, typically 1-2 for reads).
    pub max_retries: u32,
    /// Base delay used to compute the pause before a retry.
    pub initial_backoff: Duration,
    /// Master switch for retrying; off for writes by default.
    pub retry_enabled: bool,
}

impl Default for ResiliencePolicy {
    /// Conservative, write-safe defaults: a 30 s budget per attempt and no retries.
    fn default() -> Self {
        let operation_timeout = Duration::from_secs(30);
        let initial_backoff = Duration::from_millis(100);
        Self {
            retry_enabled: false,
            max_retries: 0,
            operation_timeout,
            initial_backoff,
        }
    }
}

impl ResiliencePolicy {
    /// Policy for idempotent reads: 30 s per attempt, one retry after a
    /// transient failure, 100 ms base backoff.
    pub fn for_reads() -> Self {
        let operation_timeout = Duration::from_secs(30);
        let initial_backoff = Duration::from_millis(100);
        Self {
            retry_enabled: true,
            max_retries: 1,
            operation_timeout,
            initial_backoff,
        }
    }

    /// Policy for writes: identical to the default, so retries stay off.
    pub fn for_writes() -> Self {
        Self::default()
    }

    /// Read policy built from configuration values.
    ///
    /// Retrying is switched on exactly when `max_retries` is non-zero.
    pub fn from_config_reads(timeout_secs: u64, max_retries: u32, initial_backoff_ms: u64) -> Self {
        let retry_enabled = max_retries != 0;
        Self {
            operation_timeout: Duration::from_secs(timeout_secs),
            initial_backoff: Duration::from_millis(initial_backoff_ms),
            max_retries,
            retry_enabled,
        }
    }

    /// Write policy built from a configured timeout; every other field
    /// (no retries, 100 ms backoff, retries disabled) comes from the default.
    pub fn from_config_writes(timeout_secs: u64) -> Self {
        Self {
            operation_timeout: Duration::from_secs(timeout_secs),
            ..Self::default()
        }
    }
}

/// Category a backend failure was sorted into; drives both retry decisions
/// and the HTTP status returned to the client.
#[derive(Clone, Debug)]
pub enum ClassifiedError {
    /// Backend host is marked offline by the circuit breaker; answered with 503.
    /// `until_secs` is how long the client is told to wait before retrying.
    Offline { host: String, until_secs: u64 },
    /// Pool checkout or query ran past its time budget; eligible for retry.
    Timeout { message: String },
    /// Connection- or pool-level hiccup; eligible for retry.
    Transient { message: String },
    /// Permanent failure (constraint, validation, ...); never retried.
    NonTransient { message: String },
    /// Backend reported a 5xx-class failure; answered with 502.
    Upstream { message: String },
}

impl ClassifiedError {
    /// Build the HTTP error response for this classified error.
    ///
    /// Mapping:
    /// - `Offline` → `service_unavailable`;
    /// - `Timeout` → 504 Gateway Timeout, with `trace_id` in the JSON metadata;
    /// - `Transient` / `Upstream` → `bad_gateway`;
    /// - `NonTransient` → `internal_error`.
    ///
    /// `trace_id` is embedded in every variant's response so a client-visible
    /// failure can be matched to the server-side log entry with the same id.
    pub fn to_response(&self, trace_id: &str) -> HttpResponse {
        match self {
            ClassifiedError::Offline { host, until_secs } => {
                let msg = format!(
                    "Backend '{}' temporarily unavailable; retry after {}s",
                    host, until_secs
                );
                // Include the trace id like the other variants do; without it a
                // 503 could not be correlated with the logged failure.
                let details = format!("{} (trace_id: {})", msg, trace_id);
                service_unavailable(&msg, details)
            }
            ClassifiedError::Timeout { message } => {
                let processed = ProcessedError::new(
                    ErrorCategory::DatabaseConnection,
                    StatusCode::GATEWAY_TIMEOUT,
                    "gateway_timeout",
                    format!("Backend operation timed out: {}", message),
                    trace_id.to_string(),
                )
                .with_metadata("trace_id", serde_json::json!(trace_id));
                HttpResponse::build(StatusCode::GATEWAY_TIMEOUT).json(processed.to_json())
            }
            ClassifiedError::Transient { message } => bad_gateway(
                "Transient backend failure",
                format!("{} (trace_id: {})", message, trace_id),
            ),
            ClassifiedError::NonTransient { message } => internal_error(
                "Backend error",
                format!("{} (trace_id: {})", message, trace_id),
            ),
            ClassifiedError::Upstream { message } => bad_gateway(
                "Upstream backend failure",
                format!("{} (trace_id: {})", message, trace_id),
            ),
        }
    }
}

/// Classify a backend error and decide whether it is worth retrying.
///
/// Returns `(retryable, classified)`; only [`ClassifiedError::Timeout`] and
/// [`ClassifiedError::Transient`] are reported as retryable.
///
/// Classification order:
/// 1. A [`HostOffline`] anywhere in the error's `source()` chain → `Offline`.
/// 2. Constraint / unique-violation text → `NonTransient`.
/// 3. Timeout text → `Timeout`.
/// 4. Connection / pool text → `Transient`.
/// 5. Anything else → `NonTransient`.
///
/// Steps 2-4 are case-insensitive substring matches on the error's `Display`
/// output, so they are heuristics rather than exact driver error codes.
pub fn classify_error(err: &(dyn std::error::Error + 'static)) -> (bool, ClassifiedError) {
    // Circuit breaker / host offline. Search the whole `source()` chain so a
    // `HostOffline` wrapped by another error type is still reported as offline.
    let offline = std::iter::successors(Some(err), |e| e.source())
        .find_map(|e| e.downcast_ref::<HostOffline>());
    if let Some(offline) = offline {
        let remaining = offline
            .until()
            .checked_duration_since(std::time::Instant::now())
            .unwrap_or(Duration::ZERO);
        return (
            false,
            ClassifiedError::Offline {
                host: offline.host().to_string(),
                // Round up so clients are never told to retry after 0s.
                until_secs: remaining.as_secs().max(1),
            },
        );
    }

    let message = err.to_string();
    let lower = message.to_lowercase();

    // Constraint / validation failures are permanent. This check runs before the
    // keyword heuristics below because constraint and column names routinely
    // contain words such as "connection" or "timeout" (e.g. `jobs_connection_id_fkey`),
    // which would otherwise turn a permanent failure into a retryable one.
    // "23505" is the PostgreSQL SQLSTATE for unique_violation.
    if lower.contains("unique") || lower.contains("constraint") || lower.contains("23505") {
        return (false, ClassifiedError::NonTransient { message });
    }

    // Timeout indicators.
    if lower.contains("timed out")
        || lower.contains("timeout")
        || lower.contains("deadline exceeded")
        || lower.contains("elapsed")
    {
        return (true, ClassifiedError::Timeout { message });
    }

    // Connection / pool transient failures ("connection" also covers "connection reset").
    if lower.contains("connection")
        || lower.contains("pool")
        || lower.contains("checkout")
        || lower.contains("broken pipe")
        || lower.contains("refused")
    {
        return (true, ClassifiedError::Transient { message });
    }

    // Default: unrecognised failures are treated as non-transient internal errors.
    (false, ClassifiedError::NonTransient { message })
}

/// Run `fut` under a deadline and classify any failure.
///
/// - Completes in time with `Ok` → the value is returned unchanged.
/// - Completes in time with `Err` → the error is passed through [`classify_error`].
///   The retryable flag is dropped because the returned variant already encodes it.
/// - Misses the deadline → [`ClassifiedError::Timeout`] naming the exceeded budget.
pub async fn with_timeout<T, F, E>(timeout_duration: Duration, fut: F) -> Result<T, ClassifiedError>
where
    F: std::future::Future<Output = Result<T, E>>,
    E: std::error::Error + Send + Sync + 'static,
{
    match tokio::time::timeout(timeout_duration, fut).await {
        Ok(Ok(value)) => Ok(value),
        Ok(Err(e)) => Err(classify_error(&e).1),
        // `Duration`'s Debug form keeps sub-second budgets readable ("500ms")
        // where `as_secs()` would truncate them to "0s"; whole-second budgets
        // still render the same way (e.g. "30s").
        Err(_elapsed) => Err(ClassifiedError::Timeout {
            message: format!("Operation exceeded {:?}", timeout_duration),
        }),
    }
}

/// Execute `operation` under `policy`, retrying transient failures.
///
/// Each attempt calls `operation()` afresh and runs the resulting future through
/// [`with_timeout`] with `policy.operation_timeout`. Only
/// [`ClassifiedError::Timeout`] and [`ClassifiedError::Transient`] are retried,
/// and only while `policy.retry_enabled` is set and fewer than
/// `policy.max_retries` retries have been made. Retry `n` (1-based) first sleeps
/// `policy.initial_backoff * n` (linear backoff).
///
/// # Errors
/// Returns the classified error of the last attempt. `Offline`, `NonTransient`
/// and `Upstream` failures are returned immediately without retrying.
pub async fn execute_with_retry<T, F, E>(
    policy: &ResiliencePolicy,
    mut operation: impl FnMut() -> F,
) -> Result<T, ClassifiedError>
where
    F: std::future::Future<Output = Result<T, E>>,
    E: std::error::Error + Send + Sync + 'static,
{
    let mut attempt: u32 = 0;

    loop {
        let classified = match with_timeout(policy.operation_timeout, operation()).await {
            Ok(value) => return Ok(value),
            Err(classified) => classified,
        };

        // Exhaustive on purpose: adding a variant forces an explicit retry decision.
        let retryable = match &classified {
            ClassifiedError::Timeout { .. } | ClassifiedError::Transient { .. } => true,
            ClassifiedError::Offline { .. }
            | ClassifiedError::NonTransient { .. }
            | ClassifiedError::Upstream { .. } => false,
        };

        if !retryable || !policy.retry_enabled || attempt >= policy.max_retries {
            return Err(classified);
        }

        attempt += 1;
        // `Duration * u32` panics on overflow; saturate instead so an extreme
        // configured backoff cannot crash the request handler.
        let backoff: Duration = policy.initial_backoff.saturating_mul(attempt);
        debug!(
            attempt = attempt,
            max_retries = policy.max_retries,
            backoff_ms = backoff.as_millis(),
            error = ?classified,
            "Retrying after transient failure"
        );
        tokio::time::sleep(backoff).await;
    }
}

/// Log a failed backend operation and turn its classified error into an HTTP response.
///
/// A fresh trace id is generated and used for both the warning log line and
/// [`ClassifiedError::to_response`]. `operation` and `backend` only label the
/// log entry; they do not affect the response.
pub fn gateway_error_response(
    classified: &ClassifiedError,
    operation: &str,
    backend: &str,
) -> HttpResponse {
    let trace_id = generate_trace_id();
    // Record the failure itself, not only where it happened; otherwise the log
    // entry for a trace id does not say what went wrong.
    warn!(
        trace_id = %trace_id,
        operation = %operation,
        backend = %backend,
        error = ?classified,
        "Gateway backend operation failed"
    );
    classified.to_response(&trace_id)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io;

    #[test]
    fn classify_timeout_is_transient() {
        let err: Box<dyn std::error::Error + Send + Sync> = Box::new(io::Error::new(
            io::ErrorKind::TimedOut,
            "operation timed out",
        ));
        let (transient, classified) = classify_error(err.as_ref());
        assert!(transient);
        // `matches!` only yields a bool; it must be asserted to check anything.
        assert!(matches!(classified, ClassifiedError::Timeout { .. }));
    }

    #[test]
    fn classify_connection_reset_is_transient() {
        let err: Box<dyn std::error::Error + Send + Sync> = Box::new(io::Error::new(
            io::ErrorKind::ConnectionReset,
            "connection reset by peer",
        ));
        let (transient, classified) = classify_error(err.as_ref());
        assert!(transient);
        assert!(matches!(classified, ClassifiedError::Transient { .. }));
    }

    #[test]
    fn classify_refused_is_transient() {
        let err = io::Error::new(io::ErrorKind::ConnectionRefused, "Connection refused");
        let (transient, classified) = classify_error(&err);
        assert!(transient);
        assert!(matches!(classified, ClassifiedError::Transient { .. }));
    }

    #[test]
    fn classify_unique_violation_is_non_transient() {
        let err: Box<dyn std::error::Error + Send + Sync> =
            Box::new(io::Error::other("duplicate key violates unique constraint"));
        let (transient, classified) = classify_error(err.as_ref());
        assert!(!transient);
        assert!(matches!(classified, ClassifiedError::NonTransient { .. }));
    }

    #[test]
    fn classify_unknown_error_defaults_to_non_transient() {
        let err = io::Error::new(io::ErrorKind::Other, "syntax error at or near \"FROM\"");
        let (transient, classified) = classify_error(&err);
        assert!(!transient);
        assert!(matches!(classified, ClassifiedError::NonTransient { .. }));
    }

    #[test]
    fn policy_for_writes_has_no_retries() {
        let p = ResiliencePolicy::for_writes();
        assert!(!p.retry_enabled);
        assert_eq!(p.max_retries, 0);
    }

    #[test]
    fn policy_for_reads_allows_retry() {
        let p = ResiliencePolicy::for_reads();
        assert!(p.retry_enabled);
        assert!(p.max_retries >= 1);
    }

    #[test]
    fn from_config_reads_enables_retry_only_with_nonzero_retries() {
        let p = ResiliencePolicy::from_config_reads(5, 0, 50);
        assert!(!p.retry_enabled);
        assert_eq!(p.operation_timeout, Duration::from_secs(5));

        let p = ResiliencePolicy::from_config_reads(5, 2, 50);
        assert!(p.retry_enabled);
        assert_eq!(p.max_retries, 2);
        assert_eq!(p.initial_backoff, Duration::from_millis(50));
    }

    #[test]
    fn from_config_writes_never_retries() {
        let p = ResiliencePolicy::from_config_writes(12);
        assert!(!p.retry_enabled);
        assert_eq!(p.max_retries, 0);
        assert_eq!(p.operation_timeout, Duration::from_secs(12));
    }

    #[test]
    fn timeout_maps_to_gateway_timeout_status() {
        let resp = ClassifiedError::Timeout {
            message: "slow".to_string(),
        }
        .to_response("trace-1");
        assert_eq!(resp.status(), StatusCode::GATEWAY_TIMEOUT);
    }
}