use std::sync::Arc;
use std::time::Duration;
use crate::error::Error;
/// `COUNT` hint passed to every `SCAN` call made by `clear`. Redis treats it
/// as a per-iteration work hint, not a hard limit on the number of keys returned.
const DEFAULT_SCAN_COUNT: usize = 100;

/// Cache backend that stores string values in Redis, namespacing every key
/// under an optional prefix.
pub struct RedisCache {
    /// Connection that is cloned before each command this cache issues.
    conn: redis::aio::MultiplexedConnection,
    /// Namespace joined to keys as `"{key_prefix}:{key}"`; empty means keys are used verbatim.
    key_prefix: String,
}

impl RedisCache {
    /// Connects to the Redis server at `url` and verifies it with `PING`.
    ///
    /// `key_prefix` namespaces every key this cache reads or writes; pass an
    /// empty string to use keys verbatim.
    ///
    /// # Errors
    ///
    /// Returns a cache error tagged `"connect"` if `redis::Client::open`
    /// rejects `url` or the multiplexed connection cannot be created, and one
    /// tagged `"ping"` if the `PING` health check fails.
    pub async fn new(url: &str, key_prefix: String) -> Result<Self, Error> {
        let client = redis::Client::open(url)
            .map_err(|e| Error::cache("connect", None, format!("failed: {e}")))?;
        let conn = client
            .get_multiplexed_async_connection()
            .await
            .map_err(|e| {
                Error::cache("connect", None, format!("connection creation failed: {e}"))
            })?;
        // Health-check the fresh connection before handing the cache out.
        let mut ping_conn = conn.clone();
        let _: String = redis::cmd("PING")
            .query_async(&mut ping_conn)
            .await
            .map_err(|e| Error::cache("ping", None, format!("failed: {e}")))?;
        Ok(Self { conn, key_prefix })
    }

    /// Returns the Redis key for `key`: `"{key_prefix}:{key}"`, or `key`
    /// unchanged when the prefix is empty.
    fn build_key(&self, key: &str) -> String {
        if self.key_prefix.is_empty() {
            key.to_string()
        } else {
            format!("{}:{}", self.key_prefix, key)
        }
    }

    /// Builds the `SCAN ... MATCH` pattern that selects every key owned by
    /// this cache.
    ///
    /// Redis interprets `*`, `?`, `[`, `]` and `\` in MATCH patterns as glob
    /// syntax. Those characters in the prefix are backslash-escaped so they
    /// match literally. Without this, a prefix such as `"app*"` would also
    /// match, and `clear` would delete, keys under `"appX:"`.
    ///
    /// An empty prefix yields `"*"`, which matches every key in the selected
    /// database.
    fn build_scan_pattern(&self) -> String {
        if self.key_prefix.is_empty() {
            return "*".to_string();
        }
        let mut pattern = String::with_capacity(self.key_prefix.len() + 2);
        for ch in self.key_prefix.chars() {
            if matches!(ch, '*' | '?' | '[' | ']' | '\\') {
                pattern.push('\\');
            }
            pattern.push(ch);
        }
        pattern.push_str(":*");
        pattern
    }
}
#[async_trait::async_trait]
impl super::Cache for RedisCache {
async fn get(&self, key: &str) -> Option<Arc<str>> {
let mut conn = self.conn.clone();
let full_key = self.build_key(key);
let result: redis::RedisResult<Option<String>> = redis::cmd("GET")
.arg(&full_key)
.query_async(&mut conn)
.await;
result.ok().flatten().map(|s| Arc::from(s.into_boxed_str()))
}
#[allow(clippy::cast_possible_truncation)]
async fn set(
&self,
key: String,
value: String,
ttl: Option<Duration>,
) -> crate::error::Result<()> {
let mut conn = self.conn.clone();
let full_key = self.build_key(&key);
let result: redis::RedisResult<()> = if let Some(ttl) = ttl {
let ms = ttl.as_millis() as u64;
redis::cmd("SET")
.arg(&full_key)
.arg(&value)
.arg("PX")
.arg(ms)
.query_async(&mut conn)
.await
} else {
redis::cmd("SET")
.arg(&full_key)
.arg(&value)
.query_async(&mut conn)
.await
};
result.map_err(|e| Error::cache("set", Some(key.clone()), format!("failed: {e}")))
}
async fn delete(&self, key: &str) -> crate::error::Result<()> {
let mut conn = self.conn.clone();
let full_key = self.build_key(key);
let result: redis::RedisResult<()> = redis::cmd("DEL")
.arg(&full_key)
.query_async(&mut conn)
.await;
result.map_err(|e| Error::cache("delete", Some(key.to_string()), format!("failed: {e}")))
}
async fn clear(&self) -> crate::error::Result<()> {
let mut conn = self.conn.clone();
let pattern = self.build_scan_pattern();
let mut cursor: u64 = 0;
let mut total_deleted: u64 = 0;
loop {
let scan_result: redis::RedisResult<(u64, Vec<String>)> = redis::cmd("SCAN")
.arg(cursor)
.arg("MATCH")
.arg(&pattern)
.arg("COUNT")
.arg(DEFAULT_SCAN_COUNT)
.query_async(&mut conn)
.await;
match scan_result {
Ok((new_cursor, keys)) => {
if !keys.is_empty() {
let del_result: redis::RedisResult<u64> =
redis::cmd("DEL").arg(&keys).query_async(&mut conn).await;
match del_result {
Ok(deleted) => total_deleted += deleted,
Err(e) => {
return Err(Error::cache(
"clear",
None,
format!("DEL failed: {e}"),
));
}
}
}
cursor = new_cursor;
if cursor == 0 {
break;
}
}
Err(e) => {
return Err(Error::cache("clear", None, format!("SCAN failed: {e}")));
}
}
}
if total_deleted > 0 {
tracing::debug!(
"Cleared {} cache entries with prefix '{}'",
total_deleted,
self.key_prefix
);
}
Ok(())
}
async fn exists(&self, key: &str) -> bool {
let mut conn = self.conn.clone();
let full_key = self.build_key(key);
redis::cmd("EXISTS")
.arg(&full_key)
.query_async(&mut conn)
.await
.unwrap_or(0)
> 0
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::Cache;

    /// Writes `value` under `key` with no TTL, panicking if the write fails.
    async fn put(cache: &RedisCache, key: &str, value: &str) {
        cache
            .set(key.to_string(), value.to_string(), None)
            .await
            .expect("set should succeed");
    }

    /// End-to-end check against a live server: set/get/delete, exists, then clear.
    #[tokio::test]
    #[ignore = "Requires Redis server"]
    async fn test_redis_cache_basic() {
        let cache = RedisCache::new("redis://localhost:6379", "test_prefix".to_string()).await;
        assert!(cache.is_ok());
        let cache = cache.unwrap();

        // Round trip: write, read back, delete, confirm gone.
        put(&cache, "test_key", "test_value").await;
        let value = cache.get("test_key").await;
        assert!(value.is_some());
        assert_eq!(value.unwrap().as_ref(), "test_value");
        cache
            .delete("test_key")
            .await
            .expect("delete should succeed");
        let value = cache.get("test_key").await;
        assert_eq!(value, None);

        // Presence checks for a written key and a never-written key.
        put(&cache, "exists_key", "exists_value").await;
        assert!(cache.exists("exists_key").await);
        assert!(!cache.exists("non_exists_key").await);

        // Clearing removes keys under this cache's prefix.
        put(&cache, "clear_test", "value").await;
        cache.clear().await.expect("clear should succeed");
        let cleared_value = cache.get("clear_test").await;
        assert!(cleared_value.is_none());
    }

    // NOTE: the two tests below re-state the expected key / pattern formats
    // inline rather than calling `RedisCache::build_key` /
    // `build_scan_pattern`, because constructing a `RedisCache` requires a
    // live Redis connection. They do not exercise those methods directly.

    #[test]
    fn test_build_key() {
        // (prefix, key, expected full key)
        let cases = [("", "mykey", "mykey"), ("myapp", "mykey", "myapp:mykey")];
        for (prefix, key, expected) in cases {
            let result = match prefix {
                "" => key.to_string(),
                p => format!("{p}:{key}"),
            };
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_build_scan_pattern() {
        // (prefix, expected SCAN MATCH pattern)
        let cases = [("", "*"), ("myapp", "myapp:*")];
        for (prefix, expected) in cases {
            let result = match prefix {
                "" => "*".to_string(),
                p => format!("{p}:*"),
            };
            assert_eq!(result, expected);
        }
    }
}