// athena_rs 0.77.0
//
// WIP Database API gateway
// Documentation
//! Track Scylla host availability and expose helpers for fallback logic.
use fmt::{Debug, Display, Formatter, Result};
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::error::Error;
use std::fmt;
use std::sync::Arc;
use std::sync::{Mutex, MutexGuard};
use std::time::{Duration, Instant};

/// Time source consulted by the tracker whenever it needs "now".
///
/// Keeping this behind a trait lets tests substitute a manually advanced clock
/// instead of depending on the real passage of time.
pub trait Clock: Send + Sync + 'static {
    /// Returns the instant to treat as the current time.
    fn now(&self) -> Instant;
}

/// [`Clock`] implementation that simply reports [`Instant::now`].
#[derive(Default, Clone, Copy)]
pub struct SystemClock;

impl Clock for SystemClock {
    fn now(&self) -> Instant {
        std::time::Instant::now()
    }
}

/// Per-host bookkeeping kept by the health tracker.
///
/// The default value (no failures, not offline) is equivalent to a host the
/// tracker has never seen.
#[derive(Debug, Default)]
struct HostRecord {
    /// Timestamps of recent failures; older entries are pruned against the
    /// tracker's failure window before new failures are counted.
    failures: Vec<Instant>,
    /// When set, the host is considered offline until this instant.
    offline_until: Option<Instant>,
}

/// Tracks host failures and offline windows.
///
/// A host is taken offline as soon as `failure_threshold` failures have been
/// recorded within the `failure_window` preceding the latest failure; it then
/// stays blocked for `offline_duration`. When the block expires, the host's
/// failure history is discarded so it starts again from zero.
pub struct HostHealthTracker<C: Clock> {
    /// Time source for every timestamp; injectable so tests can control time.
    clock: C,
    /// Number of failures inside `failure_window` that takes a host offline.
    failure_threshold: usize,
    /// Maximum age (inclusive) of a failure that still counts towards the threshold.
    failure_window: Duration,
    /// How long a host stays blocked once the threshold is reached.
    offline_duration: Duration,
    /// Per-host state keyed by host name, guarded by a single mutex.
    inner: Mutex<HashMap<String, HostRecord>>,
}

impl<C: Clock> HostHealthTracker<C> {
    /// Creates a new tracker with custom thresholds.
    ///
    /// * `clock` - time source used for all timestamps.
    /// * `failure_threshold` - failures within `failure_window` that take a host
    ///   offline; the failure that reaches this count trips the block.
    /// * `failure_window` - how far back failures are counted.
    /// * `offline_duration` - how long a tripped host stays blocked.
    pub fn new(
        clock: C,
        failure_threshold: usize,
        failure_window: Duration,
        offline_duration: Duration,
    ) -> Self {
        Self {
            clock,
            failure_threshold,
            failure_window,
            offline_duration,
            inner: Mutex::new(HashMap::new()),
        }
    }

    /// Locks the per-host map.
    ///
    /// A poisoned mutex (a previous holder panicked) is logged and then
    /// recovered instead of propagating the panic, so tracking keeps working.
    fn lock_hosts(&self) -> MutexGuard<'_, HashMap<String, HostRecord>> {
        self.inner.lock().unwrap_or_else(|poisoned| {
            tracing::warn!("HostHealthTracker mutex poisoned, recovering");
            poisoned.into_inner()
        })
    }

    /// Drops failures older than `failure_window` relative to `now`.
    fn prune_old(&self, record: &mut HostRecord, now: Instant) {
        record
            .failures
            .retain(|&ts| now.duration_since(ts) <= self.failure_window);
    }

    /// Returns the record's offline deadline if it is still in the future.
    ///
    /// An expired block is cleared together with the failure history, so the
    /// host starts from a clean slate.
    fn active_deadline(record: &mut HostRecord, now: Instant) -> Option<Instant> {
        let until: Instant = record.offline_until?;
        if now < until {
            return Some(until);
        }
        record.offline_until = None;
        record.failures.clear();
        None
    }

    /// Returns the cached offline deadline, if the host is currently blocked.
    ///
    /// Expired blocks are cleared as a side effect and reported as `None`.
    pub fn offline_until(&self, host: &str) -> Option<Instant> {
        let now: Instant = self.clock.now();
        let mut hosts: MutexGuard<'_, HashMap<String, HostRecord>> = self.lock_hosts();
        hosts
            .get_mut(host)
            .and_then(|record| Self::active_deadline(record, now))
    }

    /// Record a successful interaction and forget any failures (and any block).
    pub fn record_success(&self, host: &str) {
        self.lock_hosts().remove(host);
    }

    /// Records a failed interaction with `host`.
    ///
    /// Returns `Some(deadline)` when the host is offline after this call:
    /// * it was already blocked — the existing deadline is returned unchanged
    ///   and this failure is not counted; or
    /// * this failure brought the count within `failure_window` up to
    ///   `failure_threshold` — a new deadline `now + offline_duration` is set.
    ///
    /// Returns `None` while the host remains below the threshold.
    pub fn record_failure(&self, host: &str) -> Option<Instant> {
        let now: Instant = self.clock.now();
        let mut hosts: MutexGuard<'_, HashMap<String, HostRecord>> = self.lock_hosts();
        let record: &mut HostRecord = hosts.entry(host.to_string()).or_default();

        // Already blocked: keep the current deadline and ignore this failure.
        if let Some(until) = Self::active_deadline(record, now) {
            return Some(until);
        }

        self.prune_old(record, now);
        record.failures.push(now);

        if record.failures.len() < self.failure_threshold {
            return None;
        }

        // Threshold reached: block the host and start a fresh history for when it returns.
        let until: Instant = now + self.offline_duration;
        record.offline_until = Some(until);
        record.failures.clear();
        Some(until)
    }

    /// Force the host offline for the given duration (testing helper).
    ///
    /// Clears any recorded failures and returns the new deadline.
    pub fn force_offline(&self, host: &str, duration: Duration) -> Instant {
        let now: Instant = self.clock.now();
        let deadline: Instant = now + duration;
        let mut hosts: MutexGuard<'_, HashMap<String, HostRecord>> = self.lock_hosts();
        let record: &mut HostRecord = hosts.entry(host.to_string()).or_default();
        record.failures.clear();
        record.offline_until = Some(deadline);

        deadline
    }

    /// Clear all recorded state for the host (testing helper).
    pub fn reset_host(&self, host: &str) {
        self.lock_hosts().remove(host);
    }
}

/// Global tracker tailored for the Scylla driver.
pub fn global_tracker() -> &'static HostHealthTracker<SystemClock> {
    static SCYLLA_TRACKER: Lazy<HostHealthTracker<SystemClock>> = Lazy::new(|| {
        HostHealthTracker::new(
            SystemClock::default(),
            5,
            Duration::from_secs(60),
            Duration::from_secs(300),
        )
    });
    &SCYLLA_TRACKER
}

/// Error returned when the host is temporarily blocked.
///
/// Both `Display` and `Debug` report the time left until the deadline,
/// measured against `Instant::now()` at formatting time and clamped to zero
/// once the deadline has passed.
pub struct HostOffline {
    host: String,
    until: Instant,
}

impl HostOffline {
    /// Creates the error for `host`, blocked until `until`.
    pub fn new(host: String, until: Instant) -> Self {
        Self { host, until }
    }

    /// Host name that is blocked.
    pub fn host(&self) -> &str {
        &self.host
    }

    /// Deadline until the host remains offline.
    pub fn until(&self) -> Instant {
        self.until
    }

    /// Time left until `until`, saturating to zero once the deadline has passed.
    fn remaining(&self) -> Duration {
        self.until.saturating_duration_since(Instant::now())
    }
}

impl Display for HostOffline {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Host {} is offline for another {:?}",
            self.host,
            self.remaining()
        )
    }
}

impl Debug for HostOffline {
    /// Structured debug output (type name, quoted host, remaining time);
    /// honours `{:#?}` pretty-printing.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("HostOffline")
            .field("host", &self.host)
            .field("remaining", &self.remaining())
            .finish()
    }
}

impl Error for HostOffline {}

#[cfg(test)]
mod tests {
    use super::*;

    /// Manually driven clock shared between a test and the tracker under test.
    #[derive(Clone)]
    struct TestClock {
        now: Arc<Mutex<Instant>>,
    }

    impl TestClock {
        fn new(now: Instant) -> Self {
            Self {
                now: Arc::new(Mutex::new(now)),
            }
        }

        /// Moves the shared clock forward by `duration`.
        fn advance(&self, duration: Duration) {
            *self.now.lock().expect("TestClock mutex poisoned") += duration;
        }
    }

    impl Clock for TestClock {
        fn now(&self) -> Instant {
            *self.now.lock().expect("TestClock mutex poisoned")
        }
    }

    /// Tracker with threshold 3, a 60s failure window and a 120s offline period.
    fn tracker_with_clock(clock: TestClock) -> HostHealthTracker<TestClock> {
        HostHealthTracker::new(clock, 3, Duration::from_secs(60), Duration::from_secs(120))
    }

    #[test]
    fn offline_after_threshold_and_unblocks() {
        let clock: TestClock = TestClock::new(Instant::now());
        let tracker: HostHealthTracker<TestClock> = tracker_with_clock(clock.clone());
        let host: &str = "failexample";

        assert!(tracker.offline_until(host).is_none());

        // The first `threshold - 1` failures only accumulate.
        for _ in 0..2 {
            let deadline: Option<Instant> = tracker.record_failure(host);
            assert!(deadline.is_none());
            clock.advance(Duration::from_secs(10));
        }

        // The failure that reaches the threshold trips the block.
        let deadline: Option<Instant> = tracker.record_failure(host);
        assert!(deadline.is_some());
        assert_eq!(tracker.offline_until(host), deadline);

        clock.advance(Duration::from_secs(121));
        assert!(tracker.offline_until(host).is_none());
    }

    #[test]
    fn failures_outside_window_are_forgotten() {
        let clock: TestClock = TestClock::new(Instant::now());
        let tracker: HostHealthTracker<TestClock> = tracker_with_clock(clock.clone());
        let host: &str = "slow-host";

        assert!(tracker.record_failure(host).is_none());
        assert!(tracker.record_failure(host).is_none());

        // Both earlier failures are now older than the 60s window.
        clock.advance(Duration::from_secs(61));
        assert!(tracker.record_failure(host).is_none());
        assert!(tracker.record_failure(host).is_none());
        assert!(tracker.record_failure(host).is_some());
    }

    #[test]
    fn success_clears_failures() {
        let clock: TestClock = TestClock::new(Instant::now());
        let tracker: HostHealthTracker<TestClock> = tracker_with_clock(clock.clone());
        let host: &str = "flaky-host";

        assert!(tracker.record_failure(host).is_none());
        assert!(tracker.record_failure(host).is_none());
        tracker.record_success(host);

        assert!(tracker.record_failure(host).is_none());
        assert!(tracker.record_failure(host).is_none());
        assert!(tracker.record_failure(host).is_some());
    }

    #[test]
    fn failure_while_offline_keeps_deadline() {
        let clock: TestClock = TestClock::new(Instant::now());
        let tracker: HostHealthTracker<TestClock> = tracker_with_clock(clock.clone());
        let host: &str = "blocked-host";

        let until: Instant = tracker.force_offline(host, Duration::from_secs(30));
        assert_eq!(tracker.record_failure(host), Some(until));

        clock.advance(Duration::from_secs(31));
        assert!(tracker.offline_until(host).is_none());
    }

    #[test]
    fn force_offline_resets_state() {
        let clock: TestClock = TestClock::new(Instant::now());
        let tracker: HostHealthTracker<TestClock> = tracker_with_clock(clock.clone());
        let host: &str = "force-host";

        let until: Instant = tracker.force_offline(host, Duration::from_secs(30));
        assert_eq!(tracker.offline_until(host), Some(until));

        tracker.reset_host(host);
        assert!(tracker.offline_until(host).is_none());
    }
}