1use crate::cost::{Cost, CostModel};
2use chryso_planner::{LogicalPlan, PhysicalPlan};
3
/// Identifies one group stored in a [`Memo`].
///
/// The wrapped value is the group's position in the memo's internal group
/// list; ids are handed out by `Memo::insert` and are only meaningful for the
/// memo that produced them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GroupId(usize);
6
7#[derive(Debug)]
8pub struct Memo {
9 groups: Vec<Group>,
10}
11
12impl Memo {
13 pub fn new() -> Self {
14 Self { groups: Vec::new() }
15 }
16
17 pub fn insert(&mut self, plan: &LogicalPlan) -> GroupId {
18 let expr = GroupExpr::from_logical(plan, self);
19 let group_id = GroupId(self.groups.len());
20 self.groups.push(Group {
21 expressions: vec![expr],
22 });
23 group_id
24 }
25
26 pub fn best_physical(&self, root: GroupId, cost_model: &dyn CostModel) -> Option<PhysicalPlan> {
27 let group = self.groups.get(root.0)?;
28 let mut best: Option<(Cost, PhysicalPlan)> = None;
29 for expr in &group.expressions {
30 if let Some(physical) = expr.to_physical(self) {
31 let cost = cost_model.cost(&physical);
32 if best.as_ref().map(|(c, _)| cost.0 < c.0).unwrap_or(true) {
33 best = Some((cost, physical));
34 }
35 }
36 }
37 best.map(|(_, plan)| plan)
38 }
39
40 pub fn explore(&mut self, rules: &crate::rules::RuleSet) {
41 let mut new_exprs = Vec::new();
42 for group in &self.groups {
43 for expr in &group.expressions {
44 if let MemoOperator::Logical(plan) = &expr.operator {
45 for rewritten in rules.apply_all(plan) {
46 new_exprs.push(GroupExpr {
47 operator: MemoOperator::Logical(rewritten),
48 children: expr.children.clone(),
49 });
50 }
51 }
52 }
53 }
54 for expr in new_exprs {
55 self.groups.push(Group {
56 expressions: vec![expr],
57 });
58 }
59 }
60
61 #[cfg(test)]
62 pub fn group_count(&self) -> usize {
63 self.groups.len()
64 }
65}
66
67#[derive(Debug)]
68pub struct Group {
69 expressions: Vec<GroupExpr>,
70}
71
72#[derive(Debug)]
73pub struct GroupExpr {
74 operator: MemoOperator,
75 children: Vec<GroupId>,
76}
77
78impl GroupExpr {
79 pub fn from_logical(plan: &LogicalPlan, memo: &mut Memo) -> Self {
80 match plan {
81 LogicalPlan::Scan { .. } => Self {
82 operator: MemoOperator::Logical(plan.clone()),
83 children: Vec::new(),
84 },
85 LogicalPlan::IndexScan { .. } => Self {
86 operator: MemoOperator::Logical(plan.clone()),
87 children: Vec::new(),
88 },
89 LogicalPlan::Dml { .. } => Self {
90 operator: MemoOperator::Logical(plan.clone()),
91 children: Vec::new(),
92 },
93 LogicalPlan::Derived { input, .. } => {
94 let child_group = memo.insert(input);
95 Self {
96 operator: MemoOperator::Logical(plan.clone()),
97 children: vec![child_group],
98 }
99 }
100 LogicalPlan::Filter { input, .. } => {
101 let child_group = memo.insert(input);
102 Self {
103 operator: MemoOperator::Logical(plan.clone()),
104 children: vec![child_group],
105 }
106 }
107 LogicalPlan::Projection { input, .. } => {
108 let child_group = memo.insert(input);
109 Self {
110 operator: MemoOperator::Logical(plan.clone()),
111 children: vec![child_group],
112 }
113 }
114 LogicalPlan::Join { left, right, .. } => {
115 let left_group = memo.insert(left);
116 let right_group = memo.insert(right);
117 Self {
118 operator: MemoOperator::Logical(plan.clone()),
119 children: vec![left_group, right_group],
120 }
121 }
122 LogicalPlan::Aggregate { input, .. } => {
123 let child_group = memo.insert(input);
124 Self {
125 operator: MemoOperator::Logical(plan.clone()),
126 children: vec![child_group],
127 }
128 }
129 LogicalPlan::Distinct { input } => {
130 let child_group = memo.insert(input);
131 Self {
132 operator: MemoOperator::Logical(plan.clone()),
133 children: vec![child_group],
134 }
135 }
136 LogicalPlan::TopN { input, .. } => {
137 let child_group = memo.insert(input);
138 Self {
139 operator: MemoOperator::Logical(plan.clone()),
140 children: vec![child_group],
141 }
142 }
143 LogicalPlan::Sort { input, .. } => {
144 let child_group = memo.insert(input);
145 Self {
146 operator: MemoOperator::Logical(plan.clone()),
147 children: vec![child_group],
148 }
149 }
150 LogicalPlan::Limit { input, .. } => {
151 let child_group = memo.insert(input);
152 Self {
153 operator: MemoOperator::Logical(plan.clone()),
154 children: vec![child_group],
155 }
156 }
157 }
158 }
159
160 pub fn to_physical(&self, memo: &Memo) -> Option<PhysicalPlan> {
161 match &self.operator {
162 MemoOperator::Logical(plan) => Some(logical_to_physical(plan, memo)),
163 MemoOperator::Physical(plan) => Some(plan.clone()),
164 }
165 }
166}
167
168#[derive(Debug, Clone)]
169pub enum MemoOperator {
170 Logical(LogicalPlan),
171 Physical(PhysicalPlan),
172}
173
174fn logical_to_physical(logical: &LogicalPlan, memo: &Memo) -> PhysicalPlan {
175 match logical {
176 LogicalPlan::Scan { table } => PhysicalPlan::TableScan {
177 table: table.clone(),
178 },
179 LogicalPlan::IndexScan {
180 table,
181 index,
182 predicate,
183 } => PhysicalPlan::IndexScan {
184 table: table.clone(),
185 index: index.clone(),
186 predicate: predicate.clone(),
187 },
188 LogicalPlan::Dml { sql } => PhysicalPlan::Dml { sql: sql.clone() },
189 LogicalPlan::Derived {
190 input,
191 alias,
192 column_aliases,
193 } => PhysicalPlan::Derived {
194 input: Box::new(logical_to_physical(input, memo)),
195 alias: alias.clone(),
196 column_aliases: column_aliases.clone(),
197 },
198 LogicalPlan::Filter { predicate, input } => PhysicalPlan::Filter {
199 predicate: predicate.clone(),
200 input: Box::new(logical_to_physical(input, memo)),
201 },
202 LogicalPlan::Projection { exprs, input } => PhysicalPlan::Projection {
203 exprs: exprs.clone(),
204 input: Box::new(logical_to_physical(input, memo)),
205 },
206 LogicalPlan::Join {
207 join_type,
208 left,
209 right,
210 on,
211 } => PhysicalPlan::Join {
212 join_type: *join_type,
213 algorithm: chryso_planner::JoinAlgorithm::Hash,
214 left: Box::new(logical_to_physical(left, memo)),
215 right: Box::new(logical_to_physical(right, memo)),
216 on: on.clone(),
217 },
218 LogicalPlan::Aggregate {
219 group_exprs,
220 aggr_exprs,
221 input,
222 } => PhysicalPlan::Aggregate {
223 group_exprs: group_exprs.clone(),
224 aggr_exprs: aggr_exprs.clone(),
225 input: Box::new(logical_to_physical(input, memo)),
226 },
227 LogicalPlan::Distinct { input } => PhysicalPlan::Distinct {
228 input: Box::new(logical_to_physical(input, memo)),
229 },
230 LogicalPlan::TopN {
231 order_by,
232 limit,
233 input,
234 } => PhysicalPlan::TopN {
235 order_by: order_by.clone(),
236 limit: *limit,
237 input: Box::new(logical_to_physical(input, memo)),
238 },
239 LogicalPlan::Sort { order_by, input } => PhysicalPlan::Sort {
240 order_by: order_by.clone(),
241 input: Box::new(logical_to_physical(input, memo)),
242 },
243 LogicalPlan::Limit {
244 limit,
245 offset,
246 input,
247 } => PhysicalPlan::Limit {
248 limit: *limit,
249 offset: *offset,
250 input: Box::new(logical_to_physical(input, memo)),
251 },
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::Memo;
    use chryso_planner::LogicalPlan;

    /// Inserting a filter over a scan must create a group for the scan as
    /// well as one for the filter.
    #[test]
    fn memo_inserts_child_groups() {
        let scan = LogicalPlan::Scan {
            table: "users".to_string(),
        };
        let filter = LogicalPlan::Filter {
            predicate: chryso_core::ast::Expr::Identifier("x".to_string()),
            input: Box::new(scan),
        };

        let mut memo = Memo::new();
        memo.insert(&filter);

        assert!(memo.group_count() >= 2);
    }
}