1use chryso_metadata::StatsCache;
2use crate::cost::{CostModel, UnitCostModel};
3use crate::memo::Memo;
4use crate::rules::RuleSet;
5use chryso_planner::{LogicalPlan, PhysicalPlan};
6
7pub mod cost;
8pub mod estimation;
9pub mod enforcer;
10pub mod column_prune;
11pub mod join_order;
12pub mod stats_collect;
13pub mod memo;
14pub mod physical_rules;
15pub mod properties;
16pub mod rules;
17pub mod subquery;
18pub mod expr_rewrite;
19pub mod utils;
20
/// Diagnostic record produced by a single optimizer run.
///
/// Both lists start empty; the optimizer entry points fill them in as they
/// apply rewrite rules and load statistics.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OptimizerTrace {
    /// Names of the rewrite rules that produced at least one alternative.
    pub applied_rules: Vec<String>,
    /// Identifiers of the statistics entries loaded before optimization.
    pub stats_loaded: Vec<String>,
}

impl OptimizerTrace {
    /// Creates an empty trace (equivalent to [`OptimizerTrace::default`]).
    pub fn new() -> Self {
        Self::default()
    }
}
35
36pub struct OptimizerConfig {
37 pub enable_cascades: bool,
38 pub enable_properties: bool,
39 pub rules: RuleSet,
40 pub trace: bool,
41 pub debug_rules: bool,
42 pub stats_provider: Option<std::sync::Arc<dyn chryso_metadata::StatsProvider>>,
43}
44
45impl std::fmt::Debug for OptimizerConfig {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 f.debug_struct("OptimizerConfig")
48 .field("enable_cascades", &self.enable_cascades)
49 .field("enable_properties", &self.enable_properties)
50 .field("rules", &self.rules)
51 .field("trace", &self.trace)
52 .field("debug_rules", &self.debug_rules)
53 .field("stats_provider", &self.stats_provider.is_some())
54 .finish()
55 }
56}
57
#[cfg(test)]
mod tests {
    use super::cost::UnitCostModel;
    use super::{CascadesOptimizer, OptimizerConfig};
    use chryso_metadata::{type_inference::SimpleTypeInferencer, StatsCache};
    use chryso_parser::{Dialect, ParserConfig, SimpleParser, SqlParser};
    use chryso_planner::PlanBuilder;

    /// Parses an aggregate query, then checks that the typed logical explain
    /// mentions the aggregate node and the costed physical explain carries a
    /// cost annotation.
    #[test]
    fn explain_with_types_and_costs() {
        let parser_config = ParserConfig {
            dialect: Dialect::Postgres,
        };
        let statement = SimpleParser::new(parser_config)
            .parse("select sum(amount) from sales group by region")
            .expect("parse");
        let logical_plan = PlanBuilder::build(statement).expect("plan");

        let typed_explain = logical_plan.explain_typed(0, &SimpleTypeInferencer);
        assert!(typed_explain.contains("LogicalAggregate"));

        let optimizer = CascadesOptimizer::new(OptimizerConfig::default());
        let mut stats_cache = StatsCache::new();
        let physical_plan = optimizer.optimize(&logical_plan, &mut stats_cache);
        let costed_explain = physical_plan.explain_costed(0, &UnitCostModel);
        assert!(costed_explain.contains("cost="));
    }
}
83
84impl Default for OptimizerConfig {
85 fn default() -> Self {
86 Self {
87 enable_cascades: true,
88 enable_properties: true,
89 rules: RuleSet::default(),
90 trace: false,
91 debug_rules: false,
92 stats_provider: None,
93 }
94 }
95}
96
97pub struct CascadesOptimizer {
98 config: OptimizerConfig,
99}
100
101impl CascadesOptimizer {
102 pub fn new(config: OptimizerConfig) -> Self {
103 Self { config }
104 }
105
106 pub fn optimize(&self, logical: &LogicalPlan, stats: &mut StatsCache) -> PhysicalPlan {
107 let _ = ensure_stats(logical, stats, &self.config);
108 let logical = crate::expr_rewrite::rewrite_plan(logical);
109 let logical = crate::column_prune::prune_plan(&logical);
110 if self.config.enable_cascades {
111 optimize_with_cascades(&logical, &self.config, stats).0
112 } else {
113 logical_to_physical(&logical)
114 }
115 }
116
117 pub fn optimize_with_trace(
118 &self,
119 logical: &LogicalPlan,
120 stats: &mut StatsCache,
121 ) -> (PhysicalPlan, OptimizerTrace) {
122 let loaded = ensure_stats(logical, stats, &self.config).unwrap_or_default();
123 let logical = crate::expr_rewrite::rewrite_plan(logical);
124 let logical = crate::column_prune::prune_plan(&logical);
125 if self.config.enable_cascades {
126 let (plan, mut trace) = optimize_with_cascades(&logical, &self.config, stats);
127 trace.stats_loaded = loaded;
128 (plan, trace)
129 } else {
130 let mut trace = OptimizerTrace::new();
131 trace.stats_loaded = loaded;
132 (logical_to_physical(&logical), trace)
133 }
134 }
135}
136
137fn optimize_with_cascades(
138 logical: &LogicalPlan,
139 config: &OptimizerConfig,
140 _stats: &StatsCache,
141) -> (PhysicalPlan, OptimizerTrace) {
142 let mut trace = OptimizerTrace::new();
143 let logical = apply_rules_recursive(logical, &config.rules, &mut trace, config.debug_rules);
144 let logical = crate::subquery::rewrite_correlated_subqueries(&logical);
145 let logical = crate::expr_rewrite::rewrite_plan(&logical);
146 let candidates = crate::join_order::enumerate_join_orders(&logical, _stats);
147 let mut memo = Memo::new();
148 let root = memo.insert(candidates.first().unwrap_or(&logical));
149 memo.explore(&config.rules);
150 let cost_model: Box<dyn CostModel> = if _stats.is_empty() {
151 Box::new(UnitCostModel)
152 } else {
153 Box::new(cost::StatsCostModel::new(_stats))
154 };
155 let mut best = memo
156 .best_physical(root, cost_model.as_ref())
157 .unwrap_or_else(|| logical_to_physical(&logical));
158 if config.enable_properties {
159 let required = crate::properties::PhysicalProperties::default();
160 best = crate::enforcer::enforce(best, &required);
161 }
162 (best, trace)
163}
164
165fn ensure_stats(
166 logical: &LogicalPlan,
167 stats: &mut StatsCache,
168 config: &OptimizerConfig,
169) -> chryso_core::ChrysoResult<Vec<String>> {
170 let Some(provider) = &config.stats_provider else {
171 return Ok(Vec::new());
172 };
173 let requirements = crate::stats_collect::collect_requirements(logical);
174 let mut missing_tables = Vec::new();
175 for table in &requirements.tables {
176 if stats.table_stats(table).is_none() {
177 missing_tables.push(table.clone());
178 }
179 }
180 let mut missing_columns = Vec::new();
181 for (table, column) in &requirements.columns {
182 if stats.column_stats(table, column).is_none() {
183 missing_columns.push((table.clone(), column.clone()));
184 }
185 }
186 if missing_tables.is_empty() && missing_columns.is_empty() {
187 return Ok(Vec::new());
188 }
189 provider.load_stats(&missing_tables, &missing_columns, stats)?;
190 Ok(missing_tables)
191}
192
193fn apply_rules_recursive(
194 plan: &LogicalPlan,
195 rules: &RuleSet,
196 trace: &mut OptimizerTrace,
197 debug_rules: bool,
198) -> LogicalPlan {
199 let mut rewritten = plan.clone();
200 let mut matched = Vec::new();
201 for rule in rules.iter() {
202 let alternatives = rule.apply(&rewritten);
203 if !alternatives.is_empty() {
204 matched.push(rule.name().to_string());
205 rewritten = alternatives.last().cloned().unwrap_or(rewritten);
206 }
207 }
208 if debug_rules {
209 trace.applied_rules.extend(matched);
210 }
211 let rewritten = match rewritten {
212 LogicalPlan::Filter { predicate, input } => LogicalPlan::Filter {
213 predicate,
214 input: Box::new(apply_rules_recursive(input.as_ref(), rules, trace, debug_rules)),
215 },
216 LogicalPlan::Projection { exprs, input } => LogicalPlan::Projection {
217 exprs,
218 input: Box::new(apply_rules_recursive(input.as_ref(), rules, trace, debug_rules)),
219 },
220 LogicalPlan::Join {
221 join_type,
222 left,
223 right,
224 on,
225 } => LogicalPlan::Join {
226 join_type,
227 left: Box::new(apply_rules_recursive(left.as_ref(), rules, trace, debug_rules)),
228 right: Box::new(apply_rules_recursive(right.as_ref(), rules, trace, debug_rules)),
229 on,
230 },
231 LogicalPlan::Aggregate {
232 group_exprs,
233 aggr_exprs,
234 input,
235 } => LogicalPlan::Aggregate {
236 group_exprs,
237 aggr_exprs,
238 input: Box::new(apply_rules_recursive(input.as_ref(), rules, trace, debug_rules)),
239 },
240 LogicalPlan::Distinct { input } => LogicalPlan::Distinct {
241 input: Box::new(apply_rules_recursive(input.as_ref(), rules, trace, debug_rules)),
242 },
243 LogicalPlan::TopN {
244 order_by,
245 limit,
246 input,
247 } => LogicalPlan::TopN {
248 order_by,
249 limit,
250 input: Box::new(apply_rules_recursive(input.as_ref(), rules, trace, debug_rules)),
251 },
252 LogicalPlan::Sort { order_by, input } => LogicalPlan::Sort {
253 order_by,
254 input: Box::new(apply_rules_recursive(input.as_ref(), rules, trace, debug_rules)),
255 },
256 LogicalPlan::Limit {
257 limit,
258 offset,
259 input,
260 } => LogicalPlan::Limit {
261 limit,
262 offset,
263 input: Box::new(apply_rules_recursive(input.as_ref(), rules, trace, debug_rules)),
264 },
265 LogicalPlan::Derived {
266 input,
267 alias,
268 column_aliases,
269 } => LogicalPlan::Derived {
270 input: Box::new(apply_rules_recursive(input.as_ref(), rules, trace, debug_rules)),
271 alias,
272 column_aliases,
273 },
274 other => other,
275 };
276 let mut final_plan = rewritten.clone();
277 for rule in rules.iter() {
278 let alternatives = rule.apply(&final_plan);
279 if !alternatives.is_empty() {
280 final_plan = alternatives.last().cloned().unwrap_or(final_plan);
281 }
282 }
283 final_plan
284}
285
286fn logical_to_physical(logical: &LogicalPlan) -> PhysicalPlan {
287 let children = match logical {
288 LogicalPlan::Scan { .. } => Vec::new(),
289 LogicalPlan::IndexScan { .. } => Vec::new(),
290 LogicalPlan::Dml { .. } => Vec::new(),
291 LogicalPlan::Derived { input, .. } => vec![logical_to_physical(input)],
292 LogicalPlan::Filter { input, .. } => vec![logical_to_physical(input)],
293 LogicalPlan::Projection { input, .. } => vec![logical_to_physical(input)],
294 LogicalPlan::Join { left, right, .. } => {
295 vec![logical_to_physical(left), logical_to_physical(right)]
296 }
297 LogicalPlan::Aggregate { input, .. } => vec![logical_to_physical(input)],
298 LogicalPlan::Distinct { input } => vec![logical_to_physical(input)],
299 LogicalPlan::TopN { input, .. } => vec![logical_to_physical(input)],
300 LogicalPlan::Sort { input, .. } => vec![logical_to_physical(input)],
301 LogicalPlan::Limit { input, .. } => vec![logical_to_physical(input)],
302 };
303 let rules = crate::physical_rules::PhysicalRuleSet::default();
304 let candidates = rules.apply_all(logical, &children);
305 let cost_model = UnitCostModel;
306 candidates
307 .into_iter()
308 .min_by(|left, right| {
309 cost_model
310 .cost(left)
311 .0
312 .partial_cmp(&cost_model.cost(right).0)
313 .unwrap()
314 })
315 .unwrap_or(PhysicalPlan::TableScan {
316 table: "unknown".to_string(),
317 })
318}