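//! athena_rs 2.10.0 — Database gateway API.
//!
//! Documentation for the Redis cache integration: a JSON cache client
//! ([`RedisCacheClient`]) and its optional, environment-driven global
//! initialization from `ATHENA_REDIS_URL`.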
use once_cell::sync::{Lazy, OnceCell};
use redis::{Client, Commands};
use serde_json::Value;
use std::time::Duration;
use tokio::task;
use tracing::{info, warn};

pub struct RedisCacheClient {
    client: Client,
}

impl RedisCacheClient {
    pub fn from_url(url: &str) -> Result<Self, String> {
        let client = Client::open(url).map_err(|err| format!("invalid redis URL: {err}"))?;
        Ok(Self { client })
    }

    pub async fn get(&self, key: &str) -> Result<Value, String> {
        let client: Client = self.client.clone();
        let redis_key: String = key.to_string();

        task::spawn_blocking(move || {
            let mut conn = client
                .get_connection()
                .map_err(|err| format!("redis connection failed: {err}"))?;

            let raw: Option<String> = conn
                .get(redis_key)
                .map_err(|err| format!("redis get failed: {err}"))?;

            let Some(raw) = raw else {
                return Ok(Value::Null);
            };

            serde_json::from_str::<Value>(&raw)
                .map_err(|err| format!("invalid JSON payload from redis: {err}"))
        })
        .await
        .map_err(|err| format!("redis task join error: {err}"))?
    }

    pub async fn set_with_ttl(
        &self,
        key: &str,
        value: &Value,
        ttl_secs: u64,
    ) -> Result<(), String> {
        let client: Client = self.client.clone();
        let redis_key: String = key.to_string();
        let payload: String = serde_json::to_string(value)
            .map_err(|err| format!("failed to serialize JSON for redis: {err}"))?;

        task::spawn_blocking(move || {
            let mut conn = client
                .get_connection()
                .map_err(|err| format!("redis connection failed: {err}"))?;

            let _: () = conn
                .set_ex(redis_key, payload, ttl_secs)
                .map_err(|err| format!("redis set_ex failed: {err}"))?;
            Ok(())
        })
        .await
        .map_err(|err| format!("redis task join error: {err}"))?
    }
}

/// Process-wide Redis cache client.
///
/// Can be set at most once. `GLOBAL_REDIS.get()` returns `None` until a client
/// has been installed, e.g. by [`initialize_global_redis_from_env`].
pub static GLOBAL_REDIS: OnceCell<RedisCacheClient> = OnceCell::<RedisCacheClient>::new();

static REDIS_OPERATION_TIMEOUT: Lazy<Duration> = Lazy::new(|| {
    let timeout_ms: u64 = std::env::var("ATHENA_REDIS_OP_TIMEOUT_MS")
        .ok()
        .and_then(|value| value.parse::<u64>().ok())
        .filter(|value| *value > 0)
        .unwrap_or(30);
    Duration::from_millis(timeout_ms)
});

pub fn redis_operation_timeout() -> Duration {
    *REDIS_OPERATION_TIMEOUT
}

/// Installs [`GLOBAL_REDIS`] from the `ATHENA_REDIS_URL` environment variable.
///
/// Returns `true` when a global Redis client is available afterwards, including
/// when one was already installed. Returns `false` when the variable is unset,
/// blank, or holds a URL that [`RedisCacheClient::from_url`] rejects. Failures
/// are logged, never propagated, so the caller can continue without Redis.
pub fn initialize_global_redis_from_env() -> bool {
    // Fast path: a client is already in place.
    if GLOBAL_REDIS.get().is_some() {
        return true;
    }

    // Unset, non-Unicode, and whitespace-only values all mean "disabled".
    let raw_url: String = std::env::var("ATHENA_REDIS_URL").unwrap_or_default();
    let redis_url: &str = raw_url.trim();
    if redis_url.is_empty() {
        info!("ATHENA_REDIS_URL is not set; Redis cache integration is disabled");
        return false;
    }

    let client = match RedisCacheClient::from_url(redis_url) {
        Err(err) => {
            warn!(error = %err, "Failed to initialize Redis cache client; continuing without Redis");
            return false;
        }
        Ok(client) => client,
    };

    // `set` only fails if another caller installed a client after the check above.
    if GLOBAL_REDIS.set(client).is_ok() {
        info!("Redis cache integration enabled from ATHENA_REDIS_URL");
    } else {
        warn!("GLOBAL_REDIS was already initialized; skipping Redis re-initialization");
    }
    true
}

#[cfg(test)]
mod tests {
    use super::RedisCacheClient;

    #[test]
    fn redis_client_rejects_invalid_url() {
        let err = RedisCacheClient::from_url("not a redis url")
            .err()
            .expect("a non-URL string must be rejected");
        assert!(
            err.starts_with("invalid redis URL: "),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn redis_client_accepts_valid_url_without_connecting() {
        // `redis::Client::open` only parses the URL, so this succeeds even when
        // no Redis server is listening on the target address.
        assert!(RedisCacheClient::from_url("redis://127.0.0.1:6379/0").is_ok());
    }
}