use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::SystemTime;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use scirs2_core::metrics::{Counter, Histogram, MetricsRegistry, Timer};
use scirs2_core::ndarray_ext::{Array1, Array2};
use scirs2_stats::regression::{linear_regression, ridge_regression, RegressionResults};
use crate::algebra::Algebra;
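/// Equi-width histogram over the observed values of a single feature,
/// used for selectivity and cardinality estimation.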
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValueHistogram {
pub buckets: Vec<HistogramBucket>,
pub total_count: usize,
pub distinct_count: usize,
pub min_value: f64,
pub max_value: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistogramBucket {
pub lower_bound: f64,
pub upper_bound: f64,
pub count: usize,
pub distinct_count: usize,
}
impl ValueHistogram {
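/// Builds an equi-width histogram over `data` with `num_buckets` buckets.
/// All-equal inputs collapse into a single bucket, and per-bucket distinct
/// counts are tracked via the bit patterns of the values.
///
/// A minimal usage sketch (`ignore`d so it is not run as a doctest):
///
/// ```ignore
/// let hist = ValueHistogram::from_data(&[1.0, 2.0, 3.0, 4.0], 2);
/// assert_eq!(hist.total_count, 4);
/// assert_eq!(hist.buckets.len(), 2);
/// ```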
pub fn from_data(data: &[f64], num_buckets: usize) -> Self {
if data.is_empty() {
return Self::empty();
}
let min_value = data.iter().copied().fold(f64::INFINITY, f64::min);
let max_value = data.iter().copied().fold(f64::NEG_INFINITY, f64::max);
if (max_value - min_value).abs() < 1e-10 {
return Self {
buckets: vec![HistogramBucket {
lower_bound: min_value,
upper_bound: max_value + 1e-10,
count: data.len(),
distinct_count: 1,
}],
total_count: data.len(),
distinct_count: 1,
min_value,
max_value,
};
}
// Guard against a zero bucket count, which would otherwise divide by zero
// here and underflow in the last-bucket index below.
let num_buckets = num_buckets.max(1);
let bucket_width = (max_value - min_value) / num_buckets as f64;
let mut buckets = Vec::with_capacity(num_buckets);
for i in 0..num_buckets {
let lower = min_value + i as f64 * bucket_width;
let upper = if i == num_buckets - 1 {
max_value + 1e-10
} else {
min_value + (i + 1) as f64 * bucket_width
};
buckets.push(HistogramBucket {
lower_bound: lower,
upper_bound: upper,
count: 0,
distinct_count: 0,
});
}
use std::collections::HashSet;
let mut bucket_distinct: Vec<HashSet<u64>> = vec![HashSet::new(); num_buckets];
for &value in data {
let bucket_idx = if value >= max_value {
num_buckets - 1
} else {
((value - min_value) / bucket_width).floor() as usize
};
if bucket_idx < num_buckets {
buckets[bucket_idx].count += 1;
let value_bits = value.to_bits();
bucket_distinct[bucket_idx].insert(value_bits);
}
}
for (i, bucket) in buckets.iter_mut().enumerate() {
bucket.distinct_count = bucket_distinct[i].len();
}
let distinct_count = bucket_distinct.iter().map(|s| s.len()).sum();
Self {
buckets,
total_count: data.len(),
distinct_count,
min_value,
max_value,
}
}
pub fn empty() -> Self {
Self {
buckets: Vec::new(),
total_count: 0,
distinct_count: 0,
min_value: 0.0,
max_value: 0.0,
}
}
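/// Estimates the selectivity of an equality predicate, assuming rows are
/// spread uniformly across the distinct values of the containing bucket:
/// `(count / distinct_count) / total_count`.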
pub fn estimate_equality_selectivity(&self, value: f64) -> f64 {
if self.total_count == 0 {
return 0.0;
}
for bucket in &self.buckets {
if value >= bucket.lower_bound && value < bucket.upper_bound {
if bucket.distinct_count == 0 {
return 0.0;
}
let selectivity =
bucket.count as f64 / bucket.distinct_count as f64 / self.total_count as f64;
return selectivity.min(1.0);
}
}
0.0
}
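/// Estimates the selectivity of the half-open range `lower <= x < upper`
/// by summing the fractional overlap with each bucket, assuming a uniform
/// distribution of values within buckets.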
pub fn estimate_range_selectivity(&self, lower: f64, upper: f64) -> f64 {
if self.total_count == 0 || upper <= lower {
return 0.0;
}
let mut selected_count = 0.0;
for bucket in &self.buckets {
let overlap_lower = bucket.lower_bound.max(lower);
let overlap_upper = bucket.upper_bound.min(upper);
if overlap_upper > overlap_lower {
let bucket_width = bucket.upper_bound - bucket.lower_bound;
let overlap_width = overlap_upper - overlap_lower;
if bucket_width > 1e-10 {
let fraction = overlap_width / bucket_width;
selected_count += bucket.count as f64 * fraction;
}
}
}
(selected_count / self.total_count as f64).min(1.0)
}
pub fn estimate_cardinality(&self, selectivity: f64) -> usize {
(self.total_count as f64 * selectivity).ceil() as usize
}
pub fn avg_bucket_size(&self) -> f64 {
if self.buckets.is_empty() {
return 0.0;
}
self.total_count as f64 / self.buckets.len() as f64
}
}
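/// Per-feature histograms built from training data, keyed by feature name
/// (`feature_0`, `feature_1`, ...).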
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistogramCardinalityEstimator {
pub histograms: HashMap<String, ValueHistogram>,
pub config: HistogramConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistogramConfig {
pub num_buckets: usize,
pub enabled: bool,
}
impl Default for HistogramConfig {
fn default() -> Self {
Self {
num_buckets: 100,
enabled: true,
}
}
}
impl HistogramCardinalityEstimator {
pub fn new(config: HistogramConfig) -> Self {
Self {
histograms: HashMap::new(),
config,
}
}
pub fn build_from_training_data(&mut self, training_data: &[TrainingExample]) {
if !self.config.enabled || training_data.is_empty() {
return;
}
let n_features = training_data[0].features.len();
for feature_idx in 0..n_features {
let feature_values: Vec<f64> = training_data
.iter()
.map(|example| example.features.get(feature_idx).copied().unwrap_or(0.0))
.collect();
let histogram = ValueHistogram::from_data(&feature_values, self.config.num_buckets);
let feature_name = format!("feature_{}", feature_idx);
self.histograms.insert(feature_name, histogram);
}
}
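/// Estimates cardinality from the per-feature histograms. Non-zero
/// equality selectivities are combined via their geometric mean (a
/// feature-independence style assumption) and scaled by the average
/// histogram population. Returns `None` when histograms are disabled,
/// missing, or no feature produced a usable selectivity.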
pub fn estimate_cardinality_with_histogram(&self, features: &[f64]) -> Option<f64> {
if !self.config.enabled || self.histograms.is_empty() {
return None;
}
let mut product = 1.0;
let mut count = 0;
for (i, &feature_value) in features.iter().enumerate() {
let feature_name = format!("feature_{}", i);
if let Some(histogram) = self.histograms.get(&feature_name) {
let selectivity = histogram.estimate_equality_selectivity(feature_value);
if selectivity > 0.0 {
product *= selectivity;
count += 1;
}
}
}
if count > 0 {
let geometric_mean = product.powf(1.0 / count as f64);
let avg_total: f64 = self
.histograms
.values()
.map(|h| h.total_count as f64)
.sum::<f64>()
/ self.histograms.len() as f64;
Some(avg_total * geometric_mean)
} else {
None
}
}
pub fn get_statistics(&self) -> HistogramStatistics {
let total_buckets: usize = self.histograms.values().map(|h| h.buckets.len()).sum();
let avg_distinct: f64 = if !self.histograms.is_empty() {
self.histograms
.values()
.map(|h| h.distinct_count as f64)
.sum::<f64>()
/ self.histograms.len() as f64
} else {
0.0
};
HistogramStatistics {
num_histograms: self.histograms.len(),
total_buckets,
avg_distinct_values: avg_distinct,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistogramStatistics {
pub num_histograms: usize,
pub total_buckets: usize,
pub avg_distinct_values: f64,
}
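/// Configuration for the ML-based cost predictor. `confidence_threshold`
/// gates when model predictions are preferred over the heuristic fallback.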
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLConfig {
pub model_type: MLModelType,
pub confidence_threshold: f64,
pub training_interval_hours: u64,
pub max_training_examples: usize,
pub min_examples_for_training: usize,
pub feature_normalization: bool,
pub auto_retraining: bool,
pub model_persistence_path: Option<PathBuf>,
}
impl Default for MLConfig {
fn default() -> Self {
Self {
model_type: MLModelType::Ridge,
confidence_threshold: 0.7,
training_interval_hours: 24,
max_training_examples: 10000,
min_examples_for_training: 100,
feature_normalization: true,
auto_retraining: true,
model_persistence_path: None,
}
}
}
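/// Predicts query execution cost from structural query features, using the
/// trained regression model when its confidence is high enough and a
/// hand-tuned heuristic otherwise.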
pub struct MLPredictor {
model: MLModel,
training_data: Vec<TrainingExample>,
feature_extractor: FeatureExtractor,
prediction_cache: HashMap<u64, MLPrediction>,
config: MLConfig,
metrics_collector: Arc<MetricsRegistry>,
last_training: Option<SystemTime>,
histogram_estimator: HistogramCardinalityEstimator,
prediction_counter: Counter,
prediction_timer: Timer,
prediction_histogram: Histogram,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLModel {
regression_results: Option<SerializableRegressionResults>,
model_type: MLModelType,
accuracy_metrics: AccuracyMetrics,
normalization_params: Option<NormalizationParams>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableRegressionResults {
pub coefficients: Vec<f64>,
pub r_squared: f64,
pub adj_r_squared: f64,
pub residual_std_error: f64,
pub std_errors: Vec<f64>,
}
impl From<&RegressionResults<f64>> for SerializableRegressionResults {
fn from(results: &RegressionResults<f64>) -> Self {
Self {
coefficients: results.coefficients.to_vec(),
r_squared: results.r_squared,
adj_r_squared: results.adj_r_squared,
residual_std_error: results.residual_std_error,
std_errors: results.std_errors.to_vec(),
}
}
}
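/// Supported model families. Only `LinearRegression` and `Ridge` are
/// trained natively; the remaining variants currently fall back to
/// ordinary linear regression (see `train_model`).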
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum MLModelType {
LinearRegression,
Ridge,
RandomForest,
NeuralNetwork,
GradientBoosting,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingExample {
pub features: Vec<f64>,
pub target_cost: f64,
pub actual_cost: f64,
pub query_characteristics: QueryCharacteristics,
pub timestamp: SystemTime,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryCharacteristics {
pub triple_pattern_count: usize,
pub join_count: usize,
pub filter_count: usize,
pub optional_count: usize,
pub has_aggregation: bool,
pub has_sorting: bool,
pub estimated_cardinality: usize,
pub complexity_score: f64,
pub query_graph_diameter: usize,
pub avg_degree: f64,
pub max_degree: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureExtractor {
feature_weights: HashMap<String, f64>,
normalization_params: Option<NormalizationParams>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormalizationParams {
mean: Vec<f64>,
std_dev: Vec<f64>,
min_values: Vec<f64>,
max_values: Vec<f64>,
}
#[derive(Debug, Clone)]
pub struct MLPrediction {
pub predicted_cost: f64,
pub confidence: f64,
pub recommendation: OptimizationRecommendation,
pub feature_importance: Vec<(String, f64)>,
}
#[derive(Debug, Clone)]
pub enum OptimizationRecommendation {
UseIndex(String),
EnableParallelism(usize),
ApplyStreaming,
MaterializeSubquery,
ReorderJoins(Vec<usize>),
NoChange,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccuracyMetrics {
pub mean_absolute_error: f64,
pub root_mean_square_error: f64,
pub r_squared: f64,
pub confidence_interval: (f64, f64),
}
impl MLPredictor {
pub fn new(config: MLConfig) -> Result<Self> {
let metrics_collector = Arc::new(MetricsRegistry::new());
let prediction_counter = Counter::new("ml_predictor_predictions_total".to_string());
let prediction_timer = Timer::new("ml_predictor_prediction_duration_seconds".to_string());
let prediction_histogram =
Histogram::new("ml_predictor_prediction_distribution".to_string());
Ok(Self {
model: MLModel {
regression_results: None,
model_type: config.model_type.clone(),
accuracy_metrics: AccuracyMetrics {
mean_absolute_error: 0.0,
root_mean_square_error: 0.0,
r_squared: 0.0,
confidence_interval: (0.0, 0.0),
},
normalization_params: None,
},
training_data: Vec::new(),
feature_extractor: FeatureExtractor {
feature_weights: HashMap::new(),
normalization_params: None,
},
prediction_cache: HashMap::new(),
config,
metrics_collector,
last_training: None,
histogram_estimator: HistogramCardinalityEstimator::new(HistogramConfig::default()),
prediction_counter,
prediction_timer,
prediction_histogram,
})
}
pub fn from_model_type(model_type: MLModelType) -> Result<Self> {
let config = MLConfig {
model_type,
..Default::default()
};
Self::new(config)
}
pub fn load_model(path: &Path) -> Result<Self> {
let contents = std::fs::read_to_string(path)
.with_context(|| format!("Failed to read model from {:?}", path))?;
let predictor: MLPredictor = serde_json::from_str(&contents)
.with_context(|| format!("Failed to deserialize model from {:?}", path))?;
Ok(predictor)
}
pub fn save_model(&self, path: &Path) -> Result<()> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)
.with_context(|| format!("Failed to create directory {:?}", parent))?;
}
let contents = serde_json::to_string_pretty(self).context("Failed to serialize model")?;
std::fs::write(path, contents)
.with_context(|| format!("Failed to write model to {:?}", path))?;
Ok(())
}
pub fn confidence(&self) -> f64 {
let r_squared = self.model.accuracy_metrics.r_squared;
r_squared.clamp(0.0, 1.0)
}
pub fn should_use_ml(&self) -> bool {
self.confidence() >= self.config.confidence_threshold
&& self.model.regression_results.is_some()
}
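/// Extracts the 13-dimensional feature vector, in order: triple pattern,
/// join, filter, and optional counts; aggregation and sorting flags;
/// estimated cardinality; graph diameter; average and maximum degree;
/// cross-product likelihood; subquery depth; aggregation complexity.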
pub fn extract_features(&self, query: &Algebra) -> Vec<f64> {
let mut features = Vec::with_capacity(13);
let characteristics = self.analyze_query_structure(query);
features.push(characteristics.triple_pattern_count as f64);
features.push(characteristics.join_count as f64);
features.push(characteristics.filter_count as f64);
features.push(characteristics.optional_count as f64);
features.push(if characteristics.has_aggregation {
1.0
} else {
0.0
});
features.push(if characteristics.has_sorting {
1.0
} else {
0.0
});
features.push(characteristics.estimated_cardinality as f64);
features.push(characteristics.query_graph_diameter as f64);
features.push(characteristics.avg_degree);
features.push(characteristics.max_degree as f64);
features.push(self.calculate_cross_product_likelihood(query));
features.push(self.calculate_subquery_depth(query));
features.push(self.calculate_aggregation_complexity(query));
if self.config.feature_normalization {
self.normalize_features(features)
} else {
features
}
}
fn analyze_query_structure(&self, query: &Algebra) -> QueryCharacteristics {
let mut characteristics = QueryCharacteristics {
triple_pattern_count: 0,
join_count: 0,
filter_count: 0,
optional_count: 0,
has_aggregation: false,
has_sorting: false,
estimated_cardinality: 1000,
complexity_score: 0.0,
query_graph_diameter: 1,
avg_degree: 1.0,
max_degree: 1,
};
self.traverse_algebra(query, &mut characteristics);
characteristics.complexity_score = self.calculate_complexity_score(&characteristics);
characteristics.query_graph_diameter = self.calculate_graph_diameter(&characteristics);
let (avg_deg, max_deg) = self.calculate_degree_metrics(&characteristics);
characteristics.avg_degree = avg_deg;
characteristics.max_degree = max_deg;
characteristics
}
fn traverse_algebra(&self, algebra: &Algebra, characteristics: &mut QueryCharacteristics) {
match algebra {
Algebra::Service { .. } => characteristics.triple_pattern_count += 1,
Algebra::PropertyPath { .. } => characteristics.triple_pattern_count += 1,
Algebra::Join { left, right, .. } => {
characteristics.join_count += 1;
self.traverse_algebra(left, characteristics);
self.traverse_algebra(right, characteristics);
}
Algebra::LeftJoin { left, right, .. } => {
characteristics.join_count += 1;
characteristics.optional_count += 1;
self.traverse_algebra(left, characteristics);
self.traverse_algebra(right, characteristics);
}
Algebra::Filter { pattern, .. } => {
characteristics.filter_count += 1;
self.traverse_algebra(pattern, characteristics);
}
Algebra::Union { left, right } => {
self.traverse_algebra(left, characteristics);
self.traverse_algebra(right, characteristics);
}
Algebra::Extend { pattern, .. } => {
self.traverse_algebra(pattern, characteristics);
}
Algebra::OrderBy { pattern, .. } => {
characteristics.has_sorting = true;
self.traverse_algebra(pattern, characteristics);
}
Algebra::Project { pattern, .. } => {
self.traverse_algebra(pattern, characteristics);
}
Algebra::Distinct { pattern } => {
self.traverse_algebra(pattern, characteristics);
}
Algebra::Reduced { pattern } => {
self.traverse_algebra(pattern, characteristics);
}
Algebra::Slice { pattern, .. } => {
self.traverse_algebra(pattern, characteristics);
}
Algebra::Group { pattern, .. } => {
characteristics.has_aggregation = true;
self.traverse_algebra(pattern, characteristics);
}
_ => {}
}
}
fn calculate_complexity_score(&self, characteristics: &QueryCharacteristics) -> f64 {
let mut score = 0.0;
score += (characteristics.join_count as f64).powi(2) * 3.0;
score += characteristics.triple_pattern_count as f64 * 1.0;
score += characteristics.filter_count as f64 * 0.5;
score += characteristics.optional_count as f64 * 2.0;
if characteristics.has_aggregation {
score += 5.0;
}
if characteristics.has_sorting {
score += 2.0;
}
let cardinality_log = (characteristics.estimated_cardinality as f64)
.log10()
.max(1.0);
score *= cardinality_log;
score
}
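/// Approximates the query graph diameter from the join count alone; no
/// actual graph traversal is performed.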
fn calculate_graph_diameter(&self, characteristics: &QueryCharacteristics) -> usize {
if characteristics.join_count == 0 {
1
} else {
(characteristics.join_count as f64).sqrt().ceil() as usize + 1
}
}
fn calculate_degree_metrics(&self, characteristics: &QueryCharacteristics) -> (f64, usize) {
if characteristics.triple_pattern_count == 0 {
return (0.0, 0);
}
let avg_degree = if characteristics.triple_pattern_count > 0 {
(characteristics.join_count as f64 * 2.0) / characteristics.triple_pattern_count as f64
} else {
0.0
};
let max_degree = characteristics
.join_count
.min(characteristics.triple_pattern_count);
(avg_degree, max_degree)
}
fn calculate_cross_product_likelihood(&self, _query: &Algebra) -> f64 {
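// Cross-product detection is not implemented yet; this feature is a
// constant placeholder.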
0.0
}
fn calculate_subquery_depth(&self, algebra: &Algebra) -> f64 {
self.calculate_depth_recursive(algebra, 0) as f64
}
fn calculate_depth_recursive(&self, algebra: &Algebra, current_depth: usize) -> usize {
match algebra {
Algebra::Join { left, right, .. }
| Algebra::LeftJoin { left, right, .. }
| Algebra::Union { left, right } => {
let left_depth = self.calculate_depth_recursive(left, current_depth + 1);
let right_depth = self.calculate_depth_recursive(right, current_depth + 1);
left_depth.max(right_depth)
}
Algebra::Filter { pattern, .. }
| Algebra::Extend { pattern, .. }
| Algebra::OrderBy { pattern, .. }
| Algebra::Project { pattern, .. }
| Algebra::Distinct { pattern }
| Algebra::Reduced { pattern }
| Algebra::Slice { pattern, .. }
| Algebra::Group { pattern, .. } => {
self.calculate_depth_recursive(pattern, current_depth + 1)
}
_ => current_depth,
}
}
fn calculate_aggregation_complexity(&self, algebra: &Algebra) -> f64 {
self.count_aggregations(algebra) as f64
}
fn count_aggregations(&self, algebra: &Algebra) -> usize {
match algebra {
Algebra::Group { .. } => 1,
Algebra::Join { left, right, .. }
| Algebra::LeftJoin { left, right, .. }
| Algebra::Union { left, right } => {
self.count_aggregations(left) + self.count_aggregations(right)
}
Algebra::Filter { pattern, .. }
| Algebra::Extend { pattern, .. }
| Algebra::OrderBy { pattern, .. }
| Algebra::Project { pattern, .. }
| Algebra::Distinct { pattern }
| Algebra::Reduced { pattern }
| Algebra::Slice { pattern, .. } => self.count_aggregations(pattern),
_ => 0,
}
}
fn normalize_features(&self, features: Vec<f64>) -> Vec<f64> {
if let Some(ref params) = self.feature_extractor.normalization_params {
features
.iter()
.enumerate()
.map(|(i, &value)| {
if i < params.mean.len() && i < params.std_dev.len() {
let mean = params.mean[i];
let std_dev = params.std_dev[i];
if std_dev > 1e-10 {
(value - mean) / std_dev
} else {
value
}
} else {
value
}
})
.collect()
} else {
features
}
}
pub fn predict_cost(&mut self, query: &Algebra) -> Result<MLPrediction> {
let _guard = self.prediction_timer.start();
self.prediction_counter.inc();
let query_hash = self.hash_query(query);
if let Some(cached) = self.prediction_cache.get(&query_hash) {
return Ok(cached.clone());
}
// Feature extraction is comparatively expensive, so it runs only on a cache miss.
let features = self.extract_features(query);
let (predicted_cost, confidence) = if self.should_use_ml() {
self.predict_with_model(&features)?
} else {
self.heuristic_prediction(&features)?
};
self.prediction_histogram.observe(predicted_cost);
let recommendation = self.generate_recommendation(&features, predicted_cost);
let feature_importance = self.calculate_feature_importance(&features);
let prediction = MLPrediction {
predicted_cost,
confidence,
recommendation,
feature_importance,
};
self.prediction_cache.insert(query_hash, prediction.clone());
Ok(prediction)
}
fn predict_with_model(&self, features: &[f64]) -> Result<(f64, f64)> {
if let Some(ref results) = self.model.regression_results {
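// Dot product of coefficients and features; a trailing coefficient beyond
// the feature vector, if present, is treated as the intercept.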
let mut prediction = 0.0;
for (i, &coef) in results.coefficients.iter().enumerate() {
if i < features.len() {
prediction += coef * features[i];
} else if i == features.len() {
prediction += coef;
}
}
prediction = prediction.max(0.1);
let confidence = self.confidence();
Ok((prediction, confidence))
} else {
self.heuristic_prediction(features)
}
}
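/// Hand-tuned fallback cost model used when no trained model is available
/// or its confidence is below the configured threshold.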
fn heuristic_prediction(&self, features: &[f64]) -> Result<(f64, f64)> {
let mut cost = 0.0;
if features.len() >= 13 {
let triple_patterns = features[0];
let joins = features[1];
let filters = features[2];
let optional = features[3];
let has_aggregation = features[4];
let has_sorting = features[5];
let cardinality = features[6];
cost += triple_patterns * 10.0;
cost += joins * joins * 50.0;
cost += filters * 5.0;
cost += optional * 15.0;
cost += has_aggregation * 100.0;
cost += has_sorting * 20.0;
cost += (cardinality / 1000.0) * 2.0;
}
cost = cost.max(1.0);
let confidence = 0.5;
Ok((cost, confidence))
}
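/// Maps the predicted cost and a few key features to a coarse optimization
/// hint; the thresholds here are heuristic.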
fn generate_recommendation(
&self,
features: &[f64],
predicted_cost: f64,
) -> OptimizationRecommendation {
if features.len() < 7 {
return OptimizationRecommendation::NoChange;
}
let joins = features[1];
let has_aggregation = features[4];
let cardinality = features[6];
if predicted_cost > 1000.0 {
if joins > 3.0 {
return OptimizationRecommendation::ReorderJoins(vec![0, 1, 2]);
}
if cardinality > 10000.0 {
return OptimizationRecommendation::EnableParallelism(4);
}
if has_aggregation > 0.0 {
return OptimizationRecommendation::MaterializeSubquery;
}
}
if predicted_cost > 100.0 && cardinality > 5000.0 {
return OptimizationRecommendation::ApplyStreaming;
}
OptimizationRecommendation::NoChange
}
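/// Approximates feature importance by each feature's share of the total
/// absolute feature magnitude; this is not a model-derived importance.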
fn calculate_feature_importance(&self, features: &[f64]) -> Vec<(String, f64)> {
let feature_names = vec![
"triple_patterns",
"joins",
"filters",
"optional",
"has_aggregation",
"has_sorting",
"cardinality",
"graph_diameter",
"avg_degree",
"max_degree",
"cross_product",
"subquery_depth",
"aggregation_complexity",
];
let total: f64 = features.iter().map(|x| x.abs()).sum();
let mut importance: Vec<(String, f64)> = feature_names
.iter()
.zip(features.iter())
.map(|(name, &value)| {
let normalized = if total > 1e-10 {
(value.abs() / total).min(1.0)
} else {
0.0
};
(name.to_string(), normalized)
})
.collect();
importance.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
importance
}
pub fn add_training_example(&mut self, example: TrainingExample) {
self.training_data.push(example);
if self.training_data.len() > self.config.max_training_examples {
self.training_data.remove(0);
}
self.prediction_cache.clear();
}
pub fn update_from_execution(&mut self, query: &Algebra, actual_cost: f64) -> Result<()> {
let features = self.extract_features(query);
let characteristics = self.analyze_query_structure(query);
let example = TrainingExample {
features: features.clone(),
target_cost: actual_cost,
actual_cost,
query_characteristics: characteristics,
timestamp: SystemTime::now(),
};
self.add_training_example(example);
if self.should_retrain() {
self.train_model()?;
}
Ok(())
}
fn should_retrain(&self) -> bool {
if !self.config.auto_retraining {
return false;
}
if self.training_data.len() < self.config.min_examples_for_training {
return false;
}
if let Some(last_training) = self.last_training {
if let Ok(elapsed) = SystemTime::now().duration_since(last_training) {
let hours_elapsed = elapsed.as_secs() / 3600;
if hours_elapsed < self.config.training_interval_hours {
return false;
}
}
}
true
}
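/// Fits the configured regression model on the accumulated training data,
/// refreshes accuracy metrics and histograms, and optionally persists the
/// model to `model_persistence_path`.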
pub fn train_model(&mut self) -> Result<()> {
if self.training_data.len() < self.config.min_examples_for_training {
return Err(anyhow::anyhow!(
"Insufficient training data: {} < {}",
self.training_data.len(),
self.config.min_examples_for_training
));
}
let n_samples = self.training_data.len();
let n_features = self.training_data[0].features.len();
let mut x_data = Vec::with_capacity(n_samples * n_features);
let mut y_data = Vec::with_capacity(n_samples);
for example in &self.training_data {
x_data.extend_from_slice(&example.features);
y_data.push(example.actual_cost);
}
let x = Array2::from_shape_vec((n_samples, n_features), x_data)
.map_err(|e| anyhow::anyhow!("Failed to create feature matrix: {}", e))?;
let y = Array1::from_vec(y_data);
if self.config.feature_normalization {
self.calculate_normalization_params(&x);
}
let results = match self.model.model_type {
MLModelType::LinearRegression => linear_regression(&x.view(), &y.view(), None)
.map_err(|e| anyhow::anyhow!("Linear regression failed: {:?}", e))?,
MLModelType::Ridge => {
// Regularization strength for the ridge penalty.
let alpha = Some(1.0);
ridge_regression(
&x.view(),
&y.view(),
alpha,
None,
None,
None,
None,
None,
)
.map_err(|e| anyhow::anyhow!("Ridge regression failed: {:?}", e))?
}
_ => {
linear_regression(&x.view(), &y.view(), None)
.map_err(|e| anyhow::anyhow!("Linear regression failed: {:?}", e))?
}
};
self.model.regression_results = Some(SerializableRegressionResults::from(&results));
self.update_accuracy_metrics(&results, &x, &y)?;
self.histogram_estimator
.build_from_training_data(&self.training_data);
self.last_training = Some(SystemTime::now());
if let Some(ref path) = self.config.model_persistence_path {
self.save_model(path)?;
}
Ok(())
}
fn calculate_normalization_params(&mut self, x: &Array2<f64>) {
let n_features = x.ncols();
let mut means = vec![0.0; n_features];
let mut std_devs = vec![0.0; n_features];
let mut mins = vec![f64::MAX; n_features];
let mut maxs = vec![f64::MIN; n_features];
for j in 0..n_features {
let column = x.column(j);
let n = column.len() as f64;
let mean: f64 = column.iter().sum::<f64>() / n;
means[j] = mean;
let variance: f64 = column.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
std_devs[j] = variance.sqrt();
for &val in column.iter() {
if val < mins[j] {
mins[j] = val;
}
if val > maxs[j] {
maxs[j] = val;
}
}
}
let params = NormalizationParams {
mean: means,
std_dev: std_devs,
min_values: mins,
max_values: maxs,
};
self.feature_extractor.normalization_params = Some(params.clone());
self.model.normalization_params = Some(params);
}
fn update_accuracy_metrics(
&mut self,
results: &RegressionResults<f64>,
_x: &Array2<f64>,
_y: &Array1<f64>,
) -> Result<()> {
let r_squared = results.r_squared;
let residual_std_error = results.residual_std_error;
let residuals = &results.residuals;
let n = residuals.len() as f64;
let mae = residuals.iter().map(|&r| r.abs()).sum::<f64>() / n;
let rmse = (residuals.iter().map(|&r| r * r).sum::<f64>() / n).sqrt();
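// A rough +/- 1.96 * RSE / sqrt(n) band around r_squared; this is a
// heuristic indicator, not a formal confidence interval.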
let confidence_interval = (
r_squared - 1.96 * residual_std_error / n.sqrt(),
r_squared + 1.96 * residual_std_error / n.sqrt(),
);
self.model.accuracy_metrics = AccuracyMetrics {
mean_absolute_error: mae,
root_mean_square_error: rmse,
r_squared,
confidence_interval,
};
Ok(())
}
pub fn accuracy_metrics(&self) -> &AccuracyMetrics {
&self.model.accuracy_metrics
}
pub fn predictions_count(&self) -> usize {
self.prediction_counter.get() as usize
}
pub fn training_data_count(&self) -> usize {
self.training_data.len()
}
pub fn estimate_cardinality_with_histogram(&self, features: &[f64]) -> Option<f64> {
self.histogram_estimator
.estimate_cardinality_with_histogram(features)
}
pub fn get_histogram_statistics(&self) -> HistogramStatistics {
self.histogram_estimator.get_statistics()
}
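/// Like `predict_cost`, but blends the model (or heuristic) prediction
/// with the histogram-based cardinality estimate when one is available.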
pub fn predict_cost_with_histogram(&mut self, query: &Algebra) -> Result<MLPrediction> {
let _guard = self.prediction_timer.start();
self.prediction_counter.inc();
let query_hash = self.hash_query(query);
if let Some(cached) = self.prediction_cache.get(&query_hash) {
return Ok(cached.clone());
}
let features = self.extract_features(query);
let (mut predicted_cost, mut confidence) = if self.should_use_ml() {
self.predict_with_model(&features)?
} else {
self.heuristic_prediction(&features)?
};
if let Some(histogram_cardinality) = self.estimate_cardinality_with_histogram(&features) {
// Blend the model (or heuristic) prediction with the histogram estimate.
let blend_weight = 0.3;
predicted_cost =
predicted_cost * (1.0 - blend_weight) + histogram_cardinality * blend_weight;
// Reward agreement between the two estimators; clamp so a large
// disagreement cannot drive confidence negative.
let agreement =
1.0 - ((predicted_cost - histogram_cardinality).abs() / (predicted_cost + 1.0));
confidence = (confidence + agreement * 0.2).clamp(0.0, 1.0);
}
self.prediction_histogram.observe(predicted_cost);
let recommendation = self.generate_recommendation(&features, predicted_cost);
let feature_importance = self.calculate_feature_importance(&features);
let prediction = MLPrediction {
predicted_cost,
confidence,
recommendation,
feature_importance,
};
self.prediction_cache.insert(query_hash, prediction.clone());
Ok(prediction)
}
fn hash_query(&self, query: &Algebra) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
let query_string = self.algebra_to_string(query);
query_string.hash(&mut hasher);
hasher.finish()
}
fn algebra_to_string(&self, algebra: &Algebra) -> String {
match algebra {
Algebra::Service { .. } => "Service".to_string(),
Algebra::PropertyPath { .. } => "PropertyPath".to_string(),
Algebra::Join { left, right, .. } => {
format!(
"Join({},{})",
self.algebra_to_string(left),
self.algebra_to_string(right)
)
}
Algebra::LeftJoin { left, right, .. } => {
format!(
"LeftJoin({},{})",
self.algebra_to_string(left),
self.algebra_to_string(right)
)
}
Algebra::Filter { pattern, .. } => {
format!("Filter({})", self.algebra_to_string(pattern))
}
Algebra::Union { left, right } => {
format!(
"Union({},{})",
self.algebra_to_string(left),
self.algebra_to_string(right)
)
}
Algebra::Extend { pattern, .. } => {
format!("Extend({})", self.algebra_to_string(pattern))
}
Algebra::OrderBy { pattern, .. } => {
format!("OrderBy({})", self.algebra_to_string(pattern))
}
Algebra::Project { pattern, .. } => {
format!("Project({})", self.algebra_to_string(pattern))
}
Algebra::Distinct { pattern } => {
format!("Distinct({})", self.algebra_to_string(pattern))
}
Algebra::Reduced { pattern } => {
format!("Reduced({})", self.algebra_to_string(pattern))
}
Algebra::Slice { pattern, .. } => {
format!("Slice({})", self.algebra_to_string(pattern))
}
Algebra::Group { pattern, .. } => {
format!("Group({})", self.algebra_to_string(pattern))
}
_ => "Unknown".to_string(),
}
}
}
impl Serialize for MLPredictor {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeStruct;
let mut state = serializer.serialize_struct("MLPredictor", 5)?;
state.serialize_field("model", &self.model)?;
state.serialize_field("training_data", &self.training_data)?;
state.serialize_field("feature_extractor", &self.feature_extractor)?;
state.serialize_field("config", &self.config)?;
state.serialize_field("histogram_estimator", &self.histogram_estimator)?;
state.end()
}
}
impl<'de> Deserialize<'de> for MLPredictor {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
struct MLPredictorData {
model: MLModel,
training_data: Vec<TrainingExample>,
feature_extractor: FeatureExtractor,
config: MLConfig,
#[serde(default)]
histogram_estimator: Option<HistogramCardinalityEstimator>,
}
let data = MLPredictorData::deserialize(deserializer)?;
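// The prediction cache, metric instruments, and last-training timestamp
// are runtime-only state and are recreated fresh on load.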
let metrics_collector = Arc::new(MetricsRegistry::new());
let prediction_counter = Counter::new("ml_predictor_predictions_total".to_string());
let prediction_timer = Timer::new("ml_predictor_prediction_duration_seconds".to_string());
let prediction_histogram =
Histogram::new("ml_predictor_prediction_distribution".to_string());
Ok(MLPredictor {
model: data.model,
training_data: data.training_data,
feature_extractor: data.feature_extractor,
prediction_cache: HashMap::new(),
config: data.config,
metrics_collector,
last_training: None,
histogram_estimator: data
.histogram_estimator
.unwrap_or_else(|| HistogramCardinalityEstimator::new(HistogramConfig::default())),
prediction_counter,
prediction_timer,
prediction_histogram,
})
}
}
impl Clone for MLPredictor {
fn clone(&self) -> Self {
MLPredictor {
model: self.model.clone(),
training_data: self.training_data.clone(),
feature_extractor: self.feature_extractor.clone(),
prediction_cache: HashMap::new(),
config: self.config.clone(),
metrics_collector: Arc::clone(&self.metrics_collector),
last_training: self.last_training,
histogram_estimator: self.histogram_estimator.clone(),
prediction_counter: Counter::new("ml_predictor_predictions_total".to_string()),
prediction_timer: Timer::new("ml_predictor_prediction_duration_seconds".to_string()),
prediction_histogram: Histogram::new(
"ml_predictor_prediction_distribution".to_string(),
),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ml_predictor_creation() -> Result<()> {
let config = MLConfig::default();
let predictor = MLPredictor::new(config)?;
assert_eq!(predictor.training_data_count(), 0);
assert_eq!(predictor.predictions_count(), 0);
assert!(!predictor.should_use_ml());
Ok(())
}
#[test]
fn test_feature_extraction() -> Result<()> {
let config = MLConfig::default();
let predictor = MLPredictor::new(config)?;
let query = Algebra::Empty;
let features = predictor.extract_features(&query);
assert_eq!(features.len(), 13, "Should extract 13 features");
Ok(())
}
#[test]
fn test_model_serialization() -> Result<()> {
let config = MLConfig::default();
let predictor = MLPredictor::new(config)?;
let serialized = serde_json::to_string(&predictor)?;
let deserialized: MLPredictor = serde_json::from_str(&serialized)?;
assert_eq!(predictor.config.model_type, deserialized.config.model_type);
Ok(())
}
#[test]
fn test_value_histogram_creation() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let histogram = ValueHistogram::from_data(&data, 5);
assert_eq!(histogram.total_count, 10);
assert_eq!(histogram.buckets.len(), 5);
assert_eq!(histogram.min_value, 1.0);
assert_eq!(histogram.max_value, 10.0);
}
#[test]
fn test_value_histogram_empty() {
let histogram = ValueHistogram::empty();
assert_eq!(histogram.total_count, 0);
assert_eq!(histogram.buckets.len(), 0);
}
#[test]
fn test_value_histogram_uniform() {
let data = vec![5.0; 100];
let histogram = ValueHistogram::from_data(&data, 10);
assert_eq!(histogram.total_count, 100);
assert_eq!(histogram.distinct_count, 1);
assert_eq!(histogram.buckets.len(), 1);
}
#[test]
fn test_histogram_equality_selectivity() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let histogram = ValueHistogram::from_data(&data, 5);
let selectivity_5 = histogram.estimate_equality_selectivity(5.0);
assert!(selectivity_5 > 0.0);
assert!(selectivity_5 <= 1.0);
let selectivity_100 = histogram.estimate_equality_selectivity(100.0);
assert_eq!(selectivity_100, 0.0);
}
#[test]
fn test_histogram_range_selectivity() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let histogram = ValueHistogram::from_data(&data, 5);
let selectivity_full = histogram.estimate_range_selectivity(1.0, 11.0);
assert!(selectivity_full >= 0.9);
assert!(selectivity_full <= 1.0);
let selectivity_half = histogram.estimate_range_selectivity(1.0, 5.5);
assert!(selectivity_half >= 0.4);
assert!(selectivity_half <= 0.6);
let selectivity_none = histogram.estimate_range_selectivity(100.0, 200.0);
assert_eq!(selectivity_none, 0.0);
}
#[test]
fn test_histogram_cardinality_estimation() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let histogram = ValueHistogram::from_data(&data, 5);
let selectivity = 0.5;
let estimated_cardinality = histogram.estimate_cardinality(selectivity);
assert_eq!(estimated_cardinality, 5);
}
#[test]
fn test_histogram_estimator_build() -> Result<()> {
let config = MLConfig::default();
let mut predictor = MLPredictor::new(config)?;
let query_characteristics = QueryCharacteristics {
triple_pattern_count: 5,
join_count: 2,
filter_count: 1,
optional_count: 0,
has_aggregation: false,
has_sorting: false,
estimated_cardinality: 1000,
complexity_score: 10.0,
query_graph_diameter: 2,
avg_degree: 1.5,
max_degree: 2,
};
for i in 0..100 {
let features = vec![
(i % 10) as f64,
(i % 5) as f64,
(i % 3) as f64,
0.0,
0.0,
0.0,
1000.0,
2.0,
1.5,
2.0,
0.0,
0.0,
0.0,
];
let example = TrainingExample {
features,
target_cost: (i * 10) as f64,
actual_cost: (i * 10) as f64,
query_characteristics: query_characteristics.clone(),
timestamp: SystemTime::now(),
};
predictor.add_training_example(example);
}
predictor.train_model()?;
let stats = predictor.get_histogram_statistics();
assert!(stats.num_histograms > 0);
assert!(stats.total_buckets > 0);
Ok(())
}
#[test]
fn test_histogram_enhanced_prediction() -> Result<()> {
let config = MLConfig::default();
let mut predictor = MLPredictor::new(config)?;
let query_characteristics = QueryCharacteristics {
triple_pattern_count: 5,
join_count: 2,
filter_count: 1,
optional_count: 0,
has_aggregation: false,
has_sorting: false,
estimated_cardinality: 1000,
complexity_score: 10.0,
query_graph_diameter: 2,
avg_degree: 1.5,
max_degree: 2,
};
for i in 0..100 {
let features = vec![
(i % 10) as f64,
(i % 5) as f64,
1.0,
0.0,
0.0,
0.0,
1000.0,
2.0,
1.5,
2.0,
0.0,
0.0,
0.0,
];
let example = TrainingExample {
features,
target_cost: (i * 10) as f64,
actual_cost: (i * 10) as f64,
query_characteristics: query_characteristics.clone(),
timestamp: SystemTime::now(),
};
predictor.add_training_example(example);
}
predictor.train_model()?;
let query = Algebra::Empty;
let prediction = predictor.predict_cost_with_histogram(&query)?;
assert!(prediction.predicted_cost >= 0.0);
assert!(prediction.confidence >= 0.0);
assert!(prediction.confidence <= 1.0);
Ok(())
}
#[test]
fn test_histogram_cardinality_blending() -> Result<()> {
let config = MLConfig::default();
let mut predictor = MLPredictor::new(config)?;
let query_characteristics = QueryCharacteristics {
triple_pattern_count: 5,
join_count: 2,
filter_count: 1,
optional_count: 0,
has_aggregation: false,
has_sorting: false,
estimated_cardinality: 1000,
complexity_score: 10.0,
query_graph_diameter: 2,
avg_degree: 1.5,
max_degree: 2,
};
for i in 0..150 {
let features = vec![
5.0, 2.0, 1.0, 0.0, 0.0, 0.0, 1000.0, 2.0, 1.5, 2.0, 0.0, 0.0, 0.0,
];
let example = TrainingExample {
features,
target_cost: 100.0 + i as f64,
actual_cost: 100.0 + i as f64,
query_characteristics: query_characteristics.clone(),
timestamp: SystemTime::now(),
};
predictor.add_training_example(example);
}
predictor.train_model()?;
let features = vec![
5.0, 2.0, 1.0, 0.0, 0.0, 0.0, 1000.0, 2.0, 1.5, 2.0, 0.0, 0.0, 0.0,
];
let histogram_estimate = predictor.estimate_cardinality_with_histogram(&features);
assert!(histogram_estimate.is_some());
assert!(histogram_estimate.unwrap() > 0.0);
Ok(())
}
}