chryso_optimizer/
lib.rs

1use chryso_metadata::StatsCache;
2use crate::cost::{CostModel, UnitCostModel};
3use crate::memo::Memo;
4use crate::rules::RuleSet;
5use chryso_planner::{LogicalPlan, PhysicalPlan};
6
7pub mod cost;
8pub mod estimation;
9pub mod enforcer;
10pub mod column_prune;
11pub mod join_order;
12pub mod stats_collect;
13pub mod memo;
14pub mod physical_rules;
15pub mod properties;
16pub mod rules;
17pub mod subquery;
18pub mod expr_rewrite;
19pub mod utils;
20
/// Diagnostic record of what the optimizer did during one optimization call.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OptimizerTrace {
    /// Names of rewrite rules that matched, in the order they fired.
    /// Only populated when `OptimizerConfig::debug_rules` is set.
    pub applied_rules: Vec<String>,
    /// Tables whose statistics were missing from the cache and were requested
    /// from the configured stats provider.
    pub stats_loaded: Vec<String>,
}

impl OptimizerTrace {
    /// Creates an empty trace; equivalent to [`OptimizerTrace::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
35
36pub struct OptimizerConfig {
37    pub enable_cascades: bool,
38    pub enable_properties: bool,
39    pub rules: RuleSet,
40    pub trace: bool,
41    pub debug_rules: bool,
42    pub stats_provider: Option<std::sync::Arc<dyn chryso_metadata::StatsProvider>>,
43}
44
45impl std::fmt::Debug for OptimizerConfig {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        f.debug_struct("OptimizerConfig")
48            .field("enable_cascades", &self.enable_cascades)
49            .field("enable_properties", &self.enable_properties)
50            .field("rules", &self.rules)
51            .field("trace", &self.trace)
52            .field("debug_rules", &self.debug_rules)
53            .field("stats_provider", &self.stats_provider.is_some())
54            .finish()
55    }
56}
57
#[cfg(test)]
mod tests {
    use super::cost::UnitCostModel;
    use super::{CascadesOptimizer, OptimizerConfig};
    use chryso_metadata::{type_inference::SimpleTypeInferencer, StatsCache};
    use chryso_parser::{Dialect, ParserConfig, SimpleParser, SqlParser};
    use chryso_planner::PlanBuilder;

    /// End-to-end smoke test: SQL -> logical plan (typed explain), then the
    /// default optimizer -> physical plan (costed explain).
    #[test]
    fn explain_with_types_and_costs() {
        let query = "select sum(amount) from sales group by region";
        let parser_config = ParserConfig {
            dialect: Dialect::Postgres,
        };
        let parser = SimpleParser::new(parser_config);
        let statement = parser.parse(query).expect("parse");
        let logical = PlanBuilder::build(statement).expect("plan");

        let typed_explain = logical.explain_typed(0, &SimpleTypeInferencer);
        assert!(typed_explain.contains("LogicalAggregate"));

        let optimizer = CascadesOptimizer::new(OptimizerConfig::default());
        let mut stats = StatsCache::new();
        let physical = optimizer.optimize(&logical, &mut stats);
        let costed_explain = physical.explain_costed(0, &UnitCostModel);
        assert!(costed_explain.contains("cost="));
    }
}
83
84impl Default for OptimizerConfig {
85    fn default() -> Self {
86        Self {
87            enable_cascades: true,
88            enable_properties: true,
89            rules: RuleSet::default(),
90            trace: false,
91            debug_rules: false,
92            stats_provider: None,
93        }
94    }
95}
96
97pub struct CascadesOptimizer {
98    config: OptimizerConfig,
99}
100
101impl CascadesOptimizer {
102    pub fn new(config: OptimizerConfig) -> Self {
103        Self { config }
104    }
105
106    pub fn optimize(&self, logical: &LogicalPlan, stats: &mut StatsCache) -> PhysicalPlan {
107        let _ = ensure_stats(logical, stats, &self.config);
108        let logical = crate::expr_rewrite::rewrite_plan(logical);
109        let logical = crate::column_prune::prune_plan(&logical);
110        if self.config.enable_cascades {
111            optimize_with_cascades(&logical, &self.config, stats).0
112        } else {
113            logical_to_physical(&logical)
114        }
115    }
116
117    pub fn optimize_with_trace(
118        &self,
119        logical: &LogicalPlan,
120        stats: &mut StatsCache,
121    ) -> (PhysicalPlan, OptimizerTrace) {
122        let loaded = ensure_stats(logical, stats, &self.config).unwrap_or_default();
123        let logical = crate::expr_rewrite::rewrite_plan(logical);
124        let logical = crate::column_prune::prune_plan(&logical);
125        if self.config.enable_cascades {
126            let (plan, mut trace) = optimize_with_cascades(&logical, &self.config, stats);
127            trace.stats_loaded = loaded;
128            (plan, trace)
129        } else {
130            let mut trace = OptimizerTrace::new();
131            trace.stats_loaded = loaded;
132            (logical_to_physical(&logical), trace)
133        }
134    }
135}
136
137fn optimize_with_cascades(
138    logical: &LogicalPlan,
139    config: &OptimizerConfig,
140    _stats: &StatsCache,
141) -> (PhysicalPlan, OptimizerTrace) {
142    let mut trace = OptimizerTrace::new();
143    let logical = apply_rules_recursive(logical, &config.rules, &mut trace, config.debug_rules);
144    let logical = crate::subquery::rewrite_correlated_subqueries(&logical);
145    let logical = crate::expr_rewrite::rewrite_plan(&logical);
146    let candidates = crate::join_order::enumerate_join_orders(&logical, _stats);
147    let mut memo = Memo::new();
148    let root = memo.insert(candidates.first().unwrap_or(&logical));
149    memo.explore(&config.rules);
150    let cost_model: Box<dyn CostModel> = if _stats.is_empty() {
151        Box::new(UnitCostModel)
152    } else {
153        Box::new(cost::StatsCostModel::new(_stats))
154    };
155    let mut best = memo
156        .best_physical(root, cost_model.as_ref())
157        .unwrap_or_else(|| logical_to_physical(&logical));
158    if config.enable_properties {
159        let required = crate::properties::PhysicalProperties::default();
160        best = crate::enforcer::enforce(best, &required);
161    }
162    (best, trace)
163}
164
165fn ensure_stats(
166    logical: &LogicalPlan,
167    stats: &mut StatsCache,
168    config: &OptimizerConfig,
169) -> chryso_core::ChrysoResult<Vec<String>> {
170    let Some(provider) = &config.stats_provider else {
171        return Ok(Vec::new());
172    };
173    let requirements = crate::stats_collect::collect_requirements(logical);
174    let mut missing_tables = Vec::new();
175    for table in &requirements.tables {
176        if stats.table_stats(table).is_none() {
177            missing_tables.push(table.clone());
178        }
179    }
180    let mut missing_columns = Vec::new();
181    for (table, column) in &requirements.columns {
182        if stats.column_stats(table, column).is_none() {
183            missing_columns.push((table.clone(), column.clone()));
184        }
185    }
186    if missing_tables.is_empty() && missing_columns.is_empty() {
187        return Ok(Vec::new());
188    }
189    provider.load_stats(&missing_tables, &missing_columns, stats)?;
190    Ok(missing_tables)
191}
192
193fn apply_rules_recursive(
194    plan: &LogicalPlan,
195    rules: &RuleSet,
196    trace: &mut OptimizerTrace,
197    debug_rules: bool,
198) -> LogicalPlan {
199    let mut rewritten = plan.clone();
200    let mut matched = Vec::new();
201    for rule in rules.iter() {
202        let alternatives = rule.apply(&rewritten);
203        if !alternatives.is_empty() {
204            matched.push(rule.name().to_string());
205            rewritten = alternatives.last().cloned().unwrap_or(rewritten);
206        }
207    }
208    if debug_rules {
209        trace.applied_rules.extend(matched);
210    }
211    let rewritten = match rewritten {
212        LogicalPlan::Filter { predicate, input } => LogicalPlan::Filter {
213            predicate,
214            input: Box::new(apply_rules_recursive(input.as_ref(), rules, trace, debug_rules)),
215        },
216        LogicalPlan::Projection { exprs, input } => LogicalPlan::Projection {
217            exprs,
218            input: Box::new(apply_rules_recursive(input.as_ref(), rules, trace, debug_rules)),
219        },
220        LogicalPlan::Join {
221            join_type,
222            left,
223            right,
224            on,
225        } => LogicalPlan::Join {
226            join_type,
227            left: Box::new(apply_rules_recursive(left.as_ref(), rules, trace, debug_rules)),
228            right: Box::new(apply_rules_recursive(right.as_ref(), rules, trace, debug_rules)),
229            on,
230        },
231        LogicalPlan::Aggregate {
232            group_exprs,
233            aggr_exprs,
234            input,
235        } => LogicalPlan::Aggregate {
236            group_exprs,
237            aggr_exprs,
238            input: Box::new(apply_rules_recursive(input.as_ref(), rules, trace, debug_rules)),
239        },
240        LogicalPlan::Distinct { input } => LogicalPlan::Distinct {
241            input: Box::new(apply_rules_recursive(input.as_ref(), rules, trace, debug_rules)),
242        },
243        LogicalPlan::TopN {
244            order_by,
245            limit,
246            input,
247        } => LogicalPlan::TopN {
248            order_by,
249            limit,
250            input: Box::new(apply_rules_recursive(input.as_ref(), rules, trace, debug_rules)),
251        },
252        LogicalPlan::Sort { order_by, input } => LogicalPlan::Sort {
253            order_by,
254            input: Box::new(apply_rules_recursive(input.as_ref(), rules, trace, debug_rules)),
255        },
256        LogicalPlan::Limit {
257            limit,
258            offset,
259            input,
260        } => LogicalPlan::Limit {
261            limit,
262            offset,
263            input: Box::new(apply_rules_recursive(input.as_ref(), rules, trace, debug_rules)),
264        },
265        LogicalPlan::Derived {
266            input,
267            alias,
268            column_aliases,
269        } => LogicalPlan::Derived {
270            input: Box::new(apply_rules_recursive(input.as_ref(), rules, trace, debug_rules)),
271            alias,
272            column_aliases,
273        },
274        other => other,
275    };
276    let mut final_plan = rewritten.clone();
277    for rule in rules.iter() {
278        let alternatives = rule.apply(&final_plan);
279        if !alternatives.is_empty() {
280            final_plan = alternatives.last().cloned().unwrap_or(final_plan);
281        }
282    }
283    final_plan
284}
285
286fn logical_to_physical(logical: &LogicalPlan) -> PhysicalPlan {
287    let children = match logical {
288        LogicalPlan::Scan { .. } => Vec::new(),
289        LogicalPlan::IndexScan { .. } => Vec::new(),
290        LogicalPlan::Dml { .. } => Vec::new(),
291        LogicalPlan::Derived { input, .. } => vec![logical_to_physical(input)],
292        LogicalPlan::Filter { input, .. } => vec![logical_to_physical(input)],
293        LogicalPlan::Projection { input, .. } => vec![logical_to_physical(input)],
294        LogicalPlan::Join { left, right, .. } => {
295            vec![logical_to_physical(left), logical_to_physical(right)]
296        }
297        LogicalPlan::Aggregate { input, .. } => vec![logical_to_physical(input)],
298        LogicalPlan::Distinct { input } => vec![logical_to_physical(input)],
299        LogicalPlan::TopN { input, .. } => vec![logical_to_physical(input)],
300        LogicalPlan::Sort { input, .. } => vec![logical_to_physical(input)],
301        LogicalPlan::Limit { input, .. } => vec![logical_to_physical(input)],
302    };
303    let rules = crate::physical_rules::PhysicalRuleSet::default();
304    let candidates = rules.apply_all(logical, &children);
305    let cost_model = UnitCostModel;
306    candidates
307        .into_iter()
308        .min_by(|left, right| {
309            cost_model
310                .cost(left)
311                .0
312                .partial_cmp(&cost_model.cost(right).0)
313                .unwrap()
314        })
315        .unwrap_or(PhysicalPlan::TableScan {
316            table: "unknown".to_string(),
317        })
318}