Skip to main content

heliosdb_proxy/distribcache/
prefetcher.rs

//! Predictive prefetcher for intelligent cache warming.
//!
//! Learns per-session query sequences and time-of-day/day-of-week access
//! patterns, then uses them to predict likely upcoming queries and queue
//! them for cache pre-warming.
5
6use chrono::{Datelike, Timelike};
7use dashmap::DashMap;
8use std::collections::{HashMap, VecDeque};
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11
12use super::{DistribCacheConfig, QueryFingerprint, SessionId};
13
14/// Prefetch request
15#[derive(Debug, Clone)]
16pub struct PrefetchRequest {
17    /// Query fingerprint to prefetch
18    pub fingerprint: QueryFingerprint,
19    /// Priority (0-100)
20    pub priority: u32,
21}
22
23/// Prefetch queue
24pub struct PrefetchQueue {
25    /// Queue of pending requests
26    queue: std::sync::Mutex<VecDeque<PrefetchRequest>>,
27    /// Notifier for new items
28    notify: tokio::sync::Notify,
29}
30
31impl PrefetchQueue {
32    fn new() -> Self {
33        Self {
34            queue: std::sync::Mutex::new(VecDeque::new()),
35            notify: tokio::sync::Notify::new(),
36        }
37    }
38
39    pub fn enqueue(&self, request: PrefetchRequest) {
40        let mut queue = self.queue.lock().unwrap();
41
42        // Insert by priority (higher priority first)
43        let pos = queue.iter()
44            .position(|r| r.priority < request.priority)
45            .unwrap_or(queue.len());
46
47        queue.insert(pos, request);
48        self.notify.notify_one();
49    }
50
51    pub async fn dequeue(&self) -> Option<PrefetchRequest> {
52        loop {
53            {
54                let mut queue = self.queue.lock().unwrap();
55                if let Some(request) = queue.pop_front() {
56                    return Some(request);
57                }
58            }
59            self.notify.notified().await;
60        }
61    }
62
63    pub fn len(&self) -> usize {
64        self.queue.lock().unwrap().len()
65    }
66
67    pub fn is_empty(&self) -> bool {
68        self.queue.lock().unwrap().is_empty()
69    }
70}
71
72/// Temporal pattern storage
73pub struct TemporalPatternStore {
74    /// Patterns by hour of day (0-23)
75    hourly_patterns: [DashMap<QueryFingerprint, u64>; 24],
76    /// Patterns by day of week (0-6)
77    daily_patterns: [DashMap<QueryFingerprint, u64>; 7],
78}
79
80impl TemporalPatternStore {
81    fn new() -> Self {
82        Self {
83            hourly_patterns: std::array::from_fn(|_| DashMap::new()),
84            daily_patterns: std::array::from_fn(|_| DashMap::new()),
85        }
86    }
87
88    fn record(&self, fingerprint: &QueryFingerprint, hour: usize, weekday: usize) {
89        if hour < 24 {
90            self.hourly_patterns[hour]
91                .entry(fingerprint.clone())
92                .and_modify(|c| *c += 1)
93                .or_insert(1);
94        }
95        if weekday < 7 {
96            self.daily_patterns[weekday]
97                .entry(fingerprint.clone())
98                .and_modify(|c| *c += 1)
99                .or_insert(1);
100        }
101    }
102
103    fn predict_for_hour(&self, hour: usize) -> Vec<QueryFingerprint> {
104        if hour >= 24 {
105            return Vec::new();
106        }
107
108        let patterns = &self.hourly_patterns[hour];
109        let mut predictions: Vec<_> = patterns.iter()
110            .map(|e| (e.key().clone(), *e.value()))
111            .collect();
112
113        predictions.sort_by(|a, b| b.1.cmp(&a.1));
114        predictions.into_iter()
115            .take(10)
116            .map(|(fp, _)| fp)
117            .collect()
118    }
119}
120
121/// Predictive prefetcher
122pub struct PredictivePrefetcher {
123    /// Configuration
124    config: DistribCacheConfig,
125
126    /// Query sequence patterns (prev -> next queries)
127    patterns: DashMap<QueryFingerprint, Vec<QueryFingerprint>>,
128
129    /// Session-based sequences
130    session_sequences: DashMap<SessionId, VecDeque<QueryFingerprint>>,
131
132    /// Temporal patterns
133    temporal_patterns: TemporalPatternStore,
134
135    /// Prefetch queue
136    prefetch_queue: Arc<PrefetchQueue>,
137
138    /// Running flag
139    running: AtomicBool,
140
141    /// Statistics
142    predictions_made: AtomicU64,
143    prefetch_hits: AtomicU64,
144    prefetch_misses: AtomicU64,
145}
146
147impl PredictivePrefetcher {
148    /// Create a new prefetcher
149    pub fn new(config: DistribCacheConfig) -> Self {
150        Self {
151            config,
152            patterns: DashMap::new(),
153            session_sequences: DashMap::new(),
154            temporal_patterns: TemporalPatternStore::new(),
155            prefetch_queue: Arc::new(PrefetchQueue::new()),
156            running: AtomicBool::new(false),
157            predictions_made: AtomicU64::new(0),
158            prefetch_hits: AtomicU64::new(0),
159            prefetch_misses: AtomicU64::new(0),
160        }
161    }
162
163    /// Record a query for pattern learning
164    pub fn record(&self, session: &SessionId, fingerprint: QueryFingerprint) {
165        // Get or create session sequence
166        let mut seq = self.session_sequences
167            .entry(session.clone())
168            .or_insert_with(|| VecDeque::with_capacity(100));
169
170        // Learn pattern from sequence
171        if seq.len() >= 1 {
172            if let Some(prev) = seq.back() {
173                self.patterns
174                    .entry(prev.clone())
175                    .or_default()
176                    .push(fingerprint.clone());
177            }
178        }
179
180        // Add to sequence
181        seq.push_back(fingerprint.clone());
182
183        // Maintain size limit
184        while seq.len() > 100 {
185            seq.pop_front();
186        }
187
188        // Record temporal pattern
189        let now = chrono::Utc::now();
190        self.temporal_patterns.record(
191            &fingerprint,
192            now.hour() as usize,
193            now.weekday().num_days_from_monday() as usize,
194        );
195    }
196
197    /// Predict and enqueue prefetch requests
198    pub fn predict_and_prefetch(&self, current: &QueryFingerprint, _session: &SessionId) {
199        if !self.config.prefetch_enabled {
200            return;
201        }
202
203        // 1. Pattern-based prediction
204        if let Some(next_queries) = self.patterns.get(current) {
205            let predictions = self.get_top_predictions(next_queries.value());
206
207            for (fingerprint, confidence) in predictions {
208                if confidence > self.config.prefetch_confidence_threshold {
209                    self.prefetch_queue.enqueue(PrefetchRequest {
210                        fingerprint,
211                        priority: (confidence * 100.0) as u32,
212                    });
213                    self.predictions_made.fetch_add(1, Ordering::Relaxed);
214                }
215            }
216        }
217
218        // 2. Temporal prediction
219        let hour = chrono::Utc::now().hour() as usize;
220        let temporal_predictions = self.temporal_patterns.predict_for_hour(hour);
221
222        for fingerprint in temporal_predictions.into_iter().take(self.config.prefetch_lookahead as usize) {
223            self.prefetch_queue.enqueue(PrefetchRequest {
224                fingerprint,
225                priority: 50, // Medium priority for temporal
226            });
227        }
228    }
229
230    /// Get top predictions with confidence scores
231    fn get_top_predictions(&self, next_queries: &[QueryFingerprint]) -> Vec<(QueryFingerprint, f32)> {
232        // Count occurrences
233        let mut counts: HashMap<&QueryFingerprint, u32> = HashMap::new();
234        for fp in next_queries {
235            *counts.entry(fp).or_default() += 1;
236        }
237
238        let total = next_queries.len() as f32;
239
240        // Calculate confidence and sort
241        let mut predictions: Vec<_> = counts.into_iter()
242            .map(|(fp, count)| (fp.clone(), count as f32 / total))
243            .collect();
244
245        predictions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
246        predictions.into_iter()
247            .take(self.config.prefetch_lookahead as usize)
248            .collect()
249    }
250
251    /// Start the prefetch background worker
252    pub async fn start(&self) {
253        self.running.store(true, Ordering::SeqCst);
254
255        // In production, this would spawn a background task
256        // that processes the prefetch queue
257    }
258
259    /// Stop the prefetcher
260    pub async fn stop(&self) {
261        self.running.store(false, Ordering::SeqCst);
262    }
263
264    /// Record prefetch hit
265    pub fn record_hit(&self) {
266        self.prefetch_hits.fetch_add(1, Ordering::Relaxed);
267    }
268
269    /// Record prefetch miss
270    pub fn record_miss(&self) {
271        self.prefetch_misses.fetch_add(1, Ordering::Relaxed);
272    }
273
274    /// Get prefetch statistics
275    pub fn stats(&self) -> PrefetchStats {
276        let hits = self.prefetch_hits.load(Ordering::Relaxed);
277        let misses = self.prefetch_misses.load(Ordering::Relaxed);
278
279        PrefetchStats {
280            predictions_made: self.predictions_made.load(Ordering::Relaxed),
281            queue_size: self.prefetch_queue.len(),
282            hit_rate: if hits + misses > 0 {
283                hits as f64 / (hits + misses) as f64
284            } else {
285                0.0
286            },
287            patterns_learned: self.patterns.len(),
288            sessions_tracked: self.session_sequences.len(),
289        }
290    }
291
292    /// Clean up old sessions
293    pub fn cleanup_old_sessions(&self, _max_age: std::time::Duration) {
294        // In production, track timestamps and clean up
295        // For now, just limit total sessions
296        if self.session_sequences.len() > 10000 {
297            // Remove random entries to stay under limit
298            let to_remove: Vec<_> = self.session_sequences.iter()
299                .take(1000)
300                .map(|e| e.key().clone())
301                .collect();
302
303            for key in to_remove {
304                self.session_sequences.remove(&key);
305            }
306        }
307    }
308}
309
/// Point-in-time snapshot of prefetcher activity.
#[derive(Debug, Clone)]
pub struct PrefetchStats {
    /// Number of predictions made so far.
    pub predictions_made: u64,
    /// Number of requests waiting in the prefetch queue.
    pub queue_size: usize,
    /// Fraction of recorded prefetch outcomes that were hits.
    pub hit_rate: f64,
    /// Number of learned patterns.
    pub patterns_learned: usize,
    /// Number of sessions currently being tracked.
    pub sessions_tracked: usize,
}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Requests must come out ordered by descending priority, with FIFO
    /// order among requests of equal priority.
    #[test]
    fn test_prefetch_queue() {
        let queue = PrefetchQueue::new();
        assert!(queue.is_empty());

        let fp1 = QueryFingerprint::from_query("SELECT 1");
        let fp2 = QueryFingerprint::from_query("SELECT 2");
        let fp3 = QueryFingerprint::from_query("SELECT 3");
        let fp4 = QueryFingerprint::from_query("SELECT 4");

        // Add with different priorities; fp1 and fp4 share one.
        queue.enqueue(PrefetchRequest { fingerprint: fp1.clone(), priority: 50 });
        queue.enqueue(PrefetchRequest { fingerprint: fp2, priority: 100 });
        queue.enqueue(PrefetchRequest { fingerprint: fp3, priority: 25 });
        queue.enqueue(PrefetchRequest { fingerprint: fp4.clone(), priority: 50 });

        assert_eq!(queue.len(), 4);
        assert!(!queue.is_empty());

        let queued = queue.queue.lock().unwrap();
        let priorities: Vec<u32> = queued.iter().map(|r| r.priority).collect();
        assert_eq!(priorities, vec![100, 50, 50, 25]);

        // Equal priorities keep arrival order.
        assert_eq!(queued[1].fingerprint, fp1);
        assert_eq!(queued[2].fingerprint, fp4);
    }

    /// Every consecutive pair of queries in a session is learned as a
    /// transition.
    #[test]
    fn test_pattern_learning() {
        let config = DistribCacheConfig::default();
        let prefetcher = PredictivePrefetcher::new(config);
        let session = SessionId::new("test");

        let fp1 = QueryFingerprint::from_query("SELECT * FROM users");
        let fp2 = QueryFingerprint::from_query("SELECT * FROM orders");
        let fp3 = QueryFingerprint::from_query("SELECT * FROM items");

        // Simulate sequence: fp1 -> fp2 -> fp3
        prefetcher.record(&session, fp1.clone());
        prefetcher.record(&session, fp2.clone());
        prefetcher.record(&session, fp3.clone());

        // Pattern fp1 -> fp2 should be learned
        {
            let next = prefetcher.patterns.get(&fp1).expect("fp1 has a successor");
            assert!(next.contains(&fp2));
        }

        // Pattern fp2 -> fp3 should be learned as well
        {
            let next = prefetcher.patterns.get(&fp2).expect("fp2 has a successor");
            assert!(next.contains(&fp3));
        }

        // fp3 was the last query, so nothing follows it
        assert!(!prefetcher.patterns.contains_key(&fp3));

        let stats = prefetcher.stats();
        assert_eq!(stats.patterns_learned, 2);
        assert_eq!(stats.sessions_tracked, 1);
    }

    /// A transition seen every time must be predicted with full confidence
    /// and end up at the front of the prefetch queue.
    #[test]
    fn test_prediction() {
        let config = DistribCacheConfig::builder()
            .prefetch_enabled(true)
            .prefetch_confidence_threshold(0.0) // Accept all predictions for test
            .build();
        let prefetcher = PredictivePrefetcher::new(config);
        let session = SessionId::new("test");

        // Train pattern: query1 -> query2 (repeated)
        let fp1 = QueryFingerprint::from_query("SELECT * FROM users WHERE id = ?");
        let fp2 = QueryFingerprint::from_query("SELECT * FROM orders WHERE user_id = ?");

        for _ in 0..10 {
            prefetcher.record(&session, fp1.clone());
            prefetcher.record(&session, fp2.clone());
        }

        // Now predict after fp1
        prefetcher.predict_and_prefetch(&fp1, &session);

        // Should have enqueued prefetch for fp2
        assert!(!prefetcher.prefetch_queue.is_empty());
        assert_eq!(prefetcher.stats().predictions_made, 1);

        // fp2 always follows fp1, so it is queued first with priority 100
        let queued = prefetcher.prefetch_queue.queue.lock().unwrap();
        let front = queued.front().expect("queue is not empty");
        assert_eq!(front.fingerprint, fp2);
        assert_eq!(front.priority, 100);
    }

    /// Counters are kept per hour; out-of-range indices are ignored.
    #[test]
    fn test_temporal_patterns() {
        let store = TemporalPatternStore::new();
        let fp = QueryFingerprint::from_query("SELECT * FROM reports");

        // Record at hour 9 multiple times
        for _ in 0..10 {
            store.record(&fp, 9, 1);
        }

        // Out-of-range indices must be ignored without panicking
        store.record(&fp, 24, 7);

        // Predict for hour 9 should include our query
        let predictions = store.predict_for_hour(9);
        assert!(predictions.contains(&fp));

        // Other hours and invalid hours have nothing to predict
        assert!(store.predict_for_hour(10).is_empty());
        assert!(store.predict_for_hour(24).is_empty());
    }

    /// The hit rate is hits / (hits + misses), and 0.0 without samples.
    #[test]
    fn test_stats() {
        let config = DistribCacheConfig::default();
        let prefetcher = PredictivePrefetcher::new(config);

        // No samples yet
        let empty = prefetcher.stats();
        assert_eq!(empty.hit_rate, 0.0);
        assert_eq!(empty.predictions_made, 0);
        assert_eq!(empty.queue_size, 0);
        assert_eq!(empty.patterns_learned, 0);
        assert_eq!(empty.sessions_tracked, 0);

        prefetcher.record_hit();
        prefetcher.record_hit();
        prefetcher.record_miss();

        let stats = prefetcher.stats();
        assert!((stats.hit_rate - 2.0 / 3.0).abs() < 1e-9);
    }
}