use std::collections::HashMap;
use std::time::Duration;
use async_trait::async_trait;
use fred::prelude::*;
use fred::types::ConnectHandle;
use crate::backend::{Backend, HealthStatus, TtlInspectable};
use crate::error::{BackendError, BackendErrorKind};
/// Redacts credentials embedded in URLs inside a Redis error message.
///
/// Every `scheme://` occurrence in `msg` is inspected. A URL is taken to run
/// from just after `://` up to the next whitespace character (or the end of
/// the message). If that span contains an `@`, everything between `://` and
/// the *last* `@` in the span is replaced with `[REDACTED]`. The `@` and the
/// host/port/path that follow it are kept.
///
/// Using the last `@` keeps a password containing an unencoded `@`
/// (`redis://user:p@ss@host`) fully hidden. Bounding the search by
/// whitespace stops an unrelated `@` later in the message (e.g. an e-mail
/// address) from swallowing the host name. When in doubt the function
/// over-redacts rather than leaking.
///
/// Text without any `://` is returned unchanged.
fn sanitize_redis_message(msg: &str) -> String {
    const SCHEME_SEP: &str = "://";
    const MASK: &str = "[REDACTED]";

    let mut out = String::with_capacity(msg.len() + MASK.len());
    let mut rest = msg;
    while let Some(sep) = rest.find(SCHEME_SEP) {
        // `SCHEME_SEP` is ASCII, so this offset is always a char boundary.
        let authority_start = sep + SCHEME_SEP.len();
        out.push_str(&rest[..authority_start]);
        rest = &rest[authority_start..];

        // The URL token ends at the first whitespace (or end of message).
        let token_end = rest.find(char::is_whitespace).unwrap_or(rest.len());
        if let Some(at) = rest[..token_end].rfind('@') {
            out.push_str(MASK);
            // Resume at the `@`, so host, port and path are preserved.
            rest = &rest[at..];
        }
    }
    out.push_str(rest);
    out
}
/// Converts a `fred` [`RedisError`] into a crate-level [`BackendError`].
///
/// The error kind is mapped onto [`BackendErrorKind`] so callers can decide
/// whether to retry:
/// - `Auth` → `Authentication`
/// - `IO`, `Canceled`, `Backpressure` → `Transient`
/// - `Timeout` → `Timeout`
/// - every other kind → `Permanent`
///
/// The message is passed through [`sanitize_redis_message`] so URL
/// credentials are redacted. The original error is kept as `source`.
fn redis_err(e: RedisError) -> BackendError {
    let kind = match e.kind() {
        RedisErrorKind::Auth => BackendErrorKind::Authentication,
        // Backpressure means the client's in-flight command limit was hit.
        // The same command can succeed once load drops, so it is retryable.
        RedisErrorKind::IO | RedisErrorKind::Canceled | RedisErrorKind::Backpressure => {
            BackendErrorKind::Transient
        }
        RedisErrorKind::Timeout => BackendErrorKind::Timeout,
        // `RedisErrorKind` is a foreign enum whose variant set depends on
        // fred's enabled features, so a catch-all arm is required.
        _ => BackendErrorKind::Permanent,
    };
    BackendError {
        kind,
        message: sanitize_redis_message(&e.to_string()),
        // NOTE(review): the unsanitized error is retained here. Any formatter
        // that walks `source()` chains will print its raw message; confirm
        // that fred's Display never includes credentials.
        source: Some(Box::new(e)),
    }
}
/// Cache [`Backend`] that stores entries in Redis through the `fred` client.
///
/// Create one with [`RedisBackend::builder`]; [`RedisBackend::connect`]
/// initialises the underlying client connection.
pub struct RedisBackend {
    /// The `fred` client that every backend operation is issued through.
    client: RedisClient,
}
impl std::fmt::Debug for RedisBackend {
    /// Renders as `RedisBackend { .. }` without printing any client state.
    ///
    /// The client is constructed from a connection URL that can embed
    /// credentials, so its fields are intentionally omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("RedisBackend");
        builder.finish_non_exhaustive()
    }
}
impl RedisBackend {
    /// Returns an empty [`RedisBackendBuilder`] for configuring a backend.
    pub fn builder() -> RedisBackendBuilder {
        Default::default()
    }

    /// Initialises the client connection to the Redis server.
    ///
    /// Awaits the underlying client's `init` and hands back the
    /// [`ConnectHandle`] it yields.
    ///
    /// # Errors
    ///
    /// Returns a [`BackendError`] (classified like every other Redis failure
    /// in this module) if initialisation fails.
    pub async fn connect(&self) -> Result<ConnectHandle, BackendError> {
        let init_result = self.client.init().await;
        init_result.map_err(redis_err)
    }
}
#[cfg(not(target_arch = "wasm32"))]
#[cfg_attr(not(feature = "unsync"), async_trait)]
#[cfg_attr(feature = "unsync", async_trait(?Send))]
impl Backend for RedisBackend {
    /// Fetches the raw bytes stored at `key`, or `None` when nothing is there.
    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, BackendError> {
        let fetched: Option<bytes::Bytes> = self.client.get(key).await.map_err(redis_err)?;
        let owned = fetched.map(|payload| payload.to_vec());
        Ok(owned)
    }

    /// Stores `value` at `key`, optionally expiring after `ttl`.
    ///
    /// The TTL is sent as whole seconds (`EX`):
    /// - any sub-second part is dropped;
    /// - anything under one second is raised to one second;
    /// - a second count that does not fit in `i64` saturates to `i64::MAX`.
    async fn set(
        &self,
        key: &str,
        value: Vec<u8>,
        ttl: Option<Duration>,
    ) -> Result<(), BackendError> {
        let expiration = match ttl {
            None => None,
            Some(duration) => {
                // NOTE(review): the one-second floor presumably avoids a
                // zero expiry, which Redis rejects. The `i64::MAX` saturation
                // may itself exceed what Redis accepts for `EX`; confirm.
                let whole_secs = duration.as_secs().max(1);
                Some(Expiration::EX(
                    i64::try_from(whole_secs).unwrap_or(i64::MAX),
                ))
            }
        };

        let outcome = self
            .client
            .set::<(), _, _>(key, value.as_slice(), expiration, None, false)
            .await;
        outcome.map_err(redis_err)
    }

    /// Removes `key`, reporting whether at least one key was deleted.
    async fn delete(&self, key: &str) -> Result<bool, BackendError> {
        let deleted_count: i64 = self.client.del(key).await.map_err(redis_err)?;
        Ok(deleted_count >= 1)
    }

    /// Reports whether `key` is currently present.
    async fn exists(&self, key: &str) -> Result<bool, BackendError> {
        let present_count: i64 = self.client.exists(key).await.map_err(redis_err)?;
        Ok(present_count >= 1)
    }

    /// Sends `PING` and reports the round-trip latency.
    ///
    /// A failed ping is returned as `Err` rather than as an unhealthy status.
    /// On success `is_healthy` is always `true`, and `details["latency_ms"]`
    /// holds the latency as a whole number of milliseconds.
    async fn health(&self) -> Result<HealthStatus, BackendError> {
        let started = std::time::Instant::now();
        let _reply: String = self.client.ping().await.map_err(redis_err)?;
        let elapsed = started.elapsed();

        let details = HashMap::from([(
            "latency_ms".to_string(),
            elapsed.as_millis().to_string(),
        )]);

        Ok(HealthStatus {
            is_healthy: true,
            latency_ms: elapsed.as_secs_f64() * 1000.0,
            backend_type: "redis".to_string(),
            details,
        })
    }
}
#[cfg(not(target_arch = "wasm32"))]
#[cfg_attr(not(feature = "unsync"), async_trait)]
#[cfg_attr(feature = "unsync", async_trait(?Send))]
impl TtlInspectable for RedisBackend {
    /// Returns the remaining time-to-live of `key` in whole seconds.
    ///
    /// Per the Redis `TTL` command, a negative reply means the key has no
    /// expiry or does not exist. Any negative reply is reported as `None`.
    async fn ttl(&self, key: &str) -> Result<Option<Duration>, BackendError> {
        let remaining: i64 = self.client.ttl(key).await.map_err(redis_err)?;
        // Negative values fail the conversion and become `None`.
        Ok(u64::try_from(remaining).ok().map(Duration::from_secs))
    }
}
/// Step-by-step configuration for a [`RedisBackend`].
///
/// NOTE(review): no `Debug` derive. Keep it that way unless the URL is
/// redacted, since the URL may embed credentials.
#[must_use]
#[derive(Default)]
pub struct RedisBackendBuilder {
    /// Connection URL set through [`RedisBackendBuilder::url`]. Required by
    /// `build`, which also rejects an empty string.
    url: Option<String>,
}
impl RedisBackendBuilder {
    /// Sets the Redis connection URL, replacing any previously set value.
    pub fn url(mut self, url: impl Into<String>) -> Self {
        self.url = Some(url.into());
        self
    }

    /// Validates the configuration and constructs a [`RedisBackend`].
    ///
    /// This only creates the client. It does not call
    /// [`RedisBackend::connect`].
    ///
    /// # Errors
    ///
    /// Returns [`crate::error::CachekitError::Config`] in two cases:
    /// - no URL was set, or it is empty;
    /// - the URL cannot be parsed. Credentials in the parse error are redacted.
    pub fn build(self) -> Result<RedisBackend, crate::error::CachekitError> {
        use crate::error::CachekitError;

        let url = match self.url {
            Some(candidate) if !candidate.is_empty() => candidate,
            _ => return Err(CachekitError::Config("url is required".to_string())),
        };

        let config = match RedisConfig::from_url(&url) {
            Ok(parsed) => parsed,
            Err(parse_err) => {
                let detail = sanitize_redis_message(&parse_err.to_string());
                return Err(CachekitError::Config(format!(
                    "invalid Redis URL: {detail}"
                )));
            }
        };

        // NOTE(review): performance, connection and reconnect settings are
        // all left as `None`. Confirm that fred's defaults are acceptable,
        // especially whether the client reconnects after a dropped connection.
        let client = RedisClient::new(config, None, None, None);
        Ok(RedisBackend { client })
    }
}