Skip to main content

chryso_optimizer/
memo.rs

1use crate::cost::{Cost, CostModel};
2use crate::physical_rules::PhysicalRuleSet;
3use crate::rules::RuleContext;
4use crate::{MemoTrace, MemoTraceCandidate, MemoTraceGroup, RuleConfig, SearchBudget};
5use chryso_planner::{LogicalPlan, PhysicalPlan};
6use std::collections::HashMap;
7
/// Handle identifying a group stored in a [`Memo`].
///
/// The wrapped `usize` is the group's position in the memo's group vector:
/// `Memo::insert` builds it from the vector length just before pushing, and
/// lookups index the vector with it.
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
9pub struct GroupId(usize);
10
/// Memo table holding the groups of plan expressions the optimizer works on.
11#[derive(Debug)]
12pub struct Memo {
    /// All groups, addressed by the `usize` inside their `GroupId`.
13    groups: Vec<Group>,
14}
15
16impl Memo {
17    pub fn new() -> Self {
18        Self { groups: Vec::new() }
19    }
20
21    pub fn insert(&mut self, plan: &LogicalPlan) -> GroupId {
22        let expr = GroupExpr::from_logical(plan, self);
23        let group_id = GroupId(self.groups.len());
24        self.groups.push(Group {
25            expressions: vec![expr],
26        });
27        group_id
28    }
29
30    pub fn best_physical(
31        &self,
32        root: GroupId,
33        physical_rules: &PhysicalRuleSet,
34        cost_model: &dyn CostModel,
35    ) -> Option<PhysicalPlan> {
36        let mut cache = HashMap::new();
37        self.best_physical_with_cache(root, physical_rules, cost_model, &mut cache)
38            .map(|(_, plan)| plan)
39    }
40
41    pub fn trace(&self, physical_rules: &PhysicalRuleSet, cost_model: &dyn CostModel) -> MemoTrace {
42        let mut cache = HashMap::new();
43        let mut groups = Vec::with_capacity(self.groups.len());
44        for (group_id, group) in self.groups.iter().enumerate() {
45            let mut candidates = Vec::new();
46            for expr in &group.expressions {
47                let MemoOperator::Logical(logical) = &expr.operator else {
48                    continue;
49                };
50                let mut inputs = Vec::new();
51                let mut missing_input = false;
52                for child in &expr.children {
53                    if let Some((_, best)) = self.best_physical_with_cache(
54                        *child,
55                        physical_rules,
56                        cost_model,
57                        &mut cache,
58                    ) {
59                        inputs.push(best);
60                    } else {
61                        missing_input = true;
62                        break;
63                    }
64                }
65                if missing_input {
66                    continue;
67                }
68                for physical in physical_rules.apply_all(logical, &inputs) {
69                    let cost = cost_model.cost(&physical).0;
70                    let plan = physical.explain_costed(0, cost_model);
71                    candidates.push(MemoTraceCandidate { cost, plan });
72                }
73            }
74            candidates.sort_by(|left, right| {
75                left.cost
76                    .partial_cmp(&right.cost)
77                    .unwrap_or(std::cmp::Ordering::Equal)
78                    .then_with(|| left.plan.cmp(&right.plan))
79            });
80            groups.push(MemoTraceGroup {
81                id: group_id,
82                candidates,
83            });
84        }
85        MemoTrace { groups }
86    }
87
88    fn best_physical_with_cache(
89        &self,
90        root: GroupId,
91        physical_rules: &PhysicalRuleSet,
92        cost_model: &dyn CostModel,
93        cache: &mut HashMap<GroupId, (Cost, PhysicalPlan)>,
94    ) -> Option<(Cost, PhysicalPlan)> {
95        if let Some(cached) = cache.get(&root) {
96            return Some(cached.clone());
97        }
98        let group = self.groups.get(root.0)?;
99        let mut best: Option<(Cost, PhysicalPlan)> = None;
100        for expr in &group.expressions {
101            match &expr.operator {
102                MemoOperator::Logical(logical) => {
103                    let mut inputs = Vec::new();
104                    let mut missing_input = false;
105                    for child in &expr.children {
106                        if let Some((_, best_child)) =
107                            self.best_physical_with_cache(*child, physical_rules, cost_model, cache)
108                        {
109                            inputs.push(best_child);
110                        } else {
111                            missing_input = true;
112                            break;
113                        }
114                    }
115                    if missing_input {
116                        continue;
117                    }
118                    for physical in physical_rules.apply_all(logical, &inputs) {
119                        let cost = cost_model.cost(&physical);
120                        if best.as_ref().map(|(c, _)| cost.0 < c.0).unwrap_or(true) {
121                            best = Some((cost, physical));
122                        }
123                    }
124                }
125                MemoOperator::Physical(plan) => {
126                    let cost = cost_model.cost(plan);
127                    if best.as_ref().map(|(c, _)| cost.0 < c.0).unwrap_or(true) {
128                        best = Some((cost, plan.clone()));
129                    }
130                }
131            }
132        }
133        if let Some(result) = best.clone() {
134            cache.insert(root, result.clone());
135            return Some(result);
136        }
137        None
138    }
139
140    pub fn explore(
141        &mut self,
142        rules: &crate::rules::RuleSet,
143        rule_config: &RuleConfig,
144        budget: &SearchBudget,
145    ) {
146        let max_rewrites = budget.max_rewrites.unwrap_or(usize::MAX);
147        let mut new_exprs = Vec::new();
148        let mut rewrites = 0usize;
149        // RuleContext tracks side-channel information (e.g., literal conflicts) while exploring.
150        let mut rule_ctx = RuleContext::default();
151        for group in &self.groups {
152            for expr in &group.expressions {
153                if let MemoOperator::Logical(plan) = &expr.operator {
154                    for rule in rules.iter() {
155                        if rewrites >= max_rewrites {
156                            break;
157                        }
158                        if !rule_config.is_enabled(rule.name()) {
159                            continue;
160                        }
161                        for rewritten in rule.apply(plan, &mut rule_ctx) {
162                            if rewrites >= max_rewrites {
163                                break;
164                            }
165                            new_exprs.push(GroupExpr {
166                                operator: MemoOperator::Logical(rewritten),
167                                children: expr.children.clone(),
168                            });
169                            rewrites += 1;
170                        }
171                    }
172                }
173            }
174            if rewrites >= max_rewrites {
175                break;
176            }
177        }
178        for expr in new_exprs {
179            if self.groups.len() >= budget.max_groups.unwrap_or(usize::MAX) {
180                break;
181            }
182            self.groups.push(Group {
183                expressions: vec![expr],
184            });
185        }
186    }
187
188    #[cfg(test)]
    /// Number of groups currently stored in the memo; compiled for tests only.
189    pub fn group_count(&self) -> usize {
190        self.groups.len()
191    }
192}
193
/// A memo group: the expressions stored together under one `GroupId`.
194#[derive(Debug)]
195pub struct Group {
    /// Expressions of this group; the plan search picks the cheapest plan
    /// obtainable from any of them.
196    expressions: Vec<GroupExpr>,
197}
198
/// One expression inside a group: an operator plus the groups it reads from.
199#[derive(Debug)]
200pub struct GroupExpr {
    /// The logical or physical plan node this expression represents.
201    operator: MemoOperator,
    /// Ids of the groups holding this expression's inputs, in input order
    /// (for a join: left, then right).
202    children: Vec<GroupId>,
203}
204
205impl GroupExpr {
206    pub fn from_logical(plan: &LogicalPlan, memo: &mut Memo) -> Self {
207        match plan {
208            LogicalPlan::Scan { .. } => Self {
209                operator: MemoOperator::Logical(plan.clone()),
210                children: Vec::new(),
211            },
212            LogicalPlan::IndexScan { .. } => Self {
213                operator: MemoOperator::Logical(plan.clone()),
214                children: Vec::new(),
215            },
216            LogicalPlan::Dml { .. } => Self {
217                operator: MemoOperator::Logical(plan.clone()),
218                children: Vec::new(),
219            },
220            LogicalPlan::Derived { input, .. } => {
221                let child_group = memo.insert(input);
222                Self {
223                    operator: MemoOperator::Logical(plan.clone()),
224                    children: vec![child_group],
225                }
226            }
227            LogicalPlan::Filter { input, .. } => {
228                let child_group = memo.insert(input);
229                Self {
230                    operator: MemoOperator::Logical(plan.clone()),
231                    children: vec![child_group],
232                }
233            }
234            LogicalPlan::Projection { input, .. } => {
235                let child_group = memo.insert(input);
236                Self {
237                    operator: MemoOperator::Logical(plan.clone()),
238                    children: vec![child_group],
239                }
240            }
241            LogicalPlan::Join { left, right, .. } => {
242                let left_group = memo.insert(left);
243                let right_group = memo.insert(right);
244                Self {
245                    operator: MemoOperator::Logical(plan.clone()),
246                    children: vec![left_group, right_group],
247                }
248            }
249            LogicalPlan::Aggregate { input, .. } => {
250                let child_group = memo.insert(input);
251                Self {
252                    operator: MemoOperator::Logical(plan.clone()),
253                    children: vec![child_group],
254                }
255            }
256            LogicalPlan::Distinct { input } => {
257                let child_group = memo.insert(input);
258                Self {
259                    operator: MemoOperator::Logical(plan.clone()),
260                    children: vec![child_group],
261                }
262            }
263            LogicalPlan::TopN { input, .. } => {
264                let child_group = memo.insert(input);
265                Self {
266                    operator: MemoOperator::Logical(plan.clone()),
267                    children: vec![child_group],
268                }
269            }
270            LogicalPlan::Sort { input, .. } => {
271                let child_group = memo.insert(input);
272                Self {
273                    operator: MemoOperator::Logical(plan.clone()),
274                    children: vec![child_group],
275                }
276            }
277            LogicalPlan::Limit { input, .. } => {
278                let child_group = memo.insert(input);
279                Self {
280                    operator: MemoOperator::Logical(plan.clone()),
281                    children: vec![child_group],
282                }
283            }
284        }
285    }
286
287    pub fn to_physical(&self, memo: &Memo) -> Option<PhysicalPlan> {
288        match &self.operator {
289            MemoOperator::Logical(plan) => Some(logical_to_physical(plan, memo)),
290            MemoOperator::Physical(plan) => Some(plan.clone()),
291        }
292    }
293}
294
/// The operator carried by a `GroupExpr`.
295#[derive(Debug, Clone)]
296pub enum MemoOperator {
    /// A logical plan node; physical candidates are derived from it by the
    /// physical rule set during plan search.
297    Logical(LogicalPlan),
    /// A physical plan, which plan search costs as it is.
298    Physical(PhysicalPlan),
299}
300
301fn logical_to_physical(logical: &LogicalPlan, memo: &Memo) -> PhysicalPlan {
302    match logical {
303        LogicalPlan::Scan { table } => PhysicalPlan::TableScan {
304            table: table.clone(),
305        },
306        LogicalPlan::IndexScan {
307            table,
308            index,
309            predicate,
310        } => PhysicalPlan::IndexScan {
311            table: table.clone(),
312            index: index.clone(),
313            predicate: predicate.clone(),
314        },
315        LogicalPlan::Dml { sql } => PhysicalPlan::Dml { sql: sql.clone() },
316        LogicalPlan::Derived {
317            input,
318            alias,
319            column_aliases,
320        } => PhysicalPlan::Derived {
321            input: Box::new(logical_to_physical(input, memo)),
322            alias: alias.clone(),
323            column_aliases: column_aliases.clone(),
324        },
325        LogicalPlan::Filter { predicate, input } => PhysicalPlan::Filter {
326            predicate: predicate.clone(),
327            input: Box::new(logical_to_physical(input, memo)),
328        },
329        LogicalPlan::Projection { exprs, input } => PhysicalPlan::Projection {
330            exprs: exprs.clone(),
331            input: Box::new(logical_to_physical(input, memo)),
332        },
333        LogicalPlan::Join {
334            join_type,
335            left,
336            right,
337            on,
338        } => PhysicalPlan::Join {
339            join_type: *join_type,
340            algorithm: chryso_planner::JoinAlgorithm::Hash,
341            left: Box::new(logical_to_physical(left, memo)),
342            right: Box::new(logical_to_physical(right, memo)),
343            on: on.clone(),
344        },
345        LogicalPlan::Aggregate {
346            group_exprs,
347            aggr_exprs,
348            input,
349        } => PhysicalPlan::Aggregate {
350            group_exprs: group_exprs.clone(),
351            aggr_exprs: aggr_exprs.clone(),
352            input: Box::new(logical_to_physical(input, memo)),
353        },
354        LogicalPlan::Distinct { input } => PhysicalPlan::Distinct {
355            input: Box::new(logical_to_physical(input, memo)),
356        },
357        LogicalPlan::TopN {
358            order_by,
359            limit,
360            input,
361        } => PhysicalPlan::TopN {
362            order_by: order_by.clone(),
363            limit: *limit,
364            input: Box::new(logical_to_physical(input, memo)),
365        },
366        LogicalPlan::Sort { order_by, input } => PhysicalPlan::Sort {
367            order_by: order_by.clone(),
368            input: Box::new(logical_to_physical(input, memo)),
369        },
370        LogicalPlan::Limit {
371            limit,
372            offset,
373            input,
374        } => PhysicalPlan::Limit {
375            limit: *limit,
376            offset: *offset,
377            input: Box::new(logical_to_physical(input, memo)),
378        },
379    }
380}
381
#[cfg(test)]
mod tests {
    use super::Memo;
    use crate::rules::{RemoveTrueFilter, RuleSet};
    use crate::{RuleConfig, SearchBudget};
    use chryso_planner::LogicalPlan;

    /// `Filter(predicate)` over a scan of the `users` table.
    fn filter_over_users(predicate: chryso_core::ast::Expr) -> LogicalPlan {
        LogicalPlan::Filter {
            predicate,
            input: Box::new(LogicalPlan::Scan {
                table: "users".to_string(),
            }),
        }
    }

    /// Inserts a `Filter(true)` plan, explores it with `RemoveTrueFilter` under the
    /// budget built from the initial group count, and returns the group counts
    /// before and after exploring.
    fn explore_true_filter(make_budget: impl FnOnce(usize) -> SearchBudget) -> (usize, usize) {
        let always_true = chryso_core::ast::Expr::Literal(chryso_core::ast::Literal::Bool(true));
        let plan = filter_over_users(always_true);
        let rules = RuleSet::new().with_rule(RemoveTrueFilter);
        let mut memo = Memo::new();
        memo.insert(&plan);
        let before = memo.group_count();
        let budget = make_budget(before);
        memo.explore(&rules, &RuleConfig::default(), &budget);
        (before, memo.group_count())
    }

    #[test]
    fn memo_inserts_child_groups() {
        let plan = filter_over_users(chryso_core::ast::Expr::Identifier("x".to_string()));
        let mut memo = Memo::new();
        memo.insert(&plan);
        assert!(memo.group_count() >= 2);
    }

    #[test]
    fn memo_respects_max_rewrites_budget() {
        let (initial_groups, final_groups) = explore_true_filter(|_| SearchBudget {
            max_groups: None,
            max_rewrites: Some(0),
        });
        assert_eq!(final_groups, initial_groups);
    }

    #[test]
    fn memo_respects_max_groups_budget() {
        let (initial_groups, final_groups) = explore_true_filter(|initial| SearchBudget {
            max_groups: Some(initial),
            max_rewrites: None,
        });
        assert_eq!(final_groups, initial_groups);
    }
}