1use crate::ast::{AggregateFunc, Expr};
18use crate::context::ExecutionContext;
19use crate::planner::PhysicalPlan;
20use alloc::boxed::Box;
21use alloc::string::String;
22
/// Row-count match reported by `GetRowCountPass`.
///
/// Produced when a `COUNT` aggregate sits directly on a `TableScan`; it
/// records which table that scan reads.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GetRowCountPlan {
    /// Name of the table read by the matched `TableScan`.
    pub table: String,
}
28
29pub struct GetRowCountPass<'a> {
31 ctx: &'a ExecutionContext,
32}
33
34impl<'a> GetRowCountPass<'a> {
35 pub fn new(ctx: &'a ExecutionContext) -> Self {
37 Self { ctx }
38 }
39
40 pub fn optimize(&self, plan: PhysicalPlan) -> (PhysicalPlan, Option<GetRowCountPlan>) {
43 self.traverse(plan)
44 }
45
46 fn traverse(&self, plan: PhysicalPlan) -> (PhysicalPlan, Option<GetRowCountPlan>) {
47 match plan {
48 PhysicalPlan::HashAggregate {
49 input,
50 group_by,
51 aggregates,
52 } => {
53 if let Some(table) = self.is_count_star_query(&input, &group_by, &aggregates) {
55 return (
57 PhysicalPlan::HashAggregate {
58 input,
59 group_by,
60 aggregates,
61 },
62 Some(GetRowCountPlan { table }),
63 );
64 }
65
66 let (optimized_input, _) = self.traverse(*input);
68 (
69 PhysicalPlan::HashAggregate {
70 input: Box::new(optimized_input),
71 group_by,
72 aggregates,
73 },
74 None,
75 )
76 }
77
78 PhysicalPlan::Filter { input, predicate } => {
80 let (optimized_input, _) = self.traverse(*input);
81 (
82 PhysicalPlan::Filter {
83 input: Box::new(optimized_input),
84 predicate,
85 },
86 None,
87 )
88 }
89
90 PhysicalPlan::Project { input, columns } => {
91 let (optimized_input, row_count) = self.traverse(*input);
92 (
93 PhysicalPlan::Project {
94 input: Box::new(optimized_input),
95 columns,
96 },
97 row_count,
98 )
99 }
100
101 PhysicalPlan::Sort { input, order_by } => {
102 let (optimized_input, _) = self.traverse(*input);
103 (
104 PhysicalPlan::Sort {
105 input: Box::new(optimized_input),
106 order_by,
107 },
108 None,
109 )
110 }
111
112 PhysicalPlan::Limit {
113 input,
114 limit,
115 offset,
116 } => {
117 let (optimized_input, _) = self.traverse(*input);
118 (
119 PhysicalPlan::Limit {
120 input: Box::new(optimized_input),
121 limit,
122 offset,
123 },
124 None,
125 )
126 }
127
128 PhysicalPlan::HashJoin {
129 left,
130 right,
131 condition,
132 join_type,
133 } => {
134 let (left_opt, _) = self.traverse(*left);
135 let (right_opt, _) = self.traverse(*right);
136 (
137 PhysicalPlan::HashJoin {
138 left: Box::new(left_opt),
139 right: Box::new(right_opt),
140 condition,
141 join_type,
142 },
143 None,
144 )
145 }
146
147 PhysicalPlan::SortMergeJoin {
148 left,
149 right,
150 condition,
151 join_type,
152 } => {
153 let (left_opt, _) = self.traverse(*left);
154 let (right_opt, _) = self.traverse(*right);
155 (
156 PhysicalPlan::SortMergeJoin {
157 left: Box::new(left_opt),
158 right: Box::new(right_opt),
159 condition,
160 join_type,
161 },
162 None,
163 )
164 }
165
166 PhysicalPlan::NestedLoopJoin {
167 left,
168 right,
169 condition,
170 join_type,
171 } => {
172 let (left_opt, _) = self.traverse(*left);
173 let (right_opt, _) = self.traverse(*right);
174 (
175 PhysicalPlan::NestedLoopJoin {
176 left: Box::new(left_opt),
177 right: Box::new(right_opt),
178 condition,
179 join_type,
180 },
181 None,
182 )
183 }
184
185 PhysicalPlan::IndexNestedLoopJoin {
186 outer,
187 inner_table,
188 inner_index,
189 condition,
190 join_type,
191 } => {
192 let (outer_opt, _) = self.traverse(*outer);
193 (
194 PhysicalPlan::IndexNestedLoopJoin {
195 outer: Box::new(outer_opt),
196 inner_table,
197 inner_index,
198 condition,
199 join_type,
200 },
201 None,
202 )
203 }
204
205 PhysicalPlan::CrossProduct { left, right } => {
206 let (left_opt, _) = self.traverse(*left);
207 let (right_opt, _) = self.traverse(*right);
208 (
209 PhysicalPlan::CrossProduct {
210 left: Box::new(left_opt),
211 right: Box::new(right_opt),
212 },
213 None,
214 )
215 }
216
217 PhysicalPlan::NoOp { input } => {
218 let (optimized_input, row_count) = self.traverse(*input);
219 (
220 PhysicalPlan::NoOp {
221 input: Box::new(optimized_input),
222 },
223 row_count,
224 )
225 }
226
227 PhysicalPlan::TopN {
228 input,
229 order_by,
230 limit,
231 offset,
232 } => {
233 let (optimized_input, _) = self.traverse(*input);
234 (
235 PhysicalPlan::TopN {
236 input: Box::new(optimized_input),
237 order_by,
238 limit,
239 offset,
240 },
241 None,
242 )
243 }
244
245 plan @ (PhysicalPlan::TableScan { .. }
247 | PhysicalPlan::IndexScan { .. }
248 | PhysicalPlan::IndexGet { .. }
249 | PhysicalPlan::IndexInGet { .. }
250 | PhysicalPlan::GinIndexScan { .. }
251 | PhysicalPlan::GinIndexScanMulti { .. }
252 | PhysicalPlan::Empty) => (plan, None),
253 }
254 }
255
256 fn is_count_star_query(
259 &self,
260 input: &PhysicalPlan,
261 group_by: &[Expr],
262 aggregates: &[(AggregateFunc, Expr)],
263 ) -> Option<String> {
264 if !group_by.is_empty() {
266 return None;
267 }
268
269 if aggregates.len() != 1 {
271 return None;
272 }
273
274 let (func, _expr) = &aggregates[0];
275 if *func != AggregateFunc::Count {
276 return None;
277 }
278
279 match input {
285 PhysicalPlan::TableScan { table } => Some(table.clone()),
286 _ => None,
287 }
288 }
289
290 pub fn get_row_count(&self, table: &str) -> usize {
292 self.ctx.row_count(table)
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Expr;
    use crate::context::TableStats;

    /// Builds a context with a single unsorted, unindexed `users` table
    /// holding 1000 rows.
    fn create_test_context() -> ExecutionContext {
        let users = TableStats {
            row_count: 1000,
            is_sorted: false,
            indexes: alloc::vec![],
        };
        let mut ctx = ExecutionContext::new();
        ctx.register_table("users", users);
        ctx
    }

    /// Runs the pass over `plan` against the test context and returns only
    /// the row-count match, if one was reported.
    fn detect_row_count(plan: PhysicalPlan) -> Option<GetRowCountPlan> {
        let ctx = create_test_context();
        let (_, detected) = GetRowCountPass::new(&ctx).optimize(plan);
        detected
    }

    /// Wraps `input` in an ungrouped `HashAggregate` computing `aggregates`.
    fn ungrouped_aggregate(
        input: PhysicalPlan,
        aggregates: alloc::vec::Vec<(AggregateFunc, Expr)>,
    ) -> PhysicalPlan {
        PhysicalPlan::HashAggregate {
            input: Box::new(input),
            group_by: alloc::vec![],
            aggregates,
        }
    }

    /// The `COUNT(1)` aggregate used throughout these tests.
    fn count_one() -> (AggregateFunc, Expr) {
        (AggregateFunc::Count, Expr::literal(1i64))
    }

    #[test]
    fn test_count_star_optimization() {
        let plan = ungrouped_aggregate(PhysicalPlan::table_scan("users"), alloc::vec![count_one()]);

        let matched_table = detect_row_count(plan).map(|found| found.table);
        assert_eq!(matched_table.as_deref(), Some("users"));
    }

    #[test]
    fn test_count_with_group_by_not_optimized() {
        let plan = PhysicalPlan::HashAggregate {
            input: Box::new(PhysicalPlan::table_scan("users")),
            group_by: alloc::vec![Expr::column("users", "name", 1)],
            aggregates: alloc::vec![count_one()],
        };

        assert!(detect_row_count(plan).is_none());
    }

    #[test]
    fn test_count_with_filter_not_optimized() {
        let adults = PhysicalPlan::Filter {
            input: Box::new(PhysicalPlan::table_scan("users")),
            predicate: Expr::gt(Expr::column("users", "age", 1), Expr::literal(18i64)),
        };
        let plan = ungrouped_aggregate(adults, alloc::vec![count_one()]);

        assert!(detect_row_count(plan).is_none());
    }

    #[test]
    fn test_sum_not_optimized() {
        let sum_amount = (AggregateFunc::Sum, Expr::column("users", "amount", 2));
        let plan = ungrouped_aggregate(PhysicalPlan::table_scan("users"), alloc::vec![sum_amount]);

        assert!(detect_row_count(plan).is_none());
    }

    #[test]
    fn test_get_row_count() {
        let ctx = create_test_context();
        let pass = GetRowCountPass::new(&ctx);

        // A registered table reports its stats; an unknown table reports zero.
        assert_eq!(pass.get_row_count("users"), 1000);
        assert_eq!(pass.get_row_count("nonexistent"), 0);
    }

    #[test]
    fn test_multiple_aggregates_not_optimized() {
        let sum_amount = (AggregateFunc::Sum, Expr::column("users", "amount", 2));
        let plan = ungrouped_aggregate(
            PhysicalPlan::table_scan("users"),
            alloc::vec![count_one(), sum_amount],
        );

        assert!(detect_row_count(plan).is_none());
    }
}