1use bytes::Bytes;
6use std::hash::{Hash, Hasher};
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::time::{Duration, Instant};
9
10use super::normalizer::NormalizedQuery;
11use super::CacheContext;
12
13#[derive(Debug, Clone)]
15pub struct CachedResult {
16 pub data: Bytes,
18
19 pub row_count: usize,
21
22 pub cached_at: Instant,
24
25 pub ttl: Duration,
27
28 pub tables: Vec<String>,
30
31 pub execution_time: Duration,
33}
34
35impl CachedResult {
36 pub fn new(
38 data: Bytes,
39 row_count: usize,
40 ttl: Duration,
41 tables: Vec<String>,
42 execution_time: Duration,
43 ) -> Self {
44 Self {
45 data,
46 row_count,
47 cached_at: Instant::now(),
48 ttl,
49 tables,
50 execution_time,
51 }
52 }
53
54 pub fn is_expired(&self) -> bool {
56 self.cached_at.elapsed() > self.ttl
57 }
58
59 pub fn age(&self) -> Duration {
61 self.cached_at.elapsed()
62 }
63
64 pub fn remaining_ttl(&self) -> Duration {
66 self.ttl.saturating_sub(self.cached_at.elapsed())
67 }
68
69 pub fn size(&self) -> usize {
71 self.data.len()
72 }
73}
74
75#[derive(Debug, Clone)]
77pub struct CacheKey {
78 pub query_hash: u64,
80
81 pub database: String,
83
84 pub user: Option<String>,
86
87 pub branch: Option<String>,
89
90 cached_hash: u64,
92}
93
94impl CacheKey {
95 pub fn new(normalized: &NormalizedQuery, context: &CacheContext) -> Self {
97 let query_hash = normalized.hash;
98
99 let mut hasher = std::collections::hash_map::DefaultHasher::new();
101 query_hash.hash(&mut hasher);
102 context.database.hash(&mut hasher);
103 context.user.hash(&mut hasher);
104 context.branch.hash(&mut hasher);
105 let cached_hash = hasher.finish();
106
107 Self {
108 query_hash,
109 database: context.database.clone(),
110 user: context.user.clone(),
111 branch: context.branch.clone(),
112 cached_hash,
113 }
114 }
115
116 pub fn from_parts(
118 query_hash: u64,
119 database: String,
120 user: Option<String>,
121 branch: Option<String>,
122 ) -> Self {
123 let mut hasher = std::collections::hash_map::DefaultHasher::new();
124 query_hash.hash(&mut hasher);
125 database.hash(&mut hasher);
126 user.hash(&mut hasher);
127 branch.hash(&mut hasher);
128 let cached_hash = hasher.finish();
129
130 Self {
131 query_hash,
132 database,
133 user,
134 branch,
135 cached_hash,
136 }
137 }
138
139 pub fn hash_value(&self) -> u64 {
141 self.cached_hash
142 }
143}
144
145impl Hash for CacheKey {
146 fn hash<H: Hasher>(&self, state: &mut H) {
147 state.write_u64(self.cached_hash);
148 }
149}
150
151impl PartialEq for CacheKey {
152 fn eq(&self, other: &Self) -> bool {
153 self.cached_hash == other.cached_hash
154 && self.query_hash == other.query_hash
155 && self.database == other.database
156 && self.user == other.user
157 && self.branch == other.branch
158 }
159}
160
161impl Eq for CacheKey {}
162
163#[derive(Debug)]
171pub struct L1Entry {
172 pub result: CachedResult,
174
175 pub query: String,
177
178 pub access_count: AtomicU64,
180
181 pub last_access: Instant,
184}
185
186impl L1Entry {
187 pub fn new(query: String, result: CachedResult) -> Self {
189 Self {
190 result,
191 query,
192 access_count: AtomicU64::new(1),
193 last_access: Instant::now(),
194 }
195 }
196
197 pub fn touch(&self) {
199 self.access_count.fetch_add(1, Ordering::Relaxed);
200 }
201
202 pub fn access_count(&self) -> u64 {
204 self.access_count.load(Ordering::Relaxed)
205 }
206
207 pub fn is_expired(&self) -> bool {
209 self.result.is_expired()
210 }
211}
212
213#[derive(Debug, Clone)]
215pub struct L2Entry {
216 pub result: CachedResult,
218
219 pub fingerprint: String,
221
222 pub key: CacheKey,
224
225 pub access_count: u64,
227
228 pub last_access: Instant,
230
231 pub memory_size: usize,
233}
234
235impl L2Entry {
236 pub fn new(key: CacheKey, fingerprint: String, result: CachedResult) -> Self {
238 let memory_size = result.size()
239 + fingerprint.len()
240 + std::mem::size_of::<Self>()
241 + key.database.len()
242 + key.user.as_ref().map(|s| s.len()).unwrap_or(0)
243 + key.branch.as_ref().map(|s| s.len()).unwrap_or(0);
244
245 Self {
246 result,
247 fingerprint,
248 key,
249 access_count: 1,
250 last_access: Instant::now(),
251 memory_size,
252 }
253 }
254
255 pub fn touch(&mut self) {
257 self.access_count += 1;
258 self.last_access = Instant::now();
259 }
260
261 pub fn is_expired(&self) -> bool {
263 self.result.is_expired()
264 }
265}
266
267#[derive(Debug, Clone)]
269pub struct L3Entry {
270 pub result: CachedResult,
272
273 pub query: String,
275
276 pub embedding: Vec<f32>,
278
279 pub context: CacheContext,
281
282 pub access_count: u64,
284
285 pub last_access: Instant,
287}
288
289impl L3Entry {
290 pub fn new(query: String, embedding: Vec<f32>, context: CacheContext, result: CachedResult) -> Self {
292 Self {
293 result,
294 query,
295 embedding,
296 context,
297 access_count: 1,
298 last_access: Instant::now(),
299 }
300 }
301
302 pub fn touch(&mut self) {
304 self.access_count += 1;
305 self.last_access = Instant::now();
306 }
307
308 pub fn is_expired(&self) -> bool {
310 self.result.is_expired()
311 }
312
313 pub fn similarity(&self, other: &[f32]) -> f32 {
315 if self.embedding.len() != other.len() {
316 return 0.0;
317 }
318
319 let mut dot_product = 0.0f32;
320 let mut norm_a = 0.0f32;
321 let mut norm_b = 0.0f32;
322
323 for (a, b) in self.embedding.iter().zip(other.iter()) {
324 dot_product += a * b;
325 norm_a += a * a;
326 norm_b += b * b;
327 }
328
329 let norm_a = norm_a.sqrt();
330 let norm_b = norm_b.sqrt();
331
332 if norm_a == 0.0 || norm_b == 0.0 {
333 return 0.0;
334 }
335
336 dot_product / (norm_a * norm_b)
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cached_result_expiry() {
        // Freshness is asserted on a long-TTL result, so a scheduler delay
        // between construction and the check cannot make the test flaky.
        let fresh = CachedResult::new(
            Bytes::from("test"),
            1,
            Duration::from_secs(60),
            vec!["users".to_string()],
            Duration::from_millis(5),
        );
        assert!(!fresh.is_expired());

        // `thread::sleep` sleeps for at least the requested duration, so a
        // 15 ms sleep guarantees the 10 ms TTL has lapsed.
        let short = CachedResult::new(
            Bytes::from("test"),
            1,
            Duration::from_millis(10),
            vec!["users".to_string()],
            Duration::from_millis(5),
        );
        std::thread::sleep(Duration::from_millis(15));
        assert!(short.is_expired());
    }

    #[test]
    fn test_cache_key_equality() {
        let ctx1 = CacheContext {
            database: "db1".to_string(),
            user: Some("user1".to_string()),
            branch: None,
            connection_id: None,
        };

        // Differs only in connection_id, which is not part of the key.
        let ctx2 = CacheContext {
            database: "db1".to_string(),
            user: Some("user1".to_string()),
            branch: None,
            connection_id: Some(123),
        };

        let normalized = NormalizedQuery {
            fingerprint: "SELECT * FROM users WHERE id = ?".to_string(),
            hash: 12345,
            tables: vec!["users".to_string()],
            parameters: vec!["1".to_string()],
        };

        let key1 = CacheKey::new(&normalized, &ctx1);
        let key2 = CacheKey::new(&normalized, &ctx2);

        assert_eq!(key1, key2);
        assert_eq!(key1.hash_value(), key2.hash_value());
    }

    #[test]
    fn test_cache_key_different_users() {
        let ctx1 = CacheContext {
            database: "db1".to_string(),
            user: Some("user1".to_string()),
            branch: None,
            connection_id: None,
        };

        let ctx2 = CacheContext {
            database: "db1".to_string(),
            user: Some("user2".to_string()),
            branch: None,
            connection_id: None,
        };

        let normalized = NormalizedQuery {
            fingerprint: "SELECT * FROM users".to_string(),
            hash: 12345,
            tables: vec!["users".to_string()],
            parameters: vec![],
        };

        let key1 = CacheKey::new(&normalized, &ctx1);
        let key2 = CacheKey::new(&normalized, &ctx2);

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_l3_entry_similarity() {
        let result = CachedResult::new(
            Bytes::from("test"),
            1,
            Duration::from_secs(60),
            vec![],
            Duration::from_millis(5),
        );

        let ctx = CacheContext::default();

        let entry = L3Entry::new(
            "SELECT * FROM users".to_string(),
            vec![1.0, 0.0, 0.0],
            ctx,
            result,
        );

        // Identical direction.
        assert!((entry.similarity(&[1.0, 0.0, 0.0]) - 1.0).abs() < 0.001);

        // Orthogonal.
        assert!((entry.similarity(&[0.0, 1.0, 0.0])).abs() < 0.001);

        // Opposite direction.
        assert!((entry.similarity(&[-1.0, 0.0, 0.0]) + 1.0).abs() < 0.001);
    }

    #[test]
    fn test_l1_entry_touch() {
        let result = CachedResult::new(
            Bytes::from("test"),
            1,
            Duration::from_secs(60),
            vec![],
            Duration::from_millis(5),
        );

        let entry = L1Entry::new("SELECT 1".to_string(), result);
        assert_eq!(entry.access_count(), 1);

        entry.touch();
        assert_eq!(entry.access_count(), 2);

        entry.touch();
        assert_eq!(entry.access_count(), 3);
    }
}