use crate::ast::ast::{MatchClause, PathPattern};
use crate::plan::logical::LogicalPlan;
use crate::plan::pattern_optimization::{
cost_estimation::{ExecutionCost, StatisticsManager},
pattern_analysis::{PatternConnectivity, PatternPlanStrategy},
pattern_analyzer::PatternAnalyzer,
};
use std::collections::HashMap;
/// Outcome of optimizing one group of comma-separated path patterns.
#[allow(dead_code)] // not every field is read by current callers
#[derive(Clone, Debug)]
pub struct PatternOptimizationResult {
    /// Strategy selected for executing the patterns.
    pub strategy: PatternPlanStrategy,
    /// Cost attached to `strategy`. When optimization is skipped this is a fixed value set by
    /// the optimizer rather than an estimate.
    pub estimated_cost: ExecutionCost,
    /// `true` when the chosen strategy beat the Cartesian-product baseline by more than the
    /// configured improvement threshold.
    pub optimized: bool,
    /// Human-readable explanation of the decision.
    pub optimization_reason: String,
}
/// Cost-based optimizer that chooses how to execute a list of comma-separated path patterns
/// (such as the patterns of a single `MATCH` clause).
#[allow(dead_code)] // some fields are only used through optional APIs
#[derive(Debug)]
pub struct LogicalPatternOptimizer {
    /// Computes pattern connectivity and detects linear paths.
    pattern_analyzer: PatternAnalyzer,
    /// Produces the cost estimator and records executed-query statistics.
    statistics_manager: StatisticsManager,
    /// Feature switches and thresholds that steer strategy selection.
    config: OptimizationConfig,
}
/// Tunables for [`LogicalPatternOptimizer`].
#[allow(dead_code)] // not every field is read in all builds
#[derive(Clone, Debug)]
pub struct OptimizationConfig {
    /// Consider a path-traversal strategy when the patterns form a linear path.
    pub enable_path_traversal: bool,
    /// Consider a hash-join strategy built from variables shared between patterns.
    pub enable_hash_joins: bool,
    /// Groups with fewer patterns than this are not optimized.
    pub min_patterns_for_optimization: usize,
    /// Groups with more patterns than this are not optimized.
    pub max_patterns_for_optimization: usize,
    /// Fractional improvement over the Cartesian-product baseline required before a result
    /// is reported as optimized (`0.1` means "more than 10% cheaper").
    pub cost_improvement_threshold: f64,
}

impl Default for OptimizationConfig {
    /// Enables every strategy, optimizes groups of 2 to 10 patterns, and requires a
    /// 10% improvement over the baseline.
    fn default() -> Self {
        let enable_path_traversal = true;
        let enable_hash_joins = true;
        let min_patterns_for_optimization = 2;
        let max_patterns_for_optimization = 10;
        let cost_improvement_threshold = 0.1;
        Self {
            enable_path_traversal,
            enable_hash_joins,
            min_patterns_for_optimization,
            max_patterns_for_optimization,
            cost_improvement_threshold,
        }
    }
}
impl LogicalPatternOptimizer {
    /// Creates an optimizer with a fresh pattern analyzer, a fresh statistics manager and
    /// [`OptimizationConfig::default`].
    pub fn new() -> Self {
        Self {
            pattern_analyzer: PatternAnalyzer::new(),
            statistics_manager: StatisticsManager::new(),
            config: OptimizationConfig::default(),
        }
    }

    /// Creates an optimizer that uses `config` instead of the default configuration.
    #[allow(dead_code)]
    pub fn with_config(config: OptimizationConfig) -> Self {
        Self {
            pattern_analyzer: PatternAnalyzer::new(),
            statistics_manager: StatisticsManager::new(),
            config,
        }
    }

    /// Picks an execution strategy for a list of comma-separated path patterns.
    ///
    /// If the number of patterns is below `min_patterns_for_optimization` or above
    /// `max_patterns_for_optimization`, no analysis is done. An unoptimized
    /// Cartesian-product result is returned instead: zero cost for too few patterns,
    /// infinite/saturated cost for too many.
    ///
    /// Otherwise the patterns are analyzed for connectivity and every enabled candidate
    /// strategy is costed. The cheapest candidate is then chosen.
    pub fn optimize_comma_separated_patterns(
        &mut self,
        patterns: &[PathPattern],
        context: &OptimizationContext,
    ) -> PatternOptimizationResult {
        if patterns.len() < self.config.min_patterns_for_optimization {
            return PatternOptimizationResult {
                strategy: PatternPlanStrategy::CartesianProduct {
                    patterns: patterns.to_vec(),
                    estimated_cost: 0.0,
                },
                estimated_cost: ExecutionCost::new(0.0, 0, 0, 1.0),
                optimized: false,
                optimization_reason: format!(
                    "Too few patterns ({}) for optimization",
                    patterns.len()
                ),
            };
        }
        if patterns.len() > self.config.max_patterns_for_optimization {
            return PatternOptimizationResult {
                strategy: PatternPlanStrategy::CartesianProduct {
                    patterns: patterns.to_vec(),
                    estimated_cost: 0.0,
                },
                estimated_cost: ExecutionCost::new(f64::INFINITY, u64::MAX, u64::MAX, 1.0),
                optimized: false,
                optimization_reason: format!(
                    "Too many patterns ({}) for optimization",
                    patterns.len()
                ),
            };
        }

        let connectivity = self.pattern_analyzer.analyze_patterns(patterns.to_vec());
        let strategies = self.generate_execution_strategies(&connectivity, context);

        let mut cost_estimator = self.statistics_manager.create_cost_estimator();
        let strategy_costs: Vec<(PatternPlanStrategy, ExecutionCost)> = strategies
            .into_iter()
            .map(|strategy| {
                let cost = cost_estimator.estimate_cost(&strategy);
                (strategy, cost)
            })
            .collect();

        self.choose_best_strategy(strategy_costs, context)
    }

    /// Lists the candidate strategies for the analyzed patterns.
    ///
    /// The Cartesian product is always included first and serves as the baseline. The other
    /// candidates are added only when they apply:
    /// - path traversal, if enabled and a linear path is detected;
    /// - hash join, if enabled and at least one join step can be derived;
    /// - nested-loop join, for four patterns or fewer.
    fn generate_execution_strategies(
        &self,
        connectivity: &PatternConnectivity,
        context: &OptimizationContext,
    ) -> Vec<PatternPlanStrategy> {
        let mut strategies = vec![PatternPlanStrategy::CartesianProduct {
            patterns: connectivity.patterns.clone(),
            estimated_cost: 0.0,
        }];

        if self.config.enable_path_traversal {
            if let Some(linear_path) = self.pattern_analyzer.detect_linear_path(connectivity) {
                strategies.push(PatternPlanStrategy::PathTraversal(linear_path));
            }
        }

        if self.config.enable_hash_joins && connectivity.patterns.len() >= 2 {
            if let Some(join_strategy) = self.generate_hash_join_strategy(connectivity, context) {
                strategies.push(join_strategy);
            }
        }

        if connectivity.patterns.len() <= 4 {
            strategies.push(PatternPlanStrategy::NestedLoopJoin {
                patterns: connectivity.patterns.clone(),
                estimated_cost: 0.0,
            });
        }

        strategies
    }

    /// Builds a hash-join strategy, or `None` when no pair of patterns shares a variable.
    fn generate_hash_join_strategy(
        &self,
        connectivity: &PatternConnectivity,
        _context: &OptimizationContext,
    ) -> Option<PatternPlanStrategy> {
        let join_order = self.generate_join_order(connectivity);
        if join_order.is_empty() {
            return None;
        }
        Some(PatternPlanStrategy::HashJoin {
            patterns: connectivity.patterns.clone(),
            join_order,
            estimated_cost: 0.0,
        })
    }

    /// Derives hash-join steps from the variables shared between patterns.
    ///
    /// For each variable mentioned by two or more patterns, consecutive entries of its
    /// pattern-index list are joined: `(p0, p1), (p1, p2), ...`. This keeps every pattern
    /// that mentions the variable connected, not just the first two.
    ///
    /// When several variables link the same pair of patterns, in either orientation, they
    /// are merged into a single step with multiple join variables instead of joining that
    /// pair twice.
    fn generate_join_order(
        &self,
        connectivity: &PatternConnectivity,
    ) -> Vec<crate::plan::pattern_optimization::pattern_analysis::JoinStep> {
        use crate::plan::pattern_optimization::pattern_analysis::{JoinStep, JoinType};

        // Visit variables in name order, so the plan does not depend on the container's
        // iteration order. NOTE(review): assumes the variable key type is `Ord`
        // (e.g. `String`).
        let mut shared: Vec<_> = connectivity.shared_variables.iter().collect();
        shared.sort_by(|(a, _), (b, _)| a.cmp(b));

        let mut join_order: Vec<JoinStep> = Vec::new();
        for (var, pattern_indices) in shared {
            for pair in pattern_indices.windows(2) {
                let (left, right) = (pair[0], pair[1]);
                if left == right {
                    // A pattern never needs to be joined with itself.
                    continue;
                }
                let existing = join_order.iter().position(|step| {
                    (step.left_pattern_idx == left && step.right_pattern_idx == right)
                        || (step.left_pattern_idx == right && step.right_pattern_idx == left)
                });
                match existing {
                    Some(pos) => join_order[pos].join_variables.push(var.clone()),
                    None => join_order.push(JoinStep {
                        left_pattern_idx: left,
                        right_pattern_idx: right,
                        join_variables: vec![var.clone()],
                        join_type: JoinType::Hash,
                        // Fixed per-step placeholder. NOTE(review): presumably refined by
                        // the cost estimator; confirm.
                        estimated_cost: 0.1,
                    }),
                }
            }
        }
        join_order
    }

    /// Selects the cheapest strategy by `ExecutionCost::total_cost`.
    ///
    /// Ties keep the earliest candidate, and NaN costs are ignored. The result counts as
    /// optimized only when the best cost is below
    /// `cartesian_cost * (1 - cost_improvement_threshold)`. If there is no Cartesian
    /// candidate, the baseline is `+inf`.
    fn choose_best_strategy(
        &self,
        mut strategy_costs: Vec<(PatternPlanStrategy, ExecutionCost)>,
        _context: &OptimizationContext,
    ) -> PatternOptimizationResult {
        if strategy_costs.is_empty() {
            return PatternOptimizationResult {
                strategy: PatternPlanStrategy::CartesianProduct {
                    patterns: vec![],
                    estimated_cost: 0.0,
                },
                estimated_cost: ExecutionCost::new(f64::INFINITY, u64::MAX, u64::MAX, 0.0),
                optimized: false,
                optimization_reason: "No strategies available".to_string(),
            };
        }

        // A NaN cost compares false against everything. If it came first it would never be
        // replaced, so NaN candidates are skipped. Strict `<` keeps the earliest of equal
        // costs.
        let mut best: Option<(usize, f64)> = None;
        for (i, (_, cost)) in strategy_costs.iter().enumerate() {
            let score = cost.total_cost();
            if score.is_nan() {
                continue;
            }
            let better = match best {
                Some((_, best_score)) => score < best_score,
                None => true,
            };
            if better {
                best = Some((i, score));
            }
        }
        // If every cost is NaN, fall back to the first candidate.
        let (best_strategy_idx, best_cost_score) =
            best.unwrap_or_else(|| (0, strategy_costs[0].1.total_cost()));

        let cartesian_cost = strategy_costs
            .iter()
            .find(|(s, _)| matches!(s, PatternPlanStrategy::CartesianProduct { .. }))
            .map(|(_, cost)| cost.total_cost())
            .unwrap_or(f64::INFINITY);
        let optimized =
            best_cost_score < cartesian_cost * (1.0 - self.config.cost_improvement_threshold);

        let optimization_reason = if optimized {
            match &strategy_costs[best_strategy_idx].0 {
                PatternPlanStrategy::PathTraversal(_) => {
                    "Path traversal optimization applied for connected patterns".to_string()
                }
                PatternPlanStrategy::HashJoin { .. } => {
                    "Hash join optimization applied for shared variables".to_string()
                }
                PatternPlanStrategy::NestedLoopJoin { .. } => {
                    "Nested loop join selected for complex patterns".to_string()
                }
                PatternPlanStrategy::CartesianProduct { .. } => {
                    "Cartesian product is optimal for these patterns".to_string()
                }
            }
        } else {
            format!(
                "No significant improvement found (best: {:.2}, baseline: {:.2})",
                best_cost_score, cartesian_cost
            )
        };

        // Move the winner out rather than cloning it; the other candidates are discarded.
        let (strategy, estimated_cost) = strategy_costs.swap_remove(best_strategy_idx);
        PatternOptimizationResult {
            strategy,
            estimated_cost,
            optimized,
            optimization_reason,
        }
    }

    /// Forwards an executed query and its result count to the statistics manager.
    #[allow(dead_code)]
    pub fn record_query_execution(&mut self, query: &str, result_count: u64) {
        self.statistics_manager
            .record_query_execution(query, result_count);
    }

    /// Returns the current configuration.
    #[allow(dead_code)]
    pub fn get_config(&self) -> &OptimizationConfig {
        &self.config
    }

    /// Replaces the configuration used for subsequent optimizations.
    #[allow(dead_code)]
    pub fn update_config(&mut self, config: OptimizationConfig) {
        self.config = config;
    }
}
/// Per-call context handed to the pattern optimizer.
///
/// NOTE(review): the optimizer in this module currently passes this through without
/// reading any of its fields.
#[allow(dead_code)] // fields are not read yet
#[derive(Clone, Debug)]
pub struct OptimizationContext {
    /// Indexes the caller reports as available.
    pub available_indexes: Vec<IndexInfo>,
    /// Free-form key/value optimizer hints.
    pub hints: HashMap<String, String>,
    /// Optional memory budget; units are not specified here (presumably bytes; confirm).
    pub memory_budget: Option<u64>,
    /// Caller-side limits and priority.
    pub performance_requirements: PerformanceRequirements,
}
/// Description of an index that may be available to a query.
#[allow(dead_code)] // not read by the optimizer yet
#[derive(Clone, Debug)]
pub struct IndexInfo {
    /// Kind of index, as a free-form string.
    pub index_type: String,
    /// Node labels covered by the index.
    pub node_labels: Vec<String>,
    /// Relationship types covered by the index.
    pub relationship_types: Vec<String>,
    /// Indexed property names.
    pub properties: Vec<String>,
    /// Selectivity estimate (presumably a fraction in `[0, 1]`; confirm).
    pub selectivity: f64,
}
/// Caller-imposed limits and preferences for executing a query.
#[allow(dead_code)] // not read by the optimizer yet
#[derive(Clone, Debug)]
pub struct PerformanceRequirements {
    /// Optional execution-time limit in milliseconds (per the field name).
    pub max_execution_time_ms: Option<u64>,
    /// Optional memory limit in bytes (per the field name).
    pub max_memory_bytes: Option<u64>,
    /// Which resource the caller prefers to favor.
    pub optimization_priority: OptimizationPriority,
}
/// Which resource the caller would like the optimizer to favor.
///
/// Nothing in this module consults it yet; the default context selects `Balanced`.
/// It derives `Copy` and `Eq` because it is a plain fieldless tag, so passing it by value
/// is free and it has total equality.
#[allow(dead_code)] // variants other than `Balanced` are not constructed yet
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationPriority {
    /// Prefer faster plans (intent per the variant name).
    Speed,
    /// Prefer plans that use less memory (intent per the variant name).
    Memory,
    /// No strong preference.
    Balanced,
}
impl Default for OptimizationContext {
    /// An unconstrained context with no indexes, no hints, no memory budget, no time or
    /// memory limits, and a `Balanced` priority.
    fn default() -> Self {
        let performance_requirements = PerformanceRequirements {
            optimization_priority: OptimizationPriority::Balanced,
            max_execution_time_ms: None,
            max_memory_bytes: None,
        };
        Self {
            performance_requirements,
            memory_budget: None,
            hints: HashMap::new(),
            available_indexes: Vec::new(),
        }
    }
}
/// Optimizes the comma-separated path patterns of a `MATCH` clause.
///
/// An empty clause short-circuits to an unoptimized, zero-cost Cartesian-product result.
/// Otherwise the work is delegated to
/// [`LogicalPatternOptimizer::optimize_comma_separated_patterns`].
///
/// # Errors
///
/// Never returns `Err` in the current implementation; the `Result` is kept for API
/// compatibility.
pub fn optimize_match_clause_patterns(
    optimizer: &mut LogicalPatternOptimizer,
    match_clause: &MatchClause,
    context: &OptimizationContext,
) -> Result<PatternOptimizationResult, String> {
    // Borrow the clause's patterns. The optimizer only needs a slice, so cloning the whole
    // vector is unnecessary.
    let patterns = &match_clause.patterns;
    if patterns.is_empty() {
        return Ok(PatternOptimizationResult {
            strategy: PatternPlanStrategy::CartesianProduct {
                patterns: vec![],
                estimated_cost: 0.0,
            },
            estimated_cost: ExecutionCost::new(0.0, 0, 0, 1.0),
            optimized: false,
            optimization_reason: "No patterns to optimize".to_string(),
        });
    }
    Ok(optimizer.optimize_comma_separated_patterns(patterns, context))
}
/// Applies an optimization decision to an existing logical plan.
///
/// This is currently a placeholder. Whether or not `optimization_result.optimized` is set,
/// and whatever the strategy, an unchanged clone of `current_plan` is returned.
///
/// # Errors
///
/// Never returns `Err` in the current implementation.
#[allow(dead_code)]
pub fn apply_optimization_to_logical_plan(
    optimization_result: &PatternOptimizationResult,
    current_plan: &LogicalPlan,
) -> Result<LogicalPlan, String> {
    if optimization_result.optimized {
        // Kept exhaustive so that adding a strategy variant forces this placeholder to be
        // revisited.
        match &optimization_result.strategy {
            PatternPlanStrategy::PathTraversal(_)
            | PatternPlanStrategy::HashJoin { .. }
            | PatternPlanStrategy::NestedLoopJoin { .. }
            | PatternPlanStrategy::CartesianProduct { .. } => {
                // TODO: rewrite `current_plan` according to the chosen strategy.
            }
        }
    }
    Ok(current_plan.clone())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::ast::Location;

    /// A pattern with no elements; these tests only depend on how many patterns there are.
    fn blank_pattern() -> PathPattern {
        PathPattern {
            assignment: None,
            path_type: None,
            elements: vec![],
            location: Location::default(),
        }
    }

    #[test]
    fn test_pattern_optimizer_creation() {
        let optimizer = LogicalPatternOptimizer::new();
        assert!(optimizer.config.enable_path_traversal);
        assert_eq!(optimizer.config.min_patterns_for_optimization, 2);
    }

    #[test]
    fn test_optimization_with_few_patterns() {
        let context = OptimizationContext::default();
        let mut optimizer = LogicalPatternOptimizer::new();
        let single = [blank_pattern()];

        let outcome = optimizer.optimize_comma_separated_patterns(&single, &context);

        assert!(!outcome.optimized);
        assert!(outcome.optimization_reason.contains("Too few patterns"));
    }

    #[test]
    fn test_optimization_with_many_patterns() {
        let context = OptimizationContext::default();
        let mut optimizer = LogicalPatternOptimizer::new();
        let many: Vec<PathPattern> = std::iter::repeat_with(blank_pattern).take(15).collect();

        let outcome = optimizer.optimize_comma_separated_patterns(&many, &context);

        assert!(!outcome.optimized);
        assert!(outcome.optimization_reason.contains("Too many patterns"));
    }

    #[test]
    fn test_optimization_context_default() {
        let ctx = OptimizationContext::default();
        assert!(ctx.hints.is_empty());
        assert!(ctx.available_indexes.is_empty());
        assert_eq!(
            ctx.performance_requirements.optimization_priority,
            OptimizationPriority::Balanced
        );
    }

    #[test]
    fn test_optimization_config_custom() {
        let custom = OptimizationConfig {
            enable_path_traversal: false,
            enable_hash_joins: true,
            min_patterns_for_optimization: 3,
            max_patterns_for_optimization: 5,
            cost_improvement_threshold: 0.2,
        };
        let optimizer = LogicalPatternOptimizer::with_config(custom);
        let cfg = &optimizer.config;
        assert!(cfg.enable_hash_joins);
        assert!(!cfg.enable_path_traversal);
        assert_eq!(cfg.min_patterns_for_optimization, 3);
    }
}