1use crate::cost::{Cost, CostModel};
2use crate::physical_rules::PhysicalRuleSet;
3use crate::rules::RuleContext;
4use crate::{MemoTrace, MemoTraceCandidate, MemoTraceGroup, RuleConfig, SearchBudget};
5use chryso_planner::{LogicalPlan, PhysicalPlan};
6use std::collections::HashMap;
7
/// Identifier of a group inside a [`Memo`].
///
/// The wrapped value is the group's position in the memo's group list, so ids are
/// cheap to copy, compare and hash (they are used as cache keys during plan search).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct GroupId(usize);
10
11#[derive(Debug)]
12pub struct Memo {
13 groups: Vec<Group>,
14}
15
16impl Memo {
17 pub fn new() -> Self {
18 Self { groups: Vec::new() }
19 }
20
21 pub fn insert(&mut self, plan: &LogicalPlan) -> GroupId {
22 let expr = GroupExpr::from_logical(plan, self);
23 let group_id = GroupId(self.groups.len());
24 self.groups.push(Group {
25 expressions: vec![expr],
26 });
27 group_id
28 }
29
30 pub fn best_physical(
31 &self,
32 root: GroupId,
33 physical_rules: &PhysicalRuleSet,
34 cost_model: &dyn CostModel,
35 ) -> Option<PhysicalPlan> {
36 let mut cache = HashMap::new();
37 self.best_physical_with_cache(root, physical_rules, cost_model, &mut cache)
38 .map(|(_, plan)| plan)
39 }
40
41 pub fn trace(&self, physical_rules: &PhysicalRuleSet, cost_model: &dyn CostModel) -> MemoTrace {
42 let mut cache = HashMap::new();
43 let mut groups = Vec::with_capacity(self.groups.len());
44 for (group_id, group) in self.groups.iter().enumerate() {
45 let mut candidates = Vec::new();
46 for expr in &group.expressions {
47 let MemoOperator::Logical(logical) = &expr.operator else {
48 continue;
49 };
50 let mut inputs = Vec::new();
51 let mut missing_input = false;
52 for child in &expr.children {
53 if let Some((_, best)) = self.best_physical_with_cache(
54 *child,
55 physical_rules,
56 cost_model,
57 &mut cache,
58 ) {
59 inputs.push(best);
60 } else {
61 missing_input = true;
62 break;
63 }
64 }
65 if missing_input {
66 continue;
67 }
68 for physical in physical_rules.apply_all(logical, &inputs) {
69 let cost = cost_model.cost(&physical).0;
70 let plan = physical.explain_costed(0, cost_model);
71 candidates.push(MemoTraceCandidate { cost, plan });
72 }
73 }
74 candidates.sort_by(|left, right| {
75 left.cost
76 .partial_cmp(&right.cost)
77 .unwrap_or(std::cmp::Ordering::Equal)
78 .then_with(|| left.plan.cmp(&right.plan))
79 });
80 groups.push(MemoTraceGroup {
81 id: group_id,
82 candidates,
83 });
84 }
85 MemoTrace { groups }
86 }
87
88 fn best_physical_with_cache(
89 &self,
90 root: GroupId,
91 physical_rules: &PhysicalRuleSet,
92 cost_model: &dyn CostModel,
93 cache: &mut HashMap<GroupId, (Cost, PhysicalPlan)>,
94 ) -> Option<(Cost, PhysicalPlan)> {
95 if let Some(cached) = cache.get(&root) {
96 return Some(cached.clone());
97 }
98 let group = self.groups.get(root.0)?;
99 let mut best: Option<(Cost, PhysicalPlan)> = None;
100 for expr in &group.expressions {
101 match &expr.operator {
102 MemoOperator::Logical(logical) => {
103 let mut inputs = Vec::new();
104 let mut missing_input = false;
105 for child in &expr.children {
106 if let Some((_, best_child)) =
107 self.best_physical_with_cache(*child, physical_rules, cost_model, cache)
108 {
109 inputs.push(best_child);
110 } else {
111 missing_input = true;
112 break;
113 }
114 }
115 if missing_input {
116 continue;
117 }
118 for physical in physical_rules.apply_all(logical, &inputs) {
119 let cost = cost_model.cost(&physical);
120 if best.as_ref().map(|(c, _)| cost.0 < c.0).unwrap_or(true) {
121 best = Some((cost, physical));
122 }
123 }
124 }
125 MemoOperator::Physical(plan) => {
126 let cost = cost_model.cost(plan);
127 if best.as_ref().map(|(c, _)| cost.0 < c.0).unwrap_or(true) {
128 best = Some((cost, plan.clone()));
129 }
130 }
131 }
132 }
133 if let Some(result) = best.clone() {
134 cache.insert(root, result.clone());
135 return Some(result);
136 }
137 None
138 }
139
140 pub fn explore(
141 &mut self,
142 rules: &crate::rules::RuleSet,
143 rule_config: &RuleConfig,
144 budget: &SearchBudget,
145 ) {
146 let max_rewrites = budget.max_rewrites.unwrap_or(usize::MAX);
147 let mut new_exprs = Vec::new();
148 let mut rewrites = 0usize;
149 let mut rule_ctx = RuleContext::default();
151 for group in &self.groups {
152 for expr in &group.expressions {
153 if let MemoOperator::Logical(plan) = &expr.operator {
154 for rule in rules.iter() {
155 if rewrites >= max_rewrites {
156 break;
157 }
158 if !rule_config.is_enabled(rule.name()) {
159 continue;
160 }
161 for rewritten in rule.apply(plan, &mut rule_ctx) {
162 if rewrites >= max_rewrites {
163 break;
164 }
165 new_exprs.push(GroupExpr {
166 operator: MemoOperator::Logical(rewritten),
167 children: expr.children.clone(),
168 });
169 rewrites += 1;
170 }
171 }
172 }
173 }
174 if rewrites >= max_rewrites {
175 break;
176 }
177 }
178 for expr in new_exprs {
179 if self.groups.len() >= budget.max_groups.unwrap_or(usize::MAX) {
180 break;
181 }
182 self.groups.push(Group {
183 expressions: vec![expr],
184 });
185 }
186 }
187
188 #[cfg(test)]
189 pub fn group_count(&self) -> usize {
190 self.groups.len()
191 }
192}
193
194#[derive(Debug)]
195pub struct Group {
196 expressions: Vec<GroupExpr>,
197}
198
199#[derive(Debug)]
200pub struct GroupExpr {
201 operator: MemoOperator,
202 children: Vec<GroupId>,
203}
204
205impl GroupExpr {
206 pub fn from_logical(plan: &LogicalPlan, memo: &mut Memo) -> Self {
207 match plan {
208 LogicalPlan::Scan { .. } => Self {
209 operator: MemoOperator::Logical(plan.clone()),
210 children: Vec::new(),
211 },
212 LogicalPlan::IndexScan { .. } => Self {
213 operator: MemoOperator::Logical(plan.clone()),
214 children: Vec::new(),
215 },
216 LogicalPlan::Dml { .. } => Self {
217 operator: MemoOperator::Logical(plan.clone()),
218 children: Vec::new(),
219 },
220 LogicalPlan::Derived { input, .. } => {
221 let child_group = memo.insert(input);
222 Self {
223 operator: MemoOperator::Logical(plan.clone()),
224 children: vec![child_group],
225 }
226 }
227 LogicalPlan::Filter { input, .. } => {
228 let child_group = memo.insert(input);
229 Self {
230 operator: MemoOperator::Logical(plan.clone()),
231 children: vec![child_group],
232 }
233 }
234 LogicalPlan::Projection { input, .. } => {
235 let child_group = memo.insert(input);
236 Self {
237 operator: MemoOperator::Logical(plan.clone()),
238 children: vec![child_group],
239 }
240 }
241 LogicalPlan::Join { left, right, .. } => {
242 let left_group = memo.insert(left);
243 let right_group = memo.insert(right);
244 Self {
245 operator: MemoOperator::Logical(plan.clone()),
246 children: vec![left_group, right_group],
247 }
248 }
249 LogicalPlan::Aggregate { input, .. } => {
250 let child_group = memo.insert(input);
251 Self {
252 operator: MemoOperator::Logical(plan.clone()),
253 children: vec![child_group],
254 }
255 }
256 LogicalPlan::Distinct { input } => {
257 let child_group = memo.insert(input);
258 Self {
259 operator: MemoOperator::Logical(plan.clone()),
260 children: vec![child_group],
261 }
262 }
263 LogicalPlan::TopN { input, .. } => {
264 let child_group = memo.insert(input);
265 Self {
266 operator: MemoOperator::Logical(plan.clone()),
267 children: vec![child_group],
268 }
269 }
270 LogicalPlan::Sort { input, .. } => {
271 let child_group = memo.insert(input);
272 Self {
273 operator: MemoOperator::Logical(plan.clone()),
274 children: vec![child_group],
275 }
276 }
277 LogicalPlan::Limit { input, .. } => {
278 let child_group = memo.insert(input);
279 Self {
280 operator: MemoOperator::Logical(plan.clone()),
281 children: vec![child_group],
282 }
283 }
284 }
285 }
286
287 pub fn to_physical(&self, memo: &Memo) -> Option<PhysicalPlan> {
288 match &self.operator {
289 MemoOperator::Logical(plan) => Some(logical_to_physical(plan, memo)),
290 MemoOperator::Physical(plan) => Some(plan.clone()),
291 }
292 }
293}
294
295#[derive(Debug, Clone)]
296pub enum MemoOperator {
297 Logical(LogicalPlan),
298 Physical(PhysicalPlan),
299}
300
301fn logical_to_physical(logical: &LogicalPlan, memo: &Memo) -> PhysicalPlan {
302 match logical {
303 LogicalPlan::Scan { table } => PhysicalPlan::TableScan {
304 table: table.clone(),
305 },
306 LogicalPlan::IndexScan {
307 table,
308 index,
309 predicate,
310 } => PhysicalPlan::IndexScan {
311 table: table.clone(),
312 index: index.clone(),
313 predicate: predicate.clone(),
314 },
315 LogicalPlan::Dml { sql } => PhysicalPlan::Dml { sql: sql.clone() },
316 LogicalPlan::Derived {
317 input,
318 alias,
319 column_aliases,
320 } => PhysicalPlan::Derived {
321 input: Box::new(logical_to_physical(input, memo)),
322 alias: alias.clone(),
323 column_aliases: column_aliases.clone(),
324 },
325 LogicalPlan::Filter { predicate, input } => PhysicalPlan::Filter {
326 predicate: predicate.clone(),
327 input: Box::new(logical_to_physical(input, memo)),
328 },
329 LogicalPlan::Projection { exprs, input } => PhysicalPlan::Projection {
330 exprs: exprs.clone(),
331 input: Box::new(logical_to_physical(input, memo)),
332 },
333 LogicalPlan::Join {
334 join_type,
335 left,
336 right,
337 on,
338 } => PhysicalPlan::Join {
339 join_type: *join_type,
340 algorithm: chryso_planner::JoinAlgorithm::Hash,
341 left: Box::new(logical_to_physical(left, memo)),
342 right: Box::new(logical_to_physical(right, memo)),
343 on: on.clone(),
344 },
345 LogicalPlan::Aggregate {
346 group_exprs,
347 aggr_exprs,
348 input,
349 } => PhysicalPlan::Aggregate {
350 group_exprs: group_exprs.clone(),
351 aggr_exprs: aggr_exprs.clone(),
352 input: Box::new(logical_to_physical(input, memo)),
353 },
354 LogicalPlan::Distinct { input } => PhysicalPlan::Distinct {
355 input: Box::new(logical_to_physical(input, memo)),
356 },
357 LogicalPlan::TopN {
358 order_by,
359 limit,
360 input,
361 } => PhysicalPlan::TopN {
362 order_by: order_by.clone(),
363 limit: *limit,
364 input: Box::new(logical_to_physical(input, memo)),
365 },
366 LogicalPlan::Sort { order_by, input } => PhysicalPlan::Sort {
367 order_by: order_by.clone(),
368 input: Box::new(logical_to_physical(input, memo)),
369 },
370 LogicalPlan::Limit {
371 limit,
372 offset,
373 input,
374 } => PhysicalPlan::Limit {
375 limit: *limit,
376 offset: *offset,
377 input: Box::new(logical_to_physical(input, memo)),
378 },
379 }
380}
381
#[cfg(test)]
mod tests {
    use super::Memo;
    use crate::rules::{RemoveTrueFilter, RuleSet};
    use crate::{RuleConfig, SearchBudget};
    use chryso_planner::LogicalPlan;

    /// `Filter(true) -> Scan(users)`: a two-node plan that `RemoveTrueFilter` can rewrite.
    fn true_filter_over_scan() -> LogicalPlan {
        LogicalPlan::Filter {
            predicate: chryso_core::ast::Expr::Literal(chryso_core::ast::Literal::Bool(true)),
            input: Box::new(LogicalPlan::Scan {
                table: "users".to_string(),
            }),
        }
    }

    #[test]
    fn memo_inserts_child_groups() {
        let plan = LogicalPlan::Filter {
            predicate: chryso_core::ast::Expr::Identifier("x".to_string()),
            input: Box::new(LogicalPlan::Scan {
                table: "users".to_string(),
            }),
        };
        let mut memo = Memo::new();
        let root = memo.insert(&plan);
        // One group for the scan and one for the filter; inputs are inserted first,
        // so the root is the last group.
        assert_eq!(memo.group_count(), 2);
        assert_eq!(root.0, memo.group_count() - 1);
    }

    #[test]
    fn memo_explore_adds_rewrites_without_budget() {
        // Guards the budget tests below against passing vacuously: with no limits the
        // rule must actually produce at least one new group.
        let rules = RuleSet::new().with_rule(RemoveTrueFilter);
        let mut memo = Memo::new();
        memo.insert(&true_filter_over_scan());
        let initial_groups = memo.group_count();
        let budget = SearchBudget {
            max_groups: None,
            max_rewrites: None,
        };
        memo.explore(&rules, &RuleConfig::default(), &budget);
        assert!(memo.group_count() > initial_groups);
    }

    #[test]
    fn memo_respects_max_rewrites_budget() {
        let rules = RuleSet::new().with_rule(RemoveTrueFilter);
        let mut memo = Memo::new();
        memo.insert(&true_filter_over_scan());
        let initial_groups = memo.group_count();
        let budget = SearchBudget {
            max_groups: None,
            max_rewrites: Some(0),
        };
        memo.explore(&rules, &RuleConfig::default(), &budget);
        assert_eq!(memo.group_count(), initial_groups);
    }

    #[test]
    fn memo_respects_max_groups_budget() {
        let rules = RuleSet::new().with_rule(RemoveTrueFilter);
        let mut memo = Memo::new();
        memo.insert(&true_filter_over_scan());
        let initial_groups = memo.group_count();
        let budget = SearchBudget {
            max_groups: Some(initial_groups),
            max_rewrites: None,
        };
        memo.explore(&rules, &RuleConfig::default(), &budget);
        assert_eq!(memo.group_count(), initial_groups);
    }
}