Skip to main content

chryso_optimizer/
stats_collect.rs

1use chryso_core::ast::{Expr, OrderByExpr};
2use chryso_planner::LogicalPlan;
3use std::collections::HashSet;
4
/// Tables and columns that a logical plan refers to, as gathered by
/// [`collect_requirements`].
///
/// Both sets are empty in the `Default` value.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct StatsRequirements {
    /// Table names taken from the plan's scan nodes.
    pub tables: HashSet<String>,
    /// `(table, column)` pairs referenced by expressions in the plan.
    pub columns: HashSet<(String, String)>,
}
10
11pub fn collect_requirements(plan: &LogicalPlan) -> StatsRequirements {
12    let mut requirements = StatsRequirements::default();
13    collect_tables(plan, &mut requirements.tables);
14    collect_columns(plan, &requirements.tables, &mut requirements.columns);
15    requirements
16}
17
18fn collect_tables(plan: &LogicalPlan, tables: &mut HashSet<String>) {
19    match plan {
20        LogicalPlan::Scan { table } | LogicalPlan::IndexScan { table, .. } => {
21            tables.insert(table.clone());
22        }
23        LogicalPlan::Dml { .. } => {}
24        LogicalPlan::Derived { input, .. } => collect_tables(input, tables),
25        LogicalPlan::Filter { input, .. }
26        | LogicalPlan::Projection { input, .. }
27        | LogicalPlan::Aggregate { input, .. }
28        | LogicalPlan::Distinct { input }
29        | LogicalPlan::TopN { input, .. }
30        | LogicalPlan::Sort { input, .. }
31        | LogicalPlan::Limit { input, .. } => collect_tables(input, tables),
32        LogicalPlan::Join { left, right, .. } => {
33            collect_tables(left, tables);
34            collect_tables(right, tables);
35        }
36    }
37}
38
39fn collect_columns(
40    plan: &LogicalPlan,
41    tables: &HashSet<String>,
42    columns: &mut HashSet<(String, String)>,
43) {
44    match plan {
45        LogicalPlan::Scan { .. } | LogicalPlan::IndexScan { .. } | LogicalPlan::Dml { .. } => {}
46        LogicalPlan::Derived { input, .. } => collect_columns(input, tables, columns),
47        LogicalPlan::Filter { predicate, input } => {
48            collect_expr_columns(predicate, tables, columns);
49            collect_columns(input, tables, columns);
50        }
51        LogicalPlan::Projection { exprs, input } => {
52            collect_exprs_columns(exprs, tables, columns);
53            collect_columns(input, tables, columns);
54        }
55        LogicalPlan::Join {
56            on, left, right, ..
57        } => {
58            collect_expr_columns(on, tables, columns);
59            collect_columns(left, tables, columns);
60            collect_columns(right, tables, columns);
61        }
62        LogicalPlan::Aggregate {
63            group_exprs,
64            aggr_exprs,
65            input,
66        } => {
67            collect_exprs_columns(group_exprs, tables, columns);
68            collect_exprs_columns(aggr_exprs, tables, columns);
69            collect_columns(input, tables, columns);
70        }
71        LogicalPlan::Distinct { input } => collect_columns(input, tables, columns),
72        LogicalPlan::TopN {
73            order_by, input, ..
74        } => {
75            collect_order_by_columns(order_by, tables, columns);
76            collect_columns(input, tables, columns);
77        }
78        LogicalPlan::Sort { order_by, input } => {
79            collect_order_by_columns(order_by, tables, columns);
80            collect_columns(input, tables, columns);
81        }
82        LogicalPlan::Limit { input, .. } => collect_columns(input, tables, columns),
83    }
84}
85
86fn collect_exprs_columns(
87    exprs: &[Expr],
88    tables: &HashSet<String>,
89    columns: &mut HashSet<(String, String)>,
90) {
91    for expr in exprs {
92        collect_expr_columns(expr, tables, columns);
93    }
94}
95
96fn collect_order_by_columns(
97    order_by: &[OrderByExpr],
98    tables: &HashSet<String>,
99    columns: &mut HashSet<(String, String)>,
100) {
101    for item in order_by {
102        collect_expr_columns(&item.expr, tables, columns);
103    }
104}
105
106fn collect_expr_columns(
107    expr: &Expr,
108    tables: &HashSet<String>,
109    columns: &mut HashSet<(String, String)>,
110) {
111    match expr {
112        Expr::Identifier(name) => {
113            if let Some((table, column)) = name.split_once('.') {
114                if tables.contains(table) {
115                    columns.insert((table.to_string(), column.to_string()));
116                }
117            } else if tables.len() == 1 {
118                if let Some(table) = tables.iter().next() {
119                    columns.insert((table.clone(), name.to_string()));
120                }
121            }
122        }
123        Expr::BinaryOp { left, right, .. } => {
124            collect_expr_columns(left, tables, columns);
125            collect_expr_columns(right, tables, columns);
126        }
127        Expr::IsNull { expr, .. } => {
128            collect_expr_columns(expr, tables, columns);
129        }
130        Expr::UnaryOp { expr, .. } => collect_expr_columns(expr, tables, columns),
131        Expr::FunctionCall { args, .. } => {
132            for arg in args {
133                collect_expr_columns(arg, tables, columns);
134            }
135        }
136        Expr::WindowFunction { function, spec } => {
137            collect_expr_columns(function, tables, columns);
138            collect_exprs_columns(&spec.partition_by, tables, columns);
139            collect_order_by_columns(&spec.order_by, tables, columns);
140        }
141        Expr::Subquery(select) | Expr::Exists(select) => {
142            for item in &select.projection {
143                collect_expr_columns(&item.expr, tables, columns);
144            }
145        }
146        Expr::InSubquery { expr, subquery } => {
147            collect_expr_columns(expr, tables, columns);
148            for item in &subquery.projection {
149                collect_expr_columns(&item.expr, tables, columns);
150            }
151        }
152        Expr::Case {
153            operand,
154            when_then,
155            else_expr,
156        } => {
157            if let Some(expr) = operand {
158                collect_expr_columns(expr, tables, columns);
159            }
160            for (when_expr, then_expr) in when_then {
161                collect_expr_columns(when_expr, tables, columns);
162                collect_expr_columns(then_expr, tables, columns);
163            }
164            if let Some(expr) = else_expr {
165                collect_expr_columns(expr, tables, columns);
166            }
167        }
168        Expr::Literal(_) | Expr::Wildcard => {}
169    }
170}