heliosdb_proxy/distribcache/
prefetcher.rs1use chrono::{Datelike, Timelike};
7use dashmap::DashMap;
8use std::collections::{HashMap, VecDeque};
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11
12use super::{DistribCacheConfig, QueryFingerprint, SessionId};
13
14#[derive(Debug, Clone)]
16pub struct PrefetchRequest {
17 pub fingerprint: QueryFingerprint,
19 pub priority: u32,
21}
22
23pub struct PrefetchQueue {
25 queue: std::sync::Mutex<VecDeque<PrefetchRequest>>,
27 notify: tokio::sync::Notify,
29}
30
31impl PrefetchQueue {
32 fn new() -> Self {
33 Self {
34 queue: std::sync::Mutex::new(VecDeque::new()),
35 notify: tokio::sync::Notify::new(),
36 }
37 }
38
39 pub fn enqueue(&self, request: PrefetchRequest) {
40 let mut queue = self.queue.lock().unwrap();
41
42 let pos = queue
44 .iter()
45 .position(|r| r.priority < request.priority)
46 .unwrap_or(queue.len());
47
48 queue.insert(pos, request);
49 self.notify.notify_one();
50 }
51
52 pub async fn dequeue(&self) -> Option<PrefetchRequest> {
53 loop {
54 {
55 let mut queue = self.queue.lock().unwrap();
56 if let Some(request) = queue.pop_front() {
57 return Some(request);
58 }
59 }
60 self.notify.notified().await;
61 }
62 }
63
64 pub fn len(&self) -> usize {
65 self.queue.lock().unwrap().len()
66 }
67
68 pub fn is_empty(&self) -> bool {
69 self.queue.lock().unwrap().is_empty()
70 }
71}
72
73pub struct TemporalPatternStore {
75 hourly_patterns: [DashMap<QueryFingerprint, u64>; 24],
77 daily_patterns: [DashMap<QueryFingerprint, u64>; 7],
79}
80
81impl TemporalPatternStore {
82 fn new() -> Self {
83 Self {
84 hourly_patterns: std::array::from_fn(|_| DashMap::new()),
85 daily_patterns: std::array::from_fn(|_| DashMap::new()),
86 }
87 }
88
89 fn record(&self, fingerprint: &QueryFingerprint, hour: usize, weekday: usize) {
90 if hour < 24 {
91 self.hourly_patterns[hour]
92 .entry(fingerprint.clone())
93 .and_modify(|c| *c += 1)
94 .or_insert(1);
95 }
96 if weekday < 7 {
97 self.daily_patterns[weekday]
98 .entry(fingerprint.clone())
99 .and_modify(|c| *c += 1)
100 .or_insert(1);
101 }
102 }
103
104 fn predict_for_hour(&self, hour: usize) -> Vec<QueryFingerprint> {
105 if hour >= 24 {
106 return Vec::new();
107 }
108
109 let patterns = &self.hourly_patterns[hour];
110 let mut predictions: Vec<_> = patterns
111 .iter()
112 .map(|e| (e.key().clone(), *e.value()))
113 .collect();
114
115 predictions.sort_by_key(|b| std::cmp::Reverse(b.1));
116 predictions.into_iter().take(10).map(|(fp, _)| fp).collect()
117 }
118}
119
120pub struct PredictivePrefetcher {
122 config: DistribCacheConfig,
124
125 patterns: DashMap<QueryFingerprint, Vec<QueryFingerprint>>,
127
128 session_sequences: DashMap<SessionId, VecDeque<QueryFingerprint>>,
130
131 temporal_patterns: TemporalPatternStore,
133
134 prefetch_queue: Arc<PrefetchQueue>,
136
137 running: AtomicBool,
139
140 predictions_made: AtomicU64,
142 prefetch_hits: AtomicU64,
143 prefetch_misses: AtomicU64,
144}
145
146impl PredictivePrefetcher {
147 pub fn new(config: DistribCacheConfig) -> Self {
149 Self {
150 config,
151 patterns: DashMap::new(),
152 session_sequences: DashMap::new(),
153 temporal_patterns: TemporalPatternStore::new(),
154 prefetch_queue: Arc::new(PrefetchQueue::new()),
155 running: AtomicBool::new(false),
156 predictions_made: AtomicU64::new(0),
157 prefetch_hits: AtomicU64::new(0),
158 prefetch_misses: AtomicU64::new(0),
159 }
160 }
161
162 pub fn record(&self, session: &SessionId, fingerprint: QueryFingerprint) {
164 let mut seq = self
166 .session_sequences
167 .entry(session.clone())
168 .or_insert_with(|| VecDeque::with_capacity(100));
169
170 if !seq.is_empty() {
172 if let Some(prev) = seq.back() {
173 self.patterns
174 .entry(prev.clone())
175 .or_default()
176 .push(fingerprint.clone());
177 }
178 }
179
180 seq.push_back(fingerprint.clone());
182
183 while seq.len() > 100 {
185 seq.pop_front();
186 }
187
188 let now = chrono::Utc::now();
190 self.temporal_patterns.record(
191 &fingerprint,
192 now.hour() as usize,
193 now.weekday().num_days_from_monday() as usize,
194 );
195 }
196
197 pub fn predict_and_prefetch(&self, current: &QueryFingerprint, _session: &SessionId) {
199 if !self.config.prefetch_enabled {
200 return;
201 }
202
203 if let Some(next_queries) = self.patterns.get(current) {
205 let predictions = self.get_top_predictions(next_queries.value());
206
207 for (fingerprint, confidence) in predictions {
208 if confidence > self.config.prefetch_confidence_threshold {
209 self.prefetch_queue.enqueue(PrefetchRequest {
210 fingerprint,
211 priority: (confidence * 100.0) as u32,
212 });
213 self.predictions_made.fetch_add(1, Ordering::Relaxed);
214 }
215 }
216 }
217
218 let hour = chrono::Utc::now().hour() as usize;
220 let temporal_predictions = self.temporal_patterns.predict_for_hour(hour);
221
222 for fingerprint in temporal_predictions
223 .into_iter()
224 .take(self.config.prefetch_lookahead as usize)
225 {
226 self.prefetch_queue.enqueue(PrefetchRequest {
227 fingerprint,
228 priority: 50, });
230 }
231 }
232
233 fn get_top_predictions(
235 &self,
236 next_queries: &[QueryFingerprint],
237 ) -> Vec<(QueryFingerprint, f32)> {
238 let mut counts: HashMap<&QueryFingerprint, u32> = HashMap::new();
240 for fp in next_queries {
241 *counts.entry(fp).or_default() += 1;
242 }
243
244 let total = next_queries.len() as f32;
245
246 let mut predictions: Vec<_> = counts
248 .into_iter()
249 .map(|(fp, count)| (fp.clone(), count as f32 / total))
250 .collect();
251
252 predictions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
253 predictions
254 .into_iter()
255 .take(self.config.prefetch_lookahead as usize)
256 .collect()
257 }
258
259 pub async fn start(&self) {
261 self.running.store(true, Ordering::SeqCst);
262
263 }
266
267 pub async fn stop(&self) {
269 self.running.store(false, Ordering::SeqCst);
270 }
271
272 pub fn record_hit(&self) {
274 self.prefetch_hits.fetch_add(1, Ordering::Relaxed);
275 }
276
277 pub fn record_miss(&self) {
279 self.prefetch_misses.fetch_add(1, Ordering::Relaxed);
280 }
281
282 pub fn stats(&self) -> PrefetchStats {
284 let hits = self.prefetch_hits.load(Ordering::Relaxed);
285 let misses = self.prefetch_misses.load(Ordering::Relaxed);
286
287 PrefetchStats {
288 predictions_made: self.predictions_made.load(Ordering::Relaxed),
289 queue_size: self.prefetch_queue.len(),
290 hit_rate: if hits + misses > 0 {
291 hits as f64 / (hits + misses) as f64
292 } else {
293 0.0
294 },
295 patterns_learned: self.patterns.len(),
296 sessions_tracked: self.session_sequences.len(),
297 }
298 }
299
300 pub fn cleanup_old_sessions(&self, _max_age: std::time::Duration) {
302 if self.session_sequences.len() > 10000 {
305 let to_remove: Vec<_> = self
307 .session_sequences
308 .iter()
309 .take(1000)
310 .map(|e| e.key().clone())
311 .collect();
312
313 for key in to_remove {
314 self.session_sequences.remove(&key);
315 }
316 }
317 }
318}
319
/// Point-in-time metrics reported by `PredictivePrefetcher::stats`.
#[derive(Debug, Clone)]
pub struct PrefetchStats {
    /// Number of sequence-based predictions enqueued so far.
    pub predictions_made: u64,
    /// Requests currently waiting in the prefetch queue.
    pub queue_size: usize,
    /// `hits / (hits + misses)`, or `0.0` when nothing has been recorded.
    pub hit_rate: f64,
    /// Number of fingerprints that have at least one learned successor list.
    pub patterns_learned: usize,
    /// Number of sessions with a tracked query history.
    pub sessions_tracked: usize,
}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Priorities currently held by `queue`, front to back.
    fn queued_priorities(queue: &PrefetchQueue) -> Vec<u32> {
        queue
            .queue
            .lock()
            .unwrap()
            .iter()
            .map(|request| request.priority)
            .collect()
    }

    /// Requests are kept in descending priority order regardless of the order
    /// in which they were enqueued.
    #[test]
    fn test_prefetch_queue() {
        let queue = PrefetchQueue::new();
        assert!(queue.is_empty());

        queue.enqueue(PrefetchRequest {
            fingerprint: QueryFingerprint::from_query("SELECT 1"),
            priority: 50,
        });
        queue.enqueue(PrefetchRequest {
            fingerprint: QueryFingerprint::from_query("SELECT 2"),
            priority: 100,
        });
        queue.enqueue(PrefetchRequest {
            fingerprint: QueryFingerprint::from_query("SELECT 3"),
            priority: 25,
        });

        assert_eq!(queue.len(), 3);
        assert!(!queue.is_empty());
        assert_eq!(queued_priorities(&queue), vec![100, 50, 25]);
    }

    /// Requests with the same priority keep their arrival order.
    #[test]
    fn test_prefetch_queue_fifo_within_priority() {
        let queue = PrefetchQueue::new();
        let first = QueryFingerprint::from_query("SELECT * FROM users");
        let second = QueryFingerprint::from_query("SELECT * FROM orders");

        queue.enqueue(PrefetchRequest {
            fingerprint: first.clone(),
            priority: 50,
        });
        queue.enqueue(PrefetchRequest {
            fingerprint: second.clone(),
            priority: 50,
        });

        let pending = queue.queue.lock().unwrap();
        assert_eq!(pending.len(), 2);
        assert!(pending[0].fingerprint == first);
        assert!(pending[1].fingerprint == second);
    }

    /// Consecutive queries in one session are learned as transitions.
    #[test]
    fn test_pattern_learning() {
        let config = DistribCacheConfig::default();
        let prefetcher = PredictivePrefetcher::new(config);
        let session = SessionId::new("test");

        let fp1 = QueryFingerprint::from_query("SELECT * FROM users");
        let fp2 = QueryFingerprint::from_query("SELECT * FROM orders");
        let fp3 = QueryFingerprint::from_query("SELECT * FROM items");

        prefetcher.record(&session, fp1.clone());
        prefetcher.record(&session, fp2.clone());
        prefetcher.record(&session, fp3.clone());

        assert!(prefetcher.patterns.contains_key(&fp1));
        {
            let after_fp1 = prefetcher.patterns.get(&fp1).unwrap();
            assert!(after_fp1.contains(&fp2));
        }
        {
            let after_fp2 = prefetcher.patterns.get(&fp2).unwrap();
            assert!(after_fp2.contains(&fp3));
        }

        assert_eq!(prefetcher.stats().sessions_tracked, 1);
    }

    /// A transition seen every time is predicted with full confidence and
    /// ends up at the front of the queue.
    #[test]
    fn test_prediction() {
        let config = DistribCacheConfig::builder()
            .prefetch_enabled(true)
            .prefetch_confidence_threshold(0.0)
            .build();
        let prefetcher = PredictivePrefetcher::new(config);
        let session = SessionId::new("test");

        let fp1 = QueryFingerprint::from_query("SELECT * FROM users WHERE id = ?");
        let fp2 = QueryFingerprint::from_query("SELECT * FROM orders WHERE user_id = ?");

        for _ in 0..10 {
            prefetcher.record(&session, fp1.clone());
            prefetcher.record(&session, fp2.clone());
        }

        prefetcher.predict_and_prefetch(&fp1, &session);

        assert!(!prefetcher.prefetch_queue.is_empty());
        // fp2 followed fp1 on every occurrence: confidence 1.0 -> priority 100.
        assert_eq!(
            queued_priorities(&prefetcher.prefetch_queue).first(),
            Some(&100)
        );
        assert_eq!(prefetcher.stats().predictions_made, 1);
    }

    /// Hourly counters are kept per bucket and out-of-range input is ignored.
    #[test]
    fn test_temporal_patterns() {
        let store = TemporalPatternStore::new();
        let fp = QueryFingerprint::from_query("SELECT * FROM reports");

        for _ in 0..10 {
            store.record(&fp, 9, 1);
        }
        // Out-of-range buckets must be ignored rather than panic.
        store.record(&fp, 24, 7);

        let predictions = store.predict_for_hour(9);
        assert!(predictions.contains(&fp));

        assert!(store.predict_for_hour(10).is_empty());
        assert!(store.predict_for_hour(24).is_empty());
    }

    /// Hit rate is 0.0 with no data and hits / (hits + misses) afterwards.
    #[test]
    fn test_stats() {
        let config = DistribCacheConfig::default();
        let prefetcher = PredictivePrefetcher::new(config);

        assert_eq!(prefetcher.stats().hit_rate, 0.0);

        prefetcher.record_hit();
        prefetcher.record_hit();
        prefetcher.record_miss();

        let stats = prefetcher.stats();
        assert!((stats.hit_rate - 2.0 / 3.0).abs() < 1e-9);
    }
}