1use crate::ast::{BinaryOp, Expr};
45use crate::context::ExecutionContext;
46use crate::planner::PhysicalPlan;
47use alloc::boxed::Box;
48use alloc::string::String;
49use alloc::vec::Vec;
50
/// Tunable thresholds for [`MultiColumnOrPass`].
///
/// `PartialEq` is derived so that whole configurations can be compared
/// directly (for example in tests) instead of field by field.
#[derive(Clone, Debug, PartialEq)]
pub struct MultiColumnOrConfig {
    /// Minimum row count a table must have before the pass considers it;
    /// tables reporting fewer rows are left untouched.
    pub min_table_size: usize,
    /// NOTE(review): not read by the pass code in this file — presumably an
    /// upper bound on estimated predicate selectivity; confirm before relying
    /// on it.
    pub max_selectivity: f64,
    /// Largest number of OR branches the pass will consider. Predicates with
    /// more branches (or fewer than two) are skipped.
    pub max_branches: usize,
}
61
62impl Default for MultiColumnOrConfig {
63 fn default() -> Self {
64 Self {
65 min_table_size: 10000,
66 max_selectivity: 0.01,
67 max_branches: 5,
68 }
69 }
70}
71
72pub struct MultiColumnOrPass<'a> {
74 ctx: &'a ExecutionContext,
75 config: MultiColumnOrConfig,
76}
77
78impl<'a> MultiColumnOrPass<'a> {
79 pub fn new(ctx: &'a ExecutionContext) -> Self {
81 Self {
82 ctx,
83 config: MultiColumnOrConfig::default(),
84 }
85 }
86
87 pub fn with_config(ctx: &'a ExecutionContext, config: MultiColumnOrConfig) -> Self {
89 Self { ctx, config }
90 }
91
92 pub fn optimize(&self, plan: PhysicalPlan) -> PhysicalPlan {
94 self.traverse(plan)
95 }
96
97 fn traverse(&self, plan: PhysicalPlan) -> PhysicalPlan {
98 match plan {
99 PhysicalPlan::Filter { input, predicate } => {
100 let optimized_input = self.traverse(*input);
101
102 if let Some(optimized) =
104 self.try_optimize_or_predicate(&optimized_input, &predicate)
105 {
106 return optimized;
107 }
108
109 PhysicalPlan::Filter {
110 input: Box::new(optimized_input),
111 predicate,
112 }
113 }
114
115 PhysicalPlan::Project { input, columns } => PhysicalPlan::Project {
117 input: Box::new(self.traverse(*input)),
118 columns,
119 },
120
121 PhysicalPlan::Sort { input, order_by } => PhysicalPlan::Sort {
122 input: Box::new(self.traverse(*input)),
123 order_by,
124 },
125
126 PhysicalPlan::Limit {
127 input,
128 limit,
129 offset,
130 } => PhysicalPlan::Limit {
131 input: Box::new(self.traverse(*input)),
132 limit,
133 offset,
134 },
135
136 PhysicalPlan::HashJoin {
137 left,
138 right,
139 condition,
140 join_type,
141 } => PhysicalPlan::HashJoin {
142 left: Box::new(self.traverse(*left)),
143 right: Box::new(self.traverse(*right)),
144 condition,
145 join_type,
146 },
147
148 PhysicalPlan::SortMergeJoin {
149 left,
150 right,
151 condition,
152 join_type,
153 } => PhysicalPlan::SortMergeJoin {
154 left: Box::new(self.traverse(*left)),
155 right: Box::new(self.traverse(*right)),
156 condition,
157 join_type,
158 },
159
160 PhysicalPlan::NestedLoopJoin {
161 left,
162 right,
163 condition,
164 join_type,
165 } => PhysicalPlan::NestedLoopJoin {
166 left: Box::new(self.traverse(*left)),
167 right: Box::new(self.traverse(*right)),
168 condition,
169 join_type,
170 },
171
172 PhysicalPlan::IndexNestedLoopJoin {
173 outer,
174 inner_table,
175 inner_index,
176 condition,
177 join_type,
178 } => PhysicalPlan::IndexNestedLoopJoin {
179 outer: Box::new(self.traverse(*outer)),
180 inner_table,
181 inner_index,
182 condition,
183 join_type,
184 },
185
186 PhysicalPlan::HashAggregate {
187 input,
188 group_by,
189 aggregates,
190 } => PhysicalPlan::HashAggregate {
191 input: Box::new(self.traverse(*input)),
192 group_by,
193 aggregates,
194 },
195
196 PhysicalPlan::CrossProduct { left, right } => PhysicalPlan::CrossProduct {
197 left: Box::new(self.traverse(*left)),
198 right: Box::new(self.traverse(*right)),
199 },
200
201 PhysicalPlan::NoOp { input } => PhysicalPlan::NoOp {
202 input: Box::new(self.traverse(*input)),
203 },
204
205 PhysicalPlan::TopN {
206 input,
207 order_by,
208 limit,
209 offset,
210 } => PhysicalPlan::TopN {
211 input: Box::new(self.traverse(*input)),
212 order_by,
213 limit,
214 offset,
215 },
216
217 plan @ (PhysicalPlan::TableScan { .. }
219 | PhysicalPlan::IndexScan { .. }
220 | PhysicalPlan::IndexGet { .. }
221 | PhysicalPlan::IndexInGet { .. }
222 | PhysicalPlan::GinIndexScan { .. }
223 | PhysicalPlan::GinIndexScanMulti { .. }
224 | PhysicalPlan::Empty) => plan,
225 }
226 }
227
228 fn try_optimize_or_predicate(
230 &self,
231 input: &PhysicalPlan,
232 predicate: &Expr,
233 ) -> Option<PhysicalPlan> {
234 let table = match input {
236 PhysicalPlan::TableScan { table } => table,
237 _ => return None,
238 };
239
240 let row_count = self.ctx.row_count(table);
242 if row_count < self.config.min_table_size {
243 return None;
244 }
245
246 let branches = self.extract_or_branches(predicate);
248 if branches.len() < 2 || branches.len() > self.config.max_branches {
249 return None;
250 }
251
252 let mut index_candidates = Vec::new();
254 for branch in &branches {
255 if let Some((column, index_name)) = self.find_index_for_predicate(table, branch) {
256 index_candidates.push((column, index_name, branch.clone()));
257 } else {
258 return None; }
260 }
261
262 None
266 }
267
268 fn extract_or_branches(&self, predicate: &Expr) -> Vec<Expr> {
270 match predicate {
271 Expr::BinaryOp {
272 left,
273 op: BinaryOp::Or,
274 right,
275 } => {
276 let mut branches = self.extract_or_branches(left);
277 branches.extend(self.extract_or_branches(right));
278 branches
279 }
280 other => alloc::vec![other.clone()],
281 }
282 }
283
284 fn find_index_for_predicate(
286 &self,
287 table: &str,
288 predicate: &Expr,
289 ) -> Option<(String, String)> {
290 let column = self.extract_indexed_column(predicate)?;
292
293 let index = self.ctx.find_index(table, &[&column])?;
295
296 Some((column, index.name.clone()))
297 }
298
299 fn extract_indexed_column(&self, predicate: &Expr) -> Option<String> {
301 match predicate {
302 Expr::BinaryOp {
303 left,
304 op: BinaryOp::Eq,
305 right,
306 } => {
307 if let Expr::Column(col_ref) = left.as_ref() {
309 if matches!(right.as_ref(), Expr::Literal(_)) {
310 return Some(col_ref.column.clone());
311 }
312 }
313 if let Expr::Column(col_ref) = right.as_ref() {
315 if matches!(left.as_ref(), Expr::Literal(_)) {
316 return Some(col_ref.column.clone());
317 }
318 }
319 None
320 }
321 _ => None,
322 }
323 }
324
325 #[allow(dead_code)]
327 fn estimate_selectivity(&self, _table: &str, _predicate: &Expr) -> f64 {
328 0.01
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Expr;
    use crate::context::{IndexInfo, TableStats};

    /// Builds a context with a large, well-indexed `users` table and a tiny
    /// `small_table` that falls below the default size threshold.
    fn create_test_context() -> ExecutionContext {
        let users_stats = TableStats {
            row_count: 100000,
            is_sorted: false,
            indexes: alloc::vec![
                IndexInfo::new("idx_id", alloc::vec!["id".into()], true),
                IndexInfo::new("idx_name", alloc::vec!["name".into()], false),
                IndexInfo::new("idx_email", alloc::vec!["email".into()], true),
            ],
        };
        let small_stats = TableStats {
            row_count: 100,
            is_sorted: false,
            indexes: alloc::vec![IndexInfo::new("idx_id", alloc::vec!["id".into()], true)],
        };

        let mut ctx = ExecutionContext::new();
        ctx.register_table("users", users_stats);
        ctx.register_table("small_table", small_stats);
        ctx
    }

    #[test]
    fn test_extract_or_branches() {
        let ctx = create_test_context();
        let pass = MultiColumnOrPass::new(&ctx);

        // A single OR splits into its two operands.
        let two_way = Expr::or(
            Expr::eq(Expr::column("t", "a", 0), Expr::literal(1i64)),
            Expr::eq(Expr::column("t", "b", 1), Expr::literal(2i64)),
        );
        assert_eq!(pass.extract_or_branches(&two_way).len(), 2);

        // Nested ORs are flattened.
        let three_way = Expr::or(
            Expr::or(
                Expr::eq(Expr::column("t", "a", 0), Expr::literal(1i64)),
                Expr::eq(Expr::column("t", "b", 1), Expr::literal(2i64)),
            ),
            Expr::eq(Expr::column("t", "c", 2), Expr::literal(3i64)),
        );
        assert_eq!(pass.extract_or_branches(&three_way).len(), 3);

        // A predicate without OR is its own single branch.
        let no_or = Expr::eq(Expr::column("t", "a", 0), Expr::literal(1i64));
        assert_eq!(pass.extract_or_branches(&no_or).len(), 1);
    }

    #[test]
    fn test_find_index_for_predicate() {
        let ctx = create_test_context();
        let pass = MultiColumnOrPass::new(&ctx);

        // `id` is covered by `idx_id`.
        let indexed = Expr::eq(Expr::column("users", "id", 0), Expr::literal(1i64));
        assert_eq!(
            pass.find_index_for_predicate("users", &indexed),
            Some((String::from("id"), String::from("idx_id")))
        );

        // `age` has no index registered.
        let unindexed = Expr::eq(Expr::column("users", "age", 3), Expr::literal(25i64));
        assert_eq!(pass.find_index_for_predicate("users", &unindexed), None);
    }

    #[test]
    fn test_small_table_not_optimized() {
        let ctx = create_test_context();
        let pass = MultiColumnOrPass::new(&ctx);

        let predicate = Expr::or(
            Expr::eq(Expr::column("small_table", "id", 0), Expr::literal(1i64)),
            Expr::eq(Expr::column("small_table", "id", 0), Expr::literal(2i64)),
        );
        let plan = PhysicalPlan::Filter {
            input: Box::new(PhysicalPlan::table_scan("small_table")),
            predicate,
        };

        // The filter must survive untouched.
        let optimized = pass.optimize(plan);
        assert!(matches!(optimized, PhysicalPlan::Filter { .. }));
    }

    #[test]
    fn test_config_customization() {
        let ctx = create_test_context();
        let pass = MultiColumnOrPass::with_config(
            &ctx,
            MultiColumnOrConfig {
                min_table_size: 1000,
                max_selectivity: 0.05,
                max_branches: 3,
            },
        );

        let stored = &pass.config;
        assert_eq!(stored.min_table_size, 1000);
        assert_eq!(stored.max_selectivity, 0.05);
        assert_eq!(stored.max_branches, 3);
    }

    #[test]
    fn test_extract_indexed_column() {
        let ctx = create_test_context();
        let pass = MultiColumnOrPass::new(&ctx);

        // column = literal
        let col_then_lit = Expr::eq(Expr::column("t", "id", 0), Expr::literal(1i64));
        assert_eq!(
            pass.extract_indexed_column(&col_then_lit).as_deref(),
            Some("id")
        );

        // literal = column
        let lit_then_col = Expr::eq(Expr::literal(1i64), Expr::column("t", "id", 0));
        assert_eq!(
            pass.extract_indexed_column(&lit_then_col).as_deref(),
            Some("id")
        );

        // column = column is not usable
        let col_vs_col = Expr::eq(Expr::column("t", "a", 0), Expr::column("t", "b", 1));
        assert_eq!(pass.extract_indexed_column(&col_vs_col), None);

        // non-equality comparison is not usable
        let greater_than = Expr::gt(Expr::column("t", "id", 0), Expr::literal(1i64));
        assert_eq!(pass.extract_indexed_column(&greater_than), None);
    }
}