Skip to main content

heliosdb_proxy/distribcache/
prefetcher.rs

1//! Predictive prefetcher for intelligent cache warming
2//!
3//! Uses query sequence patterns and temporal patterns to predict
4//! and pre-warm cache with likely future queries.
5
6use chrono::{Datelike, Timelike};
7use dashmap::DashMap;
8use std::collections::{HashMap, VecDeque};
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11
12use super::{DistribCacheConfig, QueryFingerprint, SessionId};
13
14/// Prefetch request
15#[derive(Debug, Clone)]
16pub struct PrefetchRequest {
17    /// Query fingerprint to prefetch
18    pub fingerprint: QueryFingerprint,
19    /// Priority (0-100)
20    pub priority: u32,
21}
22
23/// Prefetch queue
24pub struct PrefetchQueue {
25    /// Queue of pending requests
26    queue: std::sync::Mutex<VecDeque<PrefetchRequest>>,
27    /// Notifier for new items
28    notify: tokio::sync::Notify,
29}
30
31impl PrefetchQueue {
32    fn new() -> Self {
33        Self {
34            queue: std::sync::Mutex::new(VecDeque::new()),
35            notify: tokio::sync::Notify::new(),
36        }
37    }
38
39    pub fn enqueue(&self, request: PrefetchRequest) {
40        let mut queue = self.queue.lock().unwrap();
41
42        // Insert by priority (higher priority first)
43        let pos = queue
44            .iter()
45            .position(|r| r.priority < request.priority)
46            .unwrap_or(queue.len());
47
48        queue.insert(pos, request);
49        self.notify.notify_one();
50    }
51
52    pub async fn dequeue(&self) -> Option<PrefetchRequest> {
53        loop {
54            {
55                let mut queue = self.queue.lock().unwrap();
56                if let Some(request) = queue.pop_front() {
57                    return Some(request);
58                }
59            }
60            self.notify.notified().await;
61        }
62    }
63
64    pub fn len(&self) -> usize {
65        self.queue.lock().unwrap().len()
66    }
67
68    pub fn is_empty(&self) -> bool {
69        self.queue.lock().unwrap().is_empty()
70    }
71}
72
73/// Temporal pattern storage
74pub struct TemporalPatternStore {
75    /// Patterns by hour of day (0-23)
76    hourly_patterns: [DashMap<QueryFingerprint, u64>; 24],
77    /// Patterns by day of week (0-6)
78    daily_patterns: [DashMap<QueryFingerprint, u64>; 7],
79}
80
81impl TemporalPatternStore {
82    fn new() -> Self {
83        Self {
84            hourly_patterns: std::array::from_fn(|_| DashMap::new()),
85            daily_patterns: std::array::from_fn(|_| DashMap::new()),
86        }
87    }
88
89    fn record(&self, fingerprint: &QueryFingerprint, hour: usize, weekday: usize) {
90        if hour < 24 {
91            self.hourly_patterns[hour]
92                .entry(fingerprint.clone())
93                .and_modify(|c| *c += 1)
94                .or_insert(1);
95        }
96        if weekday < 7 {
97            self.daily_patterns[weekday]
98                .entry(fingerprint.clone())
99                .and_modify(|c| *c += 1)
100                .or_insert(1);
101        }
102    }
103
104    fn predict_for_hour(&self, hour: usize) -> Vec<QueryFingerprint> {
105        if hour >= 24 {
106            return Vec::new();
107        }
108
109        let patterns = &self.hourly_patterns[hour];
110        let mut predictions: Vec<_> = patterns
111            .iter()
112            .map(|e| (e.key().clone(), *e.value()))
113            .collect();
114
115        predictions.sort_by_key(|b| std::cmp::Reverse(b.1));
116        predictions.into_iter().take(10).map(|(fp, _)| fp).collect()
117    }
118}
119
120/// Predictive prefetcher
121pub struct PredictivePrefetcher {
122    /// Configuration
123    config: DistribCacheConfig,
124
125    /// Query sequence patterns (prev -> next queries)
126    patterns: DashMap<QueryFingerprint, Vec<QueryFingerprint>>,
127
128    /// Session-based sequences
129    session_sequences: DashMap<SessionId, VecDeque<QueryFingerprint>>,
130
131    /// Temporal patterns
132    temporal_patterns: TemporalPatternStore,
133
134    /// Prefetch queue
135    prefetch_queue: Arc<PrefetchQueue>,
136
137    /// Running flag
138    running: AtomicBool,
139
140    /// Statistics
141    predictions_made: AtomicU64,
142    prefetch_hits: AtomicU64,
143    prefetch_misses: AtomicU64,
144}
145
146impl PredictivePrefetcher {
147    /// Create a new prefetcher
148    pub fn new(config: DistribCacheConfig) -> Self {
149        Self {
150            config,
151            patterns: DashMap::new(),
152            session_sequences: DashMap::new(),
153            temporal_patterns: TemporalPatternStore::new(),
154            prefetch_queue: Arc::new(PrefetchQueue::new()),
155            running: AtomicBool::new(false),
156            predictions_made: AtomicU64::new(0),
157            prefetch_hits: AtomicU64::new(0),
158            prefetch_misses: AtomicU64::new(0),
159        }
160    }
161
162    /// Record a query for pattern learning
163    pub fn record(&self, session: &SessionId, fingerprint: QueryFingerprint) {
164        // Get or create session sequence
165        let mut seq = self
166            .session_sequences
167            .entry(session.clone())
168            .or_insert_with(|| VecDeque::with_capacity(100));
169
170        // Learn pattern from sequence
171        if !seq.is_empty() {
172            if let Some(prev) = seq.back() {
173                self.patterns
174                    .entry(prev.clone())
175                    .or_default()
176                    .push(fingerprint.clone());
177            }
178        }
179
180        // Add to sequence
181        seq.push_back(fingerprint.clone());
182
183        // Maintain size limit
184        while seq.len() > 100 {
185            seq.pop_front();
186        }
187
188        // Record temporal pattern
189        let now = chrono::Utc::now();
190        self.temporal_patterns.record(
191            &fingerprint,
192            now.hour() as usize,
193            now.weekday().num_days_from_monday() as usize,
194        );
195    }
196
197    /// Predict and enqueue prefetch requests
198    pub fn predict_and_prefetch(&self, current: &QueryFingerprint, _session: &SessionId) {
199        if !self.config.prefetch_enabled {
200            return;
201        }
202
203        // 1. Pattern-based prediction
204        if let Some(next_queries) = self.patterns.get(current) {
205            let predictions = self.get_top_predictions(next_queries.value());
206
207            for (fingerprint, confidence) in predictions {
208                if confidence > self.config.prefetch_confidence_threshold {
209                    self.prefetch_queue.enqueue(PrefetchRequest {
210                        fingerprint,
211                        priority: (confidence * 100.0) as u32,
212                    });
213                    self.predictions_made.fetch_add(1, Ordering::Relaxed);
214                }
215            }
216        }
217
218        // 2. Temporal prediction
219        let hour = chrono::Utc::now().hour() as usize;
220        let temporal_predictions = self.temporal_patterns.predict_for_hour(hour);
221
222        for fingerprint in temporal_predictions
223            .into_iter()
224            .take(self.config.prefetch_lookahead as usize)
225        {
226            self.prefetch_queue.enqueue(PrefetchRequest {
227                fingerprint,
228                priority: 50, // Medium priority for temporal
229            });
230        }
231    }
232
233    /// Get top predictions with confidence scores
234    fn get_top_predictions(
235        &self,
236        next_queries: &[QueryFingerprint],
237    ) -> Vec<(QueryFingerprint, f32)> {
238        // Count occurrences
239        let mut counts: HashMap<&QueryFingerprint, u32> = HashMap::new();
240        for fp in next_queries {
241            *counts.entry(fp).or_default() += 1;
242        }
243
244        let total = next_queries.len() as f32;
245
246        // Calculate confidence and sort
247        let mut predictions: Vec<_> = counts
248            .into_iter()
249            .map(|(fp, count)| (fp.clone(), count as f32 / total))
250            .collect();
251
252        predictions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
253        predictions
254            .into_iter()
255            .take(self.config.prefetch_lookahead as usize)
256            .collect()
257    }
258
259    /// Start the prefetch background worker
260    pub async fn start(&self) {
261        self.running.store(true, Ordering::SeqCst);
262
263        // In production, this would spawn a background task
264        // that processes the prefetch queue
265    }
266
267    /// Stop the prefetcher
268    pub async fn stop(&self) {
269        self.running.store(false, Ordering::SeqCst);
270    }
271
272    /// Record prefetch hit
273    pub fn record_hit(&self) {
274        self.prefetch_hits.fetch_add(1, Ordering::Relaxed);
275    }
276
277    /// Record prefetch miss
278    pub fn record_miss(&self) {
279        self.prefetch_misses.fetch_add(1, Ordering::Relaxed);
280    }
281
282    /// Get prefetch statistics
283    pub fn stats(&self) -> PrefetchStats {
284        let hits = self.prefetch_hits.load(Ordering::Relaxed);
285        let misses = self.prefetch_misses.load(Ordering::Relaxed);
286
287        PrefetchStats {
288            predictions_made: self.predictions_made.load(Ordering::Relaxed),
289            queue_size: self.prefetch_queue.len(),
290            hit_rate: if hits + misses > 0 {
291                hits as f64 / (hits + misses) as f64
292            } else {
293                0.0
294            },
295            patterns_learned: self.patterns.len(),
296            sessions_tracked: self.session_sequences.len(),
297        }
298    }
299
300    /// Clean up old sessions
301    pub fn cleanup_old_sessions(&self, _max_age: std::time::Duration) {
302        // In production, track timestamps and clean up
303        // For now, just limit total sessions
304        if self.session_sequences.len() > 10000 {
305            // Remove random entries to stay under limit
306            let to_remove: Vec<_> = self
307                .session_sequences
308                .iter()
309                .take(1000)
310                .map(|e| e.key().clone())
311                .collect();
312
313            for key in to_remove {
314                self.session_sequences.remove(&key);
315            }
316        }
317    }
318}
319
/// Point-in-time snapshot of prefetcher counters, as returned by
/// `PredictivePrefetcher::stats`.
///
/// All fields are plain values, so the snapshot is cheap to copy and compare.
/// `Default` yields the all-zero snapshot of a freshly created prefetcher.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct PrefetchStats {
    /// Total predictions made
    pub predictions_made: u64,
    /// Number of requests waiting in the prefetch queue.
    pub queue_size: usize,
    /// `hits / (hits + misses)`, or `0.0` before any hit or miss was recorded.
    pub hit_rate: f64,
    /// Number of distinct queries with at least one learned successor.
    pub patterns_learned: usize,
    /// Number of sessions whose history is currently retained.
    pub sessions_tracked: usize,
}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Priorities of the pending requests, front to back.
    fn pending_priorities(queue: &PrefetchQueue) -> Vec<u32> {
        queue
            .queue
            .lock()
            .unwrap()
            .iter()
            .map(|request| request.priority)
            .collect()
    }

    #[test]
    fn test_prefetch_queue() {
        let queue = PrefetchQueue::new();
        assert!(queue.is_empty());

        let fp1 = QueryFingerprint::from_query("SELECT 1");
        let fp2 = QueryFingerprint::from_query("SELECT 2");
        let fp3 = QueryFingerprint::from_query("SELECT 3");
        let fp4 = QueryFingerprint::from_query("SELECT * FROM late_arrival");

        // Add with different priorities
        queue.enqueue(PrefetchRequest {
            fingerprint: fp1.clone(),
            priority: 50,
        });
        queue.enqueue(PrefetchRequest {
            fingerprint: fp2.clone(),
            priority: 100,
        });
        queue.enqueue(PrefetchRequest {
            fingerprint: fp3.clone(),
            priority: 25,
        });

        assert_eq!(queue.len(), 3);
        assert!(!queue.is_empty());
        // Highest priority sits at the front regardless of arrival order.
        assert_eq!(pending_priorities(&queue), vec![100, 50, 25]);

        // Equal priorities are served in arrival order.
        queue.enqueue(PrefetchRequest {
            fingerprint: fp4.clone(),
            priority: 50,
        });
        assert_eq!(pending_priorities(&queue), vec![100, 50, 50, 25]);
        let pending = queue.queue.lock().unwrap();
        assert_eq!(pending[0].fingerprint, fp2);
        assert_eq!(pending[1].fingerprint, fp1);
        assert_eq!(pending[2].fingerprint, fp4);
        assert_eq!(pending[3].fingerprint, fp3);
    }

    #[test]
    fn test_pattern_learning() {
        let config = DistribCacheConfig::default();
        let prefetcher = PredictivePrefetcher::new(config);
        let session = SessionId::new("test");

        let fp1 = QueryFingerprint::from_query("SELECT * FROM users");
        let fp2 = QueryFingerprint::from_query("SELECT * FROM orders");
        let fp3 = QueryFingerprint::from_query("SELECT * FROM items");

        // Simulate sequence: fp1 -> fp2 -> fp3
        prefetcher.record(&session, fp1.clone());
        prefetcher.record(&session, fp2.clone());
        prefetcher.record(&session, fp3.clone());

        // Both transitions fp1 -> fp2 and fp2 -> fp3 should be learned.
        {
            let after_fp1 = prefetcher.patterns.get(&fp1).expect("fp1 -> fp2 learned");
            assert_eq!(after_fp1.value(), &vec![fp2.clone()]);
            let after_fp2 = prefetcher.patterns.get(&fp2).expect("fp2 -> fp3 learned");
            assert_eq!(after_fp2.value(), &vec![fp3.clone()]);
        }
        // The last query has no successor yet.
        assert!(!prefetcher.patterns.contains_key(&fp3));

        let stats = prefetcher.stats();
        assert_eq!(stats.patterns_learned, 2);
        assert_eq!(stats.sessions_tracked, 1);
    }

    #[test]
    fn test_prediction() {
        let config = DistribCacheConfig::builder()
            .prefetch_enabled(true)
            .prefetch_confidence_threshold(0.0) // Accept all predictions for test
            .build();
        let prefetcher = PredictivePrefetcher::new(config);
        let session = SessionId::new("test");

        // Train pattern: query1 -> query2 (repeated)
        let fp1 = QueryFingerprint::from_query("SELECT * FROM users WHERE id = ?");
        let fp2 = QueryFingerprint::from_query("SELECT * FROM orders WHERE user_id = ?");

        for _ in 0..10 {
            prefetcher.record(&session, fp1.clone());
            prefetcher.record(&session, fp2.clone());
        }

        // Now predict after fp1
        prefetcher.predict_and_prefetch(&fp1, &session);

        // fp2 always followed fp1, so it is enqueued with full confidence and
        // outranks any time-of-day prediction.
        assert!(!prefetcher.prefetch_queue.is_empty());
        let front = prefetcher
            .prefetch_queue
            .queue
            .lock()
            .unwrap()
            .front()
            .cloned()
            .expect("queue is non-empty");
        assert_eq!(front.fingerprint, fp2);
        assert_eq!(front.priority, 100);
        assert!(prefetcher.stats().predictions_made >= 1);
    }

    #[test]
    fn test_temporal_patterns() {
        let store = TemporalPatternStore::new();
        let fp = QueryFingerprint::from_query("SELECT * FROM reports");

        // Record at hour 9 multiple times
        for _ in 0..10 {
            store.record(&fp, 9, 1);
        }
        // Out-of-range buckets are ignored rather than panicking.
        store.record(&fp, 24, 7);

        assert_eq!(store.hourly_patterns[9].get(&fp).map(|count| *count), Some(10));
        assert_eq!(store.daily_patterns[1].get(&fp).map(|count| *count), Some(10));

        // Predict for hour 9 should include our query, and nothing else.
        assert_eq!(store.predict_for_hour(9), vec![fp.clone()]);
        assert!(store.predict_for_hour(10).is_empty());
        assert!(store.predict_for_hour(24).is_empty());
    }

    #[test]
    fn test_stats() {
        let config = DistribCacheConfig::default();
        let prefetcher = PredictivePrefetcher::new(config);

        // With no outcomes recorded, the rate is defined as 0 rather than NaN.
        assert_eq!(prefetcher.stats().hit_rate, 0.0);

        prefetcher.record_hit();
        prefetcher.record_hit();
        prefetcher.record_miss();

        let stats = prefetcher.stats();
        assert!((stats.hit_rate - 2.0 / 3.0).abs() < 1e-9);
    }
}