1use bytes::Bytes;
6use std::hash::{Hash, Hasher};
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::time::{Duration, Instant};
9
10use super::normalizer::NormalizedQuery;
11use super::CacheContext;
12
/// A query result held in the cache, plus the metadata used to decide when it
/// expires and which tables it was read from.
#[derive(Debug, Clone)]
pub struct CachedResult {
    /// Raw result payload. Its encoding is decided by whoever builds the result.
    pub data: Bytes,

    /// Row count recorded for this result.
    pub row_count: usize,

    /// Insertion time; set to `Instant::now()` by `CachedResult::new`.
    pub cached_at: Instant,

    /// How long after `cached_at` the result is still considered valid.
    pub ttl: Duration,

    /// Names of the tables this result depends on.
    // NOTE(review): presumably consulted for invalidation on writes — confirm with callers.
    pub tables: Vec<String>,

    /// Execution time recorded for the query that produced this result (supplied by the caller).
    pub execution_time: Duration,
}
34
35impl CachedResult {
36 pub fn new(
38 data: Bytes,
39 row_count: usize,
40 ttl: Duration,
41 tables: Vec<String>,
42 execution_time: Duration,
43 ) -> Self {
44 Self {
45 data,
46 row_count,
47 cached_at: Instant::now(),
48 ttl,
49 tables,
50 execution_time,
51 }
52 }
53
54 pub fn is_expired(&self) -> bool {
56 self.cached_at.elapsed() > self.ttl
57 }
58
59 pub fn age(&self) -> Duration {
61 self.cached_at.elapsed()
62 }
63
64 pub fn remaining_ttl(&self) -> Duration {
66 self.ttl.saturating_sub(self.cached_at.elapsed())
67 }
68
69 pub fn size(&self) -> usize {
71 self.data.len()
72 }
73}
74
/// Lookup key for cached results, scoped by query hash, database, user and branch.
///
/// A combined hash of all components is computed once at construction and kept
/// in `cached_hash`, so hashing a key only feeds that single value to the hasher.
#[derive(Debug, Clone)]
pub struct CacheKey {
    /// Query hash; `CacheKey::new` takes it from `NormalizedQuery::hash`.
    pub query_hash: u64,

    /// Database the query targets.
    pub database: String,

    /// User the result is scoped to, if any.
    pub user: Option<String>,

    /// Branch the result is scoped to, if any.
    pub branch: Option<String>,

    /// Precomputed hash over the four fields above.
    // NOTE(review): the other fields are `pub`, so mutating one after construction
    // leaves this value stale and breaks `Hash`/`Eq` consistency in hash maps.
    // Consider making them private with getters.
    cached_hash: u64,
}
93
94impl CacheKey {
95 pub fn new(normalized: &NormalizedQuery, context: &CacheContext) -> Self {
97 let query_hash = normalized.hash;
98
99 let mut hasher = std::collections::hash_map::DefaultHasher::new();
101 query_hash.hash(&mut hasher);
102 context.database.hash(&mut hasher);
103 context.user.hash(&mut hasher);
104 context.branch.hash(&mut hasher);
105 let cached_hash = hasher.finish();
106
107 Self {
108 query_hash,
109 database: context.database.clone(),
110 user: context.user.clone(),
111 branch: context.branch.clone(),
112 cached_hash,
113 }
114 }
115
116 pub fn from_parts(
118 query_hash: u64,
119 database: String,
120 user: Option<String>,
121 branch: Option<String>,
122 ) -> Self {
123 let mut hasher = std::collections::hash_map::DefaultHasher::new();
124 query_hash.hash(&mut hasher);
125 database.hash(&mut hasher);
126 user.hash(&mut hasher);
127 branch.hash(&mut hasher);
128 let cached_hash = hasher.finish();
129
130 Self {
131 query_hash,
132 database,
133 user,
134 branch,
135 cached_hash,
136 }
137 }
138
139 pub fn hash_value(&self) -> u64 {
141 self.cached_hash
142 }
143}
144
145impl Hash for CacheKey {
146 fn hash<H: Hasher>(&self, state: &mut H) {
147 state.write_u64(self.cached_hash);
148 }
149}
150
151impl PartialEq for CacheKey {
152 fn eq(&self, other: &Self) -> bool {
153 self.cached_hash == other.cached_hash
154 && self.query_hash == other.query_hash
155 && self.database == other.database
156 && self.user == other.user
157 && self.branch == other.branch
158 }
159}
160
161impl Eq for CacheKey {}
162
/// Cache entry holding a result together with the query string it was stored for.
///
/// The access counter is atomic, so it can be bumped through a shared reference.
#[derive(Debug)]
pub struct L1Entry {
    /// The cached result.
    pub result: CachedResult,

    /// Query string this entry was stored under.
    pub query: String,

    /// Number of accesses; starts at 1 on creation and is incremented by `L1Entry::touch`.
    pub access_count: AtomicU64,

    /// Set to `Instant::now()` when the entry is created.
    // NOTE(review): unlike `L2Entry`/`L3Entry`, `L1Entry::touch` takes `&self` and
    // never updates this field, so it holds the insertion time, not the latest
    // access. Confirm no LRU/idle-eviction logic relies on it being refreshed.
    pub last_access: Instant,
}
185
186impl L1Entry {
187 pub fn new(query: String, result: CachedResult) -> Self {
189 Self {
190 result,
191 query,
192 access_count: AtomicU64::new(1),
193 last_access: Instant::now(),
194 }
195 }
196
197 pub fn touch(&self) {
199 self.access_count.fetch_add(1, Ordering::Relaxed);
200 }
201
202 pub fn access_count(&self) -> u64 {
204 self.access_count.load(Ordering::Relaxed)
205 }
206
207 pub fn is_expired(&self) -> bool {
209 self.result.is_expired()
210 }
211}
212
/// Cache entry storing a result alongside its [`CacheKey`], a fingerprint string,
/// access bookkeeping, and an approximate memory footprint.
#[derive(Debug, Clone)]
pub struct L2Entry {
    /// The cached result.
    pub result: CachedResult,

    /// Fingerprint string the entry was stored with.
    // NOTE(review): presumably the `NormalizedQuery::fingerprint` — confirm with the caller.
    pub fingerprint: String,

    /// Key identifying this entry.
    pub key: CacheKey,

    /// Number of accesses; 1 on creation, incremented by `L2Entry::touch`.
    pub access_count: u64,

    /// Time of creation or of the most recent `L2Entry::touch`.
    pub last_access: Instant,

    /// Approximate memory footprint in bytes, computed in `L2Entry::new`.
    pub memory_size: usize,
}
234
235impl L2Entry {
236 pub fn new(key: CacheKey, fingerprint: String, result: CachedResult) -> Self {
238 let memory_size = result.size()
239 + fingerprint.len()
240 + std::mem::size_of::<Self>()
241 + key.database.len()
242 + key.user.as_ref().map(|s| s.len()).unwrap_or(0)
243 + key.branch.as_ref().map(|s| s.len()).unwrap_or(0);
244
245 Self {
246 result,
247 fingerprint,
248 key,
249 access_count: 1,
250 last_access: Instant::now(),
251 memory_size,
252 }
253 }
254
255 pub fn touch(&mut self) {
257 self.access_count += 1;
258 self.last_access = Instant::now();
259 }
260
261 pub fn is_expired(&self) -> bool {
263 self.result.is_expired()
264 }
265}
266
/// Cache entry storing a result with its query text, an embedding vector, and
/// the [`CacheContext`] it was stored under. Entries can be compared against
/// other embeddings via `L3Entry::similarity`.
#[derive(Debug, Clone)]
pub struct L3Entry {
    /// The cached result.
    pub result: CachedResult,

    /// Query text this entry was stored for.
    pub query: String,

    /// Embedding vector used by `L3Entry::similarity` (cosine similarity).
    pub embedding: Vec<f32>,

    /// Context the entry was stored under.
    pub context: CacheContext,

    /// Number of accesses; 1 on creation, incremented by `L3Entry::touch`.
    pub access_count: u64,

    /// Time of creation or of the most recent `L3Entry::touch`.
    pub last_access: Instant,
}
288
289impl L3Entry {
290 pub fn new(
292 query: String,
293 embedding: Vec<f32>,
294 context: CacheContext,
295 result: CachedResult,
296 ) -> Self {
297 Self {
298 result,
299 query,
300 embedding,
301 context,
302 access_count: 1,
303 last_access: Instant::now(),
304 }
305 }
306
307 pub fn touch(&mut self) {
309 self.access_count += 1;
310 self.last_access = Instant::now();
311 }
312
313 pub fn is_expired(&self) -> bool {
315 self.result.is_expired()
316 }
317
318 pub fn similarity(&self, other: &[f32]) -> f32 {
320 if self.embedding.len() != other.len() {
321 return 0.0;
322 }
323
324 let mut dot_product = 0.0f32;
325 let mut norm_a = 0.0f32;
326 let mut norm_b = 0.0f32;
327
328 for (a, b) in self.embedding.iter().zip(other.iter()) {
329 dot_product += a * b;
330 norm_a += a * a;
331 norm_b += b * b;
332 }
333
334 let norm_a = norm_a.sqrt();
335 let norm_b = norm_b.sqrt();
336
337 if norm_a == 0.0 || norm_b == 0.0 {
338 return 0.0;
339 }
340
341 dot_product / (norm_a * norm_b)
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small result with the given TTL for tests that don't care about the payload.
    fn sample_result(ttl: Duration) -> CachedResult {
        CachedResult::new(
            Bytes::from("test"),
            1,
            ttl,
            vec!["users".to_string()],
            Duration::from_millis(5),
        )
    }

    #[test]
    fn test_cached_result_expiry() {
        // A long TTL keeps the "not expired" check immune to scheduler delays.
        let fresh = sample_result(Duration::from_secs(60));
        assert!(!fresh.is_expired());
        assert!(fresh.remaining_ttl() > Duration::ZERO);

        // `Instant` is monotonic, so sleeping past the TTL guarantees expiry.
        let short = sample_result(Duration::from_millis(1));
        std::thread::sleep(Duration::from_millis(5));
        assert!(short.is_expired());
        assert_eq!(short.remaining_ttl(), Duration::ZERO);
    }

    #[test]
    fn test_cache_key_equality() {
        let ctx1 = CacheContext {
            database: "db1".to_string(),
            user: Some("user1".to_string()),
            branch: None,
            connection_id: None,
        };

        // Differs only in connection id, which must not affect the key.
        let ctx2 = CacheContext {
            database: "db1".to_string(),
            user: Some("user1".to_string()),
            branch: None,
            connection_id: Some(123),
        };

        let normalized = NormalizedQuery {
            fingerprint: "SELECT * FROM users WHERE id = ?".to_string(),
            hash: 12345,
            tables: vec!["users".to_string()],
            parameters: vec!["1".to_string()],
        };

        let key1 = CacheKey::new(&normalized, &ctx1);
        let key2 = CacheKey::new(&normalized, &ctx2);

        assert_eq!(key1, key2);
        assert_eq!(key1.hash_value(), key2.hash_value());
    }

    #[test]
    fn test_cache_key_different_users() {
        let ctx1 = CacheContext {
            database: "db1".to_string(),
            user: Some("user1".to_string()),
            branch: None,
            connection_id: None,
        };

        let ctx2 = CacheContext {
            database: "db1".to_string(),
            user: Some("user2".to_string()),
            branch: None,
            connection_id: None,
        };

        let normalized = NormalizedQuery {
            fingerprint: "SELECT * FROM users".to_string(),
            hash: 12345,
            tables: vec!["users".to_string()],
            parameters: vec![],
        };

        let key1 = CacheKey::new(&normalized, &ctx1);
        let key2 = CacheKey::new(&normalized, &ctx2);

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_cache_key_from_parts_matches_new() {
        use std::collections::HashSet;

        let ctx = CacheContext {
            database: "db1".to_string(),
            user: Some("user1".to_string()),
            branch: None,
            connection_id: None,
        };

        let normalized = NormalizedQuery {
            fingerprint: "SELECT * FROM users".to_string(),
            hash: 12345,
            tables: vec!["users".to_string()],
            parameters: vec![],
        };

        let from_new = CacheKey::new(&normalized, &ctx);
        let from_parts = CacheKey::from_parts(
            12345,
            "db1".to_string(),
            Some("user1".to_string()),
            None,
        );

        assert_eq!(from_new, from_parts);
        assert_eq!(from_new.hash_value(), from_parts.hash_value());

        // `Hash` and `Eq` must agree for the key to work in hash-based collections.
        let mut set = HashSet::new();
        set.insert(from_new);
        assert!(set.contains(&from_parts));

        // A different branch yields a different key.
        let other_branch = CacheKey::from_parts(
            12345,
            "db1".to_string(),
            Some("user1".to_string()),
            Some("main".to_string()),
        );
        assert!(!set.contains(&other_branch));
    }

    #[test]
    fn test_l3_entry_similarity() {
        let result = CachedResult::new(
            Bytes::from("test"),
            1,
            Duration::from_secs(60),
            vec![],
            Duration::from_millis(5),
        );

        let ctx = CacheContext::default();

        let entry = L3Entry::new(
            "SELECT * FROM users".to_string(),
            vec![1.0, 0.0, 0.0],
            ctx,
            result,
        );

        // Identical direction.
        assert!((entry.similarity(&[1.0, 0.0, 0.0]) - 1.0).abs() < 0.001);
        // Orthogonal.
        assert!((entry.similarity(&[0.0, 1.0, 0.0])).abs() < 0.001);
        // Opposite direction.
        assert!((entry.similarity(&[-1.0, 0.0, 0.0]) + 1.0).abs() < 0.001);
        // Length mismatch and zero vectors fall back to 0.0.
        assert_eq!(entry.similarity(&[1.0, 0.0]), 0.0);
        assert_eq!(entry.similarity(&[0.0, 0.0, 0.0]), 0.0);
    }

    #[test]
    fn test_l1_entry_touch() {
        let result = CachedResult::new(
            Bytes::from("test"),
            1,
            Duration::from_secs(60),
            vec![],
            Duration::from_millis(5),
        );

        let entry = L1Entry::new("SELECT 1".to_string(), result);
        assert_eq!(entry.access_count(), 1);

        entry.touch();
        assert_eq!(entry.access_count(), 2);

        entry.touch();
        assert_eq!(entry.access_count(), 3);
    }
}