Skip to main content

heliosdb_proxy/cache/
result.rs

//! Cached Result Types
//!
//! Structures for storing and retrieving cached query results.
4
5use bytes::Bytes;
6use std::hash::{Hash, Hasher};
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::time::{Duration, Instant};
9
10use super::normalizer::NormalizedQuery;
11use super::CacheContext;
12
13/// Cached query result
14#[derive(Debug, Clone)]
15pub struct CachedResult {
16    /// Serialized result data
17    pub data: Bytes,
18
19    /// Number of rows in the result
20    pub row_count: usize,
21
22    /// When this result was cached
23    pub cached_at: Instant,
24
25    /// Time-to-live for this result
26    pub ttl: Duration,
27
28    /// Tables referenced by the query
29    pub tables: Vec<String>,
30
31    /// Original query execution time
32    pub execution_time: Duration,
33}
34
35impl CachedResult {
36    /// Create a new cached result
37    pub fn new(
38        data: Bytes,
39        row_count: usize,
40        ttl: Duration,
41        tables: Vec<String>,
42        execution_time: Duration,
43    ) -> Self {
44        Self {
45            data,
46            row_count,
47            cached_at: Instant::now(),
48            ttl,
49            tables,
50            execution_time,
51        }
52    }
53
54    /// Check if this cached result has expired
55    pub fn is_expired(&self) -> bool {
56        self.cached_at.elapsed() > self.ttl
57    }
58
59    /// Get the age of this cached result
60    pub fn age(&self) -> Duration {
61        self.cached_at.elapsed()
62    }
63
64    /// Get remaining TTL
65    pub fn remaining_ttl(&self) -> Duration {
66        self.ttl.saturating_sub(self.cached_at.elapsed())
67    }
68
69    /// Get size in bytes
70    pub fn size(&self) -> usize {
71        self.data.len()
72    }
73}
74
75/// Cache key for lookup operations
76#[derive(Debug, Clone)]
77pub struct CacheKey {
78    /// Hash of the normalized query
79    pub query_hash: u64,
80
81    /// Database name
82    pub database: String,
83
84    /// User (for RLS-aware caching)
85    pub user: Option<String>,
86
87    /// Branch (for HeliosDB branching support)
88    pub branch: Option<String>,
89
90    /// Pre-computed hash for fast lookups
91    cached_hash: u64,
92}
93
94impl CacheKey {
95    /// Create a new cache key from a normalized query and context
96    pub fn new(normalized: &NormalizedQuery, context: &CacheContext) -> Self {
97        let query_hash = normalized.hash;
98
99        // Compute combined hash
100        let mut hasher = std::collections::hash_map::DefaultHasher::new();
101        query_hash.hash(&mut hasher);
102        context.database.hash(&mut hasher);
103        context.user.hash(&mut hasher);
104        context.branch.hash(&mut hasher);
105        let cached_hash = hasher.finish();
106
107        Self {
108            query_hash,
109            database: context.database.clone(),
110            user: context.user.clone(),
111            branch: context.branch.clone(),
112            cached_hash,
113        }
114    }
115
116    /// Create a cache key from raw components
117    pub fn from_parts(
118        query_hash: u64,
119        database: String,
120        user: Option<String>,
121        branch: Option<String>,
122    ) -> Self {
123        let mut hasher = std::collections::hash_map::DefaultHasher::new();
124        query_hash.hash(&mut hasher);
125        database.hash(&mut hasher);
126        user.hash(&mut hasher);
127        branch.hash(&mut hasher);
128        let cached_hash = hasher.finish();
129
130        Self {
131            query_hash,
132            database,
133            user,
134            branch,
135            cached_hash,
136        }
137    }
138
139    /// Get the pre-computed hash
140    pub fn hash_value(&self) -> u64 {
141        self.cached_hash
142    }
143}
144
145impl Hash for CacheKey {
146    fn hash<H: Hasher>(&self, state: &mut H) {
147        state.write_u64(self.cached_hash);
148    }
149}
150
151impl PartialEq for CacheKey {
152    fn eq(&self, other: &Self) -> bool {
153        self.cached_hash == other.cached_hash
154            && self.query_hash == other.query_hash
155            && self.database == other.database
156            && self.user == other.user
157            && self.branch == other.branch
158    }
159}
160
161impl Eq for CacheKey {}
162
163/// Entry in the L1 hot cache
164///
165/// `access_count` is an `AtomicU64` so cache hits can bump it under a
166/// read lock on the containing map — `touch()` takes `&self`, not
167/// `&mut self`. `last_access` is deliberately cosmetic (not consulted
168/// by LRU eviction, which uses a separate ordered queue) and is not
169/// updated per-access.
170#[derive(Debug)]
171pub struct L1Entry {
172    /// The cached result
173    pub result: CachedResult,
174
175    /// Original query string (for exact match)
176    pub query: String,
177
178    /// Access count (atomic so hits only need a read lock on the map)
179    pub access_count: AtomicU64,
180
181    /// Creation / last-put time. Not updated on hits (LRU uses a separate
182    /// ordered queue), so this reflects when the entry was first stored.
183    pub last_access: Instant,
184}
185
186impl L1Entry {
187    /// Create a new L1 cache entry
188    pub fn new(query: String, result: CachedResult) -> Self {
189        Self {
190            result,
191            query,
192            access_count: AtomicU64::new(1),
193            last_access: Instant::now(),
194        }
195    }
196
197    /// Record an access to this entry (lock-free — takes `&self`).
198    pub fn touch(&self) {
199        self.access_count.fetch_add(1, Ordering::Relaxed);
200    }
201
202    /// Get the current access count.
203    pub fn access_count(&self) -> u64 {
204        self.access_count.load(Ordering::Relaxed)
205    }
206
207    /// Check if this entry has expired
208    pub fn is_expired(&self) -> bool {
209        self.result.is_expired()
210    }
211}
212
213/// Entry in the L2 warm cache
214#[derive(Debug, Clone)]
215pub struct L2Entry {
216    /// The cached result
217    pub result: CachedResult,
218
219    /// Normalized query fingerprint
220    pub fingerprint: String,
221
222    /// Cache key
223    pub key: CacheKey,
224
225    /// Access count
226    pub access_count: u64,
227
228    /// Last access time
229    pub last_access: Instant,
230
231    /// Estimated memory size
232    pub memory_size: usize,
233}
234
235impl L2Entry {
236    /// Create a new L2 cache entry
237    pub fn new(key: CacheKey, fingerprint: String, result: CachedResult) -> Self {
238        let memory_size = result.size()
239            + fingerprint.len()
240            + std::mem::size_of::<Self>()
241            + key.database.len()
242            + key.user.as_ref().map(|s| s.len()).unwrap_or(0)
243            + key.branch.as_ref().map(|s| s.len()).unwrap_or(0);
244
245        Self {
246            result,
247            fingerprint,
248            key,
249            access_count: 1,
250            last_access: Instant::now(),
251            memory_size,
252        }
253    }
254
255    /// Record an access to this entry
256    pub fn touch(&mut self) {
257        self.access_count += 1;
258        self.last_access = Instant::now();
259    }
260
261    /// Check if this entry has expired
262    pub fn is_expired(&self) -> bool {
263        self.result.is_expired()
264    }
265}
266
267/// Entry in the L3 semantic cache
268#[derive(Debug, Clone)]
269pub struct L3Entry {
270    /// The cached result
271    pub result: CachedResult,
272
273    /// Original query string
274    pub query: String,
275
276    /// Query embedding vector
277    pub embedding: Vec<f32>,
278
279    /// Cache context
280    pub context: CacheContext,
281
282    /// Access count
283    pub access_count: u64,
284
285    /// Last access time
286    pub last_access: Instant,
287}
288
289impl L3Entry {
290    /// Create a new L3 cache entry
291    pub fn new(query: String, embedding: Vec<f32>, context: CacheContext, result: CachedResult) -> Self {
292        Self {
293            result,
294            query,
295            embedding,
296            context,
297            access_count: 1,
298            last_access: Instant::now(),
299        }
300    }
301
302    /// Record an access to this entry
303    pub fn touch(&mut self) {
304        self.access_count += 1;
305        self.last_access = Instant::now();
306    }
307
308    /// Check if this entry has expired
309    pub fn is_expired(&self) -> bool {
310        self.result.is_expired()
311    }
312
313    /// Compute cosine similarity with another embedding
314    pub fn similarity(&self, other: &[f32]) -> f32 {
315        if self.embedding.len() != other.len() {
316            return 0.0;
317        }
318
319        let mut dot_product = 0.0f32;
320        let mut norm_a = 0.0f32;
321        let mut norm_b = 0.0f32;
322
323        for (a, b) in self.embedding.iter().zip(other.iter()) {
324            dot_product += a * b;
325            norm_a += a * a;
326            norm_b += b * b;
327        }
328
329        let norm_a = norm_a.sqrt();
330        let norm_b = norm_b.sqrt();
331
332        if norm_a == 0.0 || norm_b == 0.0 {
333            return 0.0;
334        }
335
336        dot_product / (norm_a * norm_b)
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a one-row result with payload `"test"`, a 5 ms execution time,
    /// and the given TTL and table list.
    fn sample_result(ttl: Duration, tables: Vec<String>) -> CachedResult {
        CachedResult::new(Bytes::from("test"), 1, ttl, tables, Duration::from_millis(5))
    }

    /// Context on database `db1` for `user`, with no branch and no
    /// connection id.
    fn context_for_user(user: &str) -> CacheContext {
        CacheContext {
            database: "db1".to_string(),
            user: Some(user.to_string()),
            branch: None,
            connection_id: None,
        }
    }

    /// Normalized query over the `users` table with the fixed hash 12345.
    fn users_query(fingerprint: &str, parameters: Vec<String>) -> NormalizedQuery {
        NormalizedQuery {
            fingerprint: fingerprint.to_string(),
            hash: 12345,
            tables: vec!["users".to_string()],
            parameters,
        }
    }

    #[test]
    fn test_cached_result_expiry() {
        let result = sample_result(Duration::from_millis(10), vec!["users".to_string()]);

        // Freshly created: still within its 10 ms TTL.
        assert!(!result.is_expired());

        // Sleep past the TTL and check again.
        std::thread::sleep(Duration::from_millis(15));
        assert!(result.is_expired());
    }

    #[test]
    fn test_cache_key_equality() {
        let without_connection = context_for_user("user1");
        let with_connection = CacheContext {
            // A different connection_id must not affect the key.
            connection_id: Some(123),
            ..context_for_user("user1")
        };

        let normalized = users_query(
            "SELECT * FROM users WHERE id = ?",
            vec!["1".to_string()],
        );

        let key1 = CacheKey::new(&normalized, &without_connection);
        let key2 = CacheKey::new(&normalized, &with_connection);

        assert_eq!(key1, key2);
    }

    #[test]
    fn test_cache_key_different_users() {
        let normalized = users_query("SELECT * FROM users", vec![]);

        let key1 = CacheKey::new(&normalized, &context_for_user("user1"));
        let key2 = CacheKey::new(&normalized, &context_for_user("user2"));

        // Keys must differ per user so RLS-scoped results are never shared.
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_l3_entry_similarity() {
        let entry = L3Entry::new(
            "SELECT * FROM users".to_string(),
            vec![1.0, 0.0, 0.0],
            CacheContext::default(),
            sample_result(Duration::from_secs(60), vec![]),
        );

        // Identical direction -> 1.0
        let same = entry.similarity(&[1.0, 0.0, 0.0]);
        assert!((same - 1.0).abs() < 0.001);

        // Orthogonal direction -> 0.0
        let orthogonal = entry.similarity(&[0.0, 1.0, 0.0]);
        assert!(orthogonal.abs() < 0.001);

        // Opposite direction -> -1.0
        let opposite = entry.similarity(&[-1.0, 0.0, 0.0]);
        assert!((opposite + 1.0).abs() < 0.001);
    }

    #[test]
    fn test_l1_entry_touch() {
        let entry = L1Entry::new(
            "SELECT 1".to_string(),
            sample_result(Duration::from_secs(60), vec![]),
        );
        assert_eq!(entry.access_count(), 1);

        // Each touch adds exactly one to the counter.
        entry.touch();
        assert_eq!(entry.access_count(), 2);

        entry.touch();
        assert_eq!(entry.access_count(), 3);
    }
}