use anyhow::{anyhow, Result};
use scirs2_core::metrics::{Counter, Histogram, Timer};
use scirs2_core::ndarray_ext::{Array1, Array2};
use scirs2_core::random::Random;
use scirs2_core::rngs::StdRng;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, RwLock};
/// Machine-learning-assisted query optimizer.
///
/// Predicts result cardinalities with a simple linear model that is trained
/// online from observed query executions, and derives join orderings and
/// execution hints (GPU / parallel) from structural pattern features.
pub struct MLQueryOptimizer {
    /// Sliding FIFO window of (features, cardinality, execution-time) samples.
    training_data: Arc<RwLock<TrainingBuffer>>,
    /// Linear-model weights, one per entry of `PatternFeatures::to_vector()`.
    prediction_weights: Arc<RwLock<Array1<f32>>>,
    /// Tunable thresholds and learning parameters.
    config: MLOptimizerConfig,
    /// Currently unused — reserved for future stochastic sampling/training,
    /// hence the dead_code allowance.
    #[allow(dead_code)]
    rng: Random<StdRng>,
    // Metrics below come from scirs2_core; used for observability only.
    prediction_counter: Arc<Counter>,
    training_counter: Arc<Counter>,
    prediction_timer: Arc<Timer>,
    training_timer: Arc<Timer>,
    /// Relative prediction error recorded whenever a training sample arrives.
    prediction_error_histogram: Arc<Histogram>,
}
/// Tunable parameters for [`MLQueryOptimizer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLOptimizerConfig {
    /// Maximum number of samples retained in the training buffer (FIFO eviction).
    pub training_buffer_size: usize,
    /// Minimum samples required before the learned model replaces heuristics.
    pub min_training_samples: usize,
    /// SGD step size applied when retraining the linear model.
    pub learning_rate: f64,
    /// When false, `optimize_join_order` always keeps the original order.
    pub enable_adaptive_joins: bool,
    /// Upper bound on the number of samples used per retraining pass.
    pub batch_size: usize,
}
impl Default for MLOptimizerConfig {
fn default() -> Self {
Self {
training_buffer_size: 10000,
min_training_samples: 100,
learning_rate: 0.001,
enable_adaptive_joins: true,
batch_size: 128,
}
}
}
/// Bounded FIFO store of training samples, kept as parallel vectors:
/// index `i` of each vector describes the same observed execution.
struct TrainingBuffer {
    /// Per-sample feature vectors (see `PatternFeatures::to_vector`).
    features: Vec<Vec<f32>>,
    /// Observed result cardinality per sample (the training target).
    cardinalities: Vec<f32>,
    /// Observed execution time (ms) per sample; stored but not yet consumed
    /// by training — presumably reserved for a time model. TODO confirm.
    execution_times: Vec<f32>,
    /// Capacity; the oldest sample is evicted once this is reached.
    max_size: usize,
}
impl TrainingBuffer {
    /// Create an empty buffer that retains at most `max_size` samples.
    fn new(max_size: usize) -> Self {
        Self {
            features: Vec::with_capacity(max_size),
            cardinalities: Vec::with_capacity(max_size),
            execution_times: Vec::with_capacity(max_size),
            max_size,
        }
    }

    /// Append one observation; when at capacity, the oldest sample is
    /// evicted first (FIFO) so the three parallel vectors stay aligned.
    fn add(&mut self, features: Vec<f32>, cardinality: f32, execution_time: f32) {
        let at_capacity = self.features.len() >= self.max_size;
        if at_capacity {
            // Drop index 0 from every parallel vector to keep them in sync.
            self.features.remove(0);
            self.cardinalities.remove(0);
            self.execution_times.remove(0);
        }
        self.features.push(features);
        self.cardinalities.push(cardinality);
        self.execution_times.push(execution_time);
    }

    /// Number of samples currently stored.
    fn size(&self) -> usize {
        self.features.len()
    }

    /// Materialize the first `size` samples (the oldest ones) as a
    /// (feature matrix, cardinality targets) pair, or `None` when empty.
    ///
    /// The feature dimension is taken from the first sample; shorter
    /// samples would panic on indexing, matching the original behavior.
    fn get_batch(&self, size: usize) -> Option<(Array2<f32>, Array1<f32>)> {
        let first_sample = self.features.first()?;
        let feature_dim = first_sample.len();
        let batch_size = size.min(self.features.len());
        let mut feature_matrix = Array2::zeros((batch_size, feature_dim));
        let mut target_vector = Array1::zeros(batch_size);
        for (row, sample) in self.features.iter().take(batch_size).enumerate() {
            for col in 0..feature_dim {
                feature_matrix[[row, col]] = sample[col];
            }
            target_vector[row] = self.cardinalities[row];
        }
        Some((feature_matrix, target_vector))
    }
}
/// Structural features extracted from a query's graph patterns.
///
/// The numeric encoding produced by [`PatternFeatures::to_vector`] is the
/// input to the linear cardinality model and to the heuristic estimator.
#[derive(Debug, Clone)]
pub struct PatternFeatures {
    pub pattern_count: usize,
    pub bound_variables: usize,
    pub unbound_variables: usize,
    pub avg_selectivity: f64,
    pub join_complexity: f64,
    pub max_join_depth: usize,
    pub filter_count: usize,
    pub has_property_paths: bool,
    pub has_unions: bool,
    pub has_optionals: bool,
}

impl PatternFeatures {
    /// Encode the features as a fixed-length `f32` vector: counts and
    /// scores are cast to `f32`, booleans become 1.0 / 0.0.
    ///
    /// The length always equals [`Self::FEATURE_DIM`].
    pub fn to_vector(&self) -> Vec<f32> {
        let v = vec![
            self.pattern_count as f32,
            self.bound_variables as f32,
            self.unbound_variables as f32,
            self.avg_selectivity as f32,
            self.join_complexity as f32,
            self.max_join_depth as f32,
            self.filter_count as f32,
            if self.has_property_paths { 1.0 } else { 0.0 },
            if self.has_unions { 1.0 } else { 0.0 },
            if self.has_optionals { 1.0 } else { 0.0 },
        ];
        // Guard against FEATURE_DIM silently drifting out of sync with the
        // literal above when a feature is added or removed.
        debug_assert_eq!(
            v.len(),
            Self::FEATURE_DIM,
            "to_vector length must match FEATURE_DIM"
        );
        v
    }

    /// Length of the vector returned by [`Self::to_vector`]; also the
    /// expected length of the model's weight vector.
    pub const FEATURE_DIM: usize = 10;
}
/// Outcome of a full [`MLQueryOptimizer::optimize`] pass.
#[derive(Debug, Clone)]
pub struct MLOptimizationResult {
    /// Estimated number of result rows (model or heuristic).
    pub predicted_cardinality: usize,
    /// Coarse confidence signal: 0.9 once trained, 0.5 otherwise.
    pub confidence: f64,
    /// Proposed evaluation order over the pattern indices.
    pub join_order: Vec<usize>,
    /// Rough execution-time estimate derived from cardinality and complexity.
    pub estimated_time_ms: f64,
    /// Hint: result is large enough to benefit from GPU execution.
    pub use_gpu: bool,
    /// Hint: query is complex/large enough to benefit from parallelism.
    pub use_parallel: bool,
}
impl MLQueryOptimizer {
    /// Create an optimizer with [`MLOptimizerConfig::default`] settings.
    pub fn new() -> Self {
        Self::with_config(MLOptimizerConfig::default())
    }

    /// Create an optimizer with explicit configuration, initializing the
    /// metrics instruments and hand-picked starting weights.
    pub fn with_config(config: MLOptimizerConfig) -> Self {
        let training_data = Arc::new(RwLock::new(TrainingBuffer::new(
            config.training_buffer_size,
        )));
        // One weight per entry of `PatternFeatures::to_vector()`; chosen so
        // untrained predictions are at least plausible in magnitude.
        let initial_weights = Array1::from(vec![
            100.0, 50.0, 200.0, 1000.0, 150.0, 80.0, 30.0, 500.0, 300.0, 200.0,
        ]);
        let prediction_weights = Arc::new(RwLock::new(initial_weights));
        let prediction_counter = Arc::new(Counter::new("ml_optimizer_predictions".to_string()));
        let training_counter = Arc::new(Counter::new("ml_optimizer_training".to_string()));
        let prediction_timer = Arc::new(Timer::new("ml_optimizer_prediction_time".to_string()));
        let training_timer = Arc::new(Timer::new("ml_optimizer_training_time".to_string()));
        let prediction_error_histogram =
            Arc::new(Histogram::new("ml_optimizer_prediction_error".to_string()));
        Self {
            training_data,
            prediction_weights,
            config,
            // Fixed seed keeps any future stochastic behavior reproducible.
            rng: Random::seed(42),
            prediction_counter,
            training_counter,
            prediction_timer,
            training_timer,
            prediction_error_histogram,
        }
    }

    /// Predict the result cardinality for the given pattern features.
    ///
    /// Uses the learned linear model once at least
    /// `config.min_training_samples` samples have been collected; before
    /// that it falls back to [`Self::heuristic_cardinality`].
    ///
    /// # Errors
    /// Returns an error if an internal lock is poisoned.
    pub fn predict_cardinality(&self, features: &PatternFeatures) -> Result<usize> {
        let _timer_guard = self.prediction_timer.start();
        self.prediction_counter.inc();
        let feature_vec = features.to_vector();
        let buffer = self
            .training_data
            .read()
            .map_err(|e| anyhow!("Lock error: {}", e))?;
        if buffer.size() < self.config.min_training_samples {
            // Model not trained yet: release the lock, use the heuristic.
            drop(buffer);
            return Ok(self.heuristic_cardinality(features));
        }
        drop(buffer);
        let input = Array1::from(feature_vec);
        let prediction = self.predict_with_weights(&input)? as usize;
        Ok(prediction)
    }

    /// Linear model: dot product of `input` with the current weights,
    /// clamped to at least 1.0 (this also maps a NaN result to 1.0).
    /// NOTE(review): `zip` silently truncates if lengths ever differ.
    fn predict_with_weights(&self, input: &Array1<f32>) -> Result<f32> {
        let weights = self
            .prediction_weights
            .read()
            .map_err(|e| anyhow!("Lock error: {}", e))?;
        let prediction = input
            .iter()
            .zip(weights.iter())
            .map(|(x, w)| x * w)
            .sum::<f32>();
        Ok(prediction.max(1.0))
    }

    /// Rule-based estimate used before the model is trained: a base of
    /// 1000 rows scaled by pattern count and selectivity, with penalty
    /// multipliers for unions (x2) and property paths (x3).
    fn heuristic_cardinality(&self, features: &PatternFeatures) -> usize {
        let base = 1000;
        let mut estimate = base;
        estimate *= features.pattern_count.max(1);
        estimate = (estimate as f64 * features.avg_selectivity) as usize;
        if features.has_unions {
            estimate *= 2;
        }
        if features.has_property_paths {
            estimate *= 3;
        }
        estimate.max(1)
    }

    /// Propose an evaluation order over `pattern_count` join patterns.
    ///
    /// With adaptive joins disabled the original order is kept. Otherwise:
    /// - `avg_selectivity < 0.1`: keep the original order;
    /// - `avg_selectivity > 0.5`: reverse the order;
    /// - in between: interleave the first and second halves.
    ///
    /// # Errors
    /// Currently infallible; `Result` is kept for interface stability.
    pub fn optimize_join_order(
        &self,
        pattern_count: usize,
        features: &PatternFeatures,
    ) -> Result<Vec<usize>> {
        // Fixed: previously `pattern_count == 0` also returned `vec![0]`,
        // which referenced a non-existent pattern index.
        if pattern_count == 0 {
            return Ok(Vec::new());
        }
        if pattern_count == 1 {
            return Ok(vec![0]);
        }
        if !self.config.enable_adaptive_joins {
            return Ok((0..pattern_count).collect());
        }
        let mut order: Vec<usize> = (0..pattern_count).collect();
        if features.avg_selectivity < 0.1 {
            // Highly filtering patterns: the original order is kept as-is.
        } else if features.avg_selectivity > 0.5 {
            order.reverse();
        } else {
            // Mid selectivity: interleave as 0, mid, 1, mid+1, …
            let mut reordered = Vec::with_capacity(pattern_count);
            let mid = pattern_count / 2;
            for i in 0..mid {
                reordered.push(i);
                if i + mid < pattern_count {
                    reordered.push(i + mid);
                }
            }
            if pattern_count % 2 != 0 {
                // Odd count: the last index is not reached by the loop above.
                reordered.push(pattern_count - 1);
            }
            order = reordered;
        }
        Ok(order)
    }

    /// Record an observed execution: updates the prediction-error
    /// histogram, appends a training sample, and periodically retrains.
    ///
    /// # Errors
    /// Returns an error if an internal lock is poisoned.
    pub fn train_from_execution(
        &mut self,
        features: PatternFeatures,
        actual_cardinality: usize,
        execution_time_ms: f64,
    ) -> Result<()> {
        let _timer_guard = self.training_timer.start();
        self.training_counter.inc();
        // Measure how far off the current model is before learning from the
        // new sample (relative error; defined as 0 when actual is 0).
        if let Ok(predicted) = self.predict_cardinality(&features) {
            let error_rate = if actual_cardinality > 0 {
                (predicted as f64 - actual_cardinality as f64).abs() / actual_cardinality as f64
            } else {
                0.0
            };
            self.prediction_error_histogram.observe(error_rate);
        }
        let feature_vec = features.to_vector();
        let mut buffer = self
            .training_data
            .write()
            .map_err(|e| anyhow!("Lock error: {}", e))?;
        buffer.add(
            feature_vec,
            actual_cardinality as f32,
            execution_time_ms as f32,
        );
        let buffer_size = buffer.size();
        drop(buffer);
        // Retrain every 100th sample once the minimum has been reached.
        if buffer_size >= self.config.min_training_samples && buffer_size % 100 == 0 {
            self.retrain_models()?;
        }
        Ok(())
    }

    /// One SGD pass over a batch of buffered samples, updating the
    /// linear-model weights in place.
    fn retrain_models(&self) -> Result<()> {
        let buffer = self
            .training_data
            .read()
            .map_err(|e| anyhow!("Lock error: {}", e))?;
        let batch_size = buffer.size().min(self.config.batch_size);
        if let Some((features, targets)) = buffer.get_batch(batch_size) {
            // Release the buffer lock before taking the weights write lock.
            drop(buffer);
            let mut weights = self
                .prediction_weights
                .write()
                .map_err(|e| anyhow!("Lock error: {}", e))?;
            for i in 0..batch_size {
                let prediction = features
                    .row(i)
                    .iter()
                    .zip(weights.iter())
                    .map(|(x, w)| x * w)
                    .sum::<f32>();
                let error = prediction - targets[i];
                // Gradient of squared error w.r.t. weight j is error * x_j.
                for (j, weight) in weights.iter_mut().enumerate() {
                    if j < features.ncols() {
                        let gradient = error * features[[i, j]];
                        *weight -= (self.config.learning_rate as f32) * gradient;
                    }
                }
            }
            drop(weights);
        }
        Ok(())
    }

    /// Full optimization pass: cardinality prediction, join ordering, a
    /// rough time estimate, and GPU / parallel execution hints.
    ///
    /// # Errors
    /// Returns an error if an internal lock is poisoned.
    pub fn optimize(&mut self, features: PatternFeatures) -> Result<MLOptimizationResult> {
        let predicted_cardinality = self.predict_cardinality(&features)?;
        let join_order = self.optimize_join_order(features.pattern_count, &features)?;
        let estimated_time_ms = predicted_cardinality as f64 * features.join_complexity * 0.001;
        // Heuristic thresholds: big results favor GPU; multi-pattern or
        // large results favor parallel execution.
        let use_gpu = predicted_cardinality > 10000;
        let use_parallel = features.pattern_count > 3 || predicted_cardinality > 1000;
        let buffer = self
            .training_data
            .read()
            .map_err(|e| anyhow!("Lock error: {}", e))?;
        // Coarse two-level confidence based on whether the model is trained.
        let confidence = if buffer.size() >= self.config.min_training_samples {
            0.9
        } else {
            0.5
        };
        drop(buffer);
        Ok(MLOptimizationResult {
            predicted_cardinality,
            confidence,
            join_order,
            estimated_time_ms,
            use_gpu,
            use_parallel,
        })
    }

    /// Snapshot of training progress.
    ///
    /// # Errors
    /// Returns an error if the training-data lock is poisoned.
    pub fn training_stats(&self) -> Result<TrainingStats> {
        let buffer = self
            .training_data
            .read()
            .map_err(|e| anyhow!("Lock error: {}", e))?;
        Ok(TrainingStats {
            total_samples: buffer.size(),
            is_trained: buffer.size() >= self.config.min_training_samples,
            min_samples_required: self.config.min_training_samples,
        })
    }

    /// Lifetime counters for prediction and training invocations.
    pub fn performance_metrics(&self) -> PerformanceMetrics {
        PerformanceMetrics {
            total_predictions: self.prediction_counter.get(),
            total_trainings: self.training_counter.get(),
        }
    }
}
impl Default for MLQueryOptimizer {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of the optimizer's training progress.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingStats {
    /// Samples currently held in the training buffer.
    pub total_samples: usize,
    /// True once `total_samples >= min_samples_required`.
    pub is_trained: bool,
    /// Configured `min_training_samples` threshold.
    pub min_samples_required: usize,
}
/// Lifetime usage counters exposed by [`MLQueryOptimizer::performance_metrics`].
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
    /// Total calls to `predict_cardinality` (directly or via `optimize`).
    pub total_predictions: u64,
    /// Total calls to `train_from_execution`.
    pub total_trainings: u64,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ml_optimizer_creation() {
        let optimizer = MLQueryOptimizer::new();
        assert_eq!(optimizer.config.training_buffer_size, 10000);
    }

    #[test]
    fn test_pattern_features_conversion() {
        let features = PatternFeatures {
            pattern_count: 3,
            bound_variables: 2,
            unbound_variables: 4,
            avg_selectivity: 0.1,
            join_complexity: 2.5,
            max_join_depth: 3,
            filter_count: 1,
            has_property_paths: true,
            has_unions: false,
            has_optionals: true,
        };
        let vec = features.to_vector();
        assert_eq!(vec.len(), PatternFeatures::FEATURE_DIM);
        assert_eq!(vec[0], 3.0);
        // has_property_paths is encoded as 1.0 at index 7.
        assert_eq!(vec[7], 1.0);
    }

    #[test]
    fn test_heuristic_cardinality() {
        let optimizer = MLQueryOptimizer::new();
        let simple_features = PatternFeatures {
            pattern_count: 1,
            bound_variables: 1,
            unbound_variables: 2,
            avg_selectivity: 0.1,
            join_complexity: 1.0,
            max_join_depth: 1,
            filter_count: 0,
            has_property_paths: false,
            has_unions: false,
            has_optionals: false,
        };
        let cardinality = optimizer.heuristic_cardinality(&simple_features);
        assert!(cardinality > 0);
    }

    #[test]
    fn test_training_buffer() {
        let mut buffer = TrainingBuffer::new(5);
        for i in 0..7 {
            buffer.add(vec![i as f32; 10], i as f32 * 100.0, i as f32 * 10.0);
        }
        // Capacity is 5, so the two oldest samples (i = 0, 1) were evicted
        // and the oldest remaining cardinality is 2 * 100.0.
        assert_eq!(buffer.size(), 5);
        assert_eq!(buffer.cardinalities[0], 200.0);
    }

    #[test]
    fn test_join_order_optimization() -> Result<()> {
        let optimizer = MLQueryOptimizer::new();
        let features = PatternFeatures {
            pattern_count: 5,
            bound_variables: 3,
            unbound_variables: 7,
            avg_selectivity: 0.05,
            join_complexity: 3.0,
            max_join_depth: 4,
            filter_count: 2,
            has_property_paths: false,
            has_unions: false,
            has_optionals: true,
        };
        let order = optimizer.optimize_join_order(5, &features)?;
        assert_eq!(order.len(), 5);
        Ok(())
    }

    #[test]
    fn test_adaptive_join_ordering() -> Result<()> {
        let optimizer = MLQueryOptimizer::new();
        // avg_selectivity > 0.5 (patterns retain most rows): the optimizer
        // reverses the order. Renamed from the misleading `low_sel` — its
        // selectivity *value* is high.
        let unselective = PatternFeatures {
            pattern_count: 5,
            bound_variables: 1,
            unbound_variables: 9,
            avg_selectivity: 0.6,
            join_complexity: 2.5,
            max_join_depth: 3,
            filter_count: 0,
            has_property_paths: false,
            has_unions: false,
            has_optionals: false,
        };
        let order = optimizer.optimize_join_order(5, &unselective)?;
        assert_eq!(order.len(), 5);
        assert_eq!(order, vec![4, 3, 2, 1, 0]);
        // avg_selectivity < 0.1 (highly filtering patterns): the original
        // order is kept. Renamed from the misleading `high_sel`.
        let selective = PatternFeatures {
            pattern_count: 5,
            bound_variables: 4,
            unbound_variables: 1,
            avg_selectivity: 0.05,
            join_complexity: 1.5,
            max_join_depth: 2,
            filter_count: 2,
            has_property_paths: false,
            has_unions: false,
            has_optionals: false,
        };
        let order = optimizer.optimize_join_order(5, &selective)?;
        assert_eq!(order, vec![0, 1, 2, 3, 4]);
        Ok(())
    }

    #[test]
    fn test_training_and_prediction() -> Result<()> {
        let mut optimizer = MLQueryOptimizer::with_config(MLOptimizerConfig {
            min_training_samples: 5,
            ..Default::default()
        });
        for i in 0..10 {
            let features = PatternFeatures {
                pattern_count: i % 5 + 1,
                bound_variables: i % 3,
                unbound_variables: i % 7,
                avg_selectivity: 0.1 * (i as f64 / 10.0),
                join_complexity: 1.0 + (i as f64 / 5.0),
                max_join_depth: i % 4 + 1,
                filter_count: i % 3,
                has_property_paths: i % 2 == 0,
                has_unions: i % 3 == 0,
                has_optionals: i % 4 == 0,
            };
            optimizer.train_from_execution(features, i * 100, (i * 10) as f64)?;
        }
        let stats = optimizer.training_stats()?;
        assert_eq!(stats.total_samples, 10);
        assert!(stats.is_trained);
        Ok(())
    }

    #[test]
    fn test_comprehensive_optimization() -> Result<()> {
        let mut optimizer = MLQueryOptimizer::new();
        let features = PatternFeatures {
            pattern_count: 4,
            bound_variables: 2,
            unbound_variables: 6,
            avg_selectivity: 0.15,
            join_complexity: 2.8,
            max_join_depth: 3,
            filter_count: 1,
            has_property_paths: true,
            has_unions: false,
            has_optionals: true,
        };
        let result = optimizer.optimize(features.clone())?;
        assert!(result.predicted_cardinality > 0);
        assert!(result.confidence >= 0.0 && result.confidence <= 1.0);
        assert_eq!(result.join_order.len(), 4);
        assert!(result.estimated_time_ms >= 0.0);
        Ok(())
    }
}