use super::Cache;
use crate::core::constants::MAX_JSON_DEPTH;
use crate::error::{CacheError, Result};
use crate::traits::CacheKey;
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
#[cfg(any(feature = "tracing", feature = "full"))]
use tracing::instrument;
/// Returns the nesting depth of a JSON value.
///
/// Scalars and empty containers have depth 1. A non-empty object or array is
/// one level deeper than its deepest child.
fn json_depth(value: &serde_json::Value) -> usize {
    // `max()` is `None` for scalars and for empty containers, which collapses
    // to the base depth of 1 below.
    let deepest_child = match value {
        serde_json::Value::Object(fields) => fields.values().map(json_depth).max(),
        serde_json::Value::Array(items) => items.iter().map(json_depth).max(),
        _ => None,
    };
    deepest_child.unwrap_or(0) + 1
}
/// Process-wide registry of in-flight `get_or` computations.
///
/// Maps a cache key string to the `Notify` that concurrent callers for the same
/// key string wait on while one caller computes the value. The map starts empty
/// and is created lazily on first access.
static GET_OR_LOCKS: Lazy<Mutex<HashMap<String, Arc<tokio::sync::Notify>>>> =
    Lazy::new(Mutex::default);
/// Cleanup guard held by the caller that owns an in-flight `get_or` entry.
///
/// When the guard is dropped while `removed` is still `false` (for example on an
/// early return, a panic, or the owning future being dropped), it removes the
/// key's entry from `map` and wakes every task waiting on that entry's
/// `Notify`, so waiters are not left blocked on an entry nobody will complete.
struct GetOrGuard<'a> {
    /// Registry that holds the in-flight entry.
    map: &'a Mutex<HashMap<String, Arc<tokio::sync::Notify>>>,
    /// Key string of the entry this guard is responsible for.
    key: String,
    /// `true` once the entry has already been removed explicitly; disables the
    /// drop-time cleanup.
    removed: bool,
}

impl Drop for GetOrGuard<'_> {
    fn drop(&mut self) {
        if self.removed {
            return;
        }
        // Best effort: if the registry mutex is poisoned the entry is left
        // untouched rather than panicking inside `drop`.
        let orphaned = match self.map.lock() {
            Ok(mut registry) => registry.remove(&self.key),
            Err(_) => None,
        };
        // The lock is released at this point. Waking waiters here (rather than
        // only removing the entry) is what keeps them from waiting forever when
        // the owner never reached its explicit cleanup.
        if let Some(notify) = orphaned {
            notify.notify_waiters();
        }
    }
}
/// Decodes cached JSON bytes into `V`, rejecting documents nested deeper than
/// `MAX_JSON_DEPTH`.
///
/// The bytes are first parsed into a generic `serde_json::Value` so the nesting
/// depth can be measured, and only then converted into `V`.
///
/// # Errors
///
/// Returns `CacheError::Serialization` when the bytes are not valid JSON, when
/// the measured depth exceeds the limit, or when the value cannot be converted
/// into `V`.
#[cfg(any(feature = "serialization", feature = "full"))]
fn deserialize_value<V: serde::de::DeserializeOwned>(data: &[u8]) -> Result<V> {
    let depth_limit: usize = MAX_JSON_DEPTH;
    let json_value: serde_json::Value =
        serde_json::from_slice(data).map_err(|e| CacheError::Serialization(e.to_string()))?;
    // Walk the tree once and reuse the result for both the check and the message.
    let depth = json_depth(&json_value);
    if depth > depth_limit {
        return Err(CacheError::Serialization(format!(
            "JSON深度 {} 超过最大限制 {}",
            depth, depth_limit
        )));
    }
    serde_json::from_value(json_value).map_err(|e| CacheError::Serialization(e.to_string()))
}
/// Fallback used when neither the `serialization` nor the `full` feature is
/// enabled: typed reads cannot be decoded, so every call fails.
///
/// # Errors
///
/// Always returns `CacheError::Serialization`.
#[cfg(not(any(feature = "serialization", feature = "full")))]
fn deserialize_value<V>(data: &[u8]) -> Result<V> {
    // The payload is intentionally ignored in this build.
    let _ = data;
    let reason = "Serialization feature is required for typed get operations".to_string();
    Err(CacheError::Serialization(reason))
}
impl<K, V> Cache<K, V>
where
K: CacheKey,
V: serde::Serialize + for<'de> serde::Deserialize<'de>,
{
/// Looks up `key` in the backend and decodes the stored bytes into `V`.
///
/// Returns `Ok(None)` when the backend has no bytes for the key.
///
/// # Errors
///
/// Propagates backend errors and any error from decoding the stored bytes.
#[cfg_attr(
    any(feature = "tracing", feature = "full"),
    instrument(skip(self, key), level = "debug", fields(key))
)]
pub async fn get(&self, key: &K) -> Result<Option<V>> {
    let key_str = key.to_key_string();
    self.backend
        .get(&key_str)
        .await?
        .map(|data| deserialize_value(&data))
        .transpose()
}
/// Forwards to the backend's `clear` and returns its result unchanged.
pub async fn clear(&self) -> Result<()> {
self.backend.clear().await
}
/// Forwards to the backend's `shutdown` and waits for it to complete.
pub async fn shutdown(&self) {
self.backend.shutdown().await
}
/// Forwards to the backend's `health_check` and returns its result unchanged.
pub async fn health_check(&self) -> Result<()> {
self.backend.health_check().await
}
pub async fn stats(&self) -> Result<std::collections::HashMap<String, String>> {
self.backend.stats().await
}
/// Returns whatever the backend's `len` reports.
// NOTE(review): whether this count is exact or approximate depends on the backend — confirm.
pub async fn len(&self) -> Result<u64> {
self.backend.len().await
}
/// Returns whatever the backend's `is_empty` reports.
pub async fn is_empty(&self) -> Result<bool> {
self.backend.is_empty().await
}
/// Returns whatever the backend's `capacity` reports.
pub async fn capacity(&self) -> Result<u64> {
self.backend.capacity().await
}
/// Stores `value` under `key` without a TTL.
///
/// Equivalent to calling `set_with_ttl` with `ttl` set to `None`; errors from
/// that call are returned unchanged.
#[cfg_attr(
any(feature = "tracing", feature = "full"),
instrument(skip(self, key, value), level = "debug", fields(key))
)]
pub async fn set(&self, key: &K, value: &V) -> Result<()> {
self.set_with_ttl(key, value, None).await
}
/// Serializes `value` to JSON bytes and stores them under `key`, passing `ttl`
/// through to the backend.
///
/// # Errors
///
/// Returns `CacheError::Serialization` when the value cannot be serialized, or
/// when the build has neither the `serialization` nor the `full` feature.
/// Backend errors are propagated.
pub async fn set_with_ttl(&self, key: &K, value: &V, ttl: Option<Duration>) -> Result<()> {
    let key_str = key.to_key_string();
    #[cfg(any(feature = "serialization", feature = "full"))]
    {
        let encoded =
            serde_json::to_vec(value).map_err(|e| CacheError::Serialization(e.to_string()))?;
        self.backend.set(&key_str, encoded, ttl).await
    }
    #[cfg(not(any(feature = "serialization", feature = "full")))]
    {
        // Nothing can be stored in this build; mark the inputs as intentionally unused.
        let _ = (key_str, value);
        let reason = "Serialization feature is required for typed set operations".to_string();
        Err(CacheError::Serialization(reason))
    }
}
/// Asks the backend to delete the entry stored under `key`.
///
/// # Errors
///
/// Propagates any error returned by the backend.
#[cfg_attr(
    any(feature = "tracing", feature = "full"),
    instrument(skip(self, key), level = "debug", fields(key))
)]
pub async fn delete(&self, key: &K) -> Result<()> {
    self.backend.delete(&key.to_key_string()).await
}
/// Asks the backend whether an entry is stored under `key`.
///
/// # Errors
///
/// Propagates any error returned by the backend.
pub async fn exists(&self, key: &K) -> Result<bool> {
    self.backend.exists(&key.to_key_string()).await
}
/// Returns the cached value for `key`, computing and caching it with `fallback`
/// on a miss.
///
/// Concurrent callers for the same key string are coalesced through
/// `GET_OR_LOCKS`: the first caller to register becomes the *leader* and runs
/// `fallback`; later callers become *followers*, wait for the leader to finish,
/// and then re-read the cache.
///
/// If this future is dropped while it is the leader, `GetOrGuard`'s `Drop`
/// removes the registry entry.
///
/// NOTE(review): the registry is a process-wide static keyed only by the key
/// string, so separate `Cache` instances using the same key string appear to
/// share an entry — confirm whether that coupling is intended.
///
/// # Errors
///
/// Propagates errors from the cache lookups, from `fallback`, and from writing
/// the computed value. A follower returns `CacheError::L1Error` when the key is
/// still absent after the leader finished.
///
/// # Panics
///
/// Panics if the `GET_OR_LOCKS` mutex is poisoned.
pub async fn get_or<F, Fut>(&self, key: &K, fallback: F) -> Result<V>
where
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = Result<V>>,
{
    // Fast path: nothing to coordinate when the value is already cached.
    if let Some(value) = self.get(key).await? {
        return Ok(value);
    }
    let key_str = key.to_key_string();

    let notify: Arc<tokio::sync::Notify>;
    let follower_wakeup;
    {
        let mut in_flight = GET_OR_LOCKS
            .lock()
            .expect("GET_OR_LOCKS poisoned - concurrent operation panic detected");
        match in_flight.entry(key_str.clone()) {
            std::collections::hash_map::Entry::Occupied(entry) => {
                notify = Arc::clone(entry.get());
                // Create the `Notified` future while the registry lock is still
                // held. The leader removes the entry under this lock before it
                // calls `notify_waiters()`, so a future created here is
                // guaranteed to observe that call; creating it after unlocking
                // could miss the wakeup and wait forever.
                follower_wakeup = Some(notify.notified());
            }
            std::collections::hash_map::Entry::Vacant(entry) => {
                notify = Arc::new(tokio::sync::Notify::new());
                entry.insert(Arc::clone(&notify));
                follower_wakeup = None;
            }
        }
    }

    if let Some(wakeup) = follower_wakeup {
        wakeup.await;
        return self.get(key).await?.ok_or_else(|| {
            CacheError::L1Error("get_or: concurrent fetch leader failed to cache result".to_string())
        });
    }

    // Leader path. The guard removes the registry entry if we leave without
    // doing so explicitly.
    let mut guard = GetOrGuard {
        map: &GET_OR_LOCKS,
        key: key_str.clone(),
        removed: false,
    };

    // Double-check: another caller may have populated the cache between the
    // fast-path miss and our registration.
    let rechecked = match self.get(key).await {
        Ok(None) => {
            return self
                .execute_fallback(key, &key_str, fallback, &notify, &mut guard)
                .await;
        }
        Ok(Some(value)) => Ok(value),
        Err(e) => Err(e),
    };

    // Hit or lookup error: either way no fallback will run, so release the
    // entry and wake followers before returning.
    GET_OR_LOCKS
        .lock()
        .expect("GET_OR_LOCKS poisoned - concurrent operation panic detected")
        .remove(&key_str);
    guard.removed = true;
    notify.notify_waiters();
    rechecked
}
/// Leader half of `get_or`: runs `fallback`, caches a successful result, then
/// releases the in-flight entry and wakes waiting followers.
///
/// The entry for `key_str` is removed from `GET_OR_LOCKS`, `guard.removed` is
/// set, and `notify.notify_waiters()` is called on every outcome — including
/// when writing the computed value to the cache fails — so followers are always
/// woken.
///
/// # Errors
///
/// Returns the error from `fallback`, or the error from `set` when the computed
/// value could not be cached.
///
/// # Panics
///
/// Panics if the `GET_OR_LOCKS` mutex is poisoned.
async fn execute_fallback<F, Fut>(
    &self,
    key: &K,
    key_str: &str,
    fallback: F,
    notify: &Arc<tokio::sync::Notify>,
    guard: &mut GetOrGuard<'_>,
) -> Result<V>
where
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = Result<V>>,
{
    // Compute and (on success) cache the value, but defer returning any error
    // until the cleanup below has run.
    let outcome = match fallback().await {
        Ok(value) => self.set(key, &value).await.map(|()| value),
        Err(e) => Err(e),
    };

    // Remove the entry first, then notify, so a caller that registers after the
    // wakeup starts a fresh computation instead of waiting on a finished one.
    GET_OR_LOCKS
        .lock()
        .expect("GET_OR_LOCKS poisoned - concurrent operation panic detected")
        .remove(key_str);
    guard.removed = true;
    notify.notify_waiters();

    outcome
}
}
// Tests for the typed `Cache` operations defined in this file and for the
// `json_depth` / `deserialize_value` helpers. Every cache test builds a fresh
// cache through `Cache::builder()`.
#[cfg(test)]
mod tests {
use super::*;
// A key stored before `clear` reads back as `None` afterwards.
#[tokio::test]
async fn test_cache_clear() {
let cache: Cache<String, String> = Cache::builder().build().await.unwrap();
cache.set(&"key".to_string(), &"value".to_string()).await.unwrap();
cache.clear().await.unwrap();
assert!(cache.get(&"key".to_string()).await.unwrap().is_none());
}
// `len` succeeds after one insert.
// NOTE(review): the assertion only bounds the result by 100; it does not pin an exact count.
#[tokio::test]
async fn test_cache_len() {
let cache: Cache<String, String> = Cache::builder().build().await.unwrap();
cache.set(&"key1".to_string(), &"v1".to_string()).await.unwrap();
let len = cache.len().await.unwrap();
assert!(len <= 100, "len should be reasonable after single insert");
}
// `is_empty` succeeds after one insert.
// NOTE(review): the returned flag is discarded, so only the absence of an error is checked.
#[tokio::test]
async fn test_cache_is_empty() {
let cache: Cache<String, String> = Cache::builder().build().await.unwrap();
cache.set(&"key".to_string(), &"value".to_string()).await.unwrap();
let _ = cache.is_empty().await.unwrap();
}
// `exists` is false before the key is set and true afterwards.
#[tokio::test]
async fn test_cache_exists() {
let cache: Cache<String, String> = Cache::builder().build().await.unwrap();
assert!(!cache.exists(&"key".to_string()).await.unwrap());
cache.set(&"key".to_string(), &"value".to_string()).await.unwrap();
assert!(cache.exists(&"key".to_string()).await.unwrap());
}
// A deleted key reads back as `None`.
#[tokio::test]
async fn test_cache_delete() {
let cache: Cache<String, String> = Cache::builder().build().await.unwrap();
cache.set(&"key".to_string(), &"value".to_string()).await.unwrap();
cache.delete(&"key".to_string()).await.unwrap();
assert!(cache.get(&"key".to_string()).await.unwrap().is_none());
}
// On a miss, `get_or` returns the fallback's value and a later `get` sees it.
#[tokio::test]
async fn test_cache_get_or() {
let cache: Cache<String, String> = Cache::builder().build().await.unwrap();
let value = cache
.get_or(&"key".to_string(), || async { Ok("computed".to_string()) })
.await
.unwrap();
assert_eq!(value, "computed");
let cached = cache.get(&"key".to_string()).await.unwrap().unwrap();
assert_eq!(cached, "computed");
}
// `health_check` succeeds on a freshly built cache.
#[tokio::test]
async fn test_cache_health_check() {
let cache: Cache<String, String> = Cache::builder().build().await.unwrap();
assert!(cache.health_check().await.is_ok());
}
// `stats` contains a "type" entry.
#[tokio::test]
async fn test_cache_stats() {
let cache: Cache<String, String> = Cache::builder().build().await.unwrap();
let stats = cache.stats().await.unwrap();
assert!(stats.contains_key("type"));
}
// Reading a key that was never set yields `Ok(None)`.
#[tokio::test]
async fn test_cache_get_miss_returns_none() {
let cache: Cache<String, String> = Cache::builder().build().await.unwrap();
let result = cache.get(&"missing".to_string()).await.unwrap();
assert!(result.is_none());
}
// A second `set` on the same key replaces the first value.
#[tokio::test]
async fn test_cache_set_overwrite() {
let cache: Cache<String, String> = Cache::builder().build().await.unwrap();
cache.set(&"k".to_string(), &"v1".to_string()).await.unwrap();
assert_eq!(cache.get(&"k".to_string()).await.unwrap().unwrap(), "v1".to_string());
cache.set(&"k".to_string(), &"v2".to_string()).await.unwrap();
assert_eq!(cache.get(&"k".to_string()).await.unwrap().unwrap(), "v2".to_string());
}
// Deleting a key that was never set is not an error.
#[tokio::test]
async fn test_cache_delete_missing_key_no_error() {
let cache: Cache<String, String> = Cache::builder().build().await.unwrap();
assert!(cache.delete(&"never".to_string()).await.is_ok());
}
// `exists` flips from true to false once the key is deleted.
#[tokio::test]
async fn test_cache_exists_after_delete() {
let cache: Cache<String, String> = Cache::builder().build().await.unwrap();
cache.set(&"k".to_string(), &"v".to_string()).await.unwrap();
assert!(cache.exists(&"k".to_string()).await.unwrap());
cache.delete(&"k".to_string()).await.unwrap();
assert!(!cache.exists(&"k".to_string()).await.unwrap());
}
// A value stored with a 60-second TTL is readable immediately afterwards.
#[tokio::test]
async fn test_cache_set_with_ttl() {
let cache: Cache<String, String> = Cache::builder().build().await.unwrap();
cache
.set_with_ttl(&"k".to_string(), &"v".to_string(), Some(Duration::from_secs(60)))
.await
.unwrap();
assert_eq!(cache.get(&"k".to_string()).await.unwrap().unwrap(), "v".to_string());
}
// `set_with_ttl` with `None` stores the value like a plain `set`.
#[tokio::test]
async fn test_cache_set_with_ttl_none() {
let cache: Cache<String, i32> = Cache::builder().build().await.unwrap();
cache.set_with_ttl(&"k".to_string(), &42, None).await.unwrap();
assert_eq!(cache.get(&"k".to_string()).await.unwrap().unwrap(), 42);
}
// Integer values round-trip through set/get.
#[tokio::test]
async fn test_cache_get_set_integer_type() {
let cache: Cache<String, i64> = Cache::builder().build().await.unwrap();
cache.set(&"count".to_string(), &12345).await.unwrap();
assert_eq!(cache.get(&"count".to_string()).await.unwrap().unwrap(), 12345);
}
// A serde-derived struct round-trips through set/get.
#[tokio::test]
async fn test_cache_get_set_struct_type() {
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize, PartialEq)]
struct User {
id: u64,
name: String,
}
let cache: Cache<String, User> = Cache::builder().build().await.unwrap();
let user = User {
id: 1,
name: "alice".to_string(),
};
cache.set(&"user:1".to_string(), &user).await.unwrap();
let result = cache.get(&"user:1".to_string()).await.unwrap().unwrap();
assert_eq!(result, user);
}
// On a hit, `get_or` returns the cached value; the fallback here would fail if it ran.
#[tokio::test]
async fn test_cache_get_or_cache_hit_fast_path() {
let cache: Cache<String, String> = Cache::builder().build().await.unwrap();
cache.set(&"k".to_string(), &"cached".to_string()).await.unwrap();
let value = cache
.get_or(&"k".to_string(), || async {
Err(CacheError::Operation("fallback should not be called".to_string()))
})
.await
.unwrap();
assert_eq!(value, "cached");
}
// An error returned by the fallback is handed back to the caller unchanged.
#[tokio::test]
async fn test_cache_get_or_fallback_error_propagates() {
let cache: Cache<String, String> = Cache::builder().build().await.unwrap();
let result: Result<String> = cache
.get_or(&"missing".to_string(), || async {
Err(CacheError::Operation("db down".to_string()))
})
.await;
assert!(result.is_err());
match result {
Err(CacheError::Operation(msg)) => assert_eq!(msg, "db down"),
_ => panic!("expected CacheError::Operation"),
}
}
// A value computed by `get_or` is visible to a subsequent `get`.
#[tokio::test]
async fn test_cache_get_or_writes_to_cache() {
let cache: Cache<String, i32> = Cache::builder().build().await.unwrap();
let v1 = cache.get_or(&"k".to_string(), || async { Ok(99) }).await.unwrap();
assert_eq!(v1, 99);
let cached = cache.get(&"k".to_string()).await.unwrap().unwrap();
assert_eq!(cached, 99);
}
// `capacity` reports the value passed to the builder.
#[tokio::test]
async fn test_cache_capacity() {
let cache: Cache<String, String> = Cache::builder().capacity(500).build().await.unwrap();
let capacity = cache.capacity().await.unwrap();
assert_eq!(capacity, 500);
}
// `shutdown` completes on a cache that holds an entry.
#[tokio::test]
async fn test_cache_shutdown() {
let cache: Cache<String, String> = Cache::builder().build().await.unwrap();
cache.set(&"k".to_string(), &"v".to_string()).await.unwrap();
cache.shutdown().await;
}
// A scalar has depth 1.
#[test]
fn test_json_depth_scalar() {
let v = serde_json::json!(42);
assert_eq!(json_depth(&v), 1);
}
// An empty object has depth 1.
#[test]
fn test_json_depth_empty_object() {
let v = serde_json::json!({});
assert_eq!(json_depth(&v), 1);
}
// An empty array has depth 1.
#[test]
fn test_json_depth_empty_array() {
let v = serde_json::json!([]);
assert_eq!(json_depth(&v), 1);
}
// Three nested objects around a scalar give depth 4.
#[test]
fn test_json_depth_nested_object() {
let v = serde_json::json!({"a": {"b": {"c": 1}}});
assert_eq!(json_depth(&v), 4);
}
// Three nested arrays around a scalar give depth 4.
#[test]
fn test_json_depth_nested_array() {
let v = serde_json::json!([[[1]]]);
assert_eq!(json_depth(&v), 4);
}
// Depth follows the deepest branch across mixed objects and arrays.
#[test]
fn test_json_depth_mixed() {
let v = serde_json::json!({"a": [1, {"b": 2}]});
assert_eq!(json_depth(&v), 4);
}
// A value written through `set` decodes correctly through `get`.
#[tokio::test]
async fn test_deserialize_value_valid() {
let cache: Cache<String, i32> = Cache::builder().build().await.unwrap();
cache.set(&"k".to_string(), &42).await.unwrap();
let v = cache.get(&"k".to_string()).await.unwrap().unwrap();
assert_eq!(v, 42);
}
// Bytes that are not JSON, written straight to the backend, make `get` fail.
#[tokio::test]
async fn test_deserialize_value_invalid_json() {
let cache: Cache<String, i32> = Cache::builder().build().await.unwrap();
cache.backend.set("bad", b"not json".to_vec(), None).await.unwrap();
let result = cache.get(&"bad".to_string()).await;
assert!(result.is_err());
}
// A document of MAX_JSON_DEPTH + 5 nested arrays, written straight to the
// backend, is rejected with a serialization error that mentions the depth.
#[tokio::test]
async fn test_deserialize_value_depth_exceeded() {
let mut json_str = String::new();
for _ in 0..(MAX_JSON_DEPTH + 5) {
json_str.push('[');
}
for _ in 0..(MAX_JSON_DEPTH + 5) {
json_str.push(']');
}
let cache: Cache<String, serde_json::Value> = Cache::builder().build().await.unwrap();
cache.backend.set("deep", json_str.into_bytes(), None).await.unwrap();
let result = cache.get(&"deep".to_string()).await;
assert!(result.is_err());
match result {
Err(CacheError::Serialization(msg)) => {
assert!(msg.contains("深度") || msg.contains("depth"));
}
_ => panic!("expected CacheError::Serialization"),
}
}
}