Skip to main content

cynos_query/optimizer/
get_row_count.rs

1//! Get row count pass - optimizes COUNT(*) queries.
2//!
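//! This pass identifies simple COUNT(*) queries (without WHERE, GROUP BY,
//! LIMIT, or SKIP) and flags them for a direct row count lookup: the physical
//! plan is returned unchanged together with a `GetRowCountPlan` naming the
//! scanned table, and the caller decides whether to use it.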
5//!
6//! Example:
7//! ```text
8//! HashAggregate(COUNT(*))     =>    GetRowCount(users)
9//!        |
10//! TableScan(users)
11//! ```
12//!
13//! This optimization is beneficial because:
14//! 1. It avoids scanning the entire table
15//! 2. Row count can be retrieved from table metadata in O(1)
16
17use crate::ast::{AggregateFunc, Expr};
18use crate::context::ExecutionContext;
19use crate::planner::PhysicalPlan;
20use alloc::boxed::Box;
21use alloc::string::String;
22
/// Hint produced by [`GetRowCountPass`] when a query's result is exactly the
/// row count of a single table.
///
/// The pass does not splice this into the physical plan; it is returned
/// alongside the (unchanged) plan so the caller can answer the query from the
/// table's row count instead of executing the aggregate.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GetRowCountPlan {
    /// Name of the table read by the `TableScan` under the `COUNT` aggregate.
    pub table: String,
}
28
29/// Pass that optimizes COUNT(*) queries.
30pub struct GetRowCountPass<'a> {
31    ctx: &'a ExecutionContext,
32}
33
34impl<'a> GetRowCountPass<'a> {
35    /// Creates a new GetRowCountPass with the given execution context.
36    pub fn new(ctx: &'a ExecutionContext) -> Self {
37        Self { ctx }
38    }
39
40    /// Optimizes the physical plan by replacing COUNT(*) with direct row count.
41    /// Returns the optimized plan and optionally a GetRowCountPlan if applicable.
42    pub fn optimize(&self, plan: PhysicalPlan) -> (PhysicalPlan, Option<GetRowCountPlan>) {
43        self.traverse(plan)
44    }
45
46    fn traverse(&self, plan: PhysicalPlan) -> (PhysicalPlan, Option<GetRowCountPlan>) {
47        match plan {
48            PhysicalPlan::HashAggregate {
49                input,
50                group_by,
51                aggregates,
52            } => {
53                // Check if this is a simple COUNT(*) query
54                if let Some(table) = self.is_count_star_query(&input, &group_by, &aggregates) {
55                    // Return the original plan but also indicate we can use GetRowCount
56                    return (
57                        PhysicalPlan::HashAggregate {
58                            input,
59                            group_by,
60                            aggregates,
61                        },
62                        Some(GetRowCountPlan { table }),
63                    );
64                }
65
66                // Not a COUNT(*) query, recursively optimize input
67                let (optimized_input, _) = self.traverse(*input);
68                (
69                    PhysicalPlan::HashAggregate {
70                        input: Box::new(optimized_input),
71                        group_by,
72                        aggregates,
73                    },
74                    None,
75                )
76            }
77
78            // Recursively process other nodes
79            PhysicalPlan::Filter { input, predicate } => {
80                let (optimized_input, _) = self.traverse(*input);
81                (
82                    PhysicalPlan::Filter {
83                        input: Box::new(optimized_input),
84                        predicate,
85                    },
86                    None,
87                )
88            }
89
90            PhysicalPlan::Project { input, columns } => {
91                let (optimized_input, row_count) = self.traverse(*input);
92                (
93                    PhysicalPlan::Project {
94                        input: Box::new(optimized_input),
95                        columns,
96                    },
97                    row_count,
98                )
99            }
100
101            PhysicalPlan::Sort { input, order_by } => {
102                let (optimized_input, _) = self.traverse(*input);
103                (
104                    PhysicalPlan::Sort {
105                        input: Box::new(optimized_input),
106                        order_by,
107                    },
108                    None,
109                )
110            }
111
112            PhysicalPlan::Limit {
113                input,
114                limit,
115                offset,
116            } => {
117                let (optimized_input, _) = self.traverse(*input);
118                (
119                    PhysicalPlan::Limit {
120                        input: Box::new(optimized_input),
121                        limit,
122                        offset,
123                    },
124                    None,
125                )
126            }
127
128            PhysicalPlan::HashJoin {
129                left,
130                right,
131                condition,
132                join_type,
133            } => {
134                let (left_opt, _) = self.traverse(*left);
135                let (right_opt, _) = self.traverse(*right);
136                (
137                    PhysicalPlan::HashJoin {
138                        left: Box::new(left_opt),
139                        right: Box::new(right_opt),
140                        condition,
141                        join_type,
142                    },
143                    None,
144                )
145            }
146
147            PhysicalPlan::SortMergeJoin {
148                left,
149                right,
150                condition,
151                join_type,
152            } => {
153                let (left_opt, _) = self.traverse(*left);
154                let (right_opt, _) = self.traverse(*right);
155                (
156                    PhysicalPlan::SortMergeJoin {
157                        left: Box::new(left_opt),
158                        right: Box::new(right_opt),
159                        condition,
160                        join_type,
161                    },
162                    None,
163                )
164            }
165
166            PhysicalPlan::NestedLoopJoin {
167                left,
168                right,
169                condition,
170                join_type,
171            } => {
172                let (left_opt, _) = self.traverse(*left);
173                let (right_opt, _) = self.traverse(*right);
174                (
175                    PhysicalPlan::NestedLoopJoin {
176                        left: Box::new(left_opt),
177                        right: Box::new(right_opt),
178                        condition,
179                        join_type,
180                    },
181                    None,
182                )
183            }
184
185            PhysicalPlan::IndexNestedLoopJoin {
186                outer,
187                inner_table,
188                inner_index,
189                condition,
190                join_type,
191            } => {
192                let (outer_opt, _) = self.traverse(*outer);
193                (
194                    PhysicalPlan::IndexNestedLoopJoin {
195                        outer: Box::new(outer_opt),
196                        inner_table,
197                        inner_index,
198                        condition,
199                        join_type,
200                    },
201                    None,
202                )
203            }
204
205            PhysicalPlan::CrossProduct { left, right } => {
206                let (left_opt, _) = self.traverse(*left);
207                let (right_opt, _) = self.traverse(*right);
208                (
209                    PhysicalPlan::CrossProduct {
210                        left: Box::new(left_opt),
211                        right: Box::new(right_opt),
212                    },
213                    None,
214                )
215            }
216
217            PhysicalPlan::NoOp { input } => {
218                let (optimized_input, row_count) = self.traverse(*input);
219                (
220                    PhysicalPlan::NoOp {
221                        input: Box::new(optimized_input),
222                    },
223                    row_count,
224                )
225            }
226
227            PhysicalPlan::TopN {
228                input,
229                order_by,
230                limit,
231                offset,
232            } => {
233                let (optimized_input, _) = self.traverse(*input);
234                (
235                    PhysicalPlan::TopN {
236                        input: Box::new(optimized_input),
237                        order_by,
238                        limit,
239                        offset,
240                    },
241                    None,
242                )
243            }
244
245            // Leaf nodes - no transformation
246            plan @ (PhysicalPlan::TableScan { .. }
247            | PhysicalPlan::IndexScan { .. }
248            | PhysicalPlan::IndexGet { .. }
249            | PhysicalPlan::IndexInGet { .. }
250            | PhysicalPlan::GinIndexScan { .. }
251            | PhysicalPlan::GinIndexScanMulti { .. }
252            | PhysicalPlan::Empty) => (plan, None),
253        }
254    }
255
256    /// Checks if this is a simple COUNT(*) query.
257    /// Returns the table name if it is, None otherwise.
258    fn is_count_star_query(
259        &self,
260        input: &PhysicalPlan,
261        group_by: &[Expr],
262        aggregates: &[(AggregateFunc, Expr)],
263    ) -> Option<String> {
264        // Must have no GROUP BY
265        if !group_by.is_empty() {
266            return None;
267        }
268
269        // Must have exactly one aggregate: COUNT(*)
270        if aggregates.len() != 1 {
271            return None;
272        }
273
274        let (func, _expr) = &aggregates[0];
275        if *func != AggregateFunc::Count {
276            return None;
277        }
278
279        // Check if it's COUNT(*) - represented as COUNT with a star/all expression
280        // For simplicity, we accept any COUNT aggregate here
281        // In a real implementation, we'd check for the star column specifically
282
283        // Input must be a simple TableScan (no filters, joins, etc.)
284        match input {
285            PhysicalPlan::TableScan { table } => Some(table.clone()),
286            _ => None,
287        }
288    }
289
290    /// Gets the row count for a table from the execution context.
291    pub fn get_row_count(&self, table: &str) -> usize {
292        self.ctx.row_count(table)
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Expr;
    use crate::context::TableStats;

    /// Builds a context with one registered table, `users`, holding 1000 rows.
    fn create_test_context() -> ExecutionContext {
        let mut ctx = ExecutionContext::new();

        ctx.register_table(
            "users",
            TableStats {
                row_count: 1000,
                is_sorted: false,
                indexes: alloc::vec![],
            },
        );

        ctx
    }

    /// Builds `HashAggregate(COUNT(*)) -> TableScan(users)`, with COUNT(*)
    /// written as `COUNT(1)` like the other tests in this module.
    fn count_star_users() -> PhysicalPlan {
        PhysicalPlan::HashAggregate {
            input: Box::new(PhysicalPlan::table_scan("users")),
            group_by: alloc::vec![],
            aggregates: alloc::vec![(AggregateFunc::Count, Expr::literal(1i64))],
        }
    }

    #[test]
    fn test_count_star_optimization() {
        let ctx = create_test_context();
        let pass = GetRowCountPass::new(&ctx);

        // Create: HashAggregate(COUNT(*)) -> TableScan(users)
        let (optimized, row_count_plan) = pass.optimize(count_star_users());

        // The plan itself is left in place; only the hint is reported.
        assert!(matches!(optimized, PhysicalPlan::HashAggregate { .. }));

        // Should detect COUNT(*) optimization opportunity
        assert!(row_count_plan.is_some());
        assert_eq!(row_count_plan.unwrap().table, "users");
    }

    #[test]
    fn test_hint_propagates_through_noop() {
        let ctx = create_test_context();
        let pass = GetRowCountPass::new(&ctx);

        // Create: NoOp -> HashAggregate(COUNT(*)) -> TableScan(users)
        let plan = PhysicalPlan::NoOp {
            input: Box::new(count_star_users()),
        };

        let (optimized, row_count_plan) = pass.optimize(plan);

        // NoOp is a pass-through wrapper, so the hint must survive it.
        assert!(matches!(optimized, PhysicalPlan::NoOp { .. }));
        assert_eq!(row_count_plan.map(|p| p.table).as_deref(), Some("users"));
    }

    #[test]
    fn test_filter_above_count_drops_hint() {
        let ctx = create_test_context();
        let pass = GetRowCountPass::new(&ctx);

        // Create: Filter -> HashAggregate(COUNT(*)) -> TableScan(users)
        // (HAVING-style predicate above the aggregate)
        let plan = PhysicalPlan::Filter {
            input: Box::new(count_star_users()),
            predicate: Expr::gt(Expr::column("users", "age", 1), Expr::literal(18i64)),
        };

        let (optimized, row_count_plan) = pass.optimize(plan);

        // A filter above the aggregate may remove its single row, so no hint.
        assert!(matches!(optimized, PhysicalPlan::Filter { .. }));
        assert!(row_count_plan.is_none());
    }

    #[test]
    fn test_count_with_group_by_not_optimized() {
        let ctx = create_test_context();
        let pass = GetRowCountPass::new(&ctx);

        // Create: HashAggregate(COUNT(*), GROUP BY name) -> TableScan(users)
        let plan = PhysicalPlan::HashAggregate {
            input: Box::new(PhysicalPlan::table_scan("users")),
            group_by: alloc::vec![Expr::column("users", "name", 1)],
            aggregates: alloc::vec![(AggregateFunc::Count, Expr::literal(1i64))],
        };

        let (_, row_count_plan) = pass.optimize(plan);

        // Should NOT detect optimization (has GROUP BY)
        assert!(row_count_plan.is_none());
    }

    #[test]
    fn test_count_with_filter_not_optimized() {
        let ctx = create_test_context();
        let pass = GetRowCountPass::new(&ctx);

        // Create: HashAggregate(COUNT(*)) -> Filter -> TableScan(users)
        let plan = PhysicalPlan::HashAggregate {
            input: Box::new(PhysicalPlan::Filter {
                input: Box::new(PhysicalPlan::table_scan("users")),
                predicate: Expr::gt(Expr::column("users", "age", 1), Expr::literal(18i64)),
            }),
            group_by: alloc::vec![],
            aggregates: alloc::vec![(AggregateFunc::Count, Expr::literal(1i64))],
        };

        let (_, row_count_plan) = pass.optimize(plan);

        // Should NOT detect optimization (has Filter)
        assert!(row_count_plan.is_none());
    }

    #[test]
    fn test_sum_not_optimized() {
        let ctx = create_test_context();
        let pass = GetRowCountPass::new(&ctx);

        // Create: HashAggregate(SUM(amount)) -> TableScan(users)
        let plan = PhysicalPlan::HashAggregate {
            input: Box::new(PhysicalPlan::table_scan("users")),
            group_by: alloc::vec![],
            aggregates: alloc::vec![(AggregateFunc::Sum, Expr::column("users", "amount", 2))],
        };

        let (_, row_count_plan) = pass.optimize(plan);

        // Should NOT detect optimization (not COUNT)
        assert!(row_count_plan.is_none());
    }

    #[test]
    fn test_get_row_count() {
        let ctx = create_test_context();
        let pass = GetRowCountPass::new(&ctx);

        assert_eq!(pass.get_row_count("users"), 1000);
        assert_eq!(pass.get_row_count("nonexistent"), 0);
    }

    #[test]
    fn test_multiple_aggregates_not_optimized() {
        let ctx = create_test_context();
        let pass = GetRowCountPass::new(&ctx);

        // Create: HashAggregate(COUNT(*), SUM(amount)) -> TableScan(users)
        let plan = PhysicalPlan::HashAggregate {
            input: Box::new(PhysicalPlan::table_scan("users")),
            group_by: alloc::vec![],
            aggregates: alloc::vec![
                (AggregateFunc::Count, Expr::literal(1i64)),
                (AggregateFunc::Sum, Expr::column("users", "amount", 2)),
            ],
        };

        let (_, row_count_plan) = pass.optimize(plan);

        // Should NOT detect optimization (multiple aggregates)
        assert!(row_count_plan.is_none());
    }
}