#![allow(clippy::field_reassign_with_default)]
use crate::smart_ttl::SmartTtlManager;
use parking_lot::RwLock; use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
/// Redis key prefix for cached response payloads.
const REDIS_CACHE_PREFIX: &str = "gqlcache:";
/// Redis key prefix for the per-type invalidation index sets.
const REDIS_TYPE_PREFIX: &str = "gqltype:";
/// Redis key prefix for the per-entity invalidation index sets.
const REDIS_ENTITY_PREFIX: &str = "gqlentity:";
/// Configuration for the GraphQL response cache (memory- or Redis-backed).
#[derive(Clone, Debug)]
pub struct CacheConfig {
/// Maximum number of entries held by the in-memory backend before eviction.
pub max_size: usize,
/// TTL applied to entries when no smart-TTL manager is configured.
pub default_ttl: Duration,
/// Optional window after expiry during which stale entries may still be served.
pub stale_while_revalidate: Option<Duration>,
/// When true, cached entries are invalidated after mutations.
pub invalidate_on_mutation: bool,
/// If set and a client can be created, a Redis backend is used instead of memory.
pub redis_url: Option<String>,
/// Header names whose values are mixed into the cache key by callers.
pub vary_headers: Vec<String>,
/// Optional manager that computes per-query TTLs.
pub smart_ttl_manager: Option<Arc<SmartTtlManager>>,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_size: 10_000,
default_ttl: Duration::from_secs(60),
stale_while_revalidate: None,
invalidate_on_mutation: true,
redis_url: None,
vary_headers: vec!["Authorization".to_string()],
smart_ttl_manager: None,
}
}
}
/// A cached GraphQL response plus the bookkeeping needed for TTL checks
/// and type/entity invalidation.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CachedResponse {
/// The raw JSON response body that was cached.
pub data: serde_json::Value,
// Serialized as integer milliseconds since the Unix epoch (see `serde_millis`).
#[serde(with = "serde_millis")]
pub created_at: SystemTime,
/// Time-to-live for this entry, in whole seconds.
pub ttl_secs: u64,
/// `__typename`s referenced by the response; used for type-based invalidation.
pub referenced_types: HashSet<String>,
/// `Type#id` keys referenced by the response; used for entity-based invalidation.
pub referenced_entities: HashSet<String>,
}
/// (De)serializes a `SystemTime` as integer milliseconds since the Unix epoch,
/// for use with `#[serde(with = "serde_millis")]`.
mod serde_millis {
    use serde::{Deserialize, Deserializer, Serializer};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Writes `time` as u64 millis; times before the epoch collapse to 0.
    pub fn serialize<S>(time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let since_epoch = time.duration_since(UNIX_EPOCH).unwrap_or_default();
        serializer.serialize_u64(since_epoch.as_millis() as u64)
    }

    /// Reads u64 millis back into a `SystemTime` offset from the epoch.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
    where
        D: Deserializer<'de>,
    {
        let millis = u64::deserialize(deserializer)?;
        Ok(UNIX_EPOCH + std::time::Duration::from_millis(millis))
    }
}
impl CachedResponse {
    /// True once the entry's age exceeds its TTL. If the system clock went
    /// backwards, `elapsed()` errors and the age reads as zero (entry counts
    /// as fresh).
    pub fn is_expired(&self) -> bool {
        let age = self.created_at.elapsed().unwrap_or_default();
        age > Duration::from_secs(self.ttl_secs)
    }

    /// True while the entry is past its TTL but still within `stale_window`,
    /// i.e. it may be served stale while revalidation happens elsewhere.
    pub fn is_stale_but_usable(&self, stale_window: Duration) -> bool {
        let age = self.created_at.elapsed().unwrap_or_default();
        let ttl = Duration::from_secs(self.ttl_secs);
        age > ttl && age <= ttl + stale_window
    }
}
/// Outcome of a cache lookup.
#[derive(Debug)]
pub enum CacheLookupResult {
/// Fresh entry found; serve it directly.
Hit(CachedResponse),
/// Expired entry still inside the stale-while-revalidate window.
Stale(CachedResponse),
/// No usable entry (absent, expired past the stale window, or backend error).
Miss,
}
/// GraphQL response cache with pluggable in-memory or Redis storage.
pub struct ResponseCache {
/// Cache behaviour settings (TTLs, size limit, backend selection).
pub config: CacheConfig,
// Chosen at construction time from `config.redis_url`.
backend: CacheBackend,
}
/// Storage backend for the cache.
enum CacheBackend {
/// Process-local storage with FIFO insertion-order eviction plus secondary
/// indexes for type- and entity-based invalidation.
Memory {
cache: RwLock<HashMap<String, CachedResponse>>,
// Cache keys in insertion order; the oldest are evicted first.
insertion_order: RwLock<Vec<String>>,
// type name -> cache keys whose responses referenced that type
type_index: RwLock<HashMap<String, HashSet<String>>>,
// "Type#id" entity key -> cache keys whose responses referenced that entity
entity_index: RwLock<HashMap<String, HashSet<String>>>,
},
/// Shared storage in Redis; the indexes are kept as Redis sets.
Redis {
client: redis::Client,
},
}
impl ResponseCache {
/// Builds a cache from `config`. A Redis backend is used when
/// `config.redis_url` is set and the client can be created; otherwise
/// (including on Redis client errors, which are logged) it falls back to
/// in-memory storage.
pub fn new(config: CacheConfig) -> Self {
    let backend = match config.redis_url.as_deref() {
        Some(url) => match redis::Client::open(url) {
            Ok(client) => CacheBackend::Redis { client },
            Err(e) => {
                tracing::error!("Failed to create Redis client: {}", e);
                Self::new_memory()
            }
        },
        None => Self::new_memory(),
    };
    Self { config, backend }
}
/// Creates the default in-memory backend with pre-sized primary structures.
fn new_memory() -> CacheBackend {
    // Matches the default `max_size`; avoids rehash/regrow for typical use.
    let capacity = 10_000;
    CacheBackend::Memory {
        cache: RwLock::new(HashMap::with_capacity(capacity)),
        insertion_order: RwLock::new(Vec::with_capacity(capacity)),
        type_index: RwLock::new(HashMap::new()),
        entity_index: RwLock::new(HashMap::new()),
    }
}
/// Derives a deterministic SHA-256 hex cache key from the query text,
/// sorted variables, operation name, and any extra components (e.g.
/// vary-header values). Query whitespace is normalized so formatting
/// differences map to the same key.
pub fn generate_cache_key(
    query: &str,
    variables: Option<&serde_json::Value>,
    operation_name: Option<&str>,
    extra_key_components: &[String],
) -> String {
    let mut hasher = Sha256::new();
    // Collapse every run of whitespace to a single space before hashing.
    let words: Vec<&str> = query.split_whitespace().collect();
    hasher.update(words.join(" ").as_bytes());
    // Hash variables with object keys sorted so JSON key order is irrelevant.
    if let Some(vars) = variables.filter(|v| !v.is_null()) {
        if let Ok(serialized) = serde_json::to_string(&sort_json_value(vars)) {
            hasher.update(serialized.as_bytes());
        }
    }
    if let Some(op) = operation_name {
        hasher.update(op.as_bytes());
    }
    for part in extra_key_components {
        hasher.update(part.as_bytes());
    }
    hex::encode(hasher.finalize())
}
/// Looks up `cache_key`, classifying the entry as fresh (`Hit`), within the
/// stale-while-revalidate window (`Stale`), or absent/expired (`Miss`).
/// Redis connection/command/decoding errors degrade to a `Miss`.
pub async fn get(&self, cache_key: &str) -> CacheLookupResult {
    match &self.backend {
        CacheBackend::Memory { cache, .. } => {
            let guard = cache.read();
            let Some(entry) = guard.get(cache_key) else {
                return CacheLookupResult::Miss;
            };
            if !entry.is_expired() {
                tracing::debug!(cache_key = %cache_key, "Response cache hit (Memory)");
                return CacheLookupResult::Hit(entry.clone());
            }
            let stale_ok = self
                .config
                .stale_while_revalidate
                .map_or(false, |w| entry.is_stale_but_usable(w));
            if stale_ok {
                tracing::debug!(cache_key = %cache_key, "Response cache stale hit (revalidating) (Memory)");
                return CacheLookupResult::Stale(entry.clone());
            }
            tracing::debug!(cache_key = %cache_key, "Response cache miss (expired) (Memory)");
            CacheLookupResult::Miss
        }
        CacheBackend::Redis { client, .. } => {
            let mut conn = match client.get_multiplexed_async_connection().await {
                Ok(c) => c,
                Err(e) => {
                    tracing::error!("Redis connection error: {}", e);
                    return CacheLookupResult::Miss;
                }
            };
            let payload: Option<String> = match redis::cmd("GET")
                .arg(format!("{}{}", REDIS_CACHE_PREFIX, cache_key))
                .query_async(&mut conn)
                .await
            {
                Ok(value) => value,
                Err(e) => {
                    tracing::error!("Redis GET error: {}", e);
                    return CacheLookupResult::Miss;
                }
            };
            let Some(json) = payload else {
                return CacheLookupResult::Miss;
            };
            let Ok(entry) = serde_json::from_str::<CachedResponse>(&json) else {
                return CacheLookupResult::Miss;
            };
            if !entry.is_expired() {
                tracing::debug!(cache_key = %cache_key, "Response cache hit (Redis)");
                return CacheLookupResult::Hit(entry);
            }
            if let Some(window) = self.config.stale_while_revalidate {
                if entry.is_stale_but_usable(window) {
                    tracing::debug!(cache_key = %cache_key, "Response cache stale hit (revalidating) (Redis)");
                    return CacheLookupResult::Stale(entry);
                }
            }
            CacheLookupResult::Miss
        }
    }
}
/// Fetches a single cached entity stored under `entity:{Type}#{id}`.
/// Only fresh hits are returned; stale or missing entries yield `None`.
pub async fn get_entity(&self, type_name: &str, id: &str) -> Option<serde_json::Value> {
    let key = format!("entity:{}#{}", type_name, id);
    if let CacheLookupResult::Hit(entry) = self.get(&key).await {
        Some(entry.data)
    } else {
        None
    }
}
/// Caches `response`, letting the smart-TTL manager (when configured) pick a
/// TTL based on the query and its type; otherwise the configured default TTL
/// is used. Delegates the actual write to `put_with_ttl`.
pub async fn put_with_query(
    &self,
    cache_key: String,
    query: &str,
    query_type: &str,
    response: serde_json::Value,
    referenced_types: HashSet<String>,
    referenced_entities: HashSet<String>,
) {
    let ttl_secs = match &self.config.smart_ttl_manager {
        Some(smart_ttl) => {
            let result = smart_ttl.calculate_ttl(query, query_type, None).await;
            tracing::debug!(
                cache_key = %cache_key,
                ttl_secs = result.ttl.as_secs(),
                strategy = ?result.strategy,
                "Smart TTL calculated"
            );
            result.ttl.as_secs()
        }
        None => self.config.default_ttl.as_secs(),
    };
    self.put_with_ttl(
        cache_key,
        response,
        referenced_types,
        referenced_entities,
        ttl_secs,
    )
    .await;
}
/// Stores `response` under `cache_key` with an explicit TTL (seconds) and
/// registers the key in the type/entity invalidation indexes.
///
/// Memory backend: may first evict the oldest entries to respect
/// `config.max_size`; re-inserting an existing key refreshes its position in
/// the eviction order. Redis backend: the entry and its index updates go out
/// in a single atomic pipeline; errors are logged and the write is dropped.
///
/// Improvement over the previous version: the referenced-type/entity sets are
/// no longer cloned — the indexes are updated first (memory) or iterated via
/// the entry (Redis), then the sets are moved into the stored entry.
pub async fn put_with_ttl(
    &self,
    cache_key: String,
    response: serde_json::Value,
    referenced_types: HashSet<String>,
    referenced_entities: HashSet<String>,
    ttl_secs: u64,
) {
    match &self.backend {
        CacheBackend::Memory {
            cache,
            insertion_order,
            type_index,
            entity_index,
            ..
        } => {
            self.evict_if_needed();
            // Update the secondary indexes before moving the sets into the
            // entry, avoiding the clones the old insert-first order required.
            {
                let mut type_idx = type_index.write();
                for type_name in &referenced_types {
                    type_idx
                        .entry(type_name.clone())
                        .or_default()
                        .insert(cache_key.clone());
                }
            }
            {
                let mut entity_idx = entity_index.write();
                for entity_key in &referenced_entities {
                    entity_idx
                        .entry(entity_key.clone())
                        .or_default()
                        .insert(cache_key.clone());
                }
            }
            let entry = CachedResponse {
                data: response,
                created_at: SystemTime::now(),
                ttl_secs,
                referenced_types,
                referenced_entities,
            };
            cache.write().insert(cache_key.clone(), entry);
            {
                // Move the key to the back of the FIFO eviction order.
                let mut order = insertion_order.write();
                order.retain(|k| k != &cache_key);
                order.push(cache_key.clone());
            }
            tracing::debug!(cache_key = %cache_key, ttl_secs = ttl_secs, "Response cached (Memory)");
        }
        CacheBackend::Redis { client, .. } => {
            let mut conn = match client.get_multiplexed_async_connection().await {
                Ok(c) => c,
                Err(e) => {
                    tracing::error!("Redis connection error: {}", e);
                    return;
                }
            };
            let entry = CachedResponse {
                data: response,
                created_at: SystemTime::now(),
                ttl_secs,
                referenced_types,
                referenced_entities,
            };
            let json = match serde_json::to_string(&entry) {
                Ok(j) => j,
                Err(e) => {
                    tracing::error!("Serialization error: {}", e);
                    return;
                }
            };
            let redis_key = format!("{}{}", REDIS_CACHE_PREFIX, &cache_key);
            let mut pipe = redis::pipe();
            pipe.atomic();
            pipe.set_ex(&redis_key, json, ttl_secs);
            // Index sets let invalidate_by_type / invalidate_by_entity find
            // this key later.
            for type_name in &entry.referenced_types {
                pipe.sadd(format!("{}{}", REDIS_TYPE_PREFIX, type_name), &redis_key);
            }
            for entity_key in &entry.referenced_entities {
                pipe.sadd(format!("{}{}", REDIS_ENTITY_PREFIX, entity_key), &redis_key);
            }
            if let Err(e) = pipe.query_async::<()>(&mut conn).await {
                tracing::error!("Redis PUT error: {}", e);
            } else {
                tracing::debug!(cache_key = %cache_key, ttl_secs = ttl_secs, "Response cached (Redis)");
            }
        }
    }
}
/// Caches `response` under `cache_key` using the configured default TTL.
/// NOTE(review): `as_secs` truncates, so sub-second `default_ttl` values
/// become a 0-second TTL (entry expires immediately unless a stale window
/// applies).
pub async fn put(
    &self,
    cache_key: String,
    response: serde_json::Value,
    referenced_types: HashSet<String>,
    referenced_entities: HashSet<String>,
) {
    let default_secs = self.config.default_ttl.as_secs();
    self.put_with_ttl(
        cache_key,
        response,
        referenced_types,
        referenced_entities,
        default_secs,
    )
    .await;
}
pub async fn put_all_entities(&self, response: &serde_json::Value, ttl: Option<Duration>) {
let entities = extract_entities_with_data(response);
let ttl_secs = ttl
.map(|t| t.as_secs())
.unwrap_or(self.config.default_ttl.as_secs());
for (entity_key, data) in entities {
match &self.backend {
CacheBackend::Memory { cache, .. } => {
let entry = CachedResponse {
data: data.clone(),
created_at: SystemTime::now(),
ttl_secs,
referenced_types: HashSet::new(),
referenced_entities: HashSet::new(),
};
cache
.write()
.insert(format!("entity:{}", entity_key), entry);
}
CacheBackend::Redis { client } => {
let mut conn = match client.get_multiplexed_async_connection().await {
Ok(c) => c,
Err(_) => continue,
};
if let Ok(json) = serde_json::to_string(&data) {
let _: () = redis::cmd("SETEX")
.arg(format!("entity:{}", entity_key))
.arg(ttl_secs)
.arg(json)
.query_async(&mut conn)
.await
.unwrap_or(());
}
}
}
}
}
/// Drops every cached response that referenced the GraphQL type `type_name`.
/// Returns the number of cache entries invalidated; backend errors yield 0.
pub async fn invalidate_by_type(&self, type_name: &str) -> usize {
    match &self.backend {
        CacheBackend::Memory { type_index, .. } => {
            let keys = type_index
                .read()
                .get(type_name)
                .cloned()
                .unwrap_or_default();
            let count = keys.len();
            if count > 0 {
                self.remove_entries(&keys);
                tracing::debug!(type_name = %type_name, count = count, "Invalidated cache by type (Memory)");
            }
            count
        }
        CacheBackend::Redis { client, .. } => {
            let Ok(mut conn) = client.get_multiplexed_async_connection().await else {
                return 0;
            };
            let index_key = format!("{}{}", REDIS_TYPE_PREFIX, type_name);
            // On SMEMBERS failure this falls through to the empty-set path.
            let keys: Vec<String> = redis::cmd("SMEMBERS")
                .arg(&index_key)
                .query_async(&mut conn)
                .await
                .unwrap_or_default();
            if keys.is_empty() {
                return 0;
            }
            // Delete all cached responses plus the index set atomically.
            let mut pipe = redis::pipe();
            pipe.atomic();
            for key in &keys {
                pipe.del(key);
            }
            pipe.del(&index_key);
            let _: () = pipe.query_async(&mut conn).await.unwrap_or(());
            tracing::debug!(type_name = %type_name, count = keys.len(), "Invalidated cache by type (Redis)");
            keys.len()
        }
    }
}
/// Drops every cached response that referenced the entity `entity_key`
/// (a `"Type#id"` string). Returns the number of entries invalidated;
/// backend errors yield 0.
pub async fn invalidate_by_entity(&self, entity_key: &str) -> usize {
    match &self.backend {
        CacheBackend::Memory { entity_index, .. } => {
            let keys = entity_index
                .read()
                .get(entity_key)
                .cloned()
                .unwrap_or_default();
            let count = keys.len();
            if count > 0 {
                self.remove_entries(&keys);
                tracing::debug!(entity_key = %entity_key, count = count, "Invalidated cache by entity (Memory)");
            }
            count
        }
        CacheBackend::Redis { client, .. } => {
            let Ok(mut conn) = client.get_multiplexed_async_connection().await else {
                return 0;
            };
            let index_key = format!("{}{}", REDIS_ENTITY_PREFIX, entity_key);
            // On SMEMBERS failure this falls through to the empty-set path.
            let keys: Vec<String> = redis::cmd("SMEMBERS")
                .arg(&index_key)
                .query_async(&mut conn)
                .await
                .unwrap_or_default();
            if keys.is_empty() {
                return 0;
            }
            // Delete all cached responses plus the index set atomically.
            let mut pipe = redis::pipe();
            pipe.atomic();
            for key in &keys {
                pipe.del(key);
            }
            pipe.del(&index_key);
            let _: () = pipe.query_async(&mut conn).await.unwrap_or(());
            tracing::debug!(entity_key = %entity_key, count = keys.len(), "Invalidated cache by entity (Redis)");
            keys.len()
        }
    }
}
/// After a mutation, invalidates every cached response that referenced any
/// type or entity appearing in `mutation_response`. Returns the total number
/// of invalidated entries; a no-op (0) when `config.invalidate_on_mutation`
/// is false.
pub async fn invalidate_for_mutation(&self, mutation_response: &serde_json::Value) -> usize {
    if !self.config.invalidate_on_mutation {
        return 0;
    }
    let mut invalidated = 0;
    for type_name in extract_types_from_response(mutation_response) {
        invalidated += self.invalidate_by_type(&type_name).await;
    }
    for entity_key in extract_entities_from_response(mutation_response) {
        invalidated += self.invalidate_by_entity(&entity_key).await;
    }
    if invalidated > 0 {
        tracing::info!(count = invalidated, "Cache invalidated after mutation");
    }
    invalidated
}
pub async fn clear(&self) {
match &self.backend {
CacheBackend::Memory {
cache,
insertion_order,
type_index,
entity_index,
..
} => {
cache.write().clear();
insertion_order.write().clear();
type_index.write().clear();
entity_index.write().clear();
tracing::debug!("Response cache cleared (Memory)");
}
CacheBackend::Redis { client, .. } => {
let mut conn = match client.get_multiplexed_async_connection().await {
Ok(c) => c,
Err(_) => return,
};
let mut cursor: u64 = 0;
let mut total_deleted = 0;
loop {
let (next_cursor, keys): (u64, Vec<String>) = match redis::cmd("SCAN")
.arg(cursor)
.arg("MATCH")
.arg(format!("{}*", REDIS_CACHE_PREFIX))
.arg("COUNT")
.arg(100)
.query_async(&mut conn)
.await
{
Ok(result) => result,
Err(e) => {
tracing::error!("Redis SCAN error: {}", e);
break;
}
};
if !keys.is_empty() {
let _: () = redis::cmd("DEL")
.arg(&keys)
.query_async(&mut conn)
.await
.unwrap_or(());
total_deleted += keys.len();
}
cursor = next_cursor;
if cursor == 0 {
break;
}
}
for prefix in [REDIS_TYPE_PREFIX, REDIS_ENTITY_PREFIX] {
cursor = 0;
loop {
let (next_cursor, keys): (u64, Vec<String>) = match redis::cmd("SCAN")
.arg(cursor)
.arg("MATCH")
.arg(format!("{}*", prefix))
.arg("COUNT")
.arg(100)
.query_async(&mut conn)
.await
{
Ok(result) => result,
Err(_) => break,
};
if !keys.is_empty() {
let _: () = redis::cmd("DEL")
.arg(&keys)
.query_async(&mut conn)
.await
.unwrap_or(());
total_deleted += keys.len();
}
cursor = next_cursor;
if cursor == 0 {
break;
}
}
}
tracing::debug!(
"Response cache cleared - {} keys deleted (Redis SCAN)",
total_deleted
);
}
}
}
/// Number of entries in the in-memory cache. Always 0 for the Redis backend
/// (counting remote keys would require a round-trip).
pub fn len(&self) -> usize {
    if let CacheBackend::Memory { cache, .. } = &self.backend {
        cache.read().len()
    } else {
        0
    }
}
/// True when `len()` reports no entries (always true for Redis; see `len`).
pub fn is_empty(&self) -> bool {
    self.len() == 0
}
/// Point-in-time cache occupancy snapshot. Index sizes are only meaningful
/// for the memory backend; the Redis backend reports zeros.
pub fn stats(&self) -> CacheStats {
    let (type_index_size, entity_index_size) = match &self.backend {
        CacheBackend::Memory {
            type_index,
            entity_index,
            ..
        } => (type_index.read().len(), entity_index.read().len()),
        CacheBackend::Redis { .. } => (0, 0),
    };
    CacheStats {
        size: self.len(),
        max_size: self.config.max_size,
        type_index_size,
        entity_index_size,
    }
}
/// FIFO eviction: when the memory cache is at or above `max_size`, removes
/// the oldest-inserted keys until one slot is free for the incoming entry.
/// No-op for the Redis backend.
///
/// Improvement over the previous version: the backend was matched twice
/// (outer `if let` plus an inner re-match); a single `let else` suffices.
fn evict_if_needed(&self) {
    let CacheBackend::Memory {
        insertion_order, ..
    } = &self.backend
    else {
        return;
    };
    let current_len = self.len();
    if current_len < self.config.max_size {
        return;
    }
    // Free enough slots that the caller's insert fits under max_size.
    let to_remove = current_len - self.config.max_size + 1;
    let keys_to_remove: Vec<String> = {
        let mut order = insertion_order.write();
        let drain_count = to_remove.min(order.len());
        order.drain(..drain_count).collect()
    };
    // The order lock is released (scope above) before touching the maps.
    self.remove_entries_internal(&keys_to_remove);
}
/// Removes the given cache keys (and their index back-references), then
/// drops them from the insertion-order list. Memory backend only.
fn remove_entries(&self, cache_keys: &HashSet<String>) {
    let CacheBackend::Memory {
        insertion_order, ..
    } = &self.backend
    else {
        return;
    };
    let keys: Vec<String> = cache_keys.iter().cloned().collect();
    self.remove_entries_internal(&keys);
    insertion_order.write().retain(|k| !cache_keys.contains(k));
}
/// Removes `cache_keys` from the primary map and scrubs each removed entry's
/// back-references out of the type and entity indexes, dropping index buckets
/// that become empty. Does not touch `insertion_order`; callers handle that.
///
/// Lock order: the `cache` write lock is held for the whole loop, while the
/// `type_index` and `entity_index` write locks are taken per removed entry
/// (type index explicitly dropped before the entity index is locked).
fn remove_entries_internal(&self, cache_keys: &[String]) {
if let CacheBackend::Memory {
cache,
type_index,
entity_index,
..
} = &self.backend
{
let mut c = cache.write();
for key in cache_keys {
if let Some(entry) = c.remove(key) {
// Scrub this key out of every type bucket it was indexed under.
let mut type_idx = type_index.write();
for type_name in &entry.referenced_types {
if let Some(keys) = type_idx.get_mut(type_name) {
keys.remove(key);
if keys.is_empty() {
type_idx.remove(type_name);
}
}
}
drop(type_idx);
// Same scrub for the entity index.
let mut entity_idx = entity_index.write();
for entity_key in &entry.referenced_entities {
if let Some(keys) = entity_idx.get_mut(entity_key) {
keys.remove(key);
if keys.is_empty() {
entity_idx.remove(entity_key);
}
}
}
}
}
}
}
}
impl Clone for ResponseCache {
    /// Clones the cache. For the memory backend this produces an independent
    /// snapshot (fresh locks around copied maps — the clone does not share
    /// entries with the original); for Redis only the client handle is
    /// cloned, so state remains shared through the server.
    fn clone(&self) -> Self {
        let backend = match &self.backend {
            CacheBackend::Memory {
                cache,
                insertion_order,
                type_index,
                entity_index,
            } => CacheBackend::Memory {
                cache: RwLock::new(cache.read().clone()),
                insertion_order: RwLock::new(insertion_order.read().clone()),
                type_index: RwLock::new(type_index.read().clone()),
                entity_index: RwLock::new(entity_index.read().clone()),
            },
            CacheBackend::Redis { client } => CacheBackend::Redis {
                client: client.clone(),
            },
        };
        Self {
            config: self.config.clone(),
            backend,
        }
    }
}
impl std::fmt::Debug for ResponseCache {
    /// Compact representation: the config plus the current entry count,
    /// rather than dumping every cached response.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ResponseCache");
        builder.field("config", &self.config);
        builder.field("cached_responses", &self.len());
        builder.finish()
    }
}
/// Point-in-time cache occupancy numbers (see `ResponseCache::stats`).
#[derive(Debug, Clone)]
pub struct CacheStats {
/// Current number of cached responses (0 for the Redis backend).
pub size: usize,
/// Configured maximum entry count for the memory backend.
pub max_size: usize,
/// Number of distinct type names tracked in the type index.
pub type_index_size: usize,
/// Number of distinct entity keys tracked in the entity index.
pub entity_index_size: usize,
}
/// Shared, thread-safe handle to a `ResponseCache`.
pub type SharedResponseCache = Arc<ResponseCache>;
/// Convenience constructor that wraps a new cache in an `Arc`.
pub fn create_response_cache(config: CacheConfig) -> SharedResponseCache {
Arc::new(ResponseCache::new(config))
}
/// Heuristically decides whether a GraphQL document contains a mutation.
///
/// A document counts as a mutation when it starts with the `mutation` keyword
/// followed by a token boundary (whitespace, `{`, `(`, `@`, or end of input),
/// or when `"mutation "` / `"mutation{"` appears later in the text (e.g.
/// multi-operation documents).
///
/// Fix over the previous version: a bare `starts_with("mutation")` treated
/// identifiers like `mutationX` as the keyword; the boundary check removes
/// that false positive. This remains a textual heuristic, not a parser — the
/// keyword inside a string literal or comment still matches.
pub fn is_mutation(query: &str) -> bool {
    if let Some(rest) = query.trim_start().strip_prefix("mutation") {
        match rest.chars().next() {
            // "mutation" alone, or followed by a legal separator.
            None => return true,
            Some(c) if c.is_whitespace() || c == '{' || c == '(' || c == '@' => return true,
            _ => {}
        }
    }
    query.contains("mutation ") || query.contains("mutation{")
}
/// Collects every distinct `__typename` value appearing anywhere in `response`.
fn extract_types_from_response(response: &serde_json::Value) -> HashSet<String> {
    let mut collected = HashSet::new();
    extract_types_recursive(response, &mut collected);
    collected
}
/// Depth-first walk recording each object's `__typename` string into `types`.
fn extract_types_recursive(value: &serde_json::Value, types: &mut HashSet<String>) {
    match value {
        serde_json::Value::Object(map) => {
            if let Some(name) = map.get("__typename").and_then(|t| t.as_str()) {
                types.insert(name.to_string());
            }
            for child in map.values() {
                extract_types_recursive(child, types);
            }
        }
        serde_json::Value::Array(items) => {
            for item in items {
                extract_types_recursive(item, types);
            }
        }
        // Scalars and null carry no type information.
        _ => {}
    }
}
/// Collects every `"Type#id"` entity key appearing anywhere in `response`.
fn extract_entities_from_response(response: &serde_json::Value) -> HashSet<String> {
    let mut collected = HashSet::new();
    extract_entities_recursive(response, &mut collected);
    collected
}
/// Depth-first walk recording `"{__typename}#{id}"` for every object that
/// carries both a `__typename` and a string `id` (or `_id`) field.
fn extract_entities_recursive(value: &serde_json::Value, entities: &mut HashSet<String>) {
    match value {
        serde_json::Value::Object(map) => {
            let type_name = map.get("__typename").and_then(serde_json::Value::as_str);
            // `id` wins over `_id` when both are present.
            let id = map
                .get("id")
                .and_then(serde_json::Value::as_str)
                .or_else(|| map.get("_id").and_then(serde_json::Value::as_str));
            if let (Some(tn), Some(id_val)) = (type_name, id) {
                entities.insert(format!("{}#{}", tn, id_val));
            }
            for child in map.values() {
                extract_entities_recursive(child, entities);
            }
        }
        serde_json::Value::Array(items) => {
            for item in items {
                extract_entities_recursive(item, entities);
            }
        }
        _ => {}
    }
}
/// Like `extract_entities_from_response`, but pairs each entity key with a
/// clone of the JSON object it came from.
fn extract_entities_with_data(response: &serde_json::Value) -> Vec<(String, serde_json::Value)> {
    let mut found = Vec::new();
    extract_entities_data_recursive(response, &mut found);
    found
}
/// Depth-first walk collecting `("{type}#{id}", object)` pairs for every
/// object with a `__typename` and a string `id`/`_id`. The walk continues
/// into matched objects, so nested entities are collected too.
fn extract_entities_data_recursive(
    value: &serde_json::Value,
    entities: &mut Vec<(String, serde_json::Value)>,
) {
    match value {
        serde_json::Value::Object(map) => {
            let type_name = map.get("__typename").and_then(serde_json::Value::as_str);
            // `id` wins over `_id` when both are present.
            let id = map
                .get("id")
                .and_then(serde_json::Value::as_str)
                .or_else(|| map.get("_id").and_then(serde_json::Value::as_str));
            if let (Some(tn), Some(id_val)) = (type_name, id) {
                entities.push((format!("{}#{}", tn, id_val), value.clone()));
            }
            for child in map.values() {
                extract_entities_data_recursive(child, entities);
            }
        }
        serde_json::Value::Array(items) => {
            for item in items {
                extract_entities_data_recursive(item, entities);
            }
        }
        _ => {}
    }
}
/// Returns a copy of `value` with every object's keys in sorted order
/// (recursively), so semantically-equal variable sets serialize to identical
/// strings for cache-key hashing.
fn sort_json_value(value: &serde_json::Value) -> serde_json::Value {
    match value {
        serde_json::Value::Object(map) => {
            // A BTreeMap intermediate yields the entries in key order.
            let ordered: std::collections::BTreeMap<&String, &serde_json::Value> =
                map.iter().collect();
            let sorted_map: serde_json::Map<String, serde_json::Value> = ordered
                .into_iter()
                .map(|(k, v)| (k.clone(), sort_json_value(v)))
                .collect();
            serde_json::Value::Object(sorted_map)
        }
        serde_json::Value::Array(items) => {
            serde_json::Value::Array(items.iter().map(sort_json_value).collect())
        }
        scalar => scalar.clone(),
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Identical queries hash identically; different queries differ.
#[test]
fn test_cache_key_generation() {
let query = "{ users { id name } }";
let key1 = ResponseCache::generate_cache_key(query, None, None, &[]);
assert_eq!(
key1,
ResponseCache::generate_cache_key(query, None, None, &[])
);
let key2 = ResponseCache::generate_cache_key("{ products { id } }", None, None, &[]);
assert_ne!(key1, key2);
let key3 = ResponseCache::generate_cache_key("{ users { id name } }", None, None, &[]);
assert_eq!(key1, key3);
}
// Different variable values must produce different keys.
#[test]
fn test_cache_key_with_variables() {
let query = "query GetUser($id: ID!) { user(id: $id) { name } }";
let vars1 = serde_json::json!({"id": "123"});
let vars2 = serde_json::json!({"id": "456"});
let key1 = ResponseCache::generate_cache_key(query, Some(&vars1), None, &[]);
let key2 = ResponseCache::generate_cache_key(query, Some(&vars2), None, &[]);
assert_ne!(key1, key2);
}
// Extra key components (vary-header values) must change the key deterministically.
#[test]
fn test_cache_key_with_vary_headers() {
let query = "{ users { id name } }";
let headers1 = vec!["Authorization:Bearer TokenA".to_string()];
let headers2 = vec!["Authorization:Bearer TokenB".to_string()];
let key1 = ResponseCache::generate_cache_key(query, None, None, &headers1);
let key2 = ResponseCache::generate_cache_key(query, None, None, &headers2);
let key3 = ResponseCache::generate_cache_key(query, None, None, &[]);
assert_ne!(key1, key2);
assert_ne!(key1, key3);
let key1_again = ResponseCache::generate_cache_key(query, None, None, &headers1);
assert_eq!(key1, key1_again);
}
// A stored response must come back as a fresh Hit with identical data.
#[tokio::test]
async fn test_cache_put_and_get() {
let cache = ResponseCache::new(CacheConfig::default());
let cache_key = "test_key".to_string();
let response = serde_json::json!({"data": {"user": {"id": "1", "name": "Alice"}}});
cache
.put(
cache_key.clone(),
response.clone(),
HashSet::from(["User".to_string()]),
HashSet::from(["User#1".to_string()]),
)
.await;
match cache.get(&cache_key).await {
CacheLookupResult::Hit(entry) => {
assert_eq!(entry.data, response);
}
_ => panic!("Expected cache hit"),
}
}
// Unknown keys must report Miss.
#[tokio::test]
async fn test_cache_miss() {
let cache = ResponseCache::new(CacheConfig::default());
match cache.get("nonexistent").await {
CacheLookupResult::Miss => {}
_ => panic!("Expected cache miss"),
}
}
// An entry cached with a 1s TTL must be a Hit before the TTL elapses and a
// Miss afterwards. Fix: use the async sleep — std::thread::sleep blocks the
// tokio runtime worker thread instead of yielding to it.
#[tokio::test]
async fn test_ttl_expiration() {
    let config = CacheConfig {
        max_size: 100,
        default_ttl: Duration::from_secs(1),
        stale_while_revalidate: None,
        invalidate_on_mutation: true,
        redis_url: None,
        vary_headers: vec![],
        smart_ttl_manager: None,
    };
    let cache = ResponseCache::new(config);
    let cache_key = "expiring".to_string();
    cache
        .put(
            cache_key.clone(),
            serde_json::json!({"test": true}),
            HashSet::new(),
            HashSet::new(),
        )
        .await;
    assert!(matches!(
        cache.get(&cache_key).await,
        CacheLookupResult::Hit(_)
    ));
    tokio::time::sleep(Duration::from_millis(1100)).await;
    assert!(matches!(
        cache.get(&cache_key).await,
        CacheLookupResult::Miss
    ));
}
// Past its (sub-second, hence 0s after as_secs truncation) TTL but inside the
// 500ms stale window the entry is Stale; past the window it is a Miss.
#[tokio::test]
async fn test_stale_while_revalidate() {
let config = CacheConfig {
max_size: 100,
default_ttl: Duration::from_millis(50),
stale_while_revalidate: Some(Duration::from_millis(500)),
invalidate_on_mutation: true,
redis_url: None,
vary_headers: vec![],
smart_ttl_manager: None,
};
let cache = ResponseCache::new(config);
let cache_key = "stale_test".to_string();
cache
.put(
cache_key.clone(),
serde_json::json!({"test": true}),
HashSet::new(),
HashSet::new(),
)
.await;
tokio::time::sleep(Duration::from_millis(150)).await;
assert!(matches!(
cache.get(&cache_key).await,
CacheLookupResult::Stale(_)
));
tokio::time::sleep(Duration::from_millis(500)).await;
assert!(matches!(
cache.get(&cache_key).await,
CacheLookupResult::Miss
));
}
// Invalidating a type removes exactly the entries that referenced it,
// leaving entries for other types intact.
#[tokio::test]
async fn test_invalidate_by_type() {
let cache = ResponseCache::new(CacheConfig::default());
cache
.put(
"key1".to_string(),
serde_json::json!({}),
HashSet::from(["User".to_string()]),
HashSet::new(),
)
.await;
cache
.put(
"key2".to_string(),
serde_json::json!({}),
HashSet::from(["User".to_string()]),
HashSet::new(),
)
.await;
cache
.put(
"key3".to_string(),
serde_json::json!({}),
HashSet::from(["Product".to_string()]),
HashSet::new(),
)
.await;
assert_eq!(cache.len(), 3);
let invalidated = cache.invalidate_by_type("User").await;
assert_eq!(invalidated, 2);
assert_eq!(cache.len(), 1);
assert!(matches!(cache.get("key3").await, CacheLookupResult::Hit(_)));
}
// Invalidating one entity leaves entries tied to other entities alone.
#[tokio::test]
async fn test_invalidate_by_entity() {
let cache = ResponseCache::new(CacheConfig::default());
cache
.put(
"key1".to_string(),
serde_json::json!({}),
HashSet::new(),
HashSet::from(["User#123".to_string()]),
)
.await;
cache
.put(
"key2".to_string(),
serde_json::json!({}),
HashSet::new(),
HashSet::from(["User#456".to_string()]),
)
.await;
assert_eq!(cache.len(), 2);
let invalidated = cache.invalidate_by_entity("User#123").await;
assert_eq!(invalidated, 1);
assert_eq!(cache.len(), 1);
}
// With max_size 3, inserting a 4th entry evicts the oldest (FIFO order).
#[tokio::test]
async fn test_lru_eviction() {
let config = CacheConfig {
max_size: 3,
default_ttl: Duration::from_secs(60),
stale_while_revalidate: None,
invalidate_on_mutation: true,
redis_url: None,
vary_headers: vec![],
smart_ttl_manager: None,
};
let cache = ResponseCache::new(config);
for i in 0..4 {
cache
.put(
format!("key{}", i),
serde_json::json!({"num": i}),
HashSet::new(),
HashSet::new(),
)
.await;
}
assert!(matches!(cache.get("key0").await, CacheLookupResult::Miss));
assert!(matches!(cache.get("key1").await, CacheLookupResult::Hit(_)));
assert!(matches!(cache.get("key2").await, CacheLookupResult::Hit(_)));
assert!(matches!(cache.get("key3").await, CacheLookupResult::Hit(_)));
}
// Mutation detection heuristic: keyword at start or embedded; queries are not flagged.
#[test]
fn test_is_mutation() {
assert!(is_mutation("mutation { createUser { id } }"));
assert!(is_mutation("mutation CreateUser { createUser { id } }"));
assert!(is_mutation(" mutation { test }"));
assert!(!is_mutation("query { users { id } }"));
assert!(!is_mutation("{ users { id } }"));
}
// __typename values are collected recursively and deduplicated.
#[test]
fn test_extract_types() {
let response = serde_json::json!({
"data": {
"user": {
"__typename": "User",
"id": "1",
"posts": [
{"__typename": "Post", "id": "10"},
{"__typename": "Post", "id": "11"}
]
}
}
});
let types = extract_types_from_response(&response);
assert!(types.contains("User"));
assert!(types.contains("Post"));
assert_eq!(types.len(), 2);
}
// Nested entities are collected as "Type#id" keys.
#[test]
fn test_extract_entities() {
let response = serde_json::json!({
"data": {
"user": {
"__typename": "User",
"id": "123",
"friend": {
"__typename": "User",
"id": "456"
}
}
}
});
let entities = extract_entities_from_response(&response);
assert!(entities.contains("User#123"));
assert!(entities.contains("User#456"));
assert_eq!(entities.len(), 2);
}
// clear() empties the cache (memory backend) completely.
#[tokio::test]
async fn test_clear_cache() {
let cache = ResponseCache::new(CacheConfig::default());
cache
.put(
"key1".to_string(),
serde_json::json!({}),
HashSet::from(["User".to_string()]),
HashSet::from(["User#1".to_string()]),
)
.await;
assert_eq!(cache.len(), 1);
cache.clear().await;
assert_eq!(cache.len(), 0);
assert!(cache.is_empty());
}
// put_all_entities stores each nested entity individually, readable back
// through get_entity by type name and id.
#[tokio::test]
async fn test_put_all_entities() {
let cache = ResponseCache::new(CacheConfig::default());
let response = serde_json::json!({
"data": {
"user": {
"__typename": "User",
"id": "123",
"name": "Alice",
"friend": {
"__typename": "User",
"id": "456",
"name": "Bob"
}
}
}
});
cache.put_all_entities(&response, None).await;
let alice = cache
.get_entity("User", "123")
.await
.expect("Alice should be cached");
assert_eq!(alice["name"], "Alice");
let bob = cache
.get_entity("User", "456")
.await
.expect("Bob should be cached");
assert_eq!(bob["name"], "Bob");
}
// Default config values match the documented defaults.
#[test]
fn test_cache_config_default() {
let config = CacheConfig::default();
assert_eq!(config.max_size, 10_000);
assert_eq!(config.default_ttl, std::time::Duration::from_secs(60));
assert!(config.invalidate_on_mutation);
assert!(config.redis_url.is_none());
}
// Clone preserves the configuration values.
#[test]
fn test_cache_config_clone() {
let config1 = CacheConfig::default();
let config2 = config1.clone();
assert_eq!(config1.max_size, config2.max_size);
assert_eq!(config1.default_ttl, config2.default_ttl);
}
// The stale-while-revalidate window is settable after construction.
#[test]
fn test_cache_config_with_stale_while_revalidate() {
let mut config = CacheConfig::default();
config.stale_while_revalidate = Some(std::time::Duration::from_secs(30));
assert!(config.stale_while_revalidate.is_some());
}
// Extra vary headers append to (not replace) the default Authorization entry.
#[test]
fn test_cache_config_vary_headers() {
let mut config = CacheConfig::default();
config.vary_headers.push("X-Tenant-ID".to_string());
assert!(config.vary_headers.contains(&"Authorization".to_string()));
assert!(config.vary_headers.contains(&"X-Tenant-ID".to_string()));
}
// A fresh cache reports zero entries.
#[tokio::test]
async fn test_cache_len() {
let config = CacheConfig::default();
let cache = ResponseCache::new(config);
assert_eq!(cache.len(), 0);
}
// A fresh cache is empty.
#[tokio::test]
async fn test_cache_is_empty() {
let config = CacheConfig::default();
let cache = ResponseCache::new(config);
assert!(cache.is_empty());
}
// Invalidating an unknown type is a harmless no-op returning 0.
#[tokio::test]
async fn test_invalidate_all_types() {
let config = CacheConfig::default();
let cache = ResponseCache::new(config);
let count = cache.invalidate_by_type("User").await;
assert_eq!(count, 0);
}
// Invalidating an unknown entity is a harmless no-op returning 0.
#[tokio::test]
async fn test_invalidate_all_entities() {
let config = CacheConfig::default();
let cache = ResponseCache::new(config);
let count = cache.invalidate_by_entity("User#123").await;
assert_eq!(count, 0);
}
// An entry created 120s ago with a 60s TTL is expired.
#[test]
fn test_cached_response_expiration() {
let entry = CachedResponse {
data: serde_json::json!({"test": "data"}),
created_at: std::time::SystemTime::now() - std::time::Duration::from_secs(120),
ttl_secs: 60,
referenced_types: HashSet::new(),
referenced_entities: HashSet::new(),
};
assert!(entry.is_expired());
}
// A just-created entry with a 60s TTL is fresh.
#[test]
fn test_cached_response_not_expired() {
let entry = CachedResponse {
data: serde_json::json!({"test": "data"}),
created_at: std::time::SystemTime::now(),
ttl_secs: 60,
referenced_types: HashSet::new(),
referenced_entities: HashSet::new(),
};
assert!(!entry.is_expired());
}
// 70s old, 60s TTL, 30s stale window: inside the window -> usable.
#[test]
fn test_cached_response_stale_but_usable() {
let entry = CachedResponse {
data: serde_json::json!({"test": "data"}),
created_at: std::time::SystemTime::now() - std::time::Duration::from_secs(70),
ttl_secs: 60,
referenced_types: HashSet::new(),
referenced_entities: HashSet::new(),
};
assert!(entry.is_stale_but_usable(std::time::Duration::from_secs(30)));
}
// 120s old, 60s TTL, 30s stale window: past the window -> not usable.
#[test]
fn test_cached_response_too_stale() {
let entry = CachedResponse {
data: serde_json::json!({"test": "data"}),
created_at: std::time::SystemTime::now() - std::time::Duration::from_secs(120),
ttl_secs: 60,
referenced_types: HashSet::new(),
referenced_entities: HashSet::new(),
};
assert!(!entry.is_stale_but_usable(std::time::Duration::from_secs(30)));
}
#[test]
fn test_generate_cache_key_basic() {
let key = ResponseCache::generate_cache_key("{ hello }", None, None, &[]);
assert_eq!(key.len(), 64); }
#[test]
fn test_generate_cache_key_with_variables() {
let vars = serde_json::json!({"id": "123"});
let key =
ResponseCache::generate_cache_key("{ user(id: $id) { name } }", Some(&vars), None, &[]);
assert_eq!(key.len(), 64);
}
#[test]
fn test_generate_cache_key_with_operation_name() {
let key = ResponseCache::generate_cache_key("{ hello }", None, Some("GetHello"), &[]);
assert_eq!(key.len(), 64);
}
#[test]
fn test_generate_cache_key_deterministic() {
let key1 = ResponseCache::generate_cache_key("{ hello }", None, None, &[]);
let key2 = ResponseCache::generate_cache_key("{ hello }", None, None, &[]);
assert_eq!(key1, key2);
}
#[test]
fn test_cache_config_invalidation_disabled() {
let mut config = CacheConfig::default();
config.invalidate_on_mutation = false;
assert!(!config.invalidate_on_mutation);
}
#[tokio::test]
async fn test_cache_stats() {
    // A freshly built cache reports zero entries and the configured
    // capacity (10_000 is the CacheConfig default).
    let cache = ResponseCache::new(CacheConfig::default());
    let stats = cache.stats();
    assert_eq!(stats.size, 0);
    assert_eq!(stats.max_size, 10_000);
}
#[test]
fn test_cache_with_smart_ttl() {
    // Pure config manipulation with no await points and no cache
    // construction — a plain #[test] avoids spinning up a tokio
    // runtime for nothing (the original was #[tokio::test]).
    let config = CacheConfig {
        smart_ttl_manager: None,
        ..CacheConfig::default()
    };
    assert!(config.smart_ttl_manager.is_none());
}
#[tokio::test]
async fn test_invalidate_for_mutation_disabled() {
let mut config = CacheConfig::default();
config.invalidate_on_mutation = false;
let cache = ResponseCache::new(config);
let mutation_response = serde_json::json!({
"data": {
"updateUser": {
"__typename": "User",
"id": "123"
}
}
});
let count = cache.invalidate_for_mutation(&mutation_response).await;
assert_eq!(count, 0);
}
#[test]
fn test_cache_lookup_result_debug() {
    // CacheLookupResult's derived Debug output should name the variant.
    let hit = CacheLookupResult::Hit(CachedResponse {
        data: serde_json::json!({"test": "data"}),
        created_at: std::time::SystemTime::now(),
        ttl_secs: 60,
        referenced_types: HashSet::new(),
        referenced_entities: HashSet::new(),
    });
    assert!(format!("{:?}", hit).contains("Hit"));
}
#[test]
fn test_cached_response_clone() {
    // A derived Clone duplicates every field; the original test only
    // compared ttl_secs, so verify all fields survive the clone.
    let original = CachedResponse {
        data: serde_json::json!({"test": "data"}),
        created_at: std::time::SystemTime::now(),
        ttl_secs: 60,
        referenced_types: HashSet::new(),
        referenced_entities: HashSet::new(),
    };
    let copy = original.clone();
    assert_eq!(original.ttl_secs, copy.ttl_secs);
    assert_eq!(original.data, copy.data);
    assert_eq!(original.created_at, copy.created_at);
    assert_eq!(original.referenced_types, copy.referenced_types);
    assert_eq!(original.referenced_entities, copy.referenced_entities);
}
#[test]
fn test_cached_response_debug() {
    // Derived Debug formatting should include the struct's name.
    let response = CachedResponse {
        data: serde_json::json!({"test": "data"}),
        created_at: std::time::SystemTime::now(),
        ttl_secs: 60,
        referenced_types: HashSet::new(),
        referenced_entities: HashSet::new(),
    };
    let rendered = format!("{:?}", response);
    assert!(rendered.contains("CachedResponse"));
}
#[test]
fn test_cache_config_debug() {
    // Derived Debug output of the default config names the struct.
    let rendered = format!("{:?}", CacheConfig::default());
    assert!(rendered.contains("CacheConfig"));
}
#[tokio::test]
async fn test_response_cache_clone() {
    // Cloning the cache handle yields an equally sized view.
    let original = ResponseCache::new(CacheConfig::default());
    let cloned = original.clone();
    assert_eq!(original.len(), cloned.len());
}
#[test]
fn test_generate_cache_key_with_vary_headers() {
    // Extra vary-header values participate in the hash; the result is
    // still a 64-char SHA-256 hex digest.
    let headers = vec!["Bearer token123".to_string()];
    let key = ResponseCache::generate_cache_key("{ hello }", None, None, &headers);
    assert_eq!(key.len(), 64);
}
#[test]
fn test_generate_cache_key_whitespace_normalization() {
    // The original test hashed two byte-identical strings, which only
    // re-proved determinism (already covered elsewhere) and exercised
    // no normalization. Use genuinely different whitespace instead.
    // NOTE(review): assumes generate_cache_key collapses runs of
    // whitespace before hashing — confirm against its implementation.
    let key1 = ResponseCache::generate_cache_key("{ hello }", None, None, &[]);
    let key2 = ResponseCache::generate_cache_key("{  hello   }", None, None, &[]);
    assert_eq!(key1, key2);
}
#[tokio::test]
async fn test_put_with_ttl() {
    // An entry stored with an explicit 120s TTL is immediately a Hit.
    let cache = ResponseCache::new(CacheConfig::default());
    let cache_key = "test_key".to_string();
    let response = serde_json::json!({"data": {"hello": "world"}});
    cache
        .put_with_ttl(cache_key.clone(), response, HashSet::new(), HashSet::new(), 120)
        .await;
    assert!(matches!(cache.get(&cache_key).await, CacheLookupResult::Hit(_)));
}
#[tokio::test]
async fn test_put_basic() {
    // A single put leaves exactly one entry in the cache.
    let cache = ResponseCache::new(CacheConfig::default());
    cache
        .put(
            "test_basic".to_string(),
            serde_json::json!({"data": "test"}),
            HashSet::new(),
            HashSet::new(),
        )
        .await;
    assert_eq!(cache.len(), 1);
}
#[test]
fn test_cache_config_redis_url() {
    // Struct-update syntax sets only the field under test.
    let config = CacheConfig {
        redis_url: Some("redis://localhost:6379".to_string()),
        ..CacheConfig::default()
    };
    assert!(config.redis_url.is_some());
}
#[tokio::test]
async fn test_lru_eviction_strictly_enforced() {
    // With capacity 2, inserting a third key must evict the least
    // recently used entry (key_1) and keep the two newest.
    let cache = ResponseCache::new(CacheConfig {
        max_size: 2,
        ..Default::default()
    });
    for i in 1..=3 {
        cache
            .put(
                format!("key_{}", i),
                serde_json::json!({"data": i}),
                HashSet::new(),
                HashSet::new(),
            )
            .await;
    }
    assert_eq!(cache.len(), 2);
    assert!(matches!(cache.get("key_1").await, CacheLookupResult::Miss));
    assert!(matches!(cache.get("key_2").await, CacheLookupResult::Hit(_)));
    assert!(matches!(cache.get("key_3").await, CacheLookupResult::Hit(_)));
}
#[tokio::test]
async fn test_lru_eviction_access_updates_order() {
let config = CacheConfig {
max_size: 2,
..Default::default()
};
let cache = ResponseCache::new(config);
cache
.put(
"key_1".to_string(),
serde_json::json!({"data": 1}),
HashSet::new(),
HashSet::new(),
)
.await;
cache
.put(
"key_2".to_string(),
serde_json::json!({"data": 2}),
HashSet::new(),
HashSet::new(),
)
.await;
cache
.put(
"key_1".to_string(),
serde_json::json!({"data": 1}),
HashSet::new(),
HashSet::new(),
)
.await;
cache
.put(
"key_3".to_string(),
serde_json::json!({"data": 3}),
HashSet::new(),
HashSet::new(),
)
.await;
assert!(matches!(
cache.get("key_1").await,
CacheLookupResult::Hit(_)
)); assert!(matches!(cache.get("key_2").await, CacheLookupResult::Miss)); assert!(matches!(
cache.get("key_3").await,
CacheLookupResult::Hit(_)
)); }
#[tokio::test]
async fn test_backend_memory_clear() {
    // clear() empties the in-memory backend completely.
    let cache = ResponseCache::new(CacheConfig::default());
    cache
        .put(
            "key_1".to_string(),
            serde_json::json!({"data": 1}),
            HashSet::new(),
            HashSet::new(),
        )
        .await;
    assert!(!cache.is_empty());
    cache.clear().await;
    assert!(cache.is_empty());
    assert_eq!(cache.len(), 0);
}
#[tokio::test]
async fn test_invalidation_removes_from_indices() {
    // Invalidating by type must drop both the cached entry and its
    // reverse-index bookkeeping (stats.type_index_size back to 0).
    let cache = ResponseCache::new(CacheConfig::default());
    let types: HashSet<String> = ["User".to_string()].into_iter().collect();
    cache
        .put(
            "key_1".to_string(),
            serde_json::json!({"data": 1}),
            types,
            HashSet::new(),
        )
        .await;
    let before = cache.stats();
    assert_eq!(before.size, 1);
    assert!(before.type_index_size > 0);
    cache.invalidate_by_type("User").await;
    assert_eq!(cache.len(), 0);
    assert_eq!(cache.stats().type_index_size, 0);
}
#[tokio::test]
async fn test_invalidate_by_entity_multiple_entries() {
    // Two distinct queries referencing the same entity are both evicted
    // by a single entity invalidation, and the call reports the count.
    let cache = ResponseCache::new(CacheConfig::default());
    let entities: HashSet<String> = ["User#123".to_string()].into_iter().collect();
    for (key, n) in [("query_1", 1), ("query_2", 2)] {
        cache
            .put(
                key.to_string(),
                serde_json::json!({"data": n}),
                HashSet::new(),
                entities.clone(),
            )
            .await;
    }
    assert_eq!(cache.len(), 2);
    assert_eq!(cache.invalidate_by_entity("User#123").await, 2);
    assert_eq!(cache.len(), 0);
}
#[tokio::test]
async fn test_explicit_ttl_expiration_short() {
    // Expiry is computed from SystemTime (CachedResponse::is_expired
    // uses created_at.elapsed()), so tokio's mock clock cannot be used
    // here — a real sleep longer than the 1s TTL is required.
    let cache = ResponseCache::new(CacheConfig::default());
    let key = "short_ttl";
    cache
        .put_with_ttl(
            key.to_string(),
            serde_json::json!({"data": "expire"}),
            HashSet::new(),
            HashSet::new(),
            1,
        )
        .await;
    let fresh = cache.get(key).await;
    assert!(matches!(fresh, CacheLookupResult::Hit(_)));
    tokio::time::sleep(tokio::time::Duration::from_millis(1100)).await;
    let expired = cache.get(key).await;
    assert!(matches!(expired, CacheLookupResult::Miss));
}
}