use super::{OptimizerConfig, Statistics};
use crate::algebra::{Algebra, TriplePattern};
use crate::statistics::{
CardinalityEstimator, EstimationMethod, HyperLogLogSketch, PredicateHistogram, ReservoirSample,
};
use anyhow::Result;
use std::collections::HashMap;
pub struct EnhancedOptimizer {
pub config: OptimizerConfig,
pub statistics: Statistics,
pub cardinality_estimator: CardinalityEstimator,
plan_cache: HashMap<u64, Algebra>,
max_cache_size: usize,
}
impl EnhancedOptimizer {
pub fn new(config: OptimizerConfig) -> Self {
Self::with_estimation_method(config, EstimationMethod::MachineLearning)
}
pub fn with_estimation_method(config: OptimizerConfig, method: EstimationMethod) -> Self {
Self {
config,
statistics: Statistics::new(),
cardinality_estimator: CardinalityEstimator::with_method(method),
plan_cache: HashMap::new(),
max_cache_size: 1000,
}
}
pub fn optimize(&mut self, algebra: Algebra) -> Result<Algebra> {
let query_hash = self.compute_query_hash(&algebra);
if let Some(cached_plan) = self.plan_cache.get(&query_hash) {
return Ok(cached_plan.clone());
}
let mut optimized = algebra;
let mut pass = 0;
while pass < self.config.max_passes {
let before = optimized.clone();
if self.config.filter_pushdown {
optimized = self.apply_filter_pushdown(optimized)?;
}
if self.config.join_reordering {
optimized = self.apply_cardinality_based_join_ordering(optimized)?;
}
if self.config.projection_pushdown {
optimized = self.apply_projection_pushdown(optimized)?;
}
if self.algebra_equal(&before, &optimized) {
break;
}
pass += 1;
}
if self.plan_cache.len() < self.max_cache_size {
self.plan_cache.insert(query_hash, optimized.clone());
}
Ok(optimized)
}
fn apply_cardinality_based_join_ordering(&mut self, algebra: Algebra) -> Result<Algebra> {
match algebra {
Algebra::Join { left, right } => {
let left_card = self.estimate_algebra_cardinality(&left)?;
let right_card = self.estimate_algebra_cardinality(&right)?;
let reordered = if right_card < left_card {
Algebra::Join {
left: Box::new(self.apply_cardinality_based_join_ordering(*right)?),
right: Box::new(self.apply_cardinality_based_join_ordering(*left)?),
}
} else {
Algebra::Join {
left: Box::new(self.apply_cardinality_based_join_ordering(*left)?),
right: Box::new(self.apply_cardinality_based_join_ordering(*right)?),
}
};
Ok(reordered)
}
Algebra::Union { left, right } => Ok(Algebra::Union {
left: Box::new(self.apply_cardinality_based_join_ordering(*left)?),
right: Box::new(self.apply_cardinality_based_join_ordering(*right)?),
}),
Algebra::Bgp(patterns) => {
let mut pattern_cards: Vec<(TriplePattern, u64)> = Vec::new();
for p in patterns {
let card = self
.cardinality_estimator
.estimate_pattern_cardinality(&p)
.unwrap_or(10000);
pattern_cards.push((p, card));
}
pattern_cards.sort_by_key(|(_, card)| *card);
let reordered_patterns: Vec<TriplePattern> =
pattern_cards.into_iter().map(|(p, _)| p).collect();
Ok(Algebra::Bgp(reordered_patterns))
}
other => Ok(other),
}
}
fn estimate_algebra_cardinality(&mut self, algebra: &Algebra) -> Result<u64> {
match algebra {
Algebra::Bgp(patterns) if patterns.len() == 1 => {
self.cardinality_estimator
.estimate_pattern_cardinality(&patterns[0])
.map_err(|e| anyhow::anyhow!("Cardinality estimation failed: {}", e))
}
Algebra::Bgp(patterns) if patterns.len() > 1 => {
let mut total_card = self
.cardinality_estimator
.estimate_pattern_cardinality(&patterns[0])
.map_err(|e| anyhow::anyhow!("Cardinality estimation failed: {}", e))?;
for i in 1..patterns.len() {
total_card = self
.cardinality_estimator
.estimate_join_cardinality(&patterns[i - 1], &patterns[i])
.map_err(|e| {
anyhow::anyhow!("Join cardinality estimation failed: {}", e)
})?;
}
Ok(total_card)
}
Algebra::Join { left, right } => {
let left_card = self.estimate_algebra_cardinality(left)?;
let right_card = self.estimate_algebra_cardinality(right)?;
Ok((left_card as f64 * right_card as f64).sqrt() as u64)
}
Algebra::Union { left, right } => {
let left_card = self.estimate_algebra_cardinality(left)?;
let right_card = self.estimate_algebra_cardinality(right)?;
Ok(left_card + right_card)
}
Algebra::Filter { pattern, .. } => {
let base_card = self.estimate_algebra_cardinality(pattern)?;
Ok((base_card as f64 * 0.1) as u64) }
_ => Ok(1000), }
}
fn apply_filter_pushdown(&self, algebra: Algebra) -> Result<Algebra> {
match algebra {
Algebra::Filter { pattern, condition } => {
let optimized_pattern = self.apply_filter_pushdown(*pattern)?;
Ok(Algebra::Filter {
pattern: Box::new(optimized_pattern),
condition,
})
}
Algebra::Join { left, right } => Ok(Algebra::Join {
left: Box::new(self.apply_filter_pushdown(*left)?),
right: Box::new(self.apply_filter_pushdown(*right)?),
}),
other => Ok(other),
}
}
fn apply_projection_pushdown(&self, algebra: Algebra) -> Result<Algebra> {
Ok(algebra)
}
pub fn update_cardinality_statistics(
&mut self,
predicate: String,
count: u64,
distinct_subjects: u64,
distinct_objects: u64,
) {
self.cardinality_estimator.update_statistics(
predicate,
count,
distinct_subjects,
distinct_objects,
);
}
pub fn add_predicate_histogram(&mut self, histogram: PredicateHistogram) {
self.cardinality_estimator.add_histogram(histogram);
}
pub fn add_sketch(&mut self, predicate: String, sketch: HyperLogLogSketch) {
self.cardinality_estimator.add_sketch(predicate, sketch);
}
pub fn add_sample(&mut self, predicate: String, sample: ReservoirSample) {
self.cardinality_estimator.add_sample(predicate, sample);
}
pub fn train_ml_model(&mut self, training_data: &[(TriplePattern, u64)]) {
self.cardinality_estimator.train_ml_model(training_data);
}
pub fn update_join_correlation(&mut self, pred1: String, pred2: String, correlation: f64) {
self.cardinality_estimator
.update_join_correlation(pred1, pred2, correlation);
}
fn compute_query_hash(&self, algebra: &Algebra) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
format!("{:?}", algebra).hash(&mut hasher);
hasher.finish()
}
fn algebra_equal(&self, a: &Algebra, b: &Algebra) -> bool {
format!("{:?}", a) == format!("{:?}", b)
}
#[allow(dead_code)]
fn extract_variables(&self, algebra: &Algebra) -> Vec<String> {
let mut vars = Vec::new();
match algebra {
Algebra::Bgp(patterns) => {
for pattern in patterns {
if let crate::algebra::Term::Variable(v) = &pattern.subject {
vars.push(v.name().to_string());
}
if let crate::algebra::Term::Variable(v) = &pattern.predicate {
vars.push(v.name().to_string());
}
if let crate::algebra::Term::Variable(v) = &pattern.object {
vars.push(v.name().to_string());
}
}
}
Algebra::Join { left, right } => {
vars.extend(self.extract_variables(left));
vars.extend(self.extract_variables(right));
}
Algebra::Union { left, right } => {
vars.extend(self.extract_variables(left));
vars.extend(self.extract_variables(right));
}
_ => {}
}
vars.sort();
vars.dedup();
vars
}
#[allow(dead_code)]
fn extract_expression_variables(&self, _expr: &crate::algebra::Expression) -> Vec<String> {
Vec::new()
}
pub fn plan_cache_stats(&self) -> (usize, usize) {
(self.plan_cache.len(), self.max_cache_size)
}
pub fn clear_plan_cache(&mut self) {
self.plan_cache.clear();
}
pub fn cardinality_statistics(&self) -> String {
self.cardinality_estimator.statistics_summary()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::algebra::{Term, Variable};
use oxirs_core::NamedNode;
#[test]
fn test_enhanced_optimizer_creation() {
let config = OptimizerConfig::default();
let optimizer = EnhancedOptimizer::new(config);
assert_eq!(optimizer.plan_cache.len(), 0);
assert_eq!(optimizer.max_cache_size, 1000);
}
#[test]
fn test_bgp_pattern_reordering() {
let config = OptimizerConfig::default();
let mut optimizer = EnhancedOptimizer::new(config);
optimizer.update_cardinality_statistics(
"http://xmlns.com/foaf/0.1/name".to_string(),
100,
80,
90,
);
optimizer.update_cardinality_statistics(
"http://xmlns.com/foaf/0.1/knows".to_string(),
10000,
800,
800,
);
let patterns = vec![
TriplePattern {
subject: Term::Variable(Variable::new("p").expect("Valid var")),
predicate: Term::Iri(
NamedNode::new("http://xmlns.com/foaf/0.1/knows").expect("Valid IRI"),
),
object: Term::Variable(Variable::new("f").expect("Valid var")),
},
TriplePattern {
subject: Term::Variable(Variable::new("p").expect("Valid var")),
predicate: Term::Iri(
NamedNode::new("http://xmlns.com/foaf/0.1/name").expect("Valid IRI"),
),
object: Term::Variable(Variable::new("n").expect("Valid var")),
},
];
let query = Algebra::Bgp(patterns);
let optimized = optimizer.optimize(query).unwrap();
if let Algebra::Bgp(optimized_patterns) = optimized {
assert_eq!(optimized_patterns.len(), 2);
if let Term::Iri(iri) = &optimized_patterns[0].predicate {
let iri_str = iri.as_str();
assert!(iri_str.contains("foaf/0.1/"));
assert!(iri_str.contains("name") || iri_str.contains("knows"));
}
} else {
panic!("Expected BGP after optimization, got: {:?}", optimized);
}
}
#[test]
fn test_plan_caching() {
let config = OptimizerConfig::default();
let mut optimizer = EnhancedOptimizer::new(config);
let pattern = TriplePattern {
subject: Term::Variable(Variable::new("s").expect("Valid var")),
predicate: Term::Iri(NamedNode::new("http://example.org/pred").expect("Valid IRI")),
object: Term::Variable(Variable::new("o").expect("Valid var")),
};
let query = Algebra::Bgp(vec![pattern]);
let result1 = optimizer.optimize(query.clone()).unwrap();
let (cache_size_1, _) = optimizer.plan_cache_stats();
assert_eq!(cache_size_1, 1);
let result2 = optimizer.optimize(query.clone()).unwrap();
let (cache_size_2, _) = optimizer.plan_cache_stats();
assert_eq!(cache_size_2, 1);
assert!(optimizer.algebra_equal(&result1, &result2));
}
#[test]
fn test_cardinality_estimation_integration() {
let config = OptimizerConfig::default();
let mut optimizer = EnhancedOptimizer::new(config);
optimizer.update_cardinality_statistics(
"http://example.org/pred".to_string(),
1000,
800,
900,
);
let pattern = TriplePattern {
subject: Term::Variable(Variable::new("s").expect("Valid var")),
predicate: Term::Iri(NamedNode::new("http://example.org/pred").expect("Valid IRI")),
object: Term::Variable(Variable::new("o").expect("Valid var")),
};
let estimated = optimizer
.cardinality_estimator
.estimate_pattern_cardinality(&pattern)
.unwrap();
assert!(estimated > 0);
assert!(estimated <= 2000);
}
#[test]
fn test_ml_model_training() {
let config = OptimizerConfig::default();
let mut optimizer =
EnhancedOptimizer::with_estimation_method(config, EstimationMethod::MachineLearning);
let pattern = TriplePattern {
subject: Term::Variable(Variable::new("s").expect("Valid var")),
predicate: Term::Iri(NamedNode::new("http://example.org/pred").expect("Valid IRI")),
object: Term::Variable(Variable::new("o").expect("Valid var")),
};
let training_data = vec![(pattern, 500)];
optimizer.train_ml_model(&training_data);
let stats = optimizer.cardinality_statistics();
assert!(stats.contains("ml_trained: true"));
}
#[test]
fn test_join_ordering_with_cardinality() {
let config = OptimizerConfig::default();
let mut optimizer = EnhancedOptimizer::new(config);
optimizer.update_cardinality_statistics(
"http://example.org/large".to_string(),
10000,
5000,
5000,
);
optimizer.update_cardinality_statistics(
"http://example.org/small".to_string(),
100,
80,
90,
);
let large_pattern = Algebra::Bgp(vec![TriplePattern {
subject: Term::Variable(Variable::new("s").expect("Valid var")),
predicate: Term::Iri(NamedNode::new("http://example.org/large").expect("Valid IRI")),
object: Term::Variable(Variable::new("o").expect("Valid var")),
}]);
let small_pattern = Algebra::Bgp(vec![TriplePattern {
subject: Term::Variable(Variable::new("s").expect("Valid var")),
predicate: Term::Iri(NamedNode::new("http://example.org/small").expect("Valid IRI")),
object: Term::Variable(Variable::new("o").expect("Valid var")),
}]);
let join = Algebra::Join {
left: Box::new(large_pattern),
right: Box::new(small_pattern),
};
let optimized = optimizer.optimize(join).unwrap();
if let Algebra::Join { left, .. } = optimized {
if let Algebra::Bgp(patterns) = &*left {
if let Term::Iri(iri) = &patterns[0].predicate {
assert!(iri.as_str().contains("small"));
}
}
}
}
#[test]
fn test_cache_eviction() {
let config = OptimizerConfig::default();
let mut optimizer = EnhancedOptimizer::new(config);
optimizer.max_cache_size = 2;
for i in 0..3 {
let pattern = TriplePattern {
subject: Term::Variable(Variable::new(format!("s{}", i)).expect("Valid var")),
predicate: Term::Iri(
NamedNode::new(format!("http://example.org/pred{}", i)).expect("Valid IRI"),
),
object: Term::Variable(Variable::new("o").expect("Valid var")),
};
let query = Algebra::Bgp(vec![pattern]);
optimizer.optimize(query).unwrap();
}
let (cache_size, max_size) = optimizer.plan_cache_stats();
assert!(cache_size <= max_size);
}
}