Skip to main content

oxigdal_query/cache/
mod.rs

1//! Query result caching.
2
3use crate::executor::scan::RecordBatch;
4use crate::parser::ast::Statement;
5use blake3::Hash;
6use dashmap::DashMap;
7use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
12/// Query cache.
13pub struct QueryCache {
14    /// Cache entries.
15    entries: DashMap<Hash, CacheEntry>,
16    /// Configuration.
17    config: CacheConfig,
18    /// Statistics.
19    stats: Arc<RwLock<CacheStatistics>>,
20}
21
/// Cache configuration.
///
/// Controls whether caching is active, how long cached results remain valid,
/// and the approximate byte budget shared by all cached results.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum cache size in bytes, compared against per-entry size estimates.
    pub max_size_bytes: usize,
    /// Time-to-live for cache entries, measured from when they were stored.
    pub ttl: Duration,
    /// Enable cache. When `false`, lookups return nothing and stores are ignored.
    pub enabled: bool,
}

impl Default for CacheConfig {
    /// A 1 GiB budget, a five-minute TTL, and caching switched on.
    fn default() -> Self {
        const ONE_GIB: usize = 1 << 30;
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);

        CacheConfig {
            enabled: true,
            ttl: FIVE_MINUTES,
            max_size_bytes: ONE_GIB,
        }
    }
}
42
43/// Cache entry.
44#[derive(Clone)]
45struct CacheEntry {
46    /// Cached result.
47    result: Arc<Vec<RecordBatch>>,
48    /// Creation time.
49    created_at: Instant,
50    /// Size in bytes (approximate).
51    size_bytes: usize,
52    /// Access count.
53    access_count: usize,
54}
55
56impl CacheEntry {
57    fn new(result: Vec<RecordBatch>) -> Self {
58        let size_bytes = Self::estimate_size(&result);
59        Self {
60            result: Arc::new(result),
61            created_at: Instant::now(),
62            size_bytes,
63            access_count: 0,
64        }
65    }
66
67    fn estimate_size(batches: &[RecordBatch]) -> usize {
68        batches
69            .iter()
70            .map(|batch| batch.num_rows * 100)
71            .sum::<usize>()
72    }
73
74    fn is_expired(&self, ttl: Duration) -> bool {
75        self.created_at.elapsed() > ttl
76    }
77}
78
79impl QueryCache {
80    /// Create a new query cache.
81    pub fn new(config: CacheConfig) -> Self {
82        Self {
83            entries: DashMap::new(),
84            config,
85            stats: Arc::new(RwLock::new(CacheStatistics::default())),
86        }
87    }
88
89    /// Get cached result.
90    pub fn get(&self, query: &Statement) -> Option<Vec<RecordBatch>> {
91        if !self.config.enabled {
92            return None;
93        }
94
95        let key = self.compute_key(query);
96
97        if let Some(mut entry) = self.entries.get_mut(&key) {
98            if entry.is_expired(self.config.ttl) {
99                drop(entry);
100                self.entries.remove(&key);
101                self.stats.write().misses += 1;
102                return None;
103            }
104
105            entry.access_count += 1;
106            let result = (*entry.result).clone();
107            self.stats.write().hits += 1;
108            Some(result)
109        } else {
110            self.stats.write().misses += 1;
111            None
112        }
113    }
114
115    /// Put result in cache.
116    pub fn put(&self, query: &Statement, result: Vec<RecordBatch>) {
117        if !self.config.enabled {
118            return;
119        }
120
121        let key = self.compute_key(query);
122        let entry = CacheEntry::new(result);
123
124        // Check cache size limit
125        self.evict_if_needed(entry.size_bytes);
126
127        self.entries.insert(key, entry);
128        self.stats.write().inserts += 1;
129    }
130
131    /// Invalidate cache entry.
132    pub fn invalidate(&self, query: &Statement) {
133        let key = self.compute_key(query);
134        self.entries.remove(&key);
135    }
136
137    /// Clear all cache entries.
138    pub fn clear(&self) {
139        self.entries.clear();
140        self.stats.write().clears += 1;
141    }
142
143    /// Get cache statistics.
144    pub fn statistics(&self) -> CacheStatistics {
145        *self.stats.read()
146    }
147
148    /// Compute cache key from query.
149    fn compute_key(&self, query: &Statement) -> Hash {
150        let query_string = format!("{:?}", query);
151        blake3::hash(query_string.as_bytes())
152    }
153
154    /// Evict entries if cache is too large.
155    fn evict_if_needed(&self, incoming_size: usize) {
156        let mut current_size: usize = self
157            .entries
158            .iter()
159            .map(|entry| entry.value().size_bytes)
160            .sum();
161
162        if current_size + incoming_size <= self.config.max_size_bytes {
163            return;
164        }
165
166        // Evict least recently used entries
167        let mut entries: Vec<_> = self
168            .entries
169            .iter()
170            .map(|entry| {
171                (
172                    *entry.key(),
173                    entry.value().created_at,
174                    entry.value().access_count,
175                    entry.value().size_bytes,
176                )
177            })
178            .collect();
179
180        entries.sort_by_key(|(_, created, access_count, _)| {
181            (created.elapsed().as_secs(), *access_count)
182        });
183
184        for (key, _, _, size) in entries {
185            self.entries.remove(&key);
186            current_size -= size;
187            self.stats.write().evictions += 1;
188
189            if current_size + incoming_size <= self.config.max_size_bytes {
190                break;
191            }
192        }
193    }
194}
195
196/// Cache statistics.
197#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
198pub struct CacheStatistics {
199    /// Number of cache hits.
200    pub hits: u64,
201    /// Number of cache misses.
202    pub misses: u64,
203    /// Number of inserts.
204    pub inserts: u64,
205    /// Number of evictions.
206    pub evictions: u64,
207    /// Number of cache clears.
208    pub clears: u64,
209}
210
211impl CacheStatistics {
212    /// Get hit rate.
213    pub fn hit_rate(&self) -> f64 {
214        let total = self.hits + self.misses;
215        if total == 0 {
216            0.0
217        } else {
218            self.hits as f64 / total as f64
219        }
220    }
221
222    /// Get miss rate.
223    pub fn miss_rate(&self) -> f64 {
224        1.0 - self.hit_rate()
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::executor::scan::{ColumnData, DataType, Field, Schema};
    use crate::parser::sql::parse_sql;

    /// Parses a fixed query, falling back to an empty `SELECT` if parsing fails
    /// so these cache tests do not depend on parser coverage.
    fn test_query() -> Statement {
        parse_sql("SELECT * FROM test").ok().unwrap_or_else(|| {
            Statement::Select(crate::parser::ast::SelectStatement {
                projection: vec![],
                from: None,
                selection: None,
                group_by: vec![],
                having: None,
                order_by: vec![],
                limit: None,
                offset: None,
            })
        })
    }

    /// Builds one two-row `Int64` batch; fails the test if construction fails
    /// rather than letting the caller silently skip its assertions.
    fn test_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "id".to_string(),
            DataType::Int64,
            false,
        )]));
        let columns = vec![ColumnData::Int64(vec![Some(1), Some(2)])];
        match RecordBatch::new(schema, columns, 2) {
            Ok(batch) => batch,
            Err(_) => panic!("failed to build test record batch"),
        }
    }

    #[test]
    fn test_cache_put_get() {
        let cache = QueryCache::new(CacheConfig::default());
        let query = test_query();

        cache.put(&query, vec![test_batch()]);

        let cached = match cache.get(&query) {
            Some(batches) => batches,
            None => panic!("expected a cache hit after put"),
        };
        assert_eq!(cached.len(), 1);
        assert_eq!(cached[0].num_rows, 2);

        let stats = cache.statistics();
        assert_eq!(stats.inserts, 1);
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_cache_statistics() {
        let cache = QueryCache::new(CacheConfig::default());
        let query = test_query();

        // Miss
        let _ = cache.get(&query);

        let stats = cache.statistics();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_disabled_cache_is_noop() {
        let cache = QueryCache::new(CacheConfig {
            enabled: false,
            ..CacheConfig::default()
        });
        let query = test_query();

        cache.put(&query, vec![test_batch()]);
        assert!(cache.get(&query).is_none());

        // A disabled cache records neither inserts nor lookups.
        let stats = cache.statistics();
        assert_eq!(stats.inserts, 0);
        assert_eq!(stats.hits + stats.misses, 0);
    }

    #[test]
    fn test_invalidate_removes_entry() {
        let cache = QueryCache::new(CacheConfig::default());
        let query = test_query();

        cache.put(&query, vec![test_batch()]);
        cache.invalidate(&query);

        assert!(cache.get(&query).is_none());
        assert_eq!(cache.statistics().misses, 1);
    }

    #[test]
    fn test_clear_removes_entries_and_counts() {
        let cache = QueryCache::new(CacheConfig::default());
        let query = test_query();

        cache.put(&query, vec![test_batch()]);
        cache.clear();

        assert!(cache.get(&query).is_none());
        assert_eq!(cache.statistics().clears, 1);
    }
}