heliosdb_proxy/distribcache/
prefetcher.rs1use chrono::{Datelike, Timelike};
7use dashmap::DashMap;
8use std::collections::{HashMap, VecDeque};
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11
12use super::{DistribCacheConfig, QueryFingerprint, SessionId};
13
/// A request to prefetch the query identified by `fingerprint`.
#[derive(Debug, Clone)]
pub struct PrefetchRequest {
    /// Fingerprint of the query to prefetch.
    pub fingerprint: QueryFingerprint,
    /// Scheduling priority. `PrefetchQueue::enqueue` places a request ahead of
    /// every queued request whose priority is strictly lower.
    pub priority: u32,
}
22
23pub struct PrefetchQueue {
25 queue: std::sync::Mutex<VecDeque<PrefetchRequest>>,
27 notify: tokio::sync::Notify,
29}
30
31impl PrefetchQueue {
32 fn new() -> Self {
33 Self {
34 queue: std::sync::Mutex::new(VecDeque::new()),
35 notify: tokio::sync::Notify::new(),
36 }
37 }
38
39 pub fn enqueue(&self, request: PrefetchRequest) {
40 let mut queue = self.queue.lock().unwrap();
41
42 let pos = queue.iter()
44 .position(|r| r.priority < request.priority)
45 .unwrap_or(queue.len());
46
47 queue.insert(pos, request);
48 self.notify.notify_one();
49 }
50
51 pub async fn dequeue(&self) -> Option<PrefetchRequest> {
52 loop {
53 {
54 let mut queue = self.queue.lock().unwrap();
55 if let Some(request) = queue.pop_front() {
56 return Some(request);
57 }
58 }
59 self.notify.notified().await;
60 }
61 }
62
63 pub fn len(&self) -> usize {
64 self.queue.lock().unwrap().len()
65 }
66
67 pub fn is_empty(&self) -> bool {
68 self.queue.lock().unwrap().is_empty()
69 }
70}
71
72pub struct TemporalPatternStore {
74 hourly_patterns: [DashMap<QueryFingerprint, u64>; 24],
76 daily_patterns: [DashMap<QueryFingerprint, u64>; 7],
78}
79
80impl TemporalPatternStore {
81 fn new() -> Self {
82 Self {
83 hourly_patterns: std::array::from_fn(|_| DashMap::new()),
84 daily_patterns: std::array::from_fn(|_| DashMap::new()),
85 }
86 }
87
88 fn record(&self, fingerprint: &QueryFingerprint, hour: usize, weekday: usize) {
89 if hour < 24 {
90 self.hourly_patterns[hour]
91 .entry(fingerprint.clone())
92 .and_modify(|c| *c += 1)
93 .or_insert(1);
94 }
95 if weekday < 7 {
96 self.daily_patterns[weekday]
97 .entry(fingerprint.clone())
98 .and_modify(|c| *c += 1)
99 .or_insert(1);
100 }
101 }
102
103 fn predict_for_hour(&self, hour: usize) -> Vec<QueryFingerprint> {
104 if hour >= 24 {
105 return Vec::new();
106 }
107
108 let patterns = &self.hourly_patterns[hour];
109 let mut predictions: Vec<_> = patterns.iter()
110 .map(|e| (e.key().clone(), *e.value()))
111 .collect();
112
113 predictions.sort_by(|a, b| b.1.cmp(&a.1));
114 predictions.into_iter()
115 .take(10)
116 .map(|(fp, _)| fp)
117 .collect()
118 }
119}
120
121pub struct PredictivePrefetcher {
123 config: DistribCacheConfig,
125
126 patterns: DashMap<QueryFingerprint, Vec<QueryFingerprint>>,
128
129 session_sequences: DashMap<SessionId, VecDeque<QueryFingerprint>>,
131
132 temporal_patterns: TemporalPatternStore,
134
135 prefetch_queue: Arc<PrefetchQueue>,
137
138 running: AtomicBool,
140
141 predictions_made: AtomicU64,
143 prefetch_hits: AtomicU64,
144 prefetch_misses: AtomicU64,
145}
146
147impl PredictivePrefetcher {
148 pub fn new(config: DistribCacheConfig) -> Self {
150 Self {
151 config,
152 patterns: DashMap::new(),
153 session_sequences: DashMap::new(),
154 temporal_patterns: TemporalPatternStore::new(),
155 prefetch_queue: Arc::new(PrefetchQueue::new()),
156 running: AtomicBool::new(false),
157 predictions_made: AtomicU64::new(0),
158 prefetch_hits: AtomicU64::new(0),
159 prefetch_misses: AtomicU64::new(0),
160 }
161 }
162
163 pub fn record(&self, session: &SessionId, fingerprint: QueryFingerprint) {
165 let mut seq = self.session_sequences
167 .entry(session.clone())
168 .or_insert_with(|| VecDeque::with_capacity(100));
169
170 if seq.len() >= 1 {
172 if let Some(prev) = seq.back() {
173 self.patterns
174 .entry(prev.clone())
175 .or_default()
176 .push(fingerprint.clone());
177 }
178 }
179
180 seq.push_back(fingerprint.clone());
182
183 while seq.len() > 100 {
185 seq.pop_front();
186 }
187
188 let now = chrono::Utc::now();
190 self.temporal_patterns.record(
191 &fingerprint,
192 now.hour() as usize,
193 now.weekday().num_days_from_monday() as usize,
194 );
195 }
196
197 pub fn predict_and_prefetch(&self, current: &QueryFingerprint, _session: &SessionId) {
199 if !self.config.prefetch_enabled {
200 return;
201 }
202
203 if let Some(next_queries) = self.patterns.get(current) {
205 let predictions = self.get_top_predictions(next_queries.value());
206
207 for (fingerprint, confidence) in predictions {
208 if confidence > self.config.prefetch_confidence_threshold {
209 self.prefetch_queue.enqueue(PrefetchRequest {
210 fingerprint,
211 priority: (confidence * 100.0) as u32,
212 });
213 self.predictions_made.fetch_add(1, Ordering::Relaxed);
214 }
215 }
216 }
217
218 let hour = chrono::Utc::now().hour() as usize;
220 let temporal_predictions = self.temporal_patterns.predict_for_hour(hour);
221
222 for fingerprint in temporal_predictions.into_iter().take(self.config.prefetch_lookahead as usize) {
223 self.prefetch_queue.enqueue(PrefetchRequest {
224 fingerprint,
225 priority: 50, });
227 }
228 }
229
230 fn get_top_predictions(&self, next_queries: &[QueryFingerprint]) -> Vec<(QueryFingerprint, f32)> {
232 let mut counts: HashMap<&QueryFingerprint, u32> = HashMap::new();
234 for fp in next_queries {
235 *counts.entry(fp).or_default() += 1;
236 }
237
238 let total = next_queries.len() as f32;
239
240 let mut predictions: Vec<_> = counts.into_iter()
242 .map(|(fp, count)| (fp.clone(), count as f32 / total))
243 .collect();
244
245 predictions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
246 predictions.into_iter()
247 .take(self.config.prefetch_lookahead as usize)
248 .collect()
249 }
250
251 pub async fn start(&self) {
253 self.running.store(true, Ordering::SeqCst);
254
255 }
258
259 pub async fn stop(&self) {
261 self.running.store(false, Ordering::SeqCst);
262 }
263
264 pub fn record_hit(&self) {
266 self.prefetch_hits.fetch_add(1, Ordering::Relaxed);
267 }
268
269 pub fn record_miss(&self) {
271 self.prefetch_misses.fetch_add(1, Ordering::Relaxed);
272 }
273
274 pub fn stats(&self) -> PrefetchStats {
276 let hits = self.prefetch_hits.load(Ordering::Relaxed);
277 let misses = self.prefetch_misses.load(Ordering::Relaxed);
278
279 PrefetchStats {
280 predictions_made: self.predictions_made.load(Ordering::Relaxed),
281 queue_size: self.prefetch_queue.len(),
282 hit_rate: if hits + misses > 0 {
283 hits as f64 / (hits + misses) as f64
284 } else {
285 0.0
286 },
287 patterns_learned: self.patterns.len(),
288 sessions_tracked: self.session_sequences.len(),
289 }
290 }
291
292 pub fn cleanup_old_sessions(&self, _max_age: std::time::Duration) {
294 if self.session_sequences.len() > 10000 {
297 let to_remove: Vec<_> = self.session_sequences.iter()
299 .take(1000)
300 .map(|e| e.key().clone())
301 .collect();
302
303 for key in to_remove {
304 self.session_sequences.remove(&key);
305 }
306 }
307 }
308}
309
/// Point-in-time snapshot of prefetcher counters.
///
/// `Default` yields an all-zero snapshot and `PartialEq` allows snapshots to
/// be compared directly (for example in tests).
#[derive(Debug, Clone, PartialEq, Default)]
pub struct PrefetchStats {
    /// Number of predictions that were enqueued for prefetching.
    pub predictions_made: u64,
    /// Number of requests waiting in the prefetch queue.
    pub queue_size: usize,
    /// Fraction of recorded outcomes that were hits, in `0.0..=1.0`;
    /// `0.0` when no outcome has been recorded.
    pub hit_rate: f64,
    /// Number of distinct fingerprints with learned successors.
    pub patterns_learned: usize,
    /// Number of sessions with recorded history.
    pub sessions_tracked: usize,
}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Requests must come out ordered by descending priority, with arrival
    /// order preserved among equal priorities.
    #[test]
    fn test_prefetch_queue() {
        let queue = PrefetchQueue::new();

        let fp1 = QueryFingerprint::from_query("SELECT 1");
        let fp2 = QueryFingerprint::from_query("SELECT 2");
        let fp3 = QueryFingerprint::from_query("SELECT 3");

        queue.enqueue(PrefetchRequest { fingerprint: fp1.clone(), priority: 50 });
        queue.enqueue(PrefetchRequest { fingerprint: fp2.clone(), priority: 100 });
        queue.enqueue(PrefetchRequest { fingerprint: fp3.clone(), priority: 25 });

        assert_eq!(queue.len(), 3);
        assert!(!queue.is_empty());

        let priorities: Vec<u32> = queue
            .queue
            .lock()
            .unwrap()
            .iter()
            .map(|request| request.priority)
            .collect();
        assert_eq!(priorities, vec![100, 50, 25]);
    }

    /// Two requests with the same priority keep their arrival order.
    #[test]
    fn test_prefetch_queue_fifo_within_priority() {
        let queue = PrefetchQueue::new();

        let first = QueryFingerprint::from_query("SELECT * FROM alpha");
        let second = QueryFingerprint::from_query("SELECT * FROM beta");

        queue.enqueue(PrefetchRequest { fingerprint: first.clone(), priority: 50 });
        queue.enqueue(PrefetchRequest { fingerprint: second.clone(), priority: 50 });

        let pending = queue.queue.lock().unwrap();
        assert_eq!(pending.len(), 2);
        assert!(pending[0].fingerprint == first);
        assert!(pending[1].fingerprint == second);
    }

    #[test]
    fn test_pattern_learning() {
        let config = DistribCacheConfig::default();
        let prefetcher = PredictivePrefetcher::new(config);
        let session = SessionId::new("test");

        let fp1 = QueryFingerprint::from_query("SELECT * FROM users");
        let fp2 = QueryFingerprint::from_query("SELECT * FROM orders");
        let fp3 = QueryFingerprint::from_query("SELECT * FROM items");

        prefetcher.record(&session, fp1.clone());
        prefetcher.record(&session, fp2.clone());
        prefetcher.record(&session, fp3.clone());

        // fp1 -> fp2 was observed.
        assert!(prefetcher.patterns.contains_key(&fp1));
        let after_fp1 = prefetcher.patterns.get(&fp1).unwrap();
        assert!(after_fp1.contains(&fp2));
        drop(after_fp1);

        // fp2 -> fp3 was observed; nothing has followed fp3 yet.
        let after_fp2 = prefetcher.patterns.get(&fp2).unwrap();
        assert!(after_fp2.contains(&fp3));
        drop(after_fp2);
        assert!(!prefetcher.patterns.contains_key(&fp3));

        assert_eq!(prefetcher.stats().sessions_tracked, 1);
    }

    #[test]
    fn test_prediction() {
        let config = DistribCacheConfig::builder()
            .prefetch_enabled(true)
            .prefetch_confidence_threshold(0.0)
            .build();
        let prefetcher = PredictivePrefetcher::new(config);
        let session = SessionId::new("test");

        let fp1 = QueryFingerprint::from_query("SELECT * FROM users WHERE id = ?");
        let fp2 = QueryFingerprint::from_query("SELECT * FROM orders WHERE user_id = ?");

        for _ in 0..10 {
            prefetcher.record(&session, fp1.clone());
            prefetcher.record(&session, fp2.clone());
        }

        prefetcher.predict_and_prefetch(&fp1, &session);

        assert!(!prefetcher.prefetch_queue.is_empty());
        // fp2 is the only successor of fp1, so exactly one sequence-based
        // prediction is counted.
        assert_eq!(prefetcher.stats().predictions_made, 1);
    }

    /// With prefetching disabled nothing may be enqueued or counted.
    #[test]
    fn test_prediction_disabled() {
        let config = DistribCacheConfig::builder()
            .prefetch_enabled(false)
            .prefetch_confidence_threshold(0.0)
            .build();
        let prefetcher = PredictivePrefetcher::new(config);
        let session = SessionId::new("test");

        let fp1 = QueryFingerprint::from_query("SELECT * FROM users WHERE id = ?");
        let fp2 = QueryFingerprint::from_query("SELECT * FROM orders WHERE user_id = ?");

        prefetcher.record(&session, fp1.clone());
        prefetcher.record(&session, fp2.clone());
        prefetcher.predict_and_prefetch(&fp1, &session);

        assert!(prefetcher.prefetch_queue.is_empty());
        assert_eq!(prefetcher.stats().predictions_made, 0);
    }

    #[test]
    fn test_temporal_patterns() {
        let store = TemporalPatternStore::new();
        let fp = QueryFingerprint::from_query("SELECT * FROM reports");

        for _ in 0..10 {
            store.record(&fp, 9, 1);
        }

        let predictions = store.predict_for_hour(9);
        assert!(predictions.contains(&fp));

        // Other hours are unaffected.
        assert!(store.predict_for_hour(10).is_empty());

        // Out-of-range buckets are ignored rather than panicking.
        store.record(&fp, 24, 7);
        assert!(store.predict_for_hour(24).is_empty());
    }

    #[test]
    fn test_stats() {
        let config = DistribCacheConfig::default();
        let prefetcher = PredictivePrefetcher::new(config);

        // No outcomes recorded yet: the rate is defined as zero.
        assert_eq!(prefetcher.stats().hit_rate, 0.0);

        prefetcher.record_hit();
        prefetcher.record_hit();
        prefetcher.record_miss();

        let stats = prefetcher.stats();
        assert!((stats.hit_rate - 2.0 / 3.0).abs() < 1e-9);
    }
}