use anyhow::{anyhow, Result};
use scirs2_core::metrics::{Counter, Histogram, Timer};
use scirs2_core::ndarray_ext::{Array1, Array2};
use scirs2_core::random::Random;
use scirs2_core::rngs::StdRng;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::sync::{Arc, RwLock};
42
/// Query optimizer that predicts result cardinalities with a linear model trained
/// online from observed executions, using a heuristic estimate until enough
/// training samples have been collected.
pub struct MLQueryOptimizer {
    /// Bounded buffer of observed executions used as training data.
    training_data: Arc<RwLock<TrainingBuffer>>,
    /// Weights of the linear cardinality model, one per feature of
    /// `PatternFeatures::to_vector`.
    prediction_weights: Arc<RwLock<Array1<f32>>>,
    /// Tuning parameters supplied at construction.
    config: MLOptimizerConfig,
    // Seeded in `with_config` but never read in this file, hence the `dead_code`
    // allowance. NOTE(review): presumably reserved for randomized batch
    // sampling; confirm the intent or remove the field.
    #[allow(dead_code)]
    rng: Random<StdRng>,
    /// Incremented on every `predict_cardinality` call.
    prediction_counter: Arc<Counter>,
    /// Incremented on every `train_from_execution` call.
    training_counter: Arc<Counter>,
    /// Timer started for the duration of each `predict_cardinality` call.
    prediction_timer: Arc<Timer>,
    /// Timer started for the duration of each `train_from_execution` call.
    training_timer: Arc<Timer>,
    /// Relative prediction errors observed by `train_from_execution`.
    prediction_error_histogram: Arc<Histogram>,
}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct MLOptimizerConfig {
80 pub training_buffer_size: usize,
82 pub min_training_samples: usize,
84 pub learning_rate: f64,
86 pub enable_adaptive_joins: bool,
88 pub batch_size: usize,
90}
91
92impl Default for MLOptimizerConfig {
93 fn default() -> Self {
94 Self {
95 training_buffer_size: 10000,
96 min_training_samples: 100,
97 learning_rate: 0.001,
98 enable_adaptive_joins: true,
99 batch_size: 128,
100 }
101 }
102}
103
104struct TrainingBuffer {
106 features: Vec<Vec<f32>>,
108 cardinalities: Vec<f32>,
110 execution_times: Vec<f32>,
112 max_size: usize,
114}
115
116impl TrainingBuffer {
117 fn new(max_size: usize) -> Self {
118 Self {
119 features: Vec::with_capacity(max_size),
120 cardinalities: Vec::with_capacity(max_size),
121 execution_times: Vec::with_capacity(max_size),
122 max_size,
123 }
124 }
125
126 fn add(&mut self, features: Vec<f32>, cardinality: f32, execution_time: f32) {
127 if self.features.len() >= self.max_size {
128 self.features.remove(0);
130 self.cardinalities.remove(0);
131 self.execution_times.remove(0);
132 }
133
134 self.features.push(features);
135 self.cardinalities.push(cardinality);
136 self.execution_times.push(execution_time);
137 }
138
139 fn size(&self) -> usize {
140 self.features.len()
141 }
142
143 fn get_batch(&self, size: usize) -> Option<(Array2<f32>, Array1<f32>)> {
144 if self.features.is_empty() {
145 return None;
146 }
147
148 let batch_size = size.min(self.features.len());
149 let feature_dim = self.features[0].len();
150
151 let mut features = Array2::zeros((batch_size, feature_dim));
153 let mut targets = Array1::zeros(batch_size);
154
155 for i in 0..batch_size {
156 for j in 0..feature_dim {
157 features[[i, j]] = self.features[i][j];
158 }
159 targets[i] = self.cardinalities[i];
160 }
161
162 Some((features, targets))
163 }
164}
165
/// Structural description of a query, used as model input.
#[derive(Debug, Clone)]
pub struct PatternFeatures {
    /// Number of patterns in the query.
    pub pattern_count: usize,
    /// Number of bound variables.
    pub bound_variables: usize,
    /// Number of unbound variables.
    pub unbound_variables: usize,
    /// Average selectivity estimate of the patterns.
    pub avg_selectivity: f64,
    /// Join complexity score.
    pub join_complexity: f64,
    /// Deepest join nesting level.
    pub max_join_depth: usize,
    /// Number of filters.
    pub filter_count: usize,
    /// Whether the query uses property paths.
    pub has_property_paths: bool,
    /// Whether the query uses unions.
    pub has_unions: bool,
    /// Whether the query uses optional patterns.
    pub has_optionals: bool,
}

impl PatternFeatures {
    /// Length of the vector returned by [`PatternFeatures::to_vector`].
    pub const FEATURE_DIM: usize = 10;

    /// Flattens the features into a fixed-order `f32` vector of length
    /// [`Self::FEATURE_DIM`].
    ///
    /// Order: pattern count, bound variables, unbound variables, average
    /// selectivity, join complexity, max join depth, filter count, then the
    /// property-path, union and optional flags encoded as `1.0` / `0.0`.
    pub fn to_vector(&self) -> Vec<f32> {
        let flag = |enabled: bool| -> f32 {
            if enabled {
                1.0
            } else {
                0.0
            }
        };

        let mut encoded = Vec::with_capacity(Self::FEATURE_DIM);
        // Numeric features, in declaration order.
        encoded.push(self.pattern_count as f32);
        encoded.push(self.bound_variables as f32);
        encoded.push(self.unbound_variables as f32);
        encoded.push(self.avg_selectivity as f32);
        encoded.push(self.join_complexity as f32);
        encoded.push(self.max_join_depth as f32);
        encoded.push(self.filter_count as f32);
        // Boolean query-shape flags.
        encoded.push(flag(self.has_property_paths));
        encoded.push(flag(self.has_unions));
        encoded.push(flag(self.has_optionals));
        encoded
    }
}
211
/// Outcome of [`MLQueryOptimizer::optimize`].
#[derive(Debug, Clone)]
pub struct MLOptimizationResult {
    /// Predicted number of result rows (always at least 1).
    pub predicted_cardinality: usize,
    /// `0.9` once the configured minimum number of training samples exists,
    /// `0.5` otherwise.
    pub confidence: f64,
    /// Suggested execution order, as a permutation of pattern indices.
    pub join_order: Vec<usize>,
    /// Rough estimate: `predicted_cardinality * join_complexity * 0.001` ms.
    pub estimated_time_ms: f64,
    /// Set when `predicted_cardinality` exceeds 10 000.
    pub use_gpu: bool,
    /// Set when there are more than 3 patterns or `predicted_cardinality`
    /// exceeds 1 000.
    pub use_parallel: bool,
}
228
229impl MLQueryOptimizer {
230 pub fn new() -> Self {
235 Self::with_config(MLOptimizerConfig::default())
236 }
237
238 pub fn with_config(config: MLOptimizerConfig) -> Self {
240 let training_data = Arc::new(RwLock::new(TrainingBuffer::new(
242 config.training_buffer_size,
243 )));
244
245 let initial_weights = Array1::from(vec![
247 100.0, 50.0, 200.0, 1000.0, 150.0, 80.0, 30.0, 500.0, 300.0, 200.0, ]);
258 let prediction_weights = Arc::new(RwLock::new(initial_weights));
259
260 let prediction_counter = Arc::new(Counter::new("ml_optimizer_predictions".to_string()));
262 let training_counter = Arc::new(Counter::new("ml_optimizer_training".to_string()));
263 let prediction_timer = Arc::new(Timer::new("ml_optimizer_prediction_time".to_string()));
264 let training_timer = Arc::new(Timer::new("ml_optimizer_training_time".to_string()));
265 let prediction_error_histogram =
266 Arc::new(Histogram::new("ml_optimizer_prediction_error".to_string()));
267
268 Self {
269 training_data,
270 prediction_weights,
271 config,
272 rng: Random::seed(42),
273 prediction_counter,
274 training_counter,
275 prediction_timer,
276 training_timer,
277 prediction_error_histogram,
278 }
279 }
280
281 pub fn predict_cardinality(&self, features: &PatternFeatures) -> Result<usize> {
286 let _timer_guard = self.prediction_timer.start();
288 self.prediction_counter.inc();
289
290 let feature_vec = features.to_vector();
291
292 let buffer = self
294 .training_data
295 .read()
296 .map_err(|e| anyhow!("Lock error: {}", e))?;
297 if buffer.size() < self.config.min_training_samples {
298 drop(buffer);
300 return Ok(self.heuristic_cardinality(features));
301 }
302 drop(buffer);
303
304 let input = Array1::from(feature_vec);
306
307 let prediction = self.predict_with_weights(&input)? as usize;
309
310 Ok(prediction)
311 }
312
313 fn predict_with_weights(&self, input: &Array1<f32>) -> Result<f32> {
315 let weights = self
316 .prediction_weights
317 .read()
318 .map_err(|e| anyhow!("Lock error: {}", e))?;
319
320 let prediction = input
321 .iter()
322 .zip(weights.iter())
323 .map(|(x, w)| x * w)
324 .sum::<f32>();
325
326 Ok(prediction.max(1.0)) }
328
329 fn heuristic_cardinality(&self, features: &PatternFeatures) -> usize {
331 let base = 1000; let mut estimate = base;
334
335 estimate *= features.pattern_count.max(1);
336 estimate = (estimate as f64 * features.avg_selectivity) as usize;
337
338 if features.has_unions {
339 estimate *= 2;
340 }
341 if features.has_property_paths {
342 estimate *= 3;
343 }
344
345 estimate.max(1)
346 }
347
348 pub fn optimize_join_order(
353 &self,
354 pattern_count: usize,
355 features: &PatternFeatures,
356 ) -> Result<Vec<usize>> {
357 if pattern_count <= 1 {
358 return Ok(vec![0]);
359 }
360
361 if !self.config.enable_adaptive_joins {
362 return Ok((0..pattern_count).collect());
364 }
365
366 let mut order: Vec<usize> = (0..pattern_count).collect();
368
369 if features.avg_selectivity < 0.1 {
371 } else if features.avg_selectivity > 0.5 {
373 order.reverse();
375 } else {
376 let mut reordered = Vec::with_capacity(pattern_count);
379 let mid = pattern_count / 2;
380 for i in 0..mid {
381 reordered.push(i);
382 if i + mid < pattern_count {
383 reordered.push(i + mid);
384 }
385 }
386 if pattern_count % 2 != 0 {
387 reordered.push(pattern_count - 1);
388 }
389 order = reordered;
390 }
391
392 Ok(order)
393 }
394
395 pub fn train_from_execution(
401 &mut self,
402 features: PatternFeatures,
403 actual_cardinality: usize,
404 execution_time_ms: f64,
405 ) -> Result<()> {
406 let _timer_guard = self.training_timer.start();
407 self.training_counter.inc();
408
409 if let Ok(predicted) = self.predict_cardinality(&features) {
411 let error_rate = if actual_cardinality > 0 {
412 (predicted as f64 - actual_cardinality as f64).abs() / actual_cardinality as f64
413 } else {
414 0.0
415 };
416 self.prediction_error_histogram.observe(error_rate);
417 }
418
419 let feature_vec = features.to_vector();
420
421 let mut buffer = self
423 .training_data
424 .write()
425 .map_err(|e| anyhow!("Lock error: {}", e))?;
426 buffer.add(
427 feature_vec,
428 actual_cardinality as f32,
429 execution_time_ms as f32,
430 );
431
432 let buffer_size = buffer.size();
433 drop(buffer);
434
435 if buffer_size >= self.config.min_training_samples && buffer_size % 100 == 0 {
437 self.retrain_models()?;
438 }
439
440 Ok(())
441 }
442
443 fn retrain_models(&self) -> Result<()> {
445 let buffer = self
446 .training_data
447 .read()
448 .map_err(|e| anyhow!("Lock error: {}", e))?;
449
450 let batch_size = buffer.size().min(self.config.batch_size);
451 if let Some((features, targets)) = buffer.get_batch(batch_size) {
452 drop(buffer);
453
454 let mut weights = self
456 .prediction_weights
457 .write()
458 .map_err(|e| anyhow!("Lock error: {}", e))?;
459
460 for i in 0..batch_size {
462 let prediction = features
463 .row(i)
464 .iter()
465 .zip(weights.iter())
466 .map(|(x, w)| x * w)
467 .sum::<f32>();
468 let error = prediction - targets[i];
469
470 for (j, weight) in weights.iter_mut().enumerate() {
472 if j < features.ncols() {
473 let gradient = error * features[[i, j]];
474 *weight -= (self.config.learning_rate as f32) * gradient;
475 }
476 }
477 }
478
479 drop(weights);
480 }
481
482 Ok(())
483 }
484
485 pub fn optimize(&mut self, features: PatternFeatures) -> Result<MLOptimizationResult> {
490 let predicted_cardinality = self.predict_cardinality(&features)?;
492
493 let join_order = self.optimize_join_order(features.pattern_count, &features)?;
495
496 let estimated_time_ms = predicted_cardinality as f64 * features.join_complexity * 0.001;
498
499 let use_gpu = predicted_cardinality > 10000;
501
502 let use_parallel = features.pattern_count > 3 || predicted_cardinality > 1000;
504
505 let buffer = self
507 .training_data
508 .read()
509 .map_err(|e| anyhow!("Lock error: {}", e))?;
510 let confidence = if buffer.size() >= self.config.min_training_samples {
511 0.9 } else {
513 0.5 };
515 drop(buffer);
516
517 Ok(MLOptimizationResult {
518 predicted_cardinality,
519 confidence,
520 join_order,
521 estimated_time_ms,
522 use_gpu,
523 use_parallel,
524 })
525 }
526
527 pub fn training_stats(&self) -> Result<TrainingStats> {
529 let buffer = self
530 .training_data
531 .read()
532 .map_err(|e| anyhow!("Lock error: {}", e))?;
533 Ok(TrainingStats {
534 total_samples: buffer.size(),
535 is_trained: buffer.size() >= self.config.min_training_samples,
536 min_samples_required: self.config.min_training_samples,
537 })
538 }
539
540 pub fn performance_metrics(&self) -> PerformanceMetrics {
546 PerformanceMetrics {
547 total_predictions: self.prediction_counter.get(),
548 total_trainings: self.training_counter.get(),
549 }
550 }
551}
552
553impl Default for MLQueryOptimizer {
554 fn default() -> Self {
555 Self::new()
556 }
557}
558
/// Snapshot of the training state, returned by
/// [`MLQueryOptimizer::training_stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingStats {
    /// Number of samples currently held in the training buffer.
    pub total_samples: usize,
    /// Whether `total_samples` has reached `min_samples_required`, i.e.
    /// whether predictions use the learned model instead of the heuristic.
    pub is_trained: bool,
    /// The configured `min_training_samples` threshold.
    pub min_samples_required: usize,
}
569
/// Counter values returned by [`MLQueryOptimizer::performance_metrics`].
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
    /// Current value of the optimizer's prediction counter.
    pub total_predictions: u64,
    /// Current value of the optimizer's training counter, which is incremented
    /// once per `train_from_execution` call.
    pub total_trainings: u64,
}
578
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a feature set with the given pattern count and selectivity and
    /// neutral values everywhere else.
    fn features_with(pattern_count: usize, avg_selectivity: f64) -> PatternFeatures {
        PatternFeatures {
            pattern_count,
            bound_variables: 1,
            unbound_variables: 1,
            avg_selectivity,
            join_complexity: 1.0,
            max_join_depth: 1,
            filter_count: 0,
            has_property_paths: false,
            has_unions: false,
            has_optionals: false,
        }
    }

    #[test]
    fn test_ml_optimizer_creation() {
        let optimizer = MLQueryOptimizer::new();
        assert_eq!(optimizer.config.training_buffer_size, 10000);
    }

    #[test]
    fn test_pattern_features_conversion() {
        let features = PatternFeatures {
            pattern_count: 3,
            bound_variables: 2,
            unbound_variables: 4,
            avg_selectivity: 0.1,
            join_complexity: 2.5,
            max_join_depth: 3,
            filter_count: 1,
            has_property_paths: true,
            has_unions: false,
            has_optionals: true,
        };

        let vec = features.to_vector();
        assert_eq!(vec.len(), PatternFeatures::FEATURE_DIM);
        assert_eq!(vec[0], 3.0); // pattern_count
        assert_eq!(vec[7], 1.0); // has_property_paths
    }

    #[test]
    fn test_heuristic_cardinality() {
        let optimizer = MLQueryOptimizer::new();

        let simple_features = PatternFeatures {
            pattern_count: 1,
            bound_variables: 1,
            unbound_variables: 2,
            avg_selectivity: 0.1,
            join_complexity: 1.0,
            max_join_depth: 1,
            filter_count: 0,
            has_property_paths: false,
            has_unions: false,
            has_optionals: false,
        };

        let cardinality = optimizer.heuristic_cardinality(&simple_features);
        assert!(cardinality > 0);
    }

    #[test]
    fn test_heuristic_cardinality_values() {
        let optimizer = MLQueryOptimizer::new();
        let mut features = features_with(2, 0.5);
        features.has_unions = true;
        features.has_property_paths = true;
        // 1000 * 2 patterns * 0.5 selectivity, doubled for unions, tripled for paths.
        assert_eq!(optimizer.heuristic_cardinality(&features), 6000);

        // A zero estimate is clamped to the minimum of 1.
        features.avg_selectivity = 0.0;
        assert_eq!(optimizer.heuristic_cardinality(&features), 1);
    }

    #[test]
    fn test_training_buffer() {
        let mut buffer = TrainingBuffer::new(5);

        for i in 0..7 {
            buffer.add(vec![i as f32; 10], i as f32 * 100.0, i as f32 * 10.0);
        }

        assert_eq!(buffer.size(), 5);

        // The two oldest samples (cardinalities 0 and 100) were evicted.
        assert_eq!(buffer.cardinalities[0], 200.0);
    }

    #[test]
    fn test_join_order_optimization() -> Result<()> {
        let optimizer = MLQueryOptimizer::new();

        let features = PatternFeatures {
            pattern_count: 5,
            bound_variables: 3,
            unbound_variables: 7,
            avg_selectivity: 0.05,
            join_complexity: 3.0,
            max_join_depth: 4,
            filter_count: 2,
            has_property_paths: false,
            has_unions: false,
            has_optionals: true,
        };

        let order = optimizer.optimize_join_order(5, &features)?;
        assert_eq!(order.len(), 5);

        Ok(())
    }

    #[test]
    fn test_adaptive_join_ordering() -> Result<()> {
        let optimizer = MLQueryOptimizer::new();

        // Selectivity factor 0.6: reversed order.
        let low_sel = PatternFeatures {
            pattern_count: 5,
            bound_variables: 1,
            unbound_variables: 9,
            avg_selectivity: 0.6,
            join_complexity: 2.5,
            max_join_depth: 3,
            filter_count: 0,
            has_property_paths: false,
            has_unions: false,
            has_optionals: false,
        };

        let order = optimizer.optimize_join_order(5, &low_sel)?;
        assert_eq!(order.len(), 5);
        assert_eq!(order, vec![4, 3, 2, 1, 0]);

        // Selectivity factor 0.05: original order.
        let high_sel = PatternFeatures {
            pattern_count: 5,
            bound_variables: 4,
            unbound_variables: 1,
            avg_selectivity: 0.05,
            join_complexity: 1.5,
            max_join_depth: 2,
            filter_count: 2,
            has_property_paths: false,
            has_unions: false,
            has_optionals: false,
        };

        let order = optimizer.optimize_join_order(5, &high_sel)?;
        assert_eq!(order, vec![0, 1, 2, 3, 4]);

        Ok(())
    }

    #[test]
    fn test_balanced_join_ordering_interleaves_halves() -> Result<()> {
        let optimizer = MLQueryOptimizer::new();
        let mid_sel = features_with(5, 0.3);

        // Odd count: halves interleaved, then the last pattern.
        assert_eq!(optimizer.optimize_join_order(5, &mid_sel)?, vec![0, 2, 1, 3, 4]);
        // Even count: halves interleaved.
        assert_eq!(optimizer.optimize_join_order(4, &mid_sel)?, vec![0, 2, 1, 3]);

        Ok(())
    }

    #[test]
    fn test_join_ordering_disabled_keeps_given_order() -> Result<()> {
        let optimizer = MLQueryOptimizer::with_config(MLOptimizerConfig {
            enable_adaptive_joins: false,
            ..Default::default()
        });

        // Would be reversed if adaptive joins were enabled.
        let features = features_with(4, 0.9);
        assert_eq!(optimizer.optimize_join_order(4, &features)?, vec![0, 1, 2, 3]);

        Ok(())
    }

    #[test]
    fn test_training_and_prediction() -> Result<()> {
        let mut optimizer = MLQueryOptimizer::with_config(MLOptimizerConfig {
            min_training_samples: 5,
            ..Default::default()
        });

        for i in 0..10 {
            let features = PatternFeatures {
                pattern_count: i % 5 + 1,
                bound_variables: i % 3,
                unbound_variables: i % 7,
                avg_selectivity: 0.1 * (i as f64 / 10.0),
                join_complexity: 1.0 + (i as f64 / 5.0),
                max_join_depth: i % 4 + 1,
                filter_count: i % 3,
                has_property_paths: i % 2 == 0,
                has_unions: i % 3 == 0,
                has_optionals: i % 4 == 0,
            };

            optimizer.train_from_execution(features, i * 100, (i * 10) as f64)?;
        }

        let stats = optimizer.training_stats()?;
        assert_eq!(stats.total_samples, 10);
        assert!(stats.is_trained);

        Ok(())
    }

    #[test]
    fn test_comprehensive_optimization() -> Result<()> {
        let mut optimizer = MLQueryOptimizer::new();

        let features = PatternFeatures {
            pattern_count: 4,
            bound_variables: 2,
            unbound_variables: 6,
            avg_selectivity: 0.15,
            join_complexity: 2.8,
            max_join_depth: 3,
            filter_count: 1,
            has_property_paths: true,
            has_unions: false,
            has_optionals: true,
        };

        let result = optimizer.optimize(features.clone())?;

        assert!(result.predicted_cardinality > 0);
        assert!(result.confidence >= 0.0 && result.confidence <= 1.0);
        assert_eq!(result.join_order.len(), 4);
        assert!(result.estimated_time_ms >= 0.0);

        Ok(())
    }
}