use super::goal::Goal;
use std::collections::HashMap;
/// Reorders query goals by estimated selectivity and keeps counters about the work done.
#[derive(Debug, Clone)]
pub struct QueryOptimizer {
/// Explicit selectivity values keyed by the full goal pattern string; `estimate_selectivity`
/// checks this table before falling back to the heuristic.
selectivity_map: HashMap<String, f64>,
/// When `false`, `optimize_goals` hands its input back untouched.
enable_reordering: bool,
/// Stored and toggled through `set_index_selection`; no code visible in this file reads it.
enable_index_selection: bool,
/// Stored and toggled through `set_memoization`; no code visible in this file reads it.
enable_memoization: bool,
/// Counters exposed through `stats()` and cleared by `reset_stats()`.
stats: OptimizationStats,
}
impl QueryOptimizer {
pub fn new() -> Self {
Self {
selectivity_map: HashMap::new(),
enable_reordering: true,
enable_index_selection: true,
enable_memoization: true,
stats: OptimizationStats::new(),
}
}
pub fn with_config(config: OptimizerConfig) -> Self {
Self {
selectivity_map: HashMap::new(),
enable_reordering: config.enable_reordering,
enable_index_selection: config.enable_index_selection,
enable_memoization: config.enable_memoization,
stats: OptimizationStats::new(),
}
}
pub fn optimize_goals(&mut self, goals: Vec<Goal>) -> Vec<Goal> {
if !self.enable_reordering || goals.len() <= 1 {
return goals;
}
self.stats.total_optimizations += 1;
let mut goal_selectivity: Vec<(Goal, f64)> = goals
.into_iter()
.map(|g| {
let selectivity = self.estimate_selectivity(&g);
(g, selectivity)
})
.collect();
goal_selectivity.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
let optimized: Vec<Goal> = goal_selectivity.into_iter().map(|(g, _)| g).collect();
self.stats.goals_reordered += optimized.len();
optimized
}
pub fn estimate_selectivity(&self, goal: &Goal) -> f64 {
if let Some(&selectivity) = self.selectivity_map.get(&goal.pattern) {
return selectivity;
}
self.heuristic_selectivity(goal)
}
fn heuristic_selectivity(&self, goal: &Goal) -> f64 {
let pattern = &goal.pattern;
let (bound_count, var_count) = self.count_variables(pattern);
if var_count == 0 {
return 0.1;
}
let bound_ratio = bound_count as f64 / var_count as f64;
let selectivity = 1.0 - (bound_ratio * 0.8);
if pattern.contains("in_stock") || pattern.contains("available") {
return selectivity * 0.3;
}
if pattern.contains("expensive") || pattern.contains("premium") {
return selectivity * 0.5;
}
if pattern.contains("item") || pattern.contains("product") {
return selectivity * 1.2;
}
selectivity
}
fn count_variables(&self, pattern: &str) -> (usize, usize) {
let mut bound = 0;
let mut total = 0;
let chars: Vec<char> = pattern.chars().collect();
let mut i = 0;
while i < chars.len() {
if chars[i] == '?' {
total += 1;
i += 1;
while i < chars.len() && (chars[i].is_alphanumeric() || chars[i] == '_') {
i += 1;
}
while i < chars.len() && chars[i].is_whitespace() {
i += 1;
}
if i < chars.len() && (chars[i] == '=' || chars[i] == '>' || chars[i] == '<') {
bound += 1;
}
} else {
i += 1;
}
}
(bound, total)
}
pub fn set_selectivity(&mut self, predicate: String, selectivity: f64) {
self.selectivity_map
.insert(predicate, selectivity.clamp(0.0, 1.0));
}
pub fn stats(&self) -> &OptimizationStats {
&self.stats
}
pub fn reset_stats(&mut self) {
self.stats = OptimizationStats::new();
}
pub fn set_reordering(&mut self, enabled: bool) {
self.enable_reordering = enabled;
}
pub fn set_index_selection(&mut self, enabled: bool) {
self.enable_index_selection = enabled;
}
pub fn set_memoization(&mut self, enabled: bool) {
self.enable_memoization = enabled;
}
}
impl Default for QueryOptimizer {
fn default() -> Self {
Self::new()
}
}
/// Feature switches used to build a `QueryOptimizer` via `QueryOptimizer::with_config`.
#[derive(Debug, Clone)]
pub struct OptimizerConfig {
    /// Whether goals may be reordered.
    pub enable_reordering: bool,
    /// Index-selection switch.
    pub enable_index_selection: bool,
    /// Memoization switch.
    pub enable_memoization: bool,
}

impl Default for OptimizerConfig {
    /// Every switch is on by default.
    fn default() -> Self {
        const ENABLED: bool = true;
        OptimizerConfig {
            enable_reordering: ENABLED,
            enable_index_selection: ENABLED,
            enable_memoization: ENABLED,
        }
    }
}
/// Plain counters describing optimizer activity; all of them start at zero.
#[derive(Debug, Clone, Default)]
pub struct OptimizationStats {
    /// Number of optimization passes recorded.
    pub total_optimizations: usize,
    /// Number of goals recorded as reordered.
    pub goals_reordered: usize,
    /// Number of index selections recorded.
    pub index_selections: usize,
    /// Number of memoization lookups that hit.
    pub memoization_hits: usize,
    /// Number of memoization lookups that missed.
    pub memoization_misses: usize,
}

impl OptimizationStats {
    /// Returns a value with every counter at zero.
    pub fn new() -> Self {
        Default::default()
    }

    /// Fraction of memoization lookups that hit, in `[0.0, 1.0]`.
    ///
    /// Yields `0.0` when no lookups (hits plus misses) have been recorded.
    pub fn memoization_hit_rate(&self) -> f64 {
        match self.memoization_hits + self.memoization_misses {
            0 => 0.0,
            lookups => self.memoization_hits as f64 / lookups as f64,
        }
    }

    /// One-line human-readable report: optimization count, reordered-goal count and
    /// memoization hits with the hit rate as a percentage to one decimal place.
    pub fn summary(&self) -> String {
        let hit_percent = self.memoization_hit_rate() * 100.0;
        format!(
            "Optimizations: {} | Goals reordered: {} | Memo hits: {} ({:.1}%)",
            self.total_optimizations, self.goals_reordered, self.memoization_hits, hit_percent
        )
    }
}
/// Orders goals so that those with more comparison-bound variable occurrences come first.
#[derive(Debug, Clone)]
// `cost_model` is only initialised, never read, in the code visible in this file, which is
// presumably why the dead-code lint is silenced here. NOTE(review): confirm it is still needed.
#[allow(dead_code)]
pub struct JoinOptimizer {
/// Cost table keyed by string; created empty by `new`.
cost_model: HashMap<String, f64>,
}
impl JoinOptimizer {
    /// Creates a join optimizer with an empty cost model.
    pub fn new() -> Self {
        Self {
            cost_model: HashMap::new(),
        }
    }

    /// Returns `goals` ordered by descending number of bound variable occurrences.
    ///
    /// Goals with the same count keep their relative order (the sort is stable).
    /// Inputs with fewer than two goals are returned as-is.
    pub fn optimize_joins(&self, goals: Vec<Goal>) -> Vec<Goal> {
        if goals.len() <= 1 {
            return goals;
        }
        let mut sorted_goals = goals;
        // `Reverse` gives descending order without a lossy `usize -> i32` cast, and the
        // cached-key sort scans each pattern once instead of once per comparison.
        sorted_goals
            .sort_by_cached_key(|goal| std::cmp::Reverse(self.count_bound_vars(&goal.pattern)));
        sorted_goals
    }

    /// Counts the `?variable` occurrences in `pattern` that are bound by a comparison.
    ///
    /// An occurrence is bound when the first character after its name (alphanumerics
    /// and `_`) and any following whitespace is `=`, `>` or `<`. Every `?` starts a new
    /// occurrence, so a variable that appears twice can be counted twice.
    fn count_bound_vars(&self, pattern: &str) -> usize {
        let mut count = 0;
        let mut rest = pattern.chars().peekable();

        while let Some(ch) = rest.next() {
            if ch != '?' {
                continue;
            }
            // Skip the variable name, then any whitespace that follows it.
            while rest.next_if(|c| c.is_alphanumeric() || *c == '_').is_some() {}
            while rest.next_if(|c| c.is_whitespace()).is_some() {}
            // Only peek: the operator itself is consumed by the outer loop.
            if matches!(rest.peek().copied(), Some('=') | Some('>') | Some('<')) {
                count += 1;
            }
        }
        count
    }
}
impl Default for JoinOptimizer {
fn default() -> Self {
Self::new()
}
}
// Unit tests for the optimizers and the statistics type defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// `new()` must switch every optimization on.
#[test]
fn test_optimizer_creation() {
let optimizer = QueryOptimizer::new();
assert!(optimizer.enable_reordering);
assert!(optimizer.enable_index_selection);
assert!(optimizer.enable_memoization);
}
// `with_config` must copy each switch from the supplied configuration.
#[test]
fn test_optimizer_with_config() {
let config = OptimizerConfig {
enable_reordering: false,
enable_index_selection: true,
enable_memoization: false,
};
let optimizer = QueryOptimizer::with_config(config);
assert!(!optimizer.enable_reordering);
assert!(optimizer.enable_index_selection);
assert!(!optimizer.enable_memoization);
}
// Explicit selectivities (0.1 < 0.3 < 0.9) must drive an ascending ordering.
#[test]
fn test_goal_reordering() {
let mut optimizer = QueryOptimizer::new();
optimizer.set_selectivity("in_stock(?x)".to_string(), 0.1);
optimizer.set_selectivity("expensive(?x)".to_string(), 0.3);
optimizer.set_selectivity("item(?x)".to_string(), 0.9);
let goals = vec![
Goal::new("item(?x)".to_string()),
Goal::new("expensive(?x)".to_string()),
Goal::new("in_stock(?x)".to_string()),
];
let optimized = optimizer.optimize_goals(goals);
assert_eq!(optimized[0].pattern, "in_stock(?x)");
assert_eq!(optimized[1].pattern, "expensive(?x)");
assert_eq!(optimized[2].pattern, "item(?x)");
}
// Heuristic ordering with no explicit values: a pattern without variables scores lowest,
// an unbound variable scores highest, and a comparison on the variable lowers the score.
#[test]
fn test_selectivity_estimation() {
let optimizer = QueryOptimizer::new();
let goal1 = Goal::new("employee(alice)".to_string());
let sel1 = optimizer.estimate_selectivity(&goal1);
assert!(sel1 < 0.5);
let goal2 = Goal::new("employee(?x)".to_string());
let sel2 = optimizer.estimate_selectivity(&goal2);
assert!(sel2 > sel1);
let goal3 = Goal::new("salary(?x) WHERE ?x > 100000".to_string());
let sel3 = optimizer.estimate_selectivity(&goal3);
assert!(sel3 < sel2);
}
// `count_variables` counts occurrences, not distinct names: `?x` appears twice in the
// last pattern and only the occurrence followed by `>` is bound.
#[test]
fn test_count_variables() {
let optimizer = QueryOptimizer::new();
let (bound, total) = optimizer.count_variables("employee(alice)");
assert_eq!(total, 0);
assert_eq!(bound, 0);
let (bound, total) = optimizer.count_variables("employee(?x)");
assert_eq!(total, 1);
assert_eq!(bound, 0);
let (bound, total) = optimizer.count_variables("salary(?x) WHERE ?x > 100");
assert_eq!(total, 2); assert_eq!(bound, 1); }
// One optimization pass over two goals: `goals_reordered` counts every goal that went
// through the sort, even though these two have equal estimates and keep their order.
#[test]
fn test_optimization_stats() {
let mut optimizer = QueryOptimizer::new();
let goals = vec![
Goal::new("a(?x)".to_string()),
Goal::new("b(?x)".to_string()),
];
optimizer.optimize_goals(goals);
let stats = optimizer.stats();
assert_eq!(stats.total_optimizations, 1);
assert_eq!(stats.goals_reordered, 2);
}
// The only goal with a comparison-bound variable must be moved to the front.
#[test]
fn test_join_optimizer() {
let optimizer = JoinOptimizer::new();
let goals = vec![
Goal::new("item(?x)".to_string()),
Goal::new("price(?x, ?p) WHERE ?p > 100".to_string()),
Goal::new("in_stock(?x)".to_string()),
];
let optimized = optimizer.optimize_joins(goals);
assert!(optimized[0].pattern.contains("?p > 100"));
}
// With reordering switched off the goals must come back in their original order.
#[test]
fn test_disable_reordering() {
let mut optimizer = QueryOptimizer::new();
optimizer.set_reordering(false);
let goals = vec![
Goal::new("a(?x)".to_string()),
Goal::new("b(?x)".to_string()),
Goal::new("c(?x)".to_string()),
];
let optimized = optimizer.optimize_goals(goals.clone());
assert_eq!(optimized[0].pattern, goals[0].pattern);
assert_eq!(optimized[1].pattern, goals[1].pattern);
assert_eq!(optimized[2].pattern, goals[2].pattern);
}
// The summary must mention each counter and the 80% hit rate (8 hits, 2 misses).
#[test]
fn test_stats_summary() {
let mut stats = OptimizationStats::new();
stats.total_optimizations = 10;
stats.goals_reordered = 25;
stats.memoization_hits = 8;
stats.memoization_misses = 2;
let summary = stats.summary();
assert!(summary.contains("10"));
assert!(summary.contains("25"));
assert!(summary.contains("8"));
assert!(summary.contains("80")); }
// Hit rate: 0.0 with no lookups, hits / (hits + misses) otherwise.
#[test]
fn test_memoization_hit_rate() {
let mut stats = OptimizationStats::new();
assert_eq!(stats.memoization_hit_rate(), 0.0);
stats.memoization_hits = 8;
stats.memoization_misses = 2;
assert!((stats.memoization_hit_rate() - 0.8).abs() < 0.01);
stats.memoization_hits = 10;
stats.memoization_misses = 0;
assert_eq!(stats.memoization_hit_rate(), 1.0);
}
}