use once_cell::sync::{Lazy, OnceCell};
use redis::{Client, Commands};
use serde_json::Value;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use tokio::task;
use tracing::{info, warn};
/// Thin wrapper around a `redis::Client` used as a JSON value cache and integer counter store.
///
/// Only the client (connection factory) is stored; each operation in the `impl` opens its
/// own connection from it rather than keeping one here.
pub struct RedisCacheClient {
// Connection factory; the async methods clone it into blocking tasks.
client: Client,
}
impl RedisCacheClient {
pub fn from_url(url: &str) -> Result<Self, String> {
let client: Client =
Client::open(url).map_err(|err| format!("invalid redis URL: {err}"))?;
Ok(Self { client })
}
pub async fn get(&self, key: &str) -> Result<Value, String> {
let client: Client = self.client.clone();
let redis_key: String = key.to_string();
task::spawn_blocking(move || {
let mut conn = client
.get_connection()
.map_err(|err| format!("redis connection failed: {err}"))?;
let raw: Option<String> = conn
.get(redis_key)
.map_err(|err| format!("redis get failed: {err}"))?;
let Some(raw) = raw else {
return Ok(Value::Null);
};
serde_json::from_str::<Value>(&raw)
.map_err(|err| format!("invalid JSON payload from redis: {err}"))
})
.await
.map_err(|err| format!("redis task join error: {err}"))?
}
pub async fn set_with_ttl(
&self,
key: &str,
value: &Value,
ttl_secs: u64,
) -> Result<(), String> {
let client: Client = self.client.clone();
let redis_key: String = key.to_string();
let payload: String = serde_json::to_string(value)
.map_err(|err| format!("failed to serialize JSON for redis: {err}"))?;
task::spawn_blocking(move || {
let mut conn = client
.get_connection()
.map_err(|err| format!("redis connection failed: {err}"))?;
let _: () = conn
.set_ex(redis_key, payload, ttl_secs)
.map_err(|err| format!("redis set_ex failed: {err}"))?;
Ok(())
})
.await
.map_err(|err| format!("redis task join error: {err}"))?
}
pub async fn increment_counter_with_ttl(
&self,
key: &str,
ttl_secs: u64,
) -> Result<u64, String> {
let client: Client = self.client.clone();
let redis_key: String = key.to_string();
task::spawn_blocking(move || {
let mut conn = client
.get_connection()
.map_err(|err| format!("redis connection failed: {err}"))?;
let count: i64 = conn
.incr(&redis_key, 1_i64)
.map_err(|err| format!("redis incr failed: {err}"))?;
if count == 1 {
let _: bool = conn
.expire(&redis_key, ttl_secs as i64)
.map_err(|err| format!("redis expire failed: {err}"))?;
}
Ok(count.max(0) as u64)
})
.await
.map_err(|err| format!("redis task join error: {err}"))?
}
pub fn increment_counter_with_ttl_blocking(
&self,
key: &str,
ttl_secs: u64,
) -> Result<u64, String> {
let mut conn = self
.client
.get_connection()
.map_err(|err| format!("redis connection failed: {err}"))?;
let count: i64 = conn
.incr(key, 1_i64)
.map_err(|err| format!("redis incr failed: {err}"))?;
if count == 1 {
let _: bool = conn
.expire(key, ttl_secs as i64)
.map_err(|err| format!("redis expire failed: {err}"))?;
}
Ok(count.max(0) as u64)
}
}
/// Process-wide Redis cache client.
///
/// Set by `initialize_global_redis_from_env`, which leaves it empty when
/// `ATHENA_REDIS_URL` is unset, blank, or rejected by `RedisCacheClient::from_url`.
pub static GLOBAL_REDIS: OnceCell<RedisCacheClient> = OnceCell::new();
static REDIS_OPERATION_TIMEOUT: Lazy<Duration> = Lazy::new(|| {
let timeout_ms: u64 = std::env::var("ATHENA_REDIS_OP_TIMEOUT_MS")
.ok()
.and_then(|value| value.parse::<u64>().ok())
.filter(|value| *value > 0)
.unwrap_or(30);
Duration::from_millis(timeout_ms)
});
static REDIS_COOLDOWN_UNTIL: Lazy<Mutex<Option<Instant>>> = Lazy::new(|| Mutex::new(None));
static REDIS_COOLDOWN_MS: Lazy<u64> = Lazy::new(|| {
std::env::var("ATHENA_REDIS_COOLDOWN_MS")
.ok()
.and_then(|value| value.parse::<u64>().ok())
.filter(|value| *value > 0)
.unwrap_or(5_000)
});
/// Returns the per-operation Redis timeout configured by `ATHENA_REDIS_OP_TIMEOUT_MS`.
///
/// The value is resolved once and then reused for every call.
pub fn redis_operation_timeout() -> Duration {
    let configured: &Duration = &REDIS_OPERATION_TIMEOUT;
    *configured
}
/// Reports whether a Redis failure cooldown is currently in effect.
///
/// Returns `true` only while the recorded cooldown deadline lies in the future. When no
/// deadline is recorded, or the cooldown lock is poisoned, it returns `false`, meaning
/// keep using Redis.
pub fn should_bypass_redis_temporarily() -> bool {
    match REDIS_COOLDOWN_UNTIL.lock() {
        Ok(cooldown) => matches!(*cooldown, Some(deadline) if Instant::now() < deadline),
        Err(_) => false,
    }
}
/// Starts (or restarts) the Redis bypass cooldown after a failed operation.
///
/// Records a deadline `REDIS_COOLDOWN_MS` milliseconds from now, replacing any earlier one.
/// If the cooldown lock is poisoned, nothing is recorded.
pub fn note_redis_failure_and_start_cooldown() {
    let Ok(mut deadline_slot) = REDIS_COOLDOWN_UNTIL.lock() else {
        return;
    };
    let deadline = Instant::now() + Duration::from_millis(*REDIS_COOLDOWN_MS);
    *deadline_slot = Some(deadline);
}
/// Ends any active Redis bypass cooldown after a successful operation.
///
/// If the cooldown lock is poisoned, the recorded deadline is left as it is.
pub fn note_redis_success() {
    let Ok(mut deadline_slot) = REDIS_COOLDOWN_UNTIL.lock() else {
        return;
    };
    *deadline_slot = None;
}
/// Installs `GLOBAL_REDIS` from `ATHENA_REDIS_URL` and reports whether Redis is available.
///
/// Returns `true` when a client is installed, either by this call or by an earlier or
/// concurrent one. Returns `false` when the variable is unset, blank (after trimming), or
/// rejected by `RedisCacheClient::from_url`. Problems are logged rather than returned, so
/// the caller can continue without Redis.
pub fn initialize_global_redis_from_env() -> bool {
    // Already installed: nothing to do.
    if GLOBAL_REDIS.get().is_some() {
        return true;
    }

    let redis_url = match std::env::var("ATHENA_REDIS_URL") {
        Ok(raw) if !raw.trim().is_empty() => raw.trim().to_owned(),
        _ => {
            info!("ATHENA_REDIS_URL is not set; Redis cache integration is disabled");
            return false;
        }
    };

    // Outer result: URL parsing. Inner result: whether this call won the race to set the cell.
    let outcome = RedisCacheClient::from_url(&redis_url).map(|client| GLOBAL_REDIS.set(client));
    match outcome {
        Err(err) => {
            warn!(error = %err, "Failed to initialize Redis cache client; continuing without Redis");
            false
        }
        Ok(Err(_)) => {
            warn!("GLOBAL_REDIS was already initialized; skipping Redis re-initialization");
            true
        }
        Ok(Ok(())) => {
            info!("Redis cache integration enabled from ATHENA_REDIS_URL");
            true
        }
    }
}
#[cfg(test)]
mod tests {
    use super::RedisCacheClient;

    /// A string that is not a Redis URL is rejected with the `invalid redis URL:` prefix.
    #[test]
    fn redis_client_rejects_invalid_url() {
        let err = RedisCacheClient::from_url("not a redis url")
            .err()
            .expect("a non-URL string must be rejected");
        assert!(
            err.starts_with("invalid redis URL: "),
            "unexpected error message: {err}"
        );
    }

    /// A well-formed URL is accepted without a reachable server, since only parsing happens.
    #[test]
    fn redis_client_accepts_well_formed_url() {
        assert!(RedisCacheClient::from_url("redis://127.0.0.1:6379/0").is_ok());
    }
}