// athena_rs 3.3.0
//
// Database gateway API — documentation.
use once_cell::sync::{Lazy, OnceCell};
use redis::{Client, Commands};
use serde_json::Value;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use tokio::task;
use tracing::{info, warn};

/// Redis-backed JSON cache and counter client.
///
/// Wraps a `redis::Client` and exposes async JSON `get`/`set_with_ttl` plus TTL-bounded
/// counters (async and blocking variants). The async methods run the synchronous redis API
/// inside `tokio::task::spawn_blocking`, and every operation opens a fresh connection via
/// `get_connection` (there is no connection pooling).
pub struct RedisCacheClient {
    // Parsed connection settings; the async methods clone this into each blocking task.
    client: Client,
}

impl RedisCacheClient {
    pub fn from_url(url: &str) -> Result<Self, String> {
        let client: Client =
            Client::open(url).map_err(|err| format!("invalid redis URL: {err}"))?;
        Ok(Self { client })
    }

    pub async fn get(&self, key: &str) -> Result<Value, String> {
        let client: Client = self.client.clone();
        let redis_key: String = key.to_string();

        task::spawn_blocking(move || {
            let mut conn = client
                .get_connection()
                .map_err(|err| format!("redis connection failed: {err}"))?;

            let raw: Option<String> = conn
                .get(redis_key)
                .map_err(|err| format!("redis get failed: {err}"))?;

            let Some(raw) = raw else {
                return Ok(Value::Null);
            };

            serde_json::from_str::<Value>(&raw)
                .map_err(|err| format!("invalid JSON payload from redis: {err}"))
        })
        .await
        .map_err(|err| format!("redis task join error: {err}"))?
    }

    pub async fn set_with_ttl(
        &self,
        key: &str,
        value: &Value,
        ttl_secs: u64,
    ) -> Result<(), String> {
        let client: Client = self.client.clone();
        let redis_key: String = key.to_string();
        let payload: String = serde_json::to_string(value)
            .map_err(|err| format!("failed to serialize JSON for redis: {err}"))?;

        task::spawn_blocking(move || {
            let mut conn = client
                .get_connection()
                .map_err(|err| format!("redis connection failed: {err}"))?;

            let _: () = conn
                .set_ex(redis_key, payload, ttl_secs)
                .map_err(|err| format!("redis set_ex failed: {err}"))?;
            Ok(())
        })
        .await
        .map_err(|err| format!("redis task join error: {err}"))?
    }

    pub async fn increment_counter_with_ttl(
        &self,
        key: &str,
        ttl_secs: u64,
    ) -> Result<u64, String> {
        let client: Client = self.client.clone();
        let redis_key: String = key.to_string();

        task::spawn_blocking(move || {
            let mut conn = client
                .get_connection()
                .map_err(|err| format!("redis connection failed: {err}"))?;
            let count: i64 = conn
                .incr(&redis_key, 1_i64)
                .map_err(|err| format!("redis incr failed: {err}"))?;
            if count == 1 {
                let _: bool = conn
                    .expire(&redis_key, ttl_secs as i64)
                    .map_err(|err| format!("redis expire failed: {err}"))?;
            }
            Ok(count.max(0) as u64)
        })
        .await
        .map_err(|err| format!("redis task join error: {err}"))?
    }

    pub fn increment_counter_with_ttl_blocking(
        &self,
        key: &str,
        ttl_secs: u64,
    ) -> Result<u64, String> {
        let mut conn = self
            .client
            .get_connection()
            .map_err(|err| format!("redis connection failed: {err}"))?;
        let count: i64 = conn
            .incr(key, 1_i64)
            .map_err(|err| format!("redis incr failed: {err}"))?;
        if count == 1 {
            let _: bool = conn
                .expire(key, ttl_secs as i64)
                .map_err(|err| format!("redis expire failed: {err}"))?;
        }
        Ok(count.max(0) as u64)
    }
}

/// Process-wide Redis client, installed by `initialize_global_redis_from_env`.
///
/// Remains unset when `ATHENA_REDIS_URL` is missing/blank or cannot be parsed, in which case
/// Redis cache integration is disabled.
pub static GLOBAL_REDIS: OnceCell<RedisCacheClient> = OnceCell::new();

/// Reads a strictly positive integer from the environment variable `name`.
///
/// Surrounding whitespace is ignored. A missing, unparsable, or zero value falls back to
/// `default`.
fn positive_u64_from_env(name: &str, default: u64) -> u64 {
    std::env::var(name)
        .ok()
        .and_then(|value| value.trim().parse::<u64>().ok())
        .filter(|value| *value > 0)
        .unwrap_or(default)
}

/// Redis operation timeout from `ATHENA_REDIS_OP_TIMEOUT_MS` (default 30 ms), exposed via
/// [`redis_operation_timeout`]. NOTE(review): presumably applied per operation by callers.
static REDIS_OPERATION_TIMEOUT: Lazy<Duration> = Lazy::new(|| {
    Duration::from_millis(positive_u64_from_env("ATHENA_REDIS_OP_TIMEOUT_MS", 30))
});

/// Instant until which Redis should be bypassed after a failure; `None` means no cooldown.
static REDIS_COOLDOWN_UNTIL: Lazy<Mutex<Option<Instant>>> = Lazy::new(|| Mutex::new(None));

/// Length of the post-failure cooldown in milliseconds, from `ATHENA_REDIS_COOLDOWN_MS`
/// (default 5 000).
static REDIS_COOLDOWN_MS: Lazy<u64> =
    Lazy::new(|| positive_u64_from_env("ATHENA_REDIS_COOLDOWN_MS", 5_000));

/// Returns the configured Redis operation timeout (read once from the environment).
pub fn redis_operation_timeout() -> Duration {
    *REDIS_OPERATION_TIMEOUT
}

/// Reports whether Redis should be skipped because a recent failure started a cooldown that
/// has not yet elapsed.
///
/// A poisoned cooldown lock is treated as "no cooldown", so Redis is still attempted.
pub fn should_bypass_redis_temporarily() -> bool {
    match REDIS_COOLDOWN_UNTIL.lock() {
        Ok(cooldown_until) => {
            matches!(*cooldown_until, Some(deadline) if Instant::now() < deadline)
        }
        Err(_) => false,
    }
}

/// Starts (or restarts) the Redis cooldown window. For the next `ATHENA_REDIS_COOLDOWN_MS`
/// milliseconds, [`should_bypass_redis_temporarily`] reports `true`.
///
/// Silently does nothing if the cooldown lock is poisoned.
pub fn note_redis_failure_and_start_cooldown() {
    let Ok(mut cooldown_until) = REDIS_COOLDOWN_UNTIL.lock() else {
        return;
    };
    let started_at = Instant::now();
    *cooldown_until = Some(started_at + Duration::from_millis(*REDIS_COOLDOWN_MS));
}

/// Clears any active Redis cooldown so Redis is used again immediately.
///
/// Silently does nothing if the cooldown lock is poisoned.
pub fn note_redis_success() {
    let Ok(mut cooldown_until) = REDIS_COOLDOWN_UNTIL.lock() else {
        return;
    };
    cooldown_until.take();
}

/// Installs [`GLOBAL_REDIS`] from the `ATHENA_REDIS_URL` environment variable.
///
/// Returns `true` when a global client is available afterwards. That covers both the case
/// where it was already set (including by a concurrent caller that won the race) and the case
/// where it was just installed. Returns `false`, leaving Redis disabled, when the variable is
/// unset or blank, or when the URL cannot be parsed.
pub fn initialize_global_redis_from_env() -> bool {
    if GLOBAL_REDIS.get().is_some() {
        return true;
    }

    let configured_url = std::env::var("ATHENA_REDIS_URL")
        .ok()
        .map(|raw| raw.trim().to_owned())
        .filter(|trimmed| !trimmed.is_empty());

    let redis_url = match configured_url {
        Some(url) => url,
        None => {
            info!("ATHENA_REDIS_URL is not set; Redis cache integration is disabled");
            return false;
        }
    };

    match RedisCacheClient::from_url(&redis_url) {
        Err(err) => {
            warn!(
                error = %err,
                "Failed to initialize Redis cache client; continuing without Redis"
            );
            false
        }
        Ok(client) => {
            if GLOBAL_REDIS.set(client).is_ok() {
                info!("Redis cache integration enabled from ATHENA_REDIS_URL");
            } else {
                warn!("GLOBAL_REDIS was already initialized; skipping Redis re-initialization");
            }
            true
        }
    }
}

#[cfg(test)]
mod tests {
    use super::RedisCacheClient;

    /// A non-URL string must be rejected with the documented error prefix.
    #[test]
    fn redis_client_rejects_invalid_url() {
        let result = RedisCacheClient::from_url("not a redis url");
        let Err(err) = result else {
            panic!("expected an invalid URL to be rejected");
        };
        assert!(
            err.starts_with("invalid redis URL:"),
            "unexpected error message: {err}"
        );
    }

    /// A well-formed URL is accepted; constructing the client does not require a live server.
    #[test]
    fn redis_client_accepts_well_formed_url_without_connecting() {
        assert!(RedisCacheClient::from_url("redis://127.0.0.1:6379/0").is_ok());
    }
}