Skip to main content

heliosdb_proxy/cache/
result.rs

1//! Cached Result Types
2//!
3//! Structures for storing and retrieving cached query results.
4
5use bytes::Bytes;
6use std::hash::{Hash, Hasher};
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::time::{Duration, Instant};
9
10use super::normalizer::NormalizedQuery;
11use super::CacheContext;
12
13/// Cached query result
14#[derive(Debug, Clone)]
15pub struct CachedResult {
16    /// Serialized result data
17    pub data: Bytes,
18
19    /// Number of rows in the result
20    pub row_count: usize,
21
22    /// When this result was cached
23    pub cached_at: Instant,
24
25    /// Time-to-live for this result
26    pub ttl: Duration,
27
28    /// Tables referenced by the query
29    pub tables: Vec<String>,
30
31    /// Original query execution time
32    pub execution_time: Duration,
33}
34
35impl CachedResult {
36    /// Create a new cached result
37    pub fn new(
38        data: Bytes,
39        row_count: usize,
40        ttl: Duration,
41        tables: Vec<String>,
42        execution_time: Duration,
43    ) -> Self {
44        Self {
45            data,
46            row_count,
47            cached_at: Instant::now(),
48            ttl,
49            tables,
50            execution_time,
51        }
52    }
53
54    /// Check if this cached result has expired
55    pub fn is_expired(&self) -> bool {
56        self.cached_at.elapsed() > self.ttl
57    }
58
59    /// Get the age of this cached result
60    pub fn age(&self) -> Duration {
61        self.cached_at.elapsed()
62    }
63
64    /// Get remaining TTL
65    pub fn remaining_ttl(&self) -> Duration {
66        self.ttl.saturating_sub(self.cached_at.elapsed())
67    }
68
69    /// Get size in bytes
70    pub fn size(&self) -> usize {
71        self.data.len()
72    }
73}
74
75/// Cache key for lookup operations
76#[derive(Debug, Clone)]
77pub struct CacheKey {
78    /// Hash of the normalized query
79    pub query_hash: u64,
80
81    /// Database name
82    pub database: String,
83
84    /// User (for RLS-aware caching)
85    pub user: Option<String>,
86
87    /// Branch (for HeliosDB branching support)
88    pub branch: Option<String>,
89
90    /// Pre-computed hash for fast lookups
91    cached_hash: u64,
92}
93
94impl CacheKey {
95    /// Create a new cache key from a normalized query and context
96    pub fn new(normalized: &NormalizedQuery, context: &CacheContext) -> Self {
97        let query_hash = normalized.hash;
98
99        // Compute combined hash
100        let mut hasher = std::collections::hash_map::DefaultHasher::new();
101        query_hash.hash(&mut hasher);
102        context.database.hash(&mut hasher);
103        context.user.hash(&mut hasher);
104        context.branch.hash(&mut hasher);
105        let cached_hash = hasher.finish();
106
107        Self {
108            query_hash,
109            database: context.database.clone(),
110            user: context.user.clone(),
111            branch: context.branch.clone(),
112            cached_hash,
113        }
114    }
115
116    /// Create a cache key from raw components
117    pub fn from_parts(
118        query_hash: u64,
119        database: String,
120        user: Option<String>,
121        branch: Option<String>,
122    ) -> Self {
123        let mut hasher = std::collections::hash_map::DefaultHasher::new();
124        query_hash.hash(&mut hasher);
125        database.hash(&mut hasher);
126        user.hash(&mut hasher);
127        branch.hash(&mut hasher);
128        let cached_hash = hasher.finish();
129
130        Self {
131            query_hash,
132            database,
133            user,
134            branch,
135            cached_hash,
136        }
137    }
138
139    /// Get the pre-computed hash
140    pub fn hash_value(&self) -> u64 {
141        self.cached_hash
142    }
143}
144
145impl Hash for CacheKey {
146    fn hash<H: Hasher>(&self, state: &mut H) {
147        state.write_u64(self.cached_hash);
148    }
149}
150
151impl PartialEq for CacheKey {
152    fn eq(&self, other: &Self) -> bool {
153        self.cached_hash == other.cached_hash
154            && self.query_hash == other.query_hash
155            && self.database == other.database
156            && self.user == other.user
157            && self.branch == other.branch
158    }
159}
160
161impl Eq for CacheKey {}
162
163/// Entry in the L1 hot cache
164///
165/// `access_count` is an `AtomicU64` so cache hits can bump it under a
166/// read lock on the containing map — `touch()` takes `&self`, not
167/// `&mut self`. `last_access` is deliberately cosmetic (not consulted
168/// by LRU eviction, which uses a separate ordered queue) and is not
169/// updated per-access.
170#[derive(Debug)]
171pub struct L1Entry {
172    /// The cached result
173    pub result: CachedResult,
174
175    /// Original query string (for exact match)
176    pub query: String,
177
178    /// Access count (atomic so hits only need a read lock on the map)
179    pub access_count: AtomicU64,
180
181    /// Creation / last-put time. Not updated on hits (LRU uses a separate
182    /// ordered queue), so this reflects when the entry was first stored.
183    pub last_access: Instant,
184}
185
186impl L1Entry {
187    /// Create a new L1 cache entry
188    pub fn new(query: String, result: CachedResult) -> Self {
189        Self {
190            result,
191            query,
192            access_count: AtomicU64::new(1),
193            last_access: Instant::now(),
194        }
195    }
196
197    /// Record an access to this entry (lock-free — takes `&self`).
198    pub fn touch(&self) {
199        self.access_count.fetch_add(1, Ordering::Relaxed);
200    }
201
202    /// Get the current access count.
203    pub fn access_count(&self) -> u64 {
204        self.access_count.load(Ordering::Relaxed)
205    }
206
207    /// Check if this entry has expired
208    pub fn is_expired(&self) -> bool {
209        self.result.is_expired()
210    }
211}
212
213/// Entry in the L2 warm cache
214#[derive(Debug, Clone)]
215pub struct L2Entry {
216    /// The cached result
217    pub result: CachedResult,
218
219    /// Normalized query fingerprint
220    pub fingerprint: String,
221
222    /// Cache key
223    pub key: CacheKey,
224
225    /// Access count
226    pub access_count: u64,
227
228    /// Last access time
229    pub last_access: Instant,
230
231    /// Estimated memory size
232    pub memory_size: usize,
233}
234
235impl L2Entry {
236    /// Create a new L2 cache entry
237    pub fn new(key: CacheKey, fingerprint: String, result: CachedResult) -> Self {
238        let memory_size = result.size()
239            + fingerprint.len()
240            + std::mem::size_of::<Self>()
241            + key.database.len()
242            + key.user.as_ref().map(|s| s.len()).unwrap_or(0)
243            + key.branch.as_ref().map(|s| s.len()).unwrap_or(0);
244
245        Self {
246            result,
247            fingerprint,
248            key,
249            access_count: 1,
250            last_access: Instant::now(),
251            memory_size,
252        }
253    }
254
255    /// Record an access to this entry
256    pub fn touch(&mut self) {
257        self.access_count += 1;
258        self.last_access = Instant::now();
259    }
260
261    /// Check if this entry has expired
262    pub fn is_expired(&self) -> bool {
263        self.result.is_expired()
264    }
265}
266
267/// Entry in the L3 semantic cache
268#[derive(Debug, Clone)]
269pub struct L3Entry {
270    /// The cached result
271    pub result: CachedResult,
272
273    /// Original query string
274    pub query: String,
275
276    /// Query embedding vector
277    pub embedding: Vec<f32>,
278
279    /// Cache context
280    pub context: CacheContext,
281
282    /// Access count
283    pub access_count: u64,
284
285    /// Last access time
286    pub last_access: Instant,
287}
288
289impl L3Entry {
290    /// Create a new L3 cache entry
291    pub fn new(
292        query: String,
293        embedding: Vec<f32>,
294        context: CacheContext,
295        result: CachedResult,
296    ) -> Self {
297        Self {
298            result,
299            query,
300            embedding,
301            context,
302            access_count: 1,
303            last_access: Instant::now(),
304        }
305    }
306
307    /// Record an access to this entry
308    pub fn touch(&mut self) {
309        self.access_count += 1;
310        self.last_access = Instant::now();
311    }
312
313    /// Check if this entry has expired
314    pub fn is_expired(&self) -> bool {
315        self.result.is_expired()
316    }
317
318    /// Compute cosine similarity with another embedding
319    pub fn similarity(&self, other: &[f32]) -> f32 {
320        if self.embedding.len() != other.len() {
321            return 0.0;
322        }
323
324        let mut dot_product = 0.0f32;
325        let mut norm_a = 0.0f32;
326        let mut norm_b = 0.0f32;
327
328        for (a, b) in self.embedding.iter().zip(other.iter()) {
329            dot_product += a * b;
330            norm_a += a * a;
331            norm_b += b * b;
332        }
333
334        let norm_a = norm_a.sqrt();
335        let norm_b = norm_b.sqrt();
336
337        if norm_a == 0.0 || norm_b == 0.0 {
338            return 0.0;
339        }
340
341        dot_product / (norm_a * norm_b)
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cached_result_expiry() {
        // Use a generous TTL for the "still fresh" check. With a 10ms TTL the
        // assertion could fail if the thread were descheduled between
        // construction and the check.
        let fresh = CachedResult::new(
            Bytes::from("test"),
            1,
            Duration::from_secs(60),
            vec!["users".to_string()],
            Duration::from_millis(5),
        );
        assert!(!fresh.is_expired());
        assert!(fresh.remaining_ttl() > Duration::ZERO);

        let short = CachedResult::new(
            Bytes::from("test"),
            1,
            Duration::from_millis(10),
            vec!["users".to_string()],
            Duration::from_millis(5),
        );

        // Wait past the TTL.
        std::thread::sleep(Duration::from_millis(15));
        assert!(short.is_expired());
        assert_eq!(short.remaining_ttl(), Duration::ZERO);
    }

    #[test]
    fn test_cache_key_equality() {
        let ctx1 = CacheContext {
            database: "db1".to_string(),
            user: Some("user1".to_string()),
            branch: None,
            connection_id: None,
        };

        let ctx2 = CacheContext {
            database: "db1".to_string(),
            user: Some("user1".to_string()),
            branch: None,
            connection_id: Some(123), // Different connection_id shouldn't matter
        };

        let normalized = NormalizedQuery {
            fingerprint: "SELECT * FROM users WHERE id = ?".to_string(),
            hash: 12345,
            tables: vec!["users".to_string()],
            parameters: vec!["1".to_string()],
        };

        let key1 = CacheKey::new(&normalized, &ctx1);
        let key2 = CacheKey::new(&normalized, &ctx2);

        assert_eq!(key1, key2);
    }

    #[test]
    fn test_cache_key_different_users() {
        let ctx1 = CacheContext {
            database: "db1".to_string(),
            user: Some("user1".to_string()),
            branch: None,
            connection_id: None,
        };

        let ctx2 = CacheContext {
            database: "db1".to_string(),
            user: Some("user2".to_string()),
            branch: None,
            connection_id: None,
        };

        let normalized = NormalizedQuery {
            fingerprint: "SELECT * FROM users".to_string(),
            hash: 12345,
            tables: vec!["users".to_string()],
            parameters: vec![],
        };

        let key1 = CacheKey::new(&normalized, &ctx1);
        let key2 = CacheKey::new(&normalized, &ctx2);

        // Different users should have different cache keys (for RLS)
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_cache_key_from_parts_matches_new() {
        let ctx = CacheContext {
            database: "db1".to_string(),
            user: Some("user1".to_string()),
            branch: Some("main".to_string()),
            connection_id: None,
        };

        let normalized = NormalizedQuery {
            fingerprint: "SELECT 1".to_string(),
            hash: 42,
            tables: vec![],
            parameters: vec![],
        };

        let via_new = CacheKey::new(&normalized, &ctx);
        let via_parts = CacheKey::from_parts(
            42,
            "db1".to_string(),
            Some("user1".to_string()),
            Some("main".to_string()),
        );

        // Both constructors must agree, or keys built one way never match
        // keys built the other way.
        assert_eq!(via_new, via_parts);
        assert_eq!(via_new.hash_value(), via_parts.hash_value());
    }

    #[test]
    fn test_cache_key_different_branches() {
        let main = CacheKey::from_parts(1, "db".to_string(), None, Some("main".to_string()));
        let dev = CacheKey::from_parts(1, "db".to_string(), None, Some("dev".to_string()));

        // Results from different branches must never be shared.
        assert_ne!(main, dev);
    }

    #[test]
    fn test_l3_entry_similarity() {
        let result = CachedResult::new(
            Bytes::from("test"),
            1,
            Duration::from_secs(60),
            vec![],
            Duration::from_millis(5),
        );

        let ctx = CacheContext::default();

        let entry = L3Entry::new(
            "SELECT * FROM users".to_string(),
            vec![1.0, 0.0, 0.0],
            ctx,
            result,
        );

        // Same vector should have similarity 1.0
        assert!((entry.similarity(&[1.0, 0.0, 0.0]) - 1.0).abs() < 0.001);

        // Orthogonal vector should have similarity 0.0
        assert!((entry.similarity(&[0.0, 1.0, 0.0])).abs() < 0.001);

        // Opposite vector should have similarity -1.0
        assert!((entry.similarity(&[-1.0, 0.0, 0.0]) + 1.0).abs() < 0.001);
    }

    #[test]
    fn test_l1_entry_touch() {
        let result = CachedResult::new(
            Bytes::from("test"),
            1,
            Duration::from_secs(60),
            vec![],
            Duration::from_millis(5),
        );

        let entry = L1Entry::new("SELECT 1".to_string(), result);
        assert_eq!(entry.access_count(), 1);

        entry.touch();
        assert_eq!(entry.access_count(), 2);

        entry.touch();
        assert_eq!(entry.access_count(), 3);
    }
}