1use chryso_core::ast::{Expr, OrderByExpr};
2use chryso_planner::LogicalPlan;
3use std::collections::HashSet;
4
/// Statistics a logical plan needs, as gathered by `collect_requirements`.
///
/// Both sets start out empty via `Default`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StatsRequirements {
    /// Names of the tables read by scan nodes of the plan.
    pub tables: HashSet<String>,
    /// `(table, column)` pairs referenced by expressions in the plan.
    pub columns: HashSet<(String, String)>,
}
10
11pub fn collect_requirements(plan: &LogicalPlan) -> StatsRequirements {
12 let mut requirements = StatsRequirements::default();
13 collect_tables(plan, &mut requirements.tables);
14 collect_columns(plan, &requirements.tables, &mut requirements.columns);
15 requirements
16}
17
18fn collect_tables(plan: &LogicalPlan, tables: &mut HashSet<String>) {
19 match plan {
20 LogicalPlan::Scan { table } | LogicalPlan::IndexScan { table, .. } => {
21 tables.insert(table.clone());
22 }
23 LogicalPlan::Dml { .. } => {}
24 LogicalPlan::Derived { input, .. } => collect_tables(input, tables),
25 LogicalPlan::Filter { input, .. }
26 | LogicalPlan::Projection { input, .. }
27 | LogicalPlan::Aggregate { input, .. }
28 | LogicalPlan::Distinct { input }
29 | LogicalPlan::TopN { input, .. }
30 | LogicalPlan::Sort { input, .. }
31 | LogicalPlan::Limit { input, .. } => collect_tables(input, tables),
32 LogicalPlan::Join { left, right, .. } => {
33 collect_tables(left, tables);
34 collect_tables(right, tables);
35 }
36 }
37}
38
39fn collect_columns(
40 plan: &LogicalPlan,
41 tables: &HashSet<String>,
42 columns: &mut HashSet<(String, String)>,
43) {
44 match plan {
45 LogicalPlan::Scan { .. } | LogicalPlan::IndexScan { .. } | LogicalPlan::Dml { .. } => {}
46 LogicalPlan::Derived { input, .. } => collect_columns(input, tables, columns),
47 LogicalPlan::Filter { predicate, input } => {
48 collect_expr_columns(predicate, tables, columns);
49 collect_columns(input, tables, columns);
50 }
51 LogicalPlan::Projection { exprs, input } => {
52 collect_exprs_columns(exprs, tables, columns);
53 collect_columns(input, tables, columns);
54 }
55 LogicalPlan::Join {
56 on, left, right, ..
57 } => {
58 collect_expr_columns(on, tables, columns);
59 collect_columns(left, tables, columns);
60 collect_columns(right, tables, columns);
61 }
62 LogicalPlan::Aggregate {
63 group_exprs,
64 aggr_exprs,
65 input,
66 } => {
67 collect_exprs_columns(group_exprs, tables, columns);
68 collect_exprs_columns(aggr_exprs, tables, columns);
69 collect_columns(input, tables, columns);
70 }
71 LogicalPlan::Distinct { input } => collect_columns(input, tables, columns),
72 LogicalPlan::TopN {
73 order_by, input, ..
74 } => {
75 collect_order_by_columns(order_by, tables, columns);
76 collect_columns(input, tables, columns);
77 }
78 LogicalPlan::Sort { order_by, input } => {
79 collect_order_by_columns(order_by, tables, columns);
80 collect_columns(input, tables, columns);
81 }
82 LogicalPlan::Limit { input, .. } => collect_columns(input, tables, columns),
83 }
84}
85
86fn collect_exprs_columns(
87 exprs: &[Expr],
88 tables: &HashSet<String>,
89 columns: &mut HashSet<(String, String)>,
90) {
91 for expr in exprs {
92 collect_expr_columns(expr, tables, columns);
93 }
94}
95
96fn collect_order_by_columns(
97 order_by: &[OrderByExpr],
98 tables: &HashSet<String>,
99 columns: &mut HashSet<(String, String)>,
100) {
101 for item in order_by {
102 collect_expr_columns(&item.expr, tables, columns);
103 }
104}
105
106fn collect_expr_columns(
107 expr: &Expr,
108 tables: &HashSet<String>,
109 columns: &mut HashSet<(String, String)>,
110) {
111 match expr {
112 Expr::Identifier(name) => {
113 if let Some((table, column)) = name.split_once('.') {
114 if tables.contains(table) {
115 columns.insert((table.to_string(), column.to_string()));
116 }
117 } else if tables.len() == 1 {
118 if let Some(table) = tables.iter().next() {
119 columns.insert((table.clone(), name.to_string()));
120 }
121 }
122 }
123 Expr::BinaryOp { left, right, .. } => {
124 collect_expr_columns(left, tables, columns);
125 collect_expr_columns(right, tables, columns);
126 }
127 Expr::IsNull { expr, .. } => {
128 collect_expr_columns(expr, tables, columns);
129 }
130 Expr::UnaryOp { expr, .. } => collect_expr_columns(expr, tables, columns),
131 Expr::FunctionCall { args, .. } => {
132 for arg in args {
133 collect_expr_columns(arg, tables, columns);
134 }
135 }
136 Expr::WindowFunction { function, spec } => {
137 collect_expr_columns(function, tables, columns);
138 collect_exprs_columns(&spec.partition_by, tables, columns);
139 collect_order_by_columns(&spec.order_by, tables, columns);
140 }
141 Expr::Subquery(select) | Expr::Exists(select) => {
142 for item in &select.projection {
143 collect_expr_columns(&item.expr, tables, columns);
144 }
145 }
146 Expr::InSubquery { expr, subquery } => {
147 collect_expr_columns(expr, tables, columns);
148 for item in &subquery.projection {
149 collect_expr_columns(&item.expr, tables, columns);
150 }
151 }
152 Expr::Case {
153 operand,
154 when_then,
155 else_expr,
156 } => {
157 if let Some(expr) = operand {
158 collect_expr_columns(expr, tables, columns);
159 }
160 for (when_expr, then_expr) in when_then {
161 collect_expr_columns(when_expr, tables, columns);
162 collect_expr_columns(then_expr, tables, columns);
163 }
164 if let Some(expr) = else_expr {
165 collect_expr_columns(expr, tables, columns);
166 }
167 }
168 Expr::Literal(_) | Expr::Wildcard => {}
169 }
170}