1use std::collections::HashMap;
7
/// Measurements recorded for one execution of a query.
#[derive(Debug, Clone)]
pub struct QueryStats {
    /// Hash identifying the query this record belongs to.
    pub query_hash: u64,
    /// Row count observed for the run.
    pub actual_rows: usize,
    /// Row count that had been estimated for the run.
    pub estimated_rows: usize,
    /// Execution time of the run, in milliseconds.
    pub execution_time_ms: u64,
    /// Join order used for the run.
    pub join_order: Vec<String>,
    /// Selectivity values keyed by pattern.
    pub selectivity_by_pattern: HashMap<String, f64>,
}

impl QueryStats {
    /// Builds a record from its individual measurements.
    ///
    /// Every argument is stored unchanged in the field of the same name.
    pub fn new(
        query_hash: u64,
        actual_rows: usize,
        estimated_rows: usize,
        execution_time_ms: u64,
        join_order: Vec<String>,
        selectivity_by_pattern: HashMap<String, f64>,
    ) -> Self {
        QueryStats {
            query_hash,
            actual_rows,
            estimated_rows,
            execution_time_ms,
            join_order,
            selectivity_by_pattern,
        }
    }

    /// Returns `actual_rows / estimated_rows` as a floating-point ratio.
    ///
    /// When `estimated_rows` is zero the ratio is undefined, so `1.0` is
    /// returned instead of dividing by zero.
    pub fn accuracy_ratio(&self) -> f64 {
        match self.estimated_rows {
            0 => 1.0,
            estimated => self.actual_rows as f64 / estimated as f64,
        }
    }
}
54
55pub struct RuntimeFeedbackStore {
59 stats: HashMap<u64, Vec<QueryStats>>,
60 max_history: usize,
61}
62
63impl RuntimeFeedbackStore {
64 pub fn new(max_history: usize) -> Self {
66 Self {
67 stats: HashMap::new(),
68 max_history,
69 }
70 }
71
72 pub fn record(&mut self, stats: QueryStats) {
77 let history = self.stats.entry(stats.query_hash).or_default();
78 if history.len() >= self.max_history {
79 history.remove(0);
80 }
81 history.push(stats);
82 }
83
84 pub fn get_stats(&self, query_hash: u64) -> &[QueryStats] {
86 self.stats
87 .get(&query_hash)
88 .map(|v| v.as_slice())
89 .unwrap_or(&[])
90 }
91
92 pub fn estimate_selectivity(&self, pattern: &str) -> f64 {
98 let mut total = 0.0_f64;
99 let mut count = 0usize;
100
101 for history in self.stats.values() {
102 for entry in history {
103 if let Some(&sel) = entry.selectivity_by_pattern.get(pattern) {
104 total += sel;
105 count += 1;
106 }
107 }
108 }
109
110 if count == 0 {
111 0.1
112 } else {
113 total / count as f64
114 }
115 }
116
117 pub fn estimate_cardinality(&self, pattern: &str, base_estimate: usize) -> usize {
123 let sel = self.estimate_selectivity(pattern);
124 let adjusted = base_estimate as f64 * sel / 0.1;
126 adjusted.round() as usize
127 }
128
129 pub fn best_join_order(&self, query_hash: u64) -> Option<Vec<String>> {
132 let history = self.stats.get(&query_hash)?;
133 if history.is_empty() {
134 return None;
135 }
136
137 let mut order_times: HashMap<String, (u64, usize)> = HashMap::new();
139 for entry in history {
140 let key = entry.join_order.join(",");
141 let acc = order_times.entry(key).or_default();
142 acc.0 += entry.execution_time_ms;
143 acc.1 += 1;
144 }
145
146 let best_key = order_times
148 .iter()
149 .map(|(k, (total, cnt))| (k, *total / (*cnt as u64).max(1)))
150 .min_by_key(|(_, avg)| *avg)
151 .map(|(k, _)| k.clone())?;
152
153 let parts: Vec<String> = best_key
155 .split(',')
156 .filter(|s| !s.is_empty())
157 .map(str::to_string)
158 .collect();
159 Some(parts)
160 }
161
162 pub fn prune_old(&mut self, max_age_ms: u64) {
167 for history in self.stats.values_mut() {
168 history.retain(|s| s.execution_time_ms < max_age_ms);
169 }
170 self.stats.retain(|_, v| !v.is_empty());
172 }
173
174 pub fn stats_count(&self) -> usize {
176 self.stats.values().map(|v| v.len()).sum()
177 }
178
179 pub fn query_count(&self) -> usize {
181 self.stats.len()
182 }
183}
184
185impl Default for RuntimeFeedbackStore {
186 fn default() -> Self {
187 Self::new(100)
188 }
189}
190
191pub struct AdaptiveQueryOptimizer {
195 feedback: RuntimeFeedbackStore,
196 base_selectivities: HashMap<String, f64>,
197}
198
199impl AdaptiveQueryOptimizer {
200 pub fn new() -> Self {
202 Self {
203 feedback: RuntimeFeedbackStore::new(100),
204 base_selectivities: HashMap::new(),
205 }
206 }
207
208 pub fn with_feedback(feedback: RuntimeFeedbackStore) -> Self {
210 Self {
211 feedback,
212 base_selectivities: HashMap::new(),
213 }
214 }
215
216 pub fn set_base_selectivity(&mut self, pattern: impl Into<String>, selectivity: f64) {
218 self.base_selectivities.insert(pattern.into(), selectivity);
219 }
220
221 pub fn optimize_join_order(&self, patterns: &[String]) -> Vec<String> {
226 let mut with_sel: Vec<(String, f64)> = patterns
227 .iter()
228 .map(|p| {
229 let sel = self
230 .base_selectivities
231 .get(p.as_str())
232 .copied()
233 .unwrap_or_else(|| self.feedback.estimate_selectivity(p));
234 (p.clone(), sel)
235 })
236 .collect();
237
238 with_sel.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
240
241 with_sel.into_iter().map(|(p, _)| p).collect()
242 }
243
244 pub fn feedback(&self) -> &RuntimeFeedbackStore {
246 &self.feedback
247 }
248
249 pub fn feedback_mut(&mut self) -> &mut RuntimeFeedbackStore {
251 &mut self.feedback
252 }
253
254 pub fn record_execution(&mut self, query_hash: u64, mut stats: QueryStats) {
256 stats.query_hash = query_hash;
257 self.feedback.record(stats);
258 }
259
260 pub fn hash_query(query_str: &str) -> u64 {
263 let mut hash: u64 = 5381;
264 for byte in query_str.bytes() {
265 hash = hash.wrapping_mul(33).wrapping_add(byte as u64);
266 }
267 hash
268 }
269
270 pub fn estimate_cardinality(&self, pattern: &str, base_estimate: usize) -> usize {
272 self.feedback.estimate_cardinality(pattern, base_estimate)
273 }
274}
275
276impl Default for AdaptiveQueryOptimizer {
277 fn default() -> Self {
278 Self::new()
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const EPSILON: f64 = 1e-9;

    /// Builds a `QueryStats` from borrowed, literal-friendly arguments.
    fn make_stats(
        query_hash: u64,
        actual: usize,
        estimated: usize,
        time_ms: u64,
        join_order: &[&str],
        selectivities: &[(&str, f64)],
    ) -> QueryStats {
        let sel_map: HashMap<String, f64> = selectivities
            .iter()
            .map(|&(pattern, sel)| (pattern.to_string(), sel))
            .collect();
        let order: Vec<String> = join_order.iter().map(|s| s.to_string()).collect();
        QueryStats::new(query_hash, actual, estimated, time_ms, order, sel_map)
    }

    // ---- QueryStats ----

    #[test]
    fn test_query_stats_accuracy_ratio_normal() {
        let s = make_stats(1, 50, 100, 10, &[], &[]);
        assert!((s.accuracy_ratio() - 0.5).abs() < EPSILON);
    }

    #[test]
    fn test_query_stats_accuracy_ratio_zero_estimated() {
        let s = make_stats(1, 10, 0, 10, &[], &[]);
        assert!((s.accuracy_ratio() - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_query_stats_accuracy_ratio_perfect() {
        let s = make_stats(1, 100, 100, 5, &[], &[]);
        assert!((s.accuracy_ratio() - 1.0).abs() < EPSILON);
    }

    // ---- RuntimeFeedbackStore ----

    #[test]
    fn test_store_new_is_empty() {
        let store = RuntimeFeedbackStore::new(10);
        assert_eq!(store.stats_count(), 0);
        assert_eq!(store.query_count(), 0);
    }

    #[test]
    fn test_store_record_and_get() {
        let mut store = RuntimeFeedbackStore::new(10);
        store.record(make_stats(42, 10, 20, 5, &["a", "b"], &[("a", 0.1)]));

        let hist = store.get_stats(42);
        assert_eq!(hist.len(), 1);
        assert_eq!(hist[0].actual_rows, 10);
    }

    #[test]
    fn test_store_max_history_eviction() {
        let mut store = RuntimeFeedbackStore::new(3);
        for i in 0..5u64 {
            store.record(make_stats(99, i as usize, 10, i, &[], &[]));
        }

        let hist = store.get_stats(99);
        assert_eq!(hist.len(), 3);
        assert_eq!(hist.last().map(|s| s.actual_rows), Some(4));
        // The two oldest records are evicted; the rest stay oldest-first.
        let rows: Vec<usize> = hist.iter().map(|s| s.actual_rows).collect();
        assert_eq!(rows, vec![2, 3, 4]);
    }

    #[test]
    fn test_store_get_unknown_hash() {
        let store = RuntimeFeedbackStore::new(10);
        assert!(store.get_stats(999).is_empty());
    }

    #[test]
    fn test_estimate_selectivity_no_data_returns_default() {
        let store = RuntimeFeedbackStore::new(10);
        assert!((store.estimate_selectivity("p_unknown") - 0.1).abs() < EPSILON);
    }

    #[test]
    fn test_estimate_selectivity_with_data() {
        let mut store = RuntimeFeedbackStore::new(10);
        store.record(make_stats(
            1,
            10,
            100,
            5,
            &[],
            &[("age", 0.2), ("name", 0.5)],
        ));
        store.record(make_stats(2, 5, 50, 3, &[], &[("age", 0.4)]));

        // Mean of 0.2 and 0.4.
        let sel_age = store.estimate_selectivity("age");
        assert!((sel_age - 0.3).abs() < EPSILON);

        // Single observation.
        let sel_name = store.estimate_selectivity("name");
        assert!((sel_name - 0.5).abs() < EPSILON);
    }

    #[test]
    fn test_estimate_cardinality() {
        let mut store = RuntimeFeedbackStore::new(10);
        store.record(make_stats(1, 10, 100, 5, &[], &[("p", 0.2)]));

        // 1000 * 0.2 / 0.1
        assert_eq!(store.estimate_cardinality("p", 1000), 2000);
    }

    #[test]
    fn test_estimate_cardinality_no_data() {
        let store = RuntimeFeedbackStore::new(10);

        // The default selectivity of 0.1 leaves the base estimate unchanged.
        assert_eq!(store.estimate_cardinality("p", 500), 500);
    }

    #[test]
    fn test_best_join_order_no_data() {
        let store = RuntimeFeedbackStore::new(10);
        assert!(store.best_join_order(42).is_none());
    }

    #[test]
    fn test_best_join_order_single_entry() {
        let mut store = RuntimeFeedbackStore::new(10);
        store.record(make_stats(1, 10, 10, 50, &["a", "b", "c"], &[]));

        let order = store.best_join_order(1).expect("should have an order");
        assert_eq!(order, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_best_join_order_selects_fastest() {
        let mut store = RuntimeFeedbackStore::new(20);
        // A,B averages 100 ms.
        store.record(make_stats(5, 10, 10, 100, &["A", "B"], &[]));
        store.record(make_stats(5, 10, 10, 100, &["A", "B"], &[]));
        // B,A averages 50 ms.
        store.record(make_stats(5, 10, 10, 50, &["B", "A"], &[]));
        store.record(make_stats(5, 10, 10, 50, &["B", "A"], &[]));

        let best = store.best_join_order(5).expect("should have best order");
        assert_eq!(best, vec!["B", "A"]);
    }

    #[test]
    fn test_prune_old() {
        let mut store = RuntimeFeedbackStore::new(20);
        store.record(make_stats(1, 10, 10, 5, &[], &[]));
        store.record(make_stats(1, 10, 10, 15, &[], &[]));
        store.record(make_stats(2, 10, 10, 3, &[], &[]));

        store.prune_old(10);

        // Only the record with execution_time_ms == 15 is removed.
        assert_eq!(store.get_stats(1).len(), 1);
        assert_eq!(store.get_stats(2).len(), 1);
        assert_eq!(store.stats_count(), 2);
    }

    #[test]
    fn test_stats_count() {
        let mut store = RuntimeFeedbackStore::new(20);
        store.record(make_stats(1, 1, 1, 1, &[], &[]));
        store.record(make_stats(1, 2, 2, 2, &[], &[]));
        store.record(make_stats(2, 3, 3, 3, &[], &[]));

        assert_eq!(store.stats_count(), 3);
        assert_eq!(store.query_count(), 2);
    }

    // ---- AdaptiveQueryOptimizer ----

    #[test]
    fn test_optimizer_new_empty_feedback() {
        let opt = AdaptiveQueryOptimizer::new();
        assert_eq!(opt.feedback().stats_count(), 0);
    }

    #[test]
    fn test_hash_query_deterministic() {
        let query = "SELECT * WHERE { ?s ?p ?o }";
        assert_eq!(
            AdaptiveQueryOptimizer::hash_query(query),
            AdaptiveQueryOptimizer::hash_query(query)
        );
    }

    #[test]
    fn test_hash_query_different_inputs() {
        let h1 = AdaptiveQueryOptimizer::hash_query("query_a");
        let h2 = AdaptiveQueryOptimizer::hash_query("query_b");
        assert_ne!(h1, h2);
    }

    #[test]
    fn test_optimize_join_order_no_feedback_preserves_order() {
        let opt = AdaptiveQueryOptimizer::new();
        let patterns = vec!["p1".to_string(), "p2".to_string(), "p3".to_string()];

        let result = opt.optimize_join_order(&patterns);

        assert_eq!(result.len(), 3);
        // Every pattern gets the same default selectivity and the sort is
        // stable, so the input order must survive.
        assert_eq!(result, patterns);
    }

    #[test]
    fn test_optimize_join_order_uses_selectivity() {
        let mut opt = AdaptiveQueryOptimizer::new();
        opt.set_base_selectivity("low_sel", 0.01);
        opt.set_base_selectivity("high_sel", 0.9);

        let patterns = vec!["high_sel".to_string(), "low_sel".to_string()];
        let result = opt.optimize_join_order(&patterns);

        assert_eq!(result[0], "low_sel");
        assert_eq!(result[1], "high_sel");
    }

    #[test]
    fn test_record_execution_updates_feedback() {
        let mut opt = AdaptiveQueryOptimizer::new();
        let hash = AdaptiveQueryOptimizer::hash_query("my_query");

        opt.record_execution(hash, make_stats(hash, 10, 20, 5, &["a"], &[("a", 0.3)]));

        assert_eq!(opt.feedback().get_stats(hash).len(), 1);
    }

    #[test]
    fn test_estimate_cardinality_via_optimizer() {
        let mut opt = AdaptiveQueryOptimizer::new();
        let hash = AdaptiveQueryOptimizer::hash_query("q");
        opt.record_execution(hash, make_stats(hash, 5, 10, 2, &[], &[("p", 0.5)]));

        // 100 * 0.5 / 0.1
        assert_eq!(opt.estimate_cardinality("p", 100), 500);
    }

    #[test]
    fn test_with_feedback_constructor() {
        let mut store = RuntimeFeedbackStore::new(5);
        store.record(make_stats(1, 1, 1, 1, &[], &[]));

        let opt = AdaptiveQueryOptimizer::with_feedback(store);

        assert_eq!(opt.feedback().stats_count(), 1);
    }

    #[test]
    fn test_optimize_join_order_empty() {
        let opt = AdaptiveQueryOptimizer::new();
        assert!(opt.optimize_join_order(&[]).is_empty());
    }
}