use crate::algebra::{Term, TriplePattern, Variable};
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::Instant;
use scirs2_core::ndarray_ext::Array1;
use scirs2_stats::{
kurtosis,
mean,
median,
pearsonr,
regression::linear_regression,
skew,
    std as std_dev, // renamed to avoid confusion with the `std` crate paths used below
};
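/// Multi-strategy cardinality estimator for SPARQL triple patterns and joins.
///
/// Combines three estimators behind `RwLock`s: per-predicate histograms for
/// single-pattern selectivity, pairwise correlations (with optional linear
/// regression) for join estimates, and an adaptive layer that learns
/// per-pattern correction factors from observed results.
///
/// A minimal usage sketch (illustrative only; the predicate IRI, `values`,
/// `pattern`, and `actual` are placeholders):
///
/// ```ignore
/// let estimator = CardinalityEstimator::new(EstimatorConfig::default());
/// let values: Vec<String> = (0..100).map(|i| format!("v{i}")).collect();
/// estimator.build_histogram("http://example.org/pred", values)?;
/// let estimate = estimator.estimate_triple_pattern(&pattern)?;
/// // Feed actual results back so the adaptive layer can learn corrections.
/// estimator.record_actual_result(&pattern, estimate, actual)?;
/// ```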
pub struct CardinalityEstimator {
histogram_estimator: Arc<RwLock<HistogramEstimator>>,
correlation_estimator: Arc<RwLock<CorrelationEstimator>>,
adaptive_estimator: Arc<RwLock<AdaptiveEstimator>>,
config: EstimatorConfig,
stats: Arc<RwLock<EstimationStats>>,
}
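/// Tunable parameters shared by the histogram, correlation, and adaptive
/// estimators.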
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EstimatorConfig {
pub num_histogram_buckets: usize,
pub enable_correlation: bool,
pub enable_adaptive: bool,
pub min_samples: usize,
pub confidence_level: f64,
pub outlier_filtering: bool,
pub max_error_threshold: f64,
}
impl Default for EstimatorConfig {
fn default() -> Self {
Self {
num_histogram_buckets: 100,
enable_correlation: true,
enable_adaptive: true,
min_samples: 30,
confidence_level: 0.95,
outlier_filtering: true,
            max_error_threshold: 0.2,
        }
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EstimationStats {
pub total_estimations: u64,
pub total_error: f64,
pub mape: f64,
pub rmse: f64,
pub within_threshold: u64,
pub accuracy_by_pattern: HashMap<String, f64>,
}
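/// Selectivity estimation from per-predicate value histograms.
///
/// Each histogram stores bucket boundaries over the sorted distinct values,
/// exact frequencies for the most common values (MCVs), and summary moments
/// of the bucket frequency distribution.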
pub struct HistogramEstimator {
predicate_histograms: HashMap<String, ValueHistogram>,
#[allow(dead_code)]
subject_histograms: HashMap<String, ValueHistogram>,
#[allow(dead_code)]
object_histograms: HashMap<String, ValueHistogram>,
config: EstimatorConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValueHistogram {
pub boundaries: Vec<String>,
pub frequencies: Vec<u64>,
pub total_count: u64,
pub distinct_count: u64,
    /// Most common values (MCVs) with their exact frequencies, kept separately
    /// so frequent values bypass the bucket-level uniform assumption.
    pub mcv: Vec<(String, u64)>,
pub summary: HistogramSummary,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistogramSummary {
pub mean_frequency: f64,
pub std_dev_frequency: f64,
pub median_frequency: f64,
pub skewness: f64,
pub kurtosis: f64,
}
impl HistogramEstimator {
pub fn new(config: EstimatorConfig) -> Self {
Self {
predicate_histograms: HashMap::new(),
subject_histograms: HashMap::new(),
object_histograms: HashMap::new(),
config,
}
}
pub fn build_histogram(&mut self, predicate: &str, values: Vec<String>) -> Result<()> {
let hist = self.create_histogram(values)?;
self.predicate_histograms
.insert(predicate.to_string(), hist);
Ok(())
}
fn create_histogram(&self, mut values: Vec<String>) -> Result<ValueHistogram> {
if values.is_empty() {
return Ok(ValueHistogram {
boundaries: Vec::new(),
frequencies: Vec::new(),
total_count: 0,
distinct_count: 0,
mcv: Vec::new(),
summary: HistogramSummary {
mean_frequency: 0.0,
std_dev_frequency: 0.0,
median_frequency: 0.0,
skewness: 0.0,
kurtosis: 0.0,
},
});
}
        let total_count = values.len() as u64;
        values.sort();
        // Count frequencies before deduplicating; counting afterwards would
        // report a frequency of 1 for every value.
        let mut freq_map: HashMap<String, u64> = HashMap::new();
        for val in &values {
            *freq_map.entry(val.clone()).or_insert(0) += 1;
        }
        values.dedup();
        let distinct_count = values.len() as u64;
        let mut mcv: Vec<(String, u64)> = freq_map.iter().map(|(k, v)| (k.clone(), *v)).collect();
        mcv.sort_by_key(|b| std::cmp::Reverse(b.1));
        mcv.truncate(10);
        // Equi-width bucketing over the sorted distinct values; guard against
        // a zero bucket count from a misconfigured `num_histogram_buckets`.
        let num_buckets = self
            .config
            .num_histogram_buckets
            .min(distinct_count as usize)
            .max(1);
        let bucket_size = (distinct_count as usize + num_buckets - 1) / num_buckets;
let mut boundaries = Vec::new();
let mut frequencies = Vec::new();
for i in 0..num_buckets {
let start = i * bucket_size;
let end = ((i + 1) * bucket_size).min(values.len());
if start < values.len() {
boundaries.push(values[start].clone());
let bucket_freq: u64 = values[start..end]
.iter()
.map(|v| *freq_map.get(v).unwrap_or(&0))
.sum();
frequencies.push(bucket_freq);
}
}
let freq_f64: Vec<f64> = frequencies.iter().map(|&f| f as f64).collect();
let (mean_freq, std_dev_freq, median_freq, skew_val, kurt_val) = if !freq_f64.is_empty() {
let arr = Array1::from_vec(freq_f64.clone());
let arr_view = arr.view();
            let mean_val = mean(&arr_view).unwrap_or(0.0);
            let std_val = std_dev(&arr_view, 1, None).unwrap_or(0.0);
            let median_val = median(&arr_view).unwrap_or(0.0);
            let skew_result = skew(&arr_view, false, None).unwrap_or(0.0);
            let kurt_result = kurtosis(&arr_view, true, false, None).unwrap_or(0.0);
            (mean_val, std_val, median_val, skew_result, kurt_result)
} else {
(0.0, 0.0, 0.0, 0.0, 0.0)
};
Ok(ValueHistogram {
boundaries,
frequencies,
total_count,
distinct_count,
mcv,
summary: HistogramSummary {
mean_frequency: mean_freq,
std_dev_frequency: std_dev_freq,
median_frequency: median_freq,
skewness: skew_val,
kurtosis: kurt_val,
},
})
}
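    /// Estimates the fraction of triples with `predicate` whose value equals
    /// `value`. Most common values are answered from exact counts; other
    /// values assume a uniform distribution within their histogram bucket.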
pub fn estimate_selectivity(&self, predicate: &str, value: &str) -> f64 {
if let Some(hist) = self.predicate_histograms.get(predicate) {
for (mcv_val, mcv_freq) in &hist.mcv {
if mcv_val == value {
return *mcv_freq as f64 / hist.total_count as f64;
}
}
if let Some(bucket_idx) = self.find_bucket(&hist.boundaries, value) {
if bucket_idx < hist.frequencies.len() {
let bucket_freq = hist.frequencies[bucket_idx] as f64;
let bucket_distinct =
(hist.distinct_count as f64 / hist.boundaries.len() as f64).max(1.0);
return (bucket_freq / bucket_distinct) / hist.total_count as f64;
}
}
            // Fall back to a uniform-distribution estimate over distinct values.
            1.0 / hist.distinct_count.max(1) as f64
        } else {
            // No histogram for this predicate: assume a default 10% selectivity.
            0.1
        }
    }
    fn find_bucket(&self, boundaries: &[String], value: &str) -> Option<usize> {
        // A single binary search suffices: an exact hit gives the bucket
        // index, and a miss gives the insertion point, whose predecessor
        // bucket covers the value.
        match boundaries.binary_search_by(|b| b.as_str().cmp(value)) {
            Ok(pos) => Some(pos),
            Err(pos) => Some(pos.saturating_sub(1)),
        }
    }
pub fn estimate_pattern_cardinality(&self, pattern: &TriplePattern) -> u64 {
let predicate = match &pattern.predicate {
Term::Iri(iri) => Some(iri.as_str()),
Term::Literal(lit) => Some(lit.value.as_str()),
_ => None,
};
if let Some(pred) = predicate {
if let Some(hist) = self.predicate_histograms.get(pred) {
let mut selectivity = 1.0;
if let Term::Iri(subj) = &pattern.subject {
selectivity *= self.estimate_selectivity(pred, subj.as_str());
}
if let Term::Iri(obj) = &pattern.object {
selectivity *= self.estimate_selectivity(pred, obj.as_str());
}
                return (hist.total_count as f64 * selectivity).max(1.0) as u64;
            }
        }
        // No histogram available: fall back to a fixed default cardinality.
        100
    }
}
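/// Join-cardinality estimation from pairwise predicate statistics.
///
/// Stores Pearson correlations between predicate value distributions and,
/// for significantly correlated pairs, a fitted linear regression
/// (slope, intercept, r²) used to predict join output sizes directly.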
pub struct CorrelationEstimator {
correlations: HashMap<(String, String), CorrelationInfo>,
regression_coefficients: HashMap<(String, String), (f64, f64, f64)>,
config: EstimatorConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationInfo {
pub correlation: f64,
pub p_value: f64,
pub confidence_interval: (f64, f64),
pub sample_count: u64,
pub joint_cardinality: u64,
}
impl CorrelationEstimator {
pub fn new(config: EstimatorConfig) -> Self {
Self {
correlations: HashMap::new(),
regression_coefficients: HashMap::new(),
config,
}
}
pub fn build_correlation(
&mut self,
pred1: &str,
pred2: &str,
values1: Vec<f64>,
values2: Vec<f64>,
) -> Result<()> {
if values1.len() != values2.len() || values1.len() < self.config.min_samples {
return Err(anyhow!(
"Insufficient samples for correlation: {} < {}",
values1.len(),
self.config.min_samples
));
}
let arr1 = Array1::from_vec(values1.clone());
let arr2 = Array1::from_vec(values2.clone());
let (correlation, p_value) = pearsonr(&arr1.view(), &arr2.view(), "two-sided")
.map_err(|e| anyhow!("Failed to calculate correlation: {:?}", e))?;
        // Confidence interval via the Fisher z-transformation: z = atanh(r)
        // is approximately normal with standard error 1/sqrt(n - 3), so the
        // interval is computed in z-space and mapped back with tanh.
        let n = values1.len() as f64;
        let z_score = 1.96_f64; // two-sided 95% critical value
        let se = 1.0 / (n - 3.0).sqrt();
        let fisher_z = 0.5 * ((1.0 + correlation) / (1.0 - correlation)).ln();
        let ci_low = (fisher_z - z_score * se).tanh();
        let ci_high = (fisher_z + z_score * se).tanh();
let corr_info = CorrelationInfo {
correlation,
p_value,
confidence_interval: (ci_low, ci_high),
sample_count: values1.len() as u64,
joint_cardinality: values1.len() as u64,
};
self.correlations
.insert((pred1.to_string(), pred2.to_string()), corr_info);
if p_value < (1.0 - self.config.confidence_level) && correlation.abs() > 0.3 {
use scirs2_core::ndarray_ext::Array2;
let n = arr1.len();
let mut x_2d = Array2::<f64>::zeros((n, 1));
for (i, &val) in arr1.iter().enumerate() {
x_2d[[i, 0]] = val;
}
let reg_result = linear_regression(&x_2d.view(), &arr2.view(), None)
.map_err(|e| anyhow!("Failed to fit regression model: {:?}", e))?;
let intercept = reg_result.coefficients.first().copied().unwrap_or(0.0);
let slope = reg_result.coefficients.get(1).copied().unwrap_or(0.0);
let r_squared = reg_result.r_squared;
self.regression_coefficients.insert(
(pred1.to_string(), pred2.to_string()),
(slope, intercept, r_squared),
);
}
Ok(())
}
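    /// Three-tier join estimate: a fitted regression model when one exists,
    /// then a correlation-scaled bound, and finally the geometric mean of the
    /// two input cardinalities.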
pub fn estimate_join_cardinality(
&self,
pred1: &str,
card1: u64,
pred2: &str,
card2: u64,
) -> u64 {
let key = (pred1.to_string(), pred2.to_string());
if let Some(&(slope, intercept, _r_squared)) = self.regression_coefficients.get(&key) {
let predicted = slope * (card1 as f64) + intercept;
return predicted.max(1.0).min(card2 as f64) as u64;
}
if let Some(corr_info) = self.correlations.get(&key) {
let base_estimate = (card1 as f64).min(card2 as f64);
let correlation_factor = 1.0 - corr_info.correlation.abs();
(base_estimate * correlation_factor).max(1.0) as u64
} else {
((card1 as f64 * card2 as f64).sqrt() as u64).max(1)
}
}
pub fn get_correlation(&self, pred1: &str, pred2: &str) -> Option<f64> {
self.correlations
.get(&(pred1.to_string(), pred2.to_string()))
.map(|info| info.correlation)
}
}
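/// Feedback-driven estimator that learns multiplicative correction factors
/// per pattern signature from (estimated, actual) pairs, maintaining a
/// Gaussian posterior over each factor.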
pub struct AdaptiveEstimator {
estimation_history: Vec<EstimationRecord>,
correction_factors: HashMap<String, f64>,
bayesian_priors: HashMap<String, (f64, f64)>,
learning_rate: f64,
#[allow(dead_code)]
config: EstimatorConfig,
}
#[derive(Debug, Clone)]
pub struct EstimationRecord {
pub pattern_sig: String,
pub estimated: u64,
pub actual: u64,
pub error: f64,
pub timestamp: Instant,
}
impl AdaptiveEstimator {
pub fn new(config: EstimatorConfig) -> Self {
Self {
estimation_history: Vec::new(),
correction_factors: HashMap::new(),
bayesian_priors: HashMap::new(),
learning_rate: 0.1,
config,
}
}
pub fn record_estimation(&mut self, pattern_sig: String, estimated: u64, actual: u64) {
        // Relative error, guarded against division by zero on `actual`.
        let error = if actual > 0 {
            ((estimated as f64 - actual as f64).abs() / actual as f64).min(10.0)
        } else {
            1.0
        };
let record = EstimationRecord {
pattern_sig: pattern_sig.clone(),
estimated,
actual,
error,
timestamp: Instant::now(),
};
self.estimation_history.push(record);
self.update_correction_factor(pattern_sig, estimated, actual);
if self.estimation_history.len() > 10000 {
self.estimation_history.drain(0..5000);
}
}
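    /// Updates the per-pattern correction factor from one observation of
    /// `actual / estimated` using a conjugate Gaussian step: with prior
    /// N(m, v) and observation variance `obs_var`, the posterior mean is the
    /// precision-weighted average (m/v + x/obs_var) / (1/v + 1/obs_var). The
    /// stored factor is then smoothed toward the posterior mean with an
    /// exponential moving average controlled by `learning_rate`.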
fn update_correction_factor(&mut self, pattern_sig: String, estimated: u64, actual: u64) {
let current_factor = self
.correction_factors
.get(&pattern_sig)
.copied()
.unwrap_or(1.0);
let target_factor = if estimated > 0 {
actual as f64 / estimated as f64
} else {
1.0
};
let (prior_mean, prior_std) = self
.bayesian_priors
.get(&pattern_sig)
.copied()
.unwrap_or((1.0, 1.0));
let prior_var = prior_std * prior_std;
let obs_var = 0.1;
let posterior_mean =
(prior_mean / prior_var + target_factor / obs_var) / (1.0 / prior_var + 1.0 / obs_var);
let posterior_var = 1.0 / (1.0 / prior_var + 1.0 / obs_var);
let posterior_std = posterior_var.sqrt();
self.bayesian_priors
.insert(pattern_sig.clone(), (posterior_mean, posterior_std));
let new_factor =
current_factor * (1.0 - self.learning_rate) + posterior_mean * self.learning_rate;
self.correction_factors.insert(pattern_sig, new_factor);
}
pub fn apply_correction(&self, pattern_sig: &str, estimate: u64) -> u64 {
if let Some(&factor) = self.correction_factors.get(pattern_sig) {
(estimate as f64 * factor).max(1.0) as u64
} else {
estimate
}
}
pub fn get_mape(&self) -> f64 {
if self.estimation_history.is_empty() {
return 0.0;
}
let total_error: f64 = self.estimation_history.iter().map(|r| r.error).sum();
total_error / self.estimation_history.len() as f64
}
pub fn get_accuracy(&self, threshold: f64) -> f64 {
if self.estimation_history.is_empty() {
return 0.0;
}
let within_threshold = self
.estimation_history
.iter()
.filter(|r| r.error <= threshold)
.count();
within_threshold as f64 / self.estimation_history.len() as f64
}
pub fn detect_outliers(&self, pattern_sig: &str) -> Vec<EstimationRecord> {
let errors: Vec<f64> = self
.estimation_history
.iter()
.filter(|r| r.pattern_sig == pattern_sig)
.map(|r| r.error)
.collect();
if errors.len() < 3 {
return Vec::new();
}
        // Standard IQR fences: take the first and third quartiles of the
        // sorted errors rather than approximating them by scaling the median.
        let mut sorted = errors;
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let q1 = sorted[sorted.len() / 4];
        let q3 = sorted[(3 * sorted.len()) / 4];
        let iqr = q3 - q1;
        let lower_bound = q1 - 1.5 * iqr;
        let upper_bound = q3 + 1.5 * iqr;
self.estimation_history
.iter()
.filter(|r| {
r.pattern_sig == pattern_sig && (r.error < lower_bound || r.error > upper_bound)
})
.cloned()
.collect()
}
pub fn get_correction_confidence_interval(
&self,
pattern_sig: &str,
confidence_level: f64,
) -> Option<(f64, f64)> {
let (mean, std) = self.bayesian_priors.get(pattern_sig)?;
        // Common two-sided critical values; anything else falls back to 90%.
        let z_score = if (confidence_level - 0.95).abs() < 0.01 {
            1.96
        } else if (confidence_level - 0.99).abs() < 0.01 {
            2.576
        } else {
            1.645
        };
let margin = z_score * std;
Some((mean - margin, mean + margin))
}
pub fn get_posterior_distribution(&self, pattern_sig: &str) -> Option<(f64, f64)> {
self.bayesian_priors.get(pattern_sig).copied()
}
}
impl CardinalityEstimator {
pub fn new(config: EstimatorConfig) -> Self {
Self {
histogram_estimator: Arc::new(RwLock::new(HistogramEstimator::new(config.clone()))),
correlation_estimator: Arc::new(RwLock::new(CorrelationEstimator::new(config.clone()))),
adaptive_estimator: Arc::new(RwLock::new(AdaptiveEstimator::new(config.clone()))),
config,
stats: Arc::new(RwLock::new(EstimationStats::default())),
}
}
pub fn estimate_triple_pattern(&self, pattern: &TriplePattern) -> Result<u64> {
let hist_est = self
.histogram_estimator
.read()
.map_err(|e| anyhow!("Failed to acquire histogram estimator lock: {}", e))?;
let base_estimate = hist_est.estimate_pattern_cardinality(pattern);
if self.config.enable_adaptive {
let pattern_sig = format!("{:?}", pattern);
let adaptive_est = self
.adaptive_estimator
.read()
.map_err(|e| anyhow!("Failed to acquire adaptive estimator lock: {}", e))?;
Ok(adaptive_est.apply_correction(&pattern_sig, base_estimate))
} else {
Ok(base_estimate)
}
}
pub fn estimate_join(
&self,
left_pattern: &TriplePattern,
left_card: u64,
right_pattern: &TriplePattern,
right_card: u64,
join_vars: &[Variable],
) -> Result<u64> {
let pred1 = match &left_pattern.predicate {
Term::Iri(iri) => Some(iri.as_str()),
Term::Literal(lit) => Some(lit.value.as_str()),
_ => None,
};
let pred2 = match &right_pattern.predicate {
Term::Iri(iri) => Some(iri.as_str()),
Term::Literal(lit) => Some(lit.value.as_str()),
_ => None,
};
if self.config.enable_correlation {
if let (Some(p1), Some(p2)) = (pred1, pred2) {
let corr_est = self
.correlation_estimator
.read()
.map_err(|e| anyhow!("Failed to acquire correlation estimator lock: {}", e))?;
return Ok(corr_est.estimate_join_cardinality(p1, left_card, p2, right_card));
}
}
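        // Fallback heuristic: geometric mean of the input cardinalities,
        // discounted by half for each shared join variable.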
let base_estimate = (left_card as f64 * right_card as f64).sqrt() as u64;
let join_discount = 0.5_f64.powi(join_vars.len() as i32);
Ok((base_estimate as f64 * join_discount).max(1.0) as u64)
}
pub fn record_actual_result(
&self,
pattern: &TriplePattern,
estimated: u64,
actual: u64,
) -> Result<()> {
if !self.config.enable_adaptive {
return Ok(());
}
let pattern_sig = format!("{:?}", pattern);
let mut adaptive_est = self
.adaptive_estimator
.write()
.map_err(|e| anyhow!("Failed to acquire adaptive estimator lock: {}", e))?;
adaptive_est.record_estimation(pattern_sig, estimated, actual);
let mut stats = self
.stats
.write()
.map_err(|e| anyhow!("Failed to acquire stats lock: {}", e))?;
stats.total_estimations += 1;
let error = if actual > 0 {
((estimated as f64 - actual as f64).abs() / actual as f64).min(10.0)
} else {
0.0
};
        stats.total_error += error;
        stats.mape = stats.total_error / stats.total_estimations as f64;
        // Maintain RMSE incrementally from the running mean of squared errors.
        let n = stats.total_estimations as f64;
        stats.rmse = ((stats.rmse * stats.rmse * (n - 1.0) + error * error) / n).sqrt();
if error <= self.config.max_error_threshold {
stats.within_threshold += 1;
}
Ok(())
}
pub fn get_stats(&self) -> Result<EstimationStats> {
let stats = self
.stats
.read()
.map_err(|e| anyhow!("Failed to acquire stats lock: {}", e))?;
Ok(stats.clone())
}
pub fn build_histogram(&self, predicate: &str, values: Vec<String>) -> Result<()> {
let mut hist_est = self
.histogram_estimator
.write()
.map_err(|e| anyhow!("Failed to acquire histogram estimator lock: {}", e))?;
hist_est.build_histogram(predicate, values)
}
pub fn build_correlation(
&self,
pred1: &str,
pred2: &str,
values1: Vec<f64>,
values2: Vec<f64>,
) -> Result<()> {
if !self.config.enable_correlation {
return Ok(());
}
let mut corr_est = self
.correlation_estimator
.write()
.map_err(|e| anyhow!("Failed to acquire correlation estimator lock: {}", e))?;
corr_est.build_correlation(pred1, pred2, values1, values2)
}
}
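/// Learned cardinality predictor that maps pattern features to a cardinality
/// via a simple regression model, trained online from recorded results.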
#[cfg(feature = "parallel")]
pub struct MLCardinalityPredictor {
model: Arc<RwLock<Option<NeuralCardinalityModel>>>,
feature_extractor: QueryFeatureExtractor,
training_buffer: Arc<RwLock<Vec<TrainingExample>>>,
config: MLPredictorConfig,
stats: Arc<RwLock<MLPredictionStats>>,
}
#[derive(Debug, Clone)]
pub struct MLPredictorConfig {
pub hidden_sizes: Vec<usize>,
pub learning_rate: f64,
pub batch_size: usize,
pub num_epochs: usize,
pub min_training_examples: usize,
}
impl Default for MLPredictorConfig {
fn default() -> Self {
Self {
hidden_sizes: vec![64, 32, 16],
learning_rate: 0.001,
batch_size: 32,
num_epochs: 100,
min_training_examples: 1000,
}
}
}
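/// Despite the name, this is a single-layer linear model trained with
/// stochastic gradient descent on squared error; `hidden_sizes` is accepted
/// for API compatibility but currently ignored.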
pub struct NeuralCardinalityModel {
input_size: usize,
weights: Vec<f64>,
learning_rate: f64,
iterations: usize,
}
impl NeuralCardinalityModel {
pub fn new(input_size: usize, _hidden_sizes: &[usize], learning_rate: f64) -> Result<Self> {
use scirs2_core::random::Random;
let mut rng = Random::default();
let total_weights = input_size + 1;
let weights: Vec<f64> = (0..total_weights)
.map(|_| rng.random_f64() * 0.01)
.collect();
Ok(Self {
input_size,
weights,
learning_rate,
iterations: 0,
})
}
pub fn predict(&self, features: &Array1<f64>) -> Result<f64> {
if features.len() != self.input_size {
return Err(anyhow!(
"Feature size mismatch: expected {}, got {}",
self.input_size,
features.len()
));
}
let mut prediction = 0.0;
for (i, &feature) in features.iter().enumerate() {
if i < self.weights.len() {
prediction += feature * self.weights[i];
}
}
if let Some(&bias) = self.weights.last() {
prediction += bias;
}
Ok(prediction.max(0.0))
}
pub fn train_batch(&mut self, examples: &[TrainingExample]) -> Result<f64> {
if examples.is_empty() {
return Ok(0.0);
}
let mut total_loss = 0.0;
for example in examples {
let prediction = self.predict(&example.features)?;
let actual = example.actual_cardinality as f64;
let error = prediction - actual;
for (i, &feature) in example.features.iter().enumerate() {
if i < self.weights.len() {
let gradient = 2.0 * error * feature;
self.weights[i] -= self.learning_rate * gradient;
}
}
if let Some(bias) = self.weights.last_mut() {
*bias -= self.learning_rate * 2.0 * error;
}
total_loss += error * error;
}
self.iterations += 1;
Ok(total_loss / examples.len() as f64)
}
}
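/// Encodes a triple pattern as a fixed-length numeric vector: one-hot term
/// kinds for subject/predicate/object, a predicate selectivity score, and an
/// overall specificity ratio (bound terms / 3).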
pub struct QueryFeatureExtractor {
feature_count: usize,
}
impl Default for QueryFeatureExtractor {
fn default() -> Self {
Self::new()
}
}
impl QueryFeatureExtractor {
pub fn new() -> Self {
        Self {
            // Fixed-length feature vector; trailing slots are reserved and stay zero.
            feature_count: 20,
        }
    }
pub fn extract_features(&self, pattern: &TriplePattern) -> Array1<f64> {
let mut features = Array1::zeros(self.feature_count);
let mut idx = 0;
features[idx] = if matches!(pattern.subject, Term::Variable(_)) {
1.0
} else {
0.0
};
idx += 1;
features[idx] = if matches!(pattern.subject, Term::Iri(_)) {
1.0
} else {
0.0
};
idx += 1;
features[idx] = if matches!(pattern.subject, Term::Literal(_)) {
1.0
} else {
0.0
};
idx += 1;
features[idx] = if matches!(pattern.predicate, Term::Variable(_)) {
1.0
} else {
0.0
};
idx += 1;
features[idx] = if matches!(pattern.predicate, Term::Iri(_)) {
1.0
} else {
0.0
};
idx += 1;
features[idx] = self.predicate_selectivity_score(&pattern.predicate);
idx += 1;
features[idx] = if matches!(pattern.object, Term::Variable(_)) {
1.0
} else {
0.0
};
idx += 1;
features[idx] = if matches!(pattern.object, Term::Iri(_)) {
1.0
} else {
0.0
};
idx += 1;
features[idx] = if matches!(pattern.object, Term::Literal(_)) {
1.0
} else {
0.0
};
idx += 1;
let specificity = [&pattern.subject, &pattern.predicate, &pattern.object]
.iter()
.filter(|t| !matches!(t, Term::Variable(_)))
.count() as f64
/ 3.0;
features[idx] = specificity;
        // Remaining slots stay at zero from `Array1::zeros` (reserved for
        // future features).
features
}
    fn predicate_selectivity_score(&self, predicate: &Term) -> f64 {
        match predicate {
            // Unbound predicates match everything.
            Term::Variable(_) => 1.0,
            // rdf:type-style predicates tend to be moderately selective.
            Term::Iri(iri) if iri.as_str().contains("type") => 0.3,
            _ => 0.5,
        }
    }
}
#[derive(Debug, Clone)]
pub struct TrainingExample {
pub features: Array1<f64>,
pub estimated_cardinality: u64,
pub actual_cardinality: u64,
pub pattern_sig: String,
}
#[derive(Debug, Clone, Default)]
pub struct MLPredictionStats {
pub total_predictions: u64,
pub ml_predictions: u64,
pub mae: f64,
pub rmse: f64,
pub r2_score: f64,
}
#[cfg(feature = "parallel")]
impl MLCardinalityPredictor {
pub fn new(config: MLPredictorConfig) -> Self {
Self {
model: Arc::new(RwLock::new(None)),
feature_extractor: QueryFeatureExtractor::new(),
training_buffer: Arc::new(RwLock::new(Vec::new())),
config,
stats: Arc::new(RwLock::new(MLPredictionStats::default())),
}
}
pub fn predict(&self, pattern: &TriplePattern) -> Result<Option<u64>> {
let model_lock = self
.model
.read()
.map_err(|e| anyhow!("Failed to acquire model lock: {}", e))?;
if let Some(model) = model_lock.as_ref() {
let features = self.feature_extractor.extract_features(pattern);
let prediction = model.predict(&features)?;
if let Ok(mut stats) = self.stats.write() {
stats.total_predictions += 1;
stats.ml_predictions += 1;
}
Ok(Some(prediction as u64))
        } else {
            // No trained model yet; callers fall back to the statistical estimators.
            Ok(None)
        }
    }
pub fn record_result(
&self,
pattern: &TriplePattern,
estimated: u64,
actual: u64,
) -> Result<()> {
let features = self.feature_extractor.extract_features(pattern);
let pattern_sig = format!("{:?}", pattern);
let example = TrainingExample {
features,
estimated_cardinality: estimated,
actual_cardinality: actual,
pattern_sig,
};
let mut buffer = self
.training_buffer
.write()
.map_err(|e| anyhow!("Failed to acquire training buffer lock: {}", e))?;
buffer.push(example);
        if buffer.len() >= self.config.min_training_examples {
            // Release the buffer lock before training, which re-acquires it.
            drop(buffer);
            self.train_model()?;
        }
Ok(())
}
fn train_model(&self) -> Result<()> {
        // Take ownership of the pending examples under a write lock so the
        // buffer does not grow without bound and retrigger a full retrain on
        // every subsequent observation.
        let examples: Vec<TrainingExample> = {
            let mut buffer = self
                .training_buffer
                .write()
                .map_err(|e| anyhow!("Failed to acquire training buffer lock: {}", e))?;
            if buffer.len() < self.config.min_training_examples {
                return Ok(());
            }
            buffer.drain(..).collect()
        };
        let mut model_lock = self
            .model
            .write()
            .map_err(|e| anyhow!("Failed to acquire model lock: {}", e))?;
        if model_lock.is_none() {
            *model_lock = Some(NeuralCardinalityModel::new(
                self.feature_extractor.feature_count,
                &self.config.hidden_sizes,
                self.config.learning_rate,
            )?);
        }
        let model = model_lock.as_mut().expect("model should be initialized");
for _ in 0..self.config.num_epochs {
for batch in examples.chunks(self.config.batch_size) {
let _loss = model.train_batch(batch)?;
}
}
Ok(())
}
pub fn get_stats(&self) -> Result<MLPredictionStats> {
let stats = self
.stats
.read()
.map_err(|e| anyhow!("Failed to acquire stats lock: {}", e))?;
Ok(stats.clone())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_histogram_creation() {
let config = EstimatorConfig::default();
let mut estimator = HistogramEstimator::new(config);
let values: Vec<String> = (0..1000).map(|i| format!("value_{}", i)).collect();
let result = estimator.build_histogram("test_pred", values);
assert!(result.is_ok());
let hist = estimator.predicate_histograms.get("test_pred").unwrap();
assert_eq!(hist.total_count, 1000);
assert_eq!(hist.distinct_count, 1000);
}
#[test]
fn test_selectivity_estimation() {
let config = EstimatorConfig::default();
let mut estimator = HistogramEstimator::new(config);
let values: Vec<String> = vec!["a".to_string(), "b".to_string(), "c".to_string()];
estimator.build_histogram("test_pred", values).unwrap();
let selectivity = estimator.estimate_selectivity("test_pred", "b");
assert!(selectivity > 0.0 && selectivity <= 1.0);
}
#[test]
fn test_correlation_calculation() {
let config = EstimatorConfig::default();
let mut estimator = CorrelationEstimator::new(config);
let x: Vec<f64> = (0..100).map(|i| i as f64).collect();
let y: Vec<f64> = (0..100).map(|i| i as f64 * 2.0).collect();
let arr_x = Array1::from_vec(x.clone());
let arr_y = Array1::from_vec(y.clone());
let (corr, _p_value) = pearsonr(&arr_x.view(), &arr_y.view(), "two-sided").unwrap();
assert!((corr - 1.0).abs() < 0.01);
estimator.build_correlation("pred1", "pred2", x, y).unwrap();
let stored_corr = estimator.get_correlation("pred1", "pred2").unwrap();
assert!((stored_corr - 1.0).abs() < 0.01);
}
#[test]
fn test_adaptive_learning() {
let config = EstimatorConfig::default();
let mut estimator = AdaptiveEstimator::new(config);
        for i in 0..100 {
            let estimated = 100;
            let actual = 80;
            estimator.record_estimation(format!("pattern_{}", i % 10), estimated, actual);
        }
        // Consistent over-estimation should learn a correction factor below 1.0.
        let corrected = estimator.apply_correction("pattern_0", 100);
        assert!(corrected < 100);
    }
#[test]
fn test_full_estimator() {
let config = EstimatorConfig::default();
let estimator = CardinalityEstimator::new(config);
let values: Vec<String> = (0..100).map(|i| format!("value_{}", i)).collect();
estimator
.build_histogram("http://example.org/pred", values)
.unwrap();
let pattern = TriplePattern {
subject: Term::Variable(Variable::new("s".to_string()).unwrap()),
predicate: Term::Iri(
crate::algebra::Iri::new("http://example.org/pred".to_string()).unwrap(),
),
object: Term::Variable(Variable::new("o".to_string()).unwrap()),
};
let estimate = estimator.estimate_triple_pattern(&pattern);
assert!(estimate.is_ok());
assert!(estimate.unwrap() > 0);
}
}