chryso_optimizer/
memo.rs

1use crate::cost::{Cost, CostModel};
2use chryso_planner::{LogicalPlan, PhysicalPlan};
3
/// Identifier of a group stored in a [`Memo`].
///
/// The wrapped value is the group's position in the memo's internal group list.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct GroupId(usize);
6
7#[derive(Debug)]
8pub struct Memo {
9    groups: Vec<Group>,
10}
11
12impl Memo {
13    pub fn new() -> Self {
14        Self { groups: Vec::new() }
15    }
16
17    pub fn insert(&mut self, plan: &LogicalPlan) -> GroupId {
18        let expr = GroupExpr::from_logical(plan, self);
19        let group_id = GroupId(self.groups.len());
20        self.groups.push(Group {
21            expressions: vec![expr],
22        });
23        group_id
24    }
25
26    pub fn best_physical(&self, root: GroupId, cost_model: &dyn CostModel) -> Option<PhysicalPlan> {
27        let group = self.groups.get(root.0)?;
28        let mut best: Option<(Cost, PhysicalPlan)> = None;
29        for expr in &group.expressions {
30            if let Some(physical) = expr.to_physical(self) {
31                let cost = cost_model.cost(&physical);
32                if best.as_ref().map(|(c, _)| cost.0 < c.0).unwrap_or(true) {
33                    best = Some((cost, physical));
34                }
35            }
36        }
37        best.map(|(_, plan)| plan)
38    }
39
40    pub fn explore(&mut self, rules: &crate::rules::RuleSet) {
41        let mut new_exprs = Vec::new();
42        for group in &self.groups {
43            for expr in &group.expressions {
44                if let MemoOperator::Logical(plan) = &expr.operator {
45                    for rewritten in rules.apply_all(plan) {
46                        new_exprs.push(GroupExpr {
47                            operator: MemoOperator::Logical(rewritten),
48                            children: expr.children.clone(),
49                        });
50                    }
51                }
52            }
53        }
54        for expr in new_exprs {
55            self.groups.push(Group {
56                expressions: vec![expr],
57            });
58        }
59    }
60
61    #[cfg(test)]
62    pub fn group_count(&self) -> usize {
63        self.groups.len()
64    }
65}
66
67#[derive(Debug)]
68pub struct Group {
69    expressions: Vec<GroupExpr>,
70}
71
72#[derive(Debug)]
73pub struct GroupExpr {
74    operator: MemoOperator,
75    children: Vec<GroupId>,
76}
77
78impl GroupExpr {
79    pub fn from_logical(plan: &LogicalPlan, memo: &mut Memo) -> Self {
80        match plan {
81            LogicalPlan::Scan { .. } => Self {
82                operator: MemoOperator::Logical(plan.clone()),
83                children: Vec::new(),
84            },
85            LogicalPlan::IndexScan { .. } => Self {
86                operator: MemoOperator::Logical(plan.clone()),
87                children: Vec::new(),
88            },
89            LogicalPlan::Dml { .. } => Self {
90                operator: MemoOperator::Logical(plan.clone()),
91                children: Vec::new(),
92            },
93            LogicalPlan::Derived { input, .. } => {
94                let child_group = memo.insert(input);
95                Self {
96                    operator: MemoOperator::Logical(plan.clone()),
97                    children: vec![child_group],
98                }
99            }
100            LogicalPlan::Filter { input, .. } => {
101                let child_group = memo.insert(input);
102                Self {
103                    operator: MemoOperator::Logical(plan.clone()),
104                    children: vec![child_group],
105                }
106            }
107            LogicalPlan::Projection { input, .. } => {
108                let child_group = memo.insert(input);
109                Self {
110                    operator: MemoOperator::Logical(plan.clone()),
111                    children: vec![child_group],
112                }
113            }
114            LogicalPlan::Join { left, right, .. } => {
115                let left_group = memo.insert(left);
116                let right_group = memo.insert(right);
117                Self {
118                    operator: MemoOperator::Logical(plan.clone()),
119                    children: vec![left_group, right_group],
120                }
121            }
122            LogicalPlan::Aggregate { input, .. } => {
123                let child_group = memo.insert(input);
124                Self {
125                    operator: MemoOperator::Logical(plan.clone()),
126                    children: vec![child_group],
127                }
128            }
129            LogicalPlan::Distinct { input } => {
130                let child_group = memo.insert(input);
131                Self {
132                    operator: MemoOperator::Logical(plan.clone()),
133                    children: vec![child_group],
134                }
135            }
136            LogicalPlan::TopN { input, .. } => {
137                let child_group = memo.insert(input);
138                Self {
139                    operator: MemoOperator::Logical(plan.clone()),
140                    children: vec![child_group],
141                }
142            }
143            LogicalPlan::Sort { input, .. } => {
144                let child_group = memo.insert(input);
145                Self {
146                    operator: MemoOperator::Logical(plan.clone()),
147                    children: vec![child_group],
148                }
149            }
150            LogicalPlan::Limit { input, .. } => {
151                let child_group = memo.insert(input);
152                Self {
153                    operator: MemoOperator::Logical(plan.clone()),
154                    children: vec![child_group],
155                }
156            }
157        }
158    }
159
160    pub fn to_physical(&self, memo: &Memo) -> Option<PhysicalPlan> {
161        match &self.operator {
162            MemoOperator::Logical(plan) => Some(logical_to_physical(plan, memo)),
163            MemoOperator::Physical(plan) => Some(plan.clone()),
164        }
165    }
166}
167
168#[derive(Debug, Clone)]
169pub enum MemoOperator {
170    Logical(LogicalPlan),
171    Physical(PhysicalPlan),
172}
173
174fn logical_to_physical(logical: &LogicalPlan, memo: &Memo) -> PhysicalPlan {
175    match logical {
176        LogicalPlan::Scan { table } => PhysicalPlan::TableScan {
177            table: table.clone(),
178        },
179        LogicalPlan::IndexScan {
180            table,
181            index,
182            predicate,
183        } => PhysicalPlan::IndexScan {
184            table: table.clone(),
185            index: index.clone(),
186            predicate: predicate.clone(),
187        },
188        LogicalPlan::Dml { sql } => PhysicalPlan::Dml { sql: sql.clone() },
189        LogicalPlan::Derived {
190            input,
191            alias,
192            column_aliases,
193        } => PhysicalPlan::Derived {
194            input: Box::new(logical_to_physical(input, memo)),
195            alias: alias.clone(),
196            column_aliases: column_aliases.clone(),
197        },
198        LogicalPlan::Filter { predicate, input } => PhysicalPlan::Filter {
199            predicate: predicate.clone(),
200            input: Box::new(logical_to_physical(input, memo)),
201        },
202        LogicalPlan::Projection { exprs, input } => PhysicalPlan::Projection {
203            exprs: exprs.clone(),
204            input: Box::new(logical_to_physical(input, memo)),
205        },
206        LogicalPlan::Join {
207            join_type,
208            left,
209            right,
210            on,
211        } => PhysicalPlan::Join {
212            join_type: *join_type,
213            algorithm: chryso_planner::JoinAlgorithm::Hash,
214            left: Box::new(logical_to_physical(left, memo)),
215            right: Box::new(logical_to_physical(right, memo)),
216            on: on.clone(),
217        },
218        LogicalPlan::Aggregate {
219            group_exprs,
220            aggr_exprs,
221            input,
222        } => PhysicalPlan::Aggregate {
223            group_exprs: group_exprs.clone(),
224            aggr_exprs: aggr_exprs.clone(),
225            input: Box::new(logical_to_physical(input, memo)),
226        },
227        LogicalPlan::Distinct { input } => PhysicalPlan::Distinct {
228            input: Box::new(logical_to_physical(input, memo)),
229        },
230        LogicalPlan::TopN {
231            order_by,
232            limit,
233            input,
234        } => PhysicalPlan::TopN {
235            order_by: order_by.clone(),
236            limit: *limit,
237            input: Box::new(logical_to_physical(input, memo)),
238        },
239        LogicalPlan::Sort { order_by, input } => PhysicalPlan::Sort {
240            order_by: order_by.clone(),
241            input: Box::new(logical_to_physical(input, memo)),
242        },
243        LogicalPlan::Limit {
244            limit,
245            offset,
246            input,
247        } => PhysicalPlan::Limit {
248            limit: *limit,
249            offset: *offset,
250            input: Box::new(logical_to_physical(input, memo)),
251        },
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::{GroupId, Memo};
    use chryso_planner::LogicalPlan;

    /// Inserting `Filter(Scan)` registers the scan as its own group before the
    /// filter, so the memo ends with exactly two groups and the returned root id
    /// is the later one.
    #[test]
    fn memo_inserts_child_groups() {
        let scan = LogicalPlan::Scan {
            table: "users".to_string(),
        };
        let plan = LogicalPlan::Filter {
            predicate: chryso_core::ast::Expr::Identifier("x".to_string()),
            input: Box::new(scan),
        };

        let mut memo = Memo::new();
        let root = memo.insert(&plan);

        assert_eq!(memo.group_count(), 2);
        assert_eq!(root, GroupId(1));
    }
}