use async_trait::async_trait;
use std::collections::HashMap;
use std::hash::Hash;
#[cfg(feature = "redis")]
use deadpool_redis::redis::{self, AsyncCommands};
#[cfg(feature = "redis")]
use serde::{Deserialize, Serialize};
use crate::error::AppResult;
/// A cache that supports single-key and batched reads and invalidations.
///
/// Implementors provide the four primitive operations; [`BatchCache::refresh`]
/// has a default implementation expressed in terms of [`BatchCache::delete_batch`].
#[async_trait]
pub trait BatchCache {
    /// Logical key type; hashable so batched lookups can return a map.
    type Key: Send + Sync + Clone + Hash + Eq;
    /// Cached value type.
    type Value: Send + Sync + Clone;

    /// Looks up a single key; `None` when the implementation has no value for it.
    async fn get(&self, req: Self::Key) -> AppResult<Option<Self::Value>>;

    /// Looks up many keys at once, returning only the keys that were found.
    async fn get_batch(&self, req: &[Self::Key]) -> AppResult<HashMap<Self::Key, Self::Value>>;

    /// Removes a single key.
    async fn delete(&self, req: Self::Key) -> AppResult<()>;

    /// Removes many keys at once.
    async fn delete_batch(&self, req: &[Self::Key]) -> AppResult<()>;

    /// Refreshes the entries named by `req`.
    ///
    /// The default implementation only uses the map's keys — the values are
    /// ignored — and invalidates every one of them via [`BatchCache::delete_batch`].
    async fn refresh(&self, req: &HashMap<Self::Key, Self::Value>) -> AppResult<()> {
        let stale_keys = req.keys().cloned().collect::<Vec<Self::Key>>();
        self.delete_batch(&stale_keys[..]).await
    }
}
/// Redis-backed storage helpers for a [`BatchCache`].
///
/// Values are stored as JSON strings (via `serde_json`) under the Redis key
/// produced by [`RedisBatchCache::build_cache_key`].
#[cfg(feature = "redis")]
#[async_trait]
pub trait RedisBatchCache: BatchCache
where
    Self::Value: Serialize + for<'de> Deserialize<'de>,
{
    /// Checks out a connection from the deadpool Redis pool.
    async fn get_redis_connection(&self) -> AppResult<deadpool_redis::Connection>;

    /// Maps a logical key to the Redis key string it is stored under.
    fn build_cache_key(&self, key: &Self::Key) -> String;

    /// TTL in seconds applied by the write helpers; `None` (the default) writes
    /// without an expiry.
    ///
    /// Note: Redis rejects `SETEX` with a zero TTL, so returning `Some(0)` makes
    /// every write helper fail with a Redis error.
    fn cache_expire(&self) -> Option<u64> {
        None
    }

    /// Reads and decodes a single value.
    ///
    /// Returns `Ok(None)` when the key does not exist in Redis.
    ///
    /// # Errors
    /// `AppError::Redis` on a Redis/connection failure, `AppError::Internal` if
    /// the stored string is not valid JSON for `Self::Value`.
    async fn get_from_redis(&self, req: Self::Key) -> AppResult<Option<Self::Value>> {
        use crate::error::AppError;
        let key = self.build_cache_key(&req);
        let mut conn = self.get_redis_connection().await?;
        let value: Option<String> = conn.get(&key).await.map_err(AppError::Redis)?;
        match value {
            Some(json_str) => {
                let value: Self::Value = serde_json::from_str(&json_str).map_err(|e| {
                    AppError::Internal(anyhow::anyhow!("JSON 反序列化失败: {}", e))
                })?;
                Ok(Some(value))
            }
            None => Ok(None),
        }
    }

    /// Reads many values with a single `MGET`.
    ///
    /// Keys that are missing, or whose stored JSON cannot be decoded, are simply
    /// omitted from the result (best-effort, unlike [`RedisBatchCache::get_from_redis`],
    /// which reports a decode failure as an error).
    ///
    /// # Errors
    /// `AppError::Redis` on a Redis/connection failure.
    async fn get_batch_from_redis(
        &self,
        req: &[Self::Key],
    ) -> AppResult<HashMap<Self::Key, Self::Value>> {
        use crate::error::AppError;
        if req.is_empty() {
            // MGET requires at least one key; skip the round trip entirely.
            return Ok(HashMap::new());
        }
        let keys: Vec<String> = req.iter().map(|k| self.build_cache_key(k)).collect();
        let mut conn = self.get_redis_connection().await?;
        let values: Vec<Option<String>> = conn.mget(&keys).await.map_err(AppError::Redis)?;
        // MGET replies positionally, so zip the replies back onto the requested keys.
        let result = req
            .iter()
            .zip(values)
            .filter_map(|(key, value_opt)| {
                let value = serde_json::from_str::<Self::Value>(value_opt.as_deref()?).ok()?;
                Some((key.clone(), value))
            })
            .collect();
        Ok(result)
    }

    /// Deletes a single key.
    ///
    /// # Errors
    /// `AppError::Redis` on a Redis/connection failure.
    async fn delete_from_redis(&self, req: Self::Key) -> AppResult<()> {
        use crate::error::AppError;
        let key = self.build_cache_key(&req);
        let mut conn = self.get_redis_connection().await?;
        let _: () = conn.del(&key).await.map_err(AppError::Redis)?;
        Ok(())
    }

    /// Deletes many keys with a single `DEL`.
    ///
    /// # Errors
    /// `AppError::Redis` on a Redis/connection failure.
    async fn delete_batch_from_redis(&self, req: &[Self::Key]) -> AppResult<()> {
        use crate::error::AppError;
        if req.is_empty() {
            // DEL requires at least one key; nothing to do.
            return Ok(());
        }
        let keys: Vec<String> = req.iter().map(|k| self.build_cache_key(k)).collect();
        let mut conn = self.get_redis_connection().await?;
        let _: () = conn.del(&keys).await.map_err(AppError::Redis)?;
        Ok(())
    }

    /// Serializes `value` to JSON and stores it, applying [`RedisBatchCache::cache_expire`]
    /// when set.
    ///
    /// # Errors
    /// `AppError::Internal` if serialization fails (checked before a connection
    /// is taken from the pool), `AppError::Redis` on a Redis/connection failure.
    async fn set_to_redis(&self, key: Self::Key, value: &Self::Value) -> AppResult<()> {
        use crate::error::AppError;
        let cache_key = self.build_cache_key(&key);
        let json_str = serde_json::to_string(value)
            .map_err(|e| AppError::Internal(anyhow::anyhow!("JSON 序列化失败: {}", e)))?;
        let mut conn = self.get_redis_connection().await?;
        if let Some(expire) = self.cache_expire() {
            let _: () = conn
                .set_ex(&cache_key, &json_str, expire)
                .await
                .map_err(AppError::Redis)?;
        } else {
            let _: () = conn
                .set(&cache_key, &json_str)
                .await
                .map_err(AppError::Redis)?;
        }
        Ok(())
    }

    /// Stores every item in one atomic (MULTI/EXEC) pipeline, applying
    /// [`RedisBatchCache::cache_expire`] to each entry when set.
    ///
    /// # Errors
    /// `AppError::Internal` if any value fails to serialize (nothing is written
    /// and no connection is taken from the pool), `AppError::Redis` on a
    /// Redis/connection failure.
    async fn set_batch_to_redis(&self, items: &HashMap<Self::Key, Self::Value>) -> AppResult<()> {
        use crate::error::AppError;
        if items.is_empty() {
            return Ok(());
        }
        let expire = self.cache_expire();
        // Build (and serialize) the whole pipeline before checking out a
        // connection, so a bad value fails fast — consistent with `set_to_redis`.
        let mut pipe = redis::pipe();
        pipe.atomic();
        for (key, value) in items {
            let cache_key = self.build_cache_key(key);
            let json_str = serde_json::to_string(value)
                .map_err(|e| AppError::Internal(anyhow::anyhow!("JSON 序列化失败: {}", e)))?;
            match expire {
                Some(expire_secs) => pipe.set_ex(&cache_key, &json_str, expire_secs),
                None => pipe.set(&cache_key, &json_str),
            };
        }
        let mut conn = self.get_redis_connection().await?;
        pipe.query_async::<()>(&mut conn)
            .await
            .map_err(AppError::Redis)?;
        Ok(())
    }
}
/// In-process storage helpers for a [`BatchCache`], backed by a
/// `moka::future::Cache`.
///
/// Values are cloned into the cache under the string key produced by
/// [`LocalBatchCache::build_cache_key`]. None of the helpers can currently fail;
/// they return `AppResult` for symmetry with the Redis helpers.
#[cfg(feature = "local_cache")]
#[async_trait]
pub trait LocalBatchCache: BatchCache
where
    Self::Value: Send + Sync + Clone + 'static,
    Self::Key: Send + Sync + Clone + Hash + Eq + 'static,
{
    /// Returns a shared handle to the local moka cache.
    fn get_local_cache(&self) -> std::sync::Arc<moka::future::Cache<String, Self::Value>>;

    /// Maps a logical key to the string key used inside the local cache.
    fn build_cache_key(&self, key: &Self::Key) -> String;

    /// Reads a single value; `None` when it is not in the local cache.
    async fn get_from_local(&self, req: Self::Key) -> AppResult<Option<Self::Value>> {
        let cache = self.get_local_cache();
        let key = self.build_cache_key(&req);
        Ok(cache.get(&key).await)
    }

    /// Reads many values, returning only the keys found in the local cache.
    async fn get_batch_from_local(
        &self,
        req: &[Self::Key],
    ) -> AppResult<HashMap<Self::Key, Self::Value>> {
        if req.is_empty() {
            return Ok(HashMap::new());
        }
        let cache = self.get_local_cache();
        // At most `req.len()` entries can be found.
        let mut result = HashMap::with_capacity(req.len());
        for key in req {
            let cache_key = self.build_cache_key(key);
            if let Some(value) = cache.get(&cache_key).await {
                result.insert(key.clone(), value);
            }
        }
        Ok(result)
    }

    /// Invalidates a single key in the local cache.
    async fn delete_from_local(&self, req: Self::Key) -> AppResult<()> {
        let cache = self.get_local_cache();
        let key = self.build_cache_key(&req);
        cache.invalidate(&key).await;
        Ok(())
    }

    /// Invalidates many keys in the local cache, one at a time.
    async fn delete_batch_from_local(&self, req: &[Self::Key]) -> AppResult<()> {
        if req.is_empty() {
            return Ok(());
        }
        let cache = self.get_local_cache();
        // Build each cache key on the fly; no need to collect them first.
        for key in req {
            let cache_key = self.build_cache_key(key);
            cache.invalidate(&cache_key).await;
        }
        Ok(())
    }

    /// Inserts (a clone of) `value` into the local cache.
    async fn set_to_local(&self, key: Self::Key, value: &Self::Value) -> AppResult<()> {
        let cache = self.get_local_cache();
        let cache_key = self.build_cache_key(&key);
        cache.insert(cache_key, value.clone()).await;
        Ok(())
    }

    /// Inserts (clones of) every item into the local cache.
    async fn set_batch_to_local(&self, items: &HashMap<Self::Key, Self::Value>) -> AppResult<()> {
        if items.is_empty() {
            return Ok(());
        }
        let cache = self.get_local_cache();
        for (key, value) in items {
            let cache_key = self.build_cache_key(key);
            cache.insert(cache_key, value.clone()).await;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// In-memory `BatchCache` used to exercise the trait's default methods.
    ///
    /// Deletions are recorded rather than applied, so tests can verify exactly
    /// which keys a (default) method asked to invalidate.
    struct TestCache {
        data: HashMap<i64, String>,
        deleted: Mutex<Vec<i64>>,
    }

    impl TestCache {
        fn new(data: HashMap<i64, String>) -> Self {
            Self {
                data,
                deleted: Mutex::new(Vec::new()),
            }
        }

        /// Keys passed to `delete`/`delete_batch` so far, sorted so assertions do
        /// not depend on `HashMap` iteration order.
        fn deleted_keys(&self) -> Vec<i64> {
            let mut keys = self.deleted.lock().unwrap().clone();
            keys.sort_unstable();
            keys
        }
    }

    #[async_trait]
    impl BatchCache for TestCache {
        type Key = i64;
        type Value = String;

        async fn get(&self, req: Self::Key) -> AppResult<Option<Self::Value>> {
            Ok(self.data.get(&req).cloned())
        }

        async fn get_batch(&self, req: &[Self::Key]) -> AppResult<HashMap<Self::Key, Self::Value>> {
            Ok(req
                .iter()
                .filter_map(|key| self.data.get(key).map(|value| (*key, value.clone())))
                .collect())
        }

        async fn delete(&self, req: Self::Key) -> AppResult<()> {
            self.deleted.lock().unwrap().push(req);
            Ok(())
        }

        async fn delete_batch(&self, req: &[Self::Key]) -> AppResult<()> {
            self.deleted.lock().unwrap().extend_from_slice(req);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_get() {
        let mut data = HashMap::new();
        data.insert(1, "value1".to_string());
        data.insert(2, "value2".to_string());
        let cache = TestCache::new(data);
        let result = cache.get(1).await.unwrap();
        assert_eq!(result, Some("value1".to_string()));
        let result = cache.get(3).await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_get_batch() {
        let mut data = HashMap::new();
        data.insert(1, "value1".to_string());
        data.insert(2, "value2".to_string());
        data.insert(3, "value3".to_string());
        let cache = TestCache::new(data);
        let result = cache.get_batch(&[1, 2, 4]).await.unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result.get(&1), Some(&"value1".to_string()));
        assert_eq!(result.get(&2), Some(&"value2".to_string()));
        assert_eq!(result.get(&4), None);
    }

    #[tokio::test]
    async fn test_delete() {
        let cache = TestCache::new(HashMap::new());
        let result = cache.delete(1).await;
        assert!(result.is_ok());
        assert_eq!(cache.deleted_keys(), vec![1]);
    }

    #[tokio::test]
    async fn test_delete_batch() {
        let cache = TestCache::new(HashMap::new());
        let result = cache.delete_batch(&[1, 2, 3]).await;
        assert!(result.is_ok());
        assert_eq!(cache.deleted_keys(), vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_refresh() {
        let mut data = HashMap::new();
        data.insert(1, "value1".to_string());
        data.insert(2, "value2".to_string());
        data.insert(3, "value3".to_string());
        let cache = TestCache::new(data);
        let mut refresh_map = HashMap::new();
        refresh_map.insert(1, "ignored".to_string());
        refresh_map.insert(2, "ignored".to_string());
        let result = cache.refresh(&refresh_map).await;
        assert!(result.is_ok());
        // The default `refresh` must invalidate exactly the map's keys.
        assert_eq!(cache.deleted_keys(), vec![1, 2]);
    }

    #[tokio::test]
    async fn test_refresh_empty() {
        let cache = TestCache::new(HashMap::new());
        let result = cache.refresh(&HashMap::new()).await;
        assert!(result.is_ok());
        assert!(cache.deleted_keys().is_empty());
    }
}