use bytes::Bytes;
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use super::normalizer::NormalizedQuery;
use super::CacheContext;
/// A query result held by the cache, plus the timing data used to decide
/// when it is no longer valid.
#[derive(Debug, Clone)]
pub struct CachedResult {
    /// Result payload bytes; [`CachedResult::size`] reports their length.
    pub data: Bytes,
    /// Row count recorded for this result.
    pub row_count: usize,
    /// When the result was stored; expiry is measured from this instant.
    pub cached_at: Instant,
    /// How long the result stays valid after `cached_at`.
    pub ttl: Duration,
    /// Table names recorded for this result (presumably the tables the query
    /// read — verify against callers).
    pub tables: Vec<String>,
    /// Execution time recorded alongside the result.
    pub execution_time: Duration,
}

impl CachedResult {
    /// Builds a result whose `cached_at` is the current instant.
    pub fn new(
        data: Bytes,
        row_count: usize,
        ttl: Duration,
        tables: Vec<String>,
        execution_time: Duration,
    ) -> Self {
        let cached_at = Instant::now();
        Self { data, row_count, cached_at, ttl, tables, execution_time }
    }

    /// `true` once strictly more than `ttl` has elapsed since `cached_at`.
    pub fn is_expired(&self) -> bool {
        self.age() > self.ttl
    }

    /// Time elapsed since the result was cached.
    pub fn age(&self) -> Duration {
        self.cached_at.elapsed()
    }

    /// Time left before expiry; `Duration::ZERO` once the TTL has run out.
    pub fn remaining_ttl(&self) -> Duration {
        self.ttl.saturating_sub(self.age())
    }

    /// Length of `data` in bytes (no other field is counted).
    pub fn size(&self) -> usize {
        self.data.len()
    }
}
/// Exact-match cache key: a normalized query hash scoped to a database, an
/// optional user and an optional branch.
///
/// A combined hash of all four components is computed once at construction
/// (`cached_hash`), so hashing a key only writes that single `u64`.
///
/// The component fields are public, but `cached_hash` is never recomputed:
/// mutating `query_hash`, `database`, `user` or `branch` after construction
/// leaves it stale, and the key would then hash inconsistently with `Eq`.
/// Build a new key with [`CacheKey::from_parts`] instead.
///
/// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed to be stable
/// across Rust releases; if `hash_value()` is ever persisted or shared between
/// separately built processes, confirm that is acceptable.
#[derive(Debug, Clone)]
pub struct CacheKey {
    pub query_hash: u64,
    pub database: String,
    pub user: Option<String>,
    pub branch: Option<String>,
    // Derived from the four public fields by `CacheKey::compute_hash`.
    cached_hash: u64,
}

impl CacheKey {
    /// Builds a key for `normalized` executed under `context`.
    ///
    /// Only the context's `database`, `user` and `branch` take part; any other
    /// context fields are ignored. Equivalent to [`CacheKey::from_parts`].
    pub fn new(normalized: &NormalizedQuery, context: &CacheContext) -> Self {
        // Delegate so both constructors share one hashing routine and can never drift apart.
        Self::from_parts(
            normalized.hash,
            context.database.clone(),
            context.user.clone(),
            context.branch.clone(),
        )
    }

    /// Builds a key directly from its components and precomputes its hash.
    pub fn from_parts(
        query_hash: u64,
        database: String,
        user: Option<String>,
        branch: Option<String>,
    ) -> Self {
        let mut key = Self {
            query_hash,
            database,
            user,
            branch,
            cached_hash: 0,
        };
        key.cached_hash = key.compute_hash();
        key
    }

    /// The precomputed combined hash of all key components.
    pub fn hash_value(&self) -> u64 {
        self.cached_hash
    }

    /// Hashes the components in a fixed order: query hash, database, user, branch.
    fn compute_hash(&self) -> u64 {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        self.query_hash.hash(&mut hasher);
        self.database.hash(&mut hasher);
        self.user.hash(&mut hasher);
        self.branch.hash(&mut hasher);
        hasher.finish()
    }
}

/// Feeds only the precomputed hash; keys equal under `PartialEq` share the
/// same `cached_hash`, so the Hash/Eq contract holds (absent field mutation).
impl Hash for CacheKey {
    fn hash<H: Hasher>(&self, state: &mut H) {
        state.write_u64(self.cached_hash);
    }
}

/// Compares the cheap precomputed hash first, then every component, so that
/// hash collisions never make distinct keys compare equal.
impl PartialEq for CacheKey {
    fn eq(&self, other: &Self) -> bool {
        self.cached_hash == other.cached_hash
            && self.query_hash == other.query_hash
            && self.database == other.database
            && self.user == other.user
            && self.branch == other.branch
    }
}

impl Eq for CacheKey {}
/// Cache entry holding a result and its original query text, with an atomic
/// access counter so it can be touched through a shared reference.
#[derive(Debug)]
pub struct L1Entry {
    pub result: CachedResult,
    pub query: String,
    /// Access counter; atomic so `touch` only needs `&self`.
    pub access_count: AtomicU64,
    /// Set at creation; `touch` does not update it.
    pub last_access: Instant,
}

impl L1Entry {
    /// Creates an entry whose access count starts at 1 and whose
    /// `last_access` is the creation time.
    pub fn new(query: String, result: CachedResult) -> Self {
        let access_count = AtomicU64::new(1);
        let last_access = Instant::now();
        Self { result, query, access_count, last_access }
    }

    /// Records one more access.
    ///
    /// NOTE(review): unlike `L2Entry::touch` / `L3Entry::touch`, this does not
    /// refresh `last_access` — it takes `&self` and `last_access` is a plain
    /// `Instant`, so the field stays at the creation time. Confirm intended.
    pub fn touch(&self) {
        let _previous = self.access_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Current access count. Relaxed ordering: this is a plain counter and
    /// does not synchronize with any other data.
    pub fn access_count(&self) -> u64 {
        self.access_count.load(Ordering::Relaxed)
    }

    /// Delegates to the wrapped result's TTL check.
    pub fn is_expired(&self) -> bool {
        let result = &self.result;
        result.is_expired()
    }
}
/// Cache entry stored under a [`CacheKey`] together with the normalized query
/// fingerprint and an approximate memory footprint.
#[derive(Debug, Clone)]
pub struct L2Entry {
    pub result: CachedResult,
    pub fingerprint: String,
    pub key: CacheKey,
    /// Number of accesses; starts at 1 and is bumped by `touch`.
    pub access_count: u64,
    /// Creation time, or the time of the most recent `touch`.
    pub last_access: Instant,
    /// Approximate bytes attributable to this entry, computed once in `new`
    /// (see `L2Entry::estimate_memory`); not recomputed if fields change later.
    pub memory_size: usize,
}

impl L2Entry {
    /// Creates an entry with an access count of 1, `last_access` set to now and
    /// `memory_size` estimated from the supplied parts.
    pub fn new(key: CacheKey, fingerprint: String, result: CachedResult) -> Self {
        let memory_size = Self::estimate_memory(&key, &fingerprint, &result);
        Self {
            result,
            fingerprint,
            key,
            access_count: 1,
            last_access: Instant::now(),
            memory_size,
        }
    }

    /// Records an access: bumps the counter and refreshes `last_access`.
    pub fn touch(&mut self) {
        self.access_count += 1;
        self.last_access = Instant::now();
    }

    /// Delegates to the wrapped result's TTL check.
    pub fn is_expired(&self) -> bool {
        self.result.is_expired()
    }

    /// Approximates the memory owned by an entry built from these parts.
    ///
    /// The estimate is the entry's inline size plus the heap bytes of its payload
    /// (`result.size()`), fingerprint, key strings and table names. It uses `len()`
    /// rather than `capacity()`, so allocator slack is not counted.
    fn estimate_memory(key: &CacheKey, fingerprint: &str, result: &CachedResult) -> usize {
        let key_strings = key.database.len()
            + key.user.as_ref().map_or(0, String::len)
            + key.branch.as_ref().map_or(0, String::len);
        // `tables` is a `Vec<String>`: its heap buffer holds one `String` header per
        // table, and each header owns that table name's bytes. Previously omitted.
        let table_names: usize = result
            .tables
            .iter()
            .map(|table| std::mem::size_of::<String>() + table.len())
            .sum();
        std::mem::size_of::<Self>()
            + result.size()
            + fingerprint.len()
            + key_strings
            + table_names
    }
}
/// Cache entry that keeps the query text, an embedding vector and the
/// [`CacheContext`] it was produced under, so it can be compared against a
/// probe embedding with [`L3Entry::similarity`].
#[derive(Debug, Clone)]
pub struct L3Entry {
    pub result: CachedResult,
    pub query: String,
    pub embedding: Vec<f32>,
    pub context: CacheContext,
    /// Number of accesses; starts at 1 and is bumped by `touch`.
    pub access_count: u64,
    /// Creation time, or the time of the most recent `touch`.
    pub last_access: Instant,
}

impl L3Entry {
    /// Creates an entry with an access count of 1 and `last_access` set to now.
    pub fn new(
        query: String,
        embedding: Vec<f32>,
        context: CacheContext,
        result: CachedResult,
    ) -> Self {
        Self {
            result,
            query,
            embedding,
            context,
            access_count: 1,
            last_access: Instant::now(),
        }
    }

    /// Records an access: bumps the counter and refreshes `last_access`.
    pub fn touch(&mut self) {
        self.access_count += 1;
        self.last_access = Instant::now();
    }

    /// Delegates to the wrapped result's TTL check.
    pub fn is_expired(&self) -> bool {
        self.result.is_expired()
    }

    /// Cosine similarity between this entry's embedding and `other`, in `[-1.0, 1.0]`.
    ///
    /// Returns `0.0` when the vectors differ in length or either has zero norm.
    /// `NaN`/infinite components still propagate as `NaN`.
    ///
    /// Sums are accumulated in `f64`. In `f32`, squaring large components (~1e20)
    /// overflows to infinity and yields `NaN`, and squaring tiny ones (~1e-25)
    /// underflows to a zero norm. Long embeddings also lose precision. The result
    /// is clamped because rounding can push it marginally outside `[-1, 1]`.
    pub fn similarity(&self, other: &[f32]) -> f32 {
        if self.embedding.len() != other.len() {
            return 0.0;
        }
        let (dot, norm_a_sq, norm_b_sq) = self.embedding.iter().zip(other).fold(
            (0.0f64, 0.0f64, 0.0f64),
            |(dot, norm_a_sq, norm_b_sq), (&a, &b)| {
                let (a, b) = (f64::from(a), f64::from(b));
                (dot + a * b, norm_a_sq + a * a, norm_b_sq + b * b)
            },
        );
        let norm_a = norm_a_sq.sqrt();
        let norm_b = norm_b_sq.sqrt();
        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }
        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0) as f32
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cached_result_expiry() {
        // A generous TTL keeps the "fresh" assertions deterministic even on a
        // heavily loaded CI machine (a 10 ms TTL could already have elapsed).
        let fresh = CachedResult::new(
            Bytes::from("test"),
            1,
            Duration::from_secs(60),
            vec!["users".to_string()],
            Duration::from_millis(5),
        );
        assert!(!fresh.is_expired());
        assert!(fresh.remaining_ttl() > Duration::ZERO);
        assert_eq!(fresh.size(), 4);

        // `thread::sleep` never sleeps less than requested, so sleeping longer
        // than the TTL guarantees expiry without relying on timing margins.
        let mut stale = fresh.clone();
        stale.ttl = Duration::from_millis(1);
        std::thread::sleep(Duration::from_millis(5));
        assert!(stale.is_expired());
        assert_eq!(stale.remaining_ttl(), Duration::ZERO);
    }

    #[test]
    fn test_cache_key_equality() {
        let ctx1 = CacheContext {
            database: "db1".to_string(),
            user: Some("user1".to_string()),
            branch: None,
            connection_id: None,
        };
        // Differs only in `connection_id`, which must not affect the key.
        let ctx2 = CacheContext {
            database: "db1".to_string(),
            user: Some("user1".to_string()),
            branch: None,
            connection_id: Some(123),
        };
        let normalized = NormalizedQuery {
            fingerprint: "SELECT * FROM users WHERE id = ?".to_string(),
            hash: 12345,
            tables: vec!["users".to_string()],
            parameters: vec!["1".to_string()],
        };
        let key1 = CacheKey::new(&normalized, &ctx1);
        let key2 = CacheKey::new(&normalized, &ctx2);
        assert_eq!(key1, key2);
        assert_eq!(key1.hash_value(), key2.hash_value());

        // Equal keys must also be interchangeable in hash-based collections.
        let mut set = std::collections::HashSet::new();
        set.insert(key1);
        assert!(set.contains(&key2));
    }

    #[test]
    fn test_cache_key_different_users() {
        let ctx1 = CacheContext {
            database: "db1".to_string(),
            user: Some("user1".to_string()),
            branch: None,
            connection_id: None,
        };
        let ctx2 = CacheContext {
            database: "db1".to_string(),
            user: Some("user2".to_string()),
            branch: None,
            connection_id: None,
        };
        let normalized = NormalizedQuery {
            fingerprint: "SELECT * FROM users".to_string(),
            hash: 12345,
            tables: vec!["users".to_string()],
            parameters: vec![],
        };
        let key1 = CacheKey::new(&normalized, &ctx1);
        let key2 = CacheKey::new(&normalized, &ctx2);
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_cache_key_from_parts_matches_new() {
        let ctx = CacheContext {
            database: "db1".to_string(),
            user: Some("user1".to_string()),
            branch: Some("main".to_string()),
            connection_id: None,
        };
        let normalized = NormalizedQuery {
            fingerprint: "SELECT 1".to_string(),
            hash: 42,
            tables: vec![],
            parameters: vec![],
        };
        let from_ctx = CacheKey::new(&normalized, &ctx);
        let from_parts = CacheKey::from_parts(
            42,
            "db1".to_string(),
            Some("user1".to_string()),
            Some("main".to_string()),
        );
        assert_eq!(from_ctx, from_parts);
        assert_eq!(from_ctx.hash_value(), from_parts.hash_value());
    }

    #[test]
    fn test_l3_entry_similarity() {
        let result = CachedResult::new(
            Bytes::from("test"),
            1,
            Duration::from_secs(60),
            vec![],
            Duration::from_millis(5),
        );
        let ctx = CacheContext::default();
        let entry = L3Entry::new(
            "SELECT * FROM users".to_string(),
            vec![1.0, 0.0, 0.0],
            ctx,
            result,
        );
        assert!((entry.similarity(&[1.0, 0.0, 0.0]) - 1.0).abs() < 0.001);
        assert!((entry.similarity(&[0.0, 1.0, 0.0])).abs() < 0.001);
        assert!((entry.similarity(&[-1.0, 0.0, 0.0]) + 1.0).abs() < 0.001);
        // Length mismatch and zero-norm probes are defined to score 0.0.
        assert_eq!(entry.similarity(&[1.0, 0.0]), 0.0);
        assert_eq!(entry.similarity(&[0.0, 0.0, 0.0]), 0.0);
    }

    #[test]
    fn test_l1_entry_touch() {
        let result = CachedResult::new(
            Bytes::from("test"),
            1,
            Duration::from_secs(60),
            vec![],
            Duration::from_millis(5),
        );
        let entry = L1Entry::new("SELECT 1".to_string(), result);
        assert_eq!(entry.access_count(), 1);
        entry.touch();
        assert_eq!(entry.access_count(), 2);
        entry.touch();
        assert_eq!(entry.access_count(), 3);
    }
}