Skip to main content

featherdb_query/optimizer/
mod.rs

1//! Query optimizer
2
3pub mod cost;
4mod helpers;
5mod join_order;
6mod rule;
7mod rules;
8
9use crate::planner::LogicalPlan;
10use featherdb_catalog::Catalog;
11use rule::OptimizationRule;
12use rules::{
13    ConstantFolding, IndexSelection, PkSeekRule, PredicatePushdown, ProjectionPushdown,
14    SubqueryToJoinConversion,
15};
16use std::sync::Arc;
17
18// Re-export public types
19pub use cost::{constants as cost_constants, CostEstimator};
20
21/// Query optimizer that applies transformation rules
22pub struct Optimizer {
23    rules: Vec<Box<dyn OptimizationRule>>,
24    /// Catalog reference for index selection
25    #[allow(dead_code)]
26    catalog: Option<Arc<Catalog>>,
27}
28
29impl Optimizer {
30    /// Create a new optimizer with default rules
31    pub fn new() -> Self {
32        Optimizer {
33            rules: vec![
34                Box::new(PredicatePushdown),
35                Box::new(PkSeekRule::new()),
36                Box::new(ProjectionPushdown),
37                Box::new(ConstantFolding),
38            ],
39            catalog: None,
40        }
41    }
42
43    /// Create an optimizer with catalog access for index selection and join ordering
44    pub fn with_catalog(catalog: Arc<Catalog>) -> Self {
45        Optimizer {
46            rules: vec![
47                // SubqueryToJoinConversion should run early to convert subqueries to joins
48                Box::new(SubqueryToJoinConversion),
49                // PredicatePushdown before JoinOrder so that single-table filters
50                // are already on scans when we estimate relation cardinalities
51                Box::new(PredicatePushdown),
52                // JoinOrderOptimizer reorders joins using cardinality estimates
53                Box::new(join_order::JoinOrderOptimizer::new(catalog.clone())),
54                // Run PredicatePushdown again to push any new predicates from reordered joins
55                Box::new(PredicatePushdown),
56                // PkSeekRule after PredicatePushdown, before IndexSelection
57                Box::new(PkSeekRule::new()),
58                // IndexSelection converts Scan + Filter into IndexScan when beneficial
59                Box::new(IndexSelection::new(catalog.clone())),
60                Box::new(ProjectionPushdown),
61                Box::new(ConstantFolding),
62            ],
63            catalog: Some(catalog),
64        }
65    }
66
67    /// Optimize a logical plan
68    pub fn optimize(&self, plan: LogicalPlan) -> LogicalPlan {
69        let mut current = plan;
70
71        // Apply each rule
72        for rule in &self.rules {
73            current = rule.apply(current);
74        }
75
76        current
77    }
78}
79
80impl Default for Optimizer {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::{BinaryOp, Expr};
    use featherdb_catalog::{Catalog, ColumnConstraint, Index, TableBuilder};
    use featherdb_core::{ColumnType, PageId, Value};

    // ---------------------------------------------------------------------
    // Fixtures and plan-building helpers
    // ---------------------------------------------------------------------

    /// Catalog with a single `users` table (`id` primary key, `name`, `age`)
    /// plus one index over column 0 (`id`) and one over column 2 (`age`).
    fn create_test_catalog() -> Catalog {
        let catalog = Catalog::new();

        let users_table = TableBuilder::new("users")
            .column_with(
                "id",
                ColumnType::Integer,
                vec![ColumnConstraint::PrimaryKey],
            )
            .column("name", ColumnType::Text { max_len: None })
            .column("age", ColumnType::Integer)
            .build(1, PageId(10));
        catalog.create_table(users_table).unwrap();

        // NOTE(review): the boolean is presumably the uniqueness flag of the
        // index -- confirm against `Index::new`.
        for (index_name, column, flag, page) in [
            ("idx_users_id", 0, true, 20),
            ("idx_users_age", 2, false, 30),
        ] {
            let index = Index::new(index_name, "users", vec![column], flag, PageId(page));
            catalog.add_index("users", index).unwrap();
        }

        catalog
    }

    /// Registers the two-column `orders(id, user_id)` table used by the join tests.
    fn add_orders_table(catalog: &Catalog) {
        let orders = TableBuilder::new("orders")
            .column("id", ColumnType::Integer)
            .column("user_id", ColumnType::Integer)
            .build(2, PageId(100));
        catalog.create_table(orders).unwrap();
    }

    /// Unfiltered, unprojected, unaliased scan over the named table.
    fn full_scan(catalog: &Catalog, table_name: &str) -> LogicalPlan {
        LogicalPlan::Scan {
            table: catalog.get_table(table_name).unwrap().clone(),
            alias: None,
            projection: None,
            filter: None,
        }
    }

    /// Wraps `input` in a `Filter` node carrying `predicate`.
    fn filter_over(input: LogicalPlan, predicate: Expr) -> LogicalPlan {
        LogicalPlan::Filter {
            input: Box::new(input),
            predicate,
        }
    }

    /// Inner join of `left` and `right` on the condition `on`.
    fn inner_join(left: LogicalPlan, right: LogicalPlan, on: Expr) -> LogicalPlan {
        LogicalPlan::Join {
            left: Box::new(left),
            right: Box::new(right),
            join_type: crate::planner::JoinType::Inner,
            condition: Some(on),
        }
    }

    /// `age <op> <bound>` comparison against an integer literal.
    fn age_cmp(op: BinaryOp, bound: i64) -> Expr {
        Expr::binary(Expr::column("age"), op, Expr::literal(bound))
    }

    /// `name = 'Alice'` predicate (the fixture has no index on `name`).
    fn name_is_alice() -> Expr {
        Expr::binary(
            Expr::column("name"),
            BinaryOp::Eq,
            Expr::literal(Value::Text("Alice".to_string())),
        )
    }

    /// Runs the catalog-aware optimizer pipeline over `plan`.
    fn optimize_with(catalog: Arc<Catalog>, plan: LogicalPlan) -> LogicalPlan {
        Optimizer::with_catalog(catalog).optimize(plan)
    }

    // ---------------------------------------------------------------------
    // Expression helpers
    // ---------------------------------------------------------------------

    #[test]
    fn test_constant_folding() {
        // 1 + 2 must collapse to the literal 3.
        let sum = Expr::binary(Expr::literal(1i64), BinaryOp::Add, Expr::literal(2i64));

        let result = helpers::fold_constants(sum);
        assert!(matches!(result, Expr::Literal(Value::Integer(3))));
    }

    #[test]
    fn test_split_conjunctions() {
        // a AND (b AND c) flattens into three conjuncts.
        let nested = Expr::and(
            Expr::column("a"),
            Expr::and(Expr::column("b"), Expr::column("c")),
        );

        let conjuncts = helpers::split_conjunctions(nested);
        assert_eq!(conjuncts.len(), 3);
    }

    // ---------------------------------------------------------------------
    // Index / primary-key access path selection
    // ---------------------------------------------------------------------

    #[test]
    fn test_index_selection_equality() {
        let catalog = Arc::new(create_test_catalog());

        // WHERE id = 5 -- `id` is the primary key, so the expected access path
        // is a PkSeek rather than an IndexScan.
        let plan = filter_over(
            full_scan(&catalog, "users"),
            Expr::binary(Expr::column("id"), BinaryOp::Eq, Expr::literal(5i64)),
        );

        match optimize_with(catalog, plan) {
            LogicalPlan::PkSeek {
                key_values,
                residual_filter,
                ..
            } => {
                assert_eq!(key_values.len(), 1);
                assert!(matches!(key_values[0], Expr::Literal(Value::Integer(5))));
                assert!(residual_filter.is_none());
            }
            other => panic!("Expected PkSeek, got {:?}", other),
        }
    }

    #[test]
    fn test_index_selection_range() {
        let catalog = Arc::new(create_test_catalog());

        // WHERE age > 18 AND age < 65
        let plan = filter_over(
            full_scan(&catalog, "users"),
            Expr::and(age_cmp(BinaryOp::Gt, 18), age_cmp(BinaryOp::Lt, 65)),
        );

        // Expect a range IndexScan over the age index.
        match optimize_with(catalog, plan) {
            LogicalPlan::IndexScan {
                index,
                range,
                residual_filter,
                ..
            } => {
                assert_eq!(index.name, "idx_users_age");
                assert!(!range.is_point_lookup());
                // Both bounds are absorbed by the range, leaving nothing behind.
                assert!(residual_filter.is_none());
            }
            other => panic!("Expected IndexScan, got {:?}", other),
        }
    }

    #[test]
    fn test_index_selection_with_residual_filter() {
        let catalog = Arc::new(create_test_catalog());

        // WHERE age > 18 AND name = 'Alice'
        // Only the age bound can be served by an index; the name test stays
        // behind as a residual filter.
        let plan = filter_over(
            full_scan(&catalog, "users"),
            Expr::and(age_cmp(BinaryOp::Gt, 18), name_is_alice()),
        );

        match optimize_with(catalog, plan) {
            LogicalPlan::IndexScan {
                index,
                range,
                residual_filter,
                ..
            } => {
                assert_eq!(index.name, "idx_users_age");
                assert!(!range.is_point_lookup());
                assert!(residual_filter.is_some());
            }
            other => panic!("Expected IndexScan, got {:?}", other),
        }
    }

    #[test]
    fn test_index_selection_no_index() {
        let catalog = Arc::new(create_test_catalog());

        // WHERE name = 'Alice' -- there is no index on `name`.
        let plan = filter_over(full_scan(&catalog, "users"), name_is_alice());

        // The plan should remain a Scan, with the predicate pushed into it.
        match optimize_with(catalog, plan) {
            LogicalPlan::Scan { filter, .. } => {
                assert!(filter.is_some());
            }
            other => panic!("Expected Scan with filter, got {:?}", other),
        }
    }

    // ---------------------------------------------------------------------
    // Cost estimation
    // ---------------------------------------------------------------------

    #[test]
    fn test_cost_estimator_basic() {
        let catalog = create_test_catalog();
        let scan = full_scan(&catalog, "users");

        let estimator = CostEstimator::new(&catalog);
        let rows = estimator.estimate_cardinality(&scan);
        let cost = estimator.estimate_cost(&scan);

        // New tables default to a row count of 1000.
        assert_eq!(rows, 1000.0);
        assert!(cost > 0.0);
    }

    #[test]
    fn test_cost_estimator_filter() {
        let catalog = create_test_catalog();
        let plan = filter_over(full_scan(&catalog, "users"), age_cmp(BinaryOp::Gt, 18));

        let estimator = CostEstimator::new(&catalog);
        let rows = estimator.estimate_cardinality(&plan);

        // A filter must shrink the estimate without driving it to zero.
        assert!(rows < 1000.0);
        assert!(rows > 0.0);
    }

    #[test]
    fn test_cost_estimator_join() {
        let catalog = create_test_catalog();
        add_orders_table(&catalog);

        let users_scan = full_scan(&catalog, "users");
        let orders_scan = full_scan(&catalog, "orders");

        let plan = inner_join(
            users_scan,
            orders_scan,
            Expr::binary(
                Expr::column("users.id"),
                BinaryOp::Eq,
                Expr::column("orders.user_id"),
            ),
        );

        let estimator = CostEstimator::new(&catalog);
        let rows = estimator.estimate_cardinality(&plan);
        let cost = estimator.estimate_cost(&plan);

        // Both estimates must be strictly positive for a join.
        assert!(rows > 0.0);
        assert!(cost > 0.0);
    }

    #[test]
    fn test_cost_estimator_format_plan() {
        let catalog = create_test_catalog();
        let scan = full_scan(&catalog, "users");

        let estimator = CostEstimator::new(&catalog);
        let rendered = estimator.format_plan_with_costs(&scan, 0);

        for fragment in ["Scan: users", "cost=", "rows="] {
            assert!(rendered.contains(fragment));
        }
    }

    #[test]
    fn test_selectivity_estimation() {
        let catalog = create_test_catalog();
        let estimator = CostEstimator::new(&catalog);

        // Conjunction of two range bounds; AND selectivities should multiply.
        let both_bounds = Expr::and(age_cmp(BinaryOp::Gt, 18), age_cmp(BinaryOp::Lt, 65));

        // Here the predicate lives directly on the scan node.
        let scan = LogicalPlan::Scan {
            table: catalog.get_table("users").unwrap().clone(),
            alias: None,
            projection: None,
            filter: Some(both_bounds),
        };

        let rows = estimator.estimate_cardinality(&scan);

        // Well below the base row count, yet still positive.
        assert!(rows < 1000.0);
        assert!(rows > 0.0);
    }

    // ---------------------------------------------------------------------
    // Join ordering
    // ---------------------------------------------------------------------

    #[test]
    fn test_join_order_optimizer() {
        let catalog = Arc::new(create_test_catalog());

        // Extra tables so that a three-way join can be built.
        add_orders_table(&catalog);
        let products = TableBuilder::new("products")
            .column("id", ColumnType::Integer)
            .column("order_id", ColumnType::Integer)
            .build(3, PageId(200));
        catalog.create_table(products).unwrap();

        let users_scan = full_scan(&catalog, "users");
        let orders_scan = full_scan(&catalog, "orders");
        let products_scan = full_scan(&catalog, "products");

        // (users JOIN orders) JOIN products
        let users_orders = inner_join(
            users_scan,
            orders_scan,
            Expr::binary(
                Expr::column("users.id"),
                BinaryOp::Eq,
                Expr::column("orders.user_id"),
            ),
        );
        let three_way = inner_join(
            users_orders,
            products_scan,
            Expr::binary(
                Expr::column("orders.id"),
                BinaryOp::Eq,
                Expr::column("products.order_id"),
            ),
        );

        let optimized = optimize_with(catalog, three_way);

        // The root must still be a join; its inputs may have been reordered.
        if !matches!(optimized, LogicalPlan::Join { .. }) {
            panic!("Expected Join, got {:?}", optimized);
        }
    }
}