1pub mod cost;
4mod helpers;
5mod join_order;
6mod rule;
7mod rules;
8
9use crate::planner::LogicalPlan;
10use featherdb_catalog::Catalog;
11use rule::OptimizationRule;
12use rules::{
13 ConstantFolding, IndexSelection, PkSeekRule, PredicatePushdown, ProjectionPushdown,
14 SubqueryToJoinConversion,
15};
16use std::sync::Arc;
17
18pub use cost::{constants as cost_constants, CostEstimator};
20
/// Rule-based optimizer for [`LogicalPlan`] trees.
///
/// Holds an ordered list of rewrite rules; `optimize` hands the plan to each
/// rule exactly once, in list order, and returns the final result.
pub struct Optimizer {
    /// Rewrite rules, stored in the order in which they are applied.
    rules: Vec<Box<dyn OptimizationRule>>,
    // NOTE(review): nothing in the visible part of this module reads the field
    // after construction, which is presumably why the lint is silenced —
    // confirm whether it is kept for planned use or can be removed.
    #[allow(dead_code)]
    /// Catalog handle: `Some` when built through `with_catalog`, `None` when
    /// built through `new`.
    catalog: Option<Arc<Catalog>>,
}
28
29impl Optimizer {
30 pub fn new() -> Self {
32 Optimizer {
33 rules: vec![
34 Box::new(PredicatePushdown),
35 Box::new(PkSeekRule::new()),
36 Box::new(ProjectionPushdown),
37 Box::new(ConstantFolding),
38 ],
39 catalog: None,
40 }
41 }
42
43 pub fn with_catalog(catalog: Arc<Catalog>) -> Self {
45 Optimizer {
46 rules: vec![
47 Box::new(SubqueryToJoinConversion),
49 Box::new(PredicatePushdown),
52 Box::new(join_order::JoinOrderOptimizer::new(catalog.clone())),
54 Box::new(PredicatePushdown),
56 Box::new(PkSeekRule::new()),
58 Box::new(IndexSelection::new(catalog.clone())),
60 Box::new(ProjectionPushdown),
61 Box::new(ConstantFolding),
62 ],
63 catalog: Some(catalog),
64 }
65 }
66
67 pub fn optimize(&self, plan: LogicalPlan) -> LogicalPlan {
69 let mut current = plan;
70
71 for rule in &self.rules {
73 current = rule.apply(current);
74 }
75
76 current
77 }
78}
79
80impl Default for Optimizer {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::{BinaryOp, Expr};
    use featherdb_catalog::{Catalog, ColumnConstraint, Index, TableBuilder};
    use featherdb_core::{ColumnType, PageId, Value};

    /// Builds a catalog holding a `users(id, name, age)` table, where `id` is
    /// declared as primary key, plus two indexes: `idx_users_id` and
    /// `idx_users_age`.
    ///
    /// NOTE(review): the `vec![0]` / `vec![2]` arguments are presumably column
    /// ordinals (`id` and `age`) and the boolean presumably marks uniqueness —
    /// confirm against `Index::new`.
    fn create_test_catalog() -> Catalog {
        let catalog = Catalog::new();
        let users = TableBuilder::new("users")
            .column_with(
                "id",
                ColumnType::Integer,
                vec![ColumnConstraint::PrimaryKey],
            )
            .column("name", ColumnType::Text { max_len: None })
            .column("age", ColumnType::Integer)
            .build(1, PageId(10));
        catalog.create_table(users).unwrap();

        let idx = Index::new("idx_users_id", "users", vec![0], true, PageId(20));
        catalog.add_index("users", idx).unwrap();

        let idx_age = Index::new("idx_users_age", "users", vec![2], false, PageId(30));
        catalog.add_index("users", idx_age).unwrap();

        catalog
    }

    /// Registers an `orders(id, user_id)` table in `catalog`.
    fn add_orders_table(catalog: &Catalog) {
        let orders = TableBuilder::new("orders")
            .column("id", ColumnType::Integer)
            .column("user_id", ColumnType::Integer)
            .build(2, PageId(100));
        catalog.create_table(orders).unwrap();
    }

    /// Builds a scan over `table_name` with no alias, projection or filter.
    ///
    /// Panics if the table is not present in `catalog`.
    fn full_scan(catalog: &Catalog, table_name: &str) -> LogicalPlan {
        let table = catalog.get_table(table_name).unwrap();
        LogicalPlan::Scan {
            table: table.clone(),
            alias: None,
            projection: None,
            filter: None,
        }
    }

    /// Wraps `input` in a `Filter` node carrying `predicate`.
    fn filter_over(input: LogicalPlan, predicate: Expr) -> LogicalPlan {
        LogicalPlan::Filter {
            input: Box::new(input),
            predicate,
        }
    }

    /// Builds the comparison `age <op> <bound>`.
    fn age_cmp(op: BinaryOp, bound: i64) -> Expr {
        Expr::binary(Expr::column("age"), op, Expr::literal(bound))
    }

    /// Builds the predicate `name = 'Alice'`.
    fn name_is_alice() -> Expr {
        Expr::binary(
            Expr::column("name"),
            BinaryOp::Eq,
            Expr::literal(Value::Text("Alice".to_string())),
        )
    }

    /// Builds `users INNER JOIN orders ON users.id = orders.user_id` over plain
    /// scans. Both tables must already be registered in `catalog`.
    fn users_orders_join(catalog: &Catalog) -> LogicalPlan {
        LogicalPlan::Join {
            left: Box::new(full_scan(catalog, "users")),
            right: Box::new(full_scan(catalog, "orders")),
            join_type: crate::planner::JoinType::Inner,
            condition: Some(Expr::binary(
                Expr::column("users.id"),
                BinaryOp::Eq,
                Expr::column("orders.user_id"),
            )),
        }
    }

    /// `1 + 2` over integer literals folds to the literal `3`.
    #[test]
    fn test_constant_folding() {
        let expr = Expr::binary(Expr::literal(1i64), BinaryOp::Add, Expr::literal(2i64));

        let folded = helpers::fold_constants(expr);
        assert!(matches!(folded, Expr::Literal(Value::Integer(3))));
    }

    /// A nested `a AND (b AND c)` splits into three conjuncts.
    #[test]
    fn test_split_conjunctions() {
        let expr = Expr::and(
            Expr::column("a"),
            Expr::and(Expr::column("b"), Expr::column("c")),
        );

        let parts = helpers::split_conjunctions(expr);
        assert_eq!(parts.len(), 3);
    }

    /// `id = 5` on the primary-key column becomes a `PkSeek` on the literal 5
    /// with nothing left over as residual filter.
    #[test]
    fn test_index_selection_equality() {
        let catalog = Arc::new(create_test_catalog());
        let plan = filter_over(
            full_scan(&catalog, "users"),
            Expr::binary(Expr::column("id"), BinaryOp::Eq, Expr::literal(5i64)),
        );

        let optimized = Optimizer::with_catalog(catalog).optimize(plan);

        match optimized {
            LogicalPlan::PkSeek {
                key_values,
                residual_filter,
                ..
            } => {
                assert_eq!(key_values.len(), 1);
                assert!(matches!(key_values[0], Expr::Literal(Value::Integer(5))));
                assert!(residual_filter.is_none());
            }
            other => panic!("Expected PkSeek, got {:?}", other),
        }
    }

    /// `age > 18 AND age < 65` becomes a non-point `IndexScan` on
    /// `idx_users_age`, with both bounds absorbed (no residual filter).
    #[test]
    fn test_index_selection_range() {
        let catalog = Arc::new(create_test_catalog());
        let plan = filter_over(
            full_scan(&catalog, "users"),
            Expr::and(age_cmp(BinaryOp::Gt, 18), age_cmp(BinaryOp::Lt, 65)),
        );

        let optimized = Optimizer::with_catalog(catalog).optimize(plan);

        match optimized {
            LogicalPlan::IndexScan {
                index,
                range,
                residual_filter,
                ..
            } => {
                assert_eq!(index.name, "idx_users_age");
                assert!(!range.is_point_lookup());
                assert!(residual_filter.is_none());
            }
            other => panic!("Expected IndexScan, got {:?}", other),
        }
    }

    /// `age > 18 AND name = 'Alice'` uses `idx_users_age` for the range and
    /// keeps a residual filter on the `IndexScan`.
    #[test]
    fn test_index_selection_with_residual_filter() {
        let catalog = Arc::new(create_test_catalog());
        let plan = filter_over(
            full_scan(&catalog, "users"),
            Expr::and(age_cmp(BinaryOp::Gt, 18), name_is_alice()),
        );

        let optimized = Optimizer::with_catalog(catalog).optimize(plan);

        match optimized {
            LogicalPlan::IndexScan {
                index,
                range,
                residual_filter,
                ..
            } => {
                assert_eq!(index.name, "idx_users_age");
                assert!(!range.is_point_lookup());
                assert!(residual_filter.is_some());
            }
            other => panic!("Expected IndexScan, got {:?}", other),
        }
    }

    /// A predicate on `name`, for which the fixture defines no index, stays a
    /// plain `Scan` with the predicate attached as its filter.
    #[test]
    fn test_index_selection_no_index() {
        let catalog = Arc::new(create_test_catalog());
        let plan = filter_over(full_scan(&catalog, "users"), name_is_alice());

        let optimized = Optimizer::with_catalog(catalog).optimize(plan);

        match optimized {
            LogicalPlan::Scan { filter, .. } => {
                assert!(filter.is_some());
            }
            other => panic!("Expected Scan with filter, got {:?}", other),
        }
    }

    /// An unfiltered scan of `users` is estimated at exactly 1000 rows with a
    /// positive cost.
    ///
    /// NOTE(review): 1000 is presumably the estimator's default row count for a
    /// table without statistics — confirm in `cost`.
    #[test]
    fn test_cost_estimator_basic() {
        let catalog = create_test_catalog();
        let scan = full_scan(&catalog, "users");

        let estimator = CostEstimator::new(&catalog);
        let cardinality = estimator.estimate_cardinality(&scan);
        let cost = estimator.estimate_cost(&scan);

        assert_eq!(cardinality, 1000.0);
        assert!(cost > 0.0);
    }

    /// Filtering on `age > 18` yields an estimate strictly between 0 and the
    /// unfiltered 1000 rows.
    #[test]
    fn test_cost_estimator_filter() {
        let catalog = create_test_catalog();
        let plan = filter_over(full_scan(&catalog, "users"), age_cmp(BinaryOp::Gt, 18));

        let estimator = CostEstimator::new(&catalog);
        let cardinality = estimator.estimate_cardinality(&plan);

        assert!(cardinality < 1000.0);
        assert!(cardinality > 0.0);
    }

    /// An inner join of `users` and `orders` has positive estimated
    /// cardinality and cost.
    #[test]
    fn test_cost_estimator_join() {
        let catalog = create_test_catalog();
        add_orders_table(&catalog);
        let join = users_orders_join(&catalog);

        let estimator = CostEstimator::new(&catalog);
        let cardinality = estimator.estimate_cardinality(&join);
        let cost = estimator.estimate_cost(&join);

        assert!(cardinality > 0.0);
        assert!(cost > 0.0);
    }

    /// The formatted plan for a `users` scan names the scan and includes both
    /// `cost=` and `rows=` annotations.
    #[test]
    fn test_cost_estimator_format_plan() {
        let catalog = create_test_catalog();
        let scan = full_scan(&catalog, "users");

        let estimator = CostEstimator::new(&catalog);
        let formatted = estimator.format_plan_with_costs(&scan, 0);

        assert!(formatted.contains("Scan: users"));
        assert!(formatted.contains("cost="));
        assert!(formatted.contains("rows="));
    }

    /// Optimizing a three-table inner join chain still yields a `Join` at the
    /// root. Only the root node kind is checked, not the chosen join order.
    #[test]
    fn test_join_order_optimizer() {
        let catalog = Arc::new(create_test_catalog());
        add_orders_table(&catalog);

        let products = TableBuilder::new("products")
            .column("id", ColumnType::Integer)
            .column("order_id", ColumnType::Integer)
            .build(3, PageId(200));
        catalog.create_table(products).unwrap();

        // (users JOIN orders) JOIN products
        let plan = LogicalPlan::Join {
            left: Box::new(users_orders_join(&catalog)),
            right: Box::new(full_scan(&catalog, "products")),
            join_type: crate::planner::JoinType::Inner,
            condition: Some(Expr::binary(
                Expr::column("orders.id"),
                BinaryOp::Eq,
                Expr::column("products.order_id"),
            )),
        };

        let optimized = Optimizer::with_catalog(catalog).optimize(plan);

        assert!(
            matches!(optimized, LogicalPlan::Join { .. }),
            "Expected Join, got {:?}",
            optimized
        );
    }

    /// A scan carrying `age > 18 AND age < 65` as its own filter is estimated
    /// strictly between 0 and 1000 rows.
    #[test]
    fn test_selectivity_estimation() {
        let catalog = create_test_catalog();
        let table = catalog.get_table("users").unwrap();

        let estimator = CostEstimator::new(&catalog);

        let scan = LogicalPlan::Scan {
            table: table.clone(),
            alias: None,
            projection: None,
            filter: Some(Expr::and(
                age_cmp(BinaryOp::Gt, 18),
                age_cmp(BinaryOp::Lt, 65),
            )),
        };

        let cardinality = estimator.estimate_cardinality(&scan);

        assert!(cardinality < 1000.0);
        assert!(cardinality > 0.0);
    }
}