1pub use chryso_planner::cost::{Cost, CostModel};
2use chryso_metadata::StatsCache;
3use chryso_planner::PhysicalPlan;
4
5pub struct UnitCostModel;
6
7impl CostModel for UnitCostModel {
8 fn cost(&self, plan: &PhysicalPlan) -> Cost {
9 Cost(plan.node_count() as f64 + join_penalty(plan))
10 }
11}
12
13impl std::fmt::Debug for UnitCostModel {
14 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15 f.write_str("UnitCostModel")
16 }
17}
18
19pub struct StatsCostModel<'a> {
20 stats: &'a StatsCache,
21}
22
23impl<'a> StatsCostModel<'a> {
24 pub fn new(stats: &'a StatsCache) -> Self {
25 Self { stats }
26 }
27}
28
29impl CostModel for StatsCostModel<'_> {
30 fn cost(&self, plan: &PhysicalPlan) -> Cost {
31 Cost(estimate_rows(plan, self.stats) + join_penalty(plan))
32 }
33}
34
35impl std::fmt::Debug for StatsCostModel<'_> {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 f.write_str("StatsCostModel")
38 }
39}
40
#[cfg(test)]
mod tests {
    use super::{CostModel, StatsCache, StatsCostModel, UnitCostModel};
    use chryso_core::ast::{BinaryOperator, Expr, JoinType, Literal};
    use chryso_metadata::{ColumnStats, TableStats};
    use chryso_planner::{JoinAlgorithm, PhysicalPlan};

    /// Builds a plain scan over `table`.
    fn scan(table: &str) -> PhysicalPlan {
        PhysicalPlan::TableScan {
            table: table.to_string(),
        }
    }

    /// Builds an inner join of scans over `t1` and `t2` run with `algorithm`.
    fn inner_join(algorithm: JoinAlgorithm) -> PhysicalPlan {
        PhysicalPlan::Join {
            join_type: JoinType::Inner,
            algorithm,
            left: Box::new(scan("t1")),
            right: Box::new(scan("t2")),
            on: Expr::Identifier("t1.id = t2.id".to_string()),
        }
    }

    #[test]
    fn unit_cost_counts_nodes() {
        // Filter + scan: two nodes and no join, so the cost is exactly 2.
        let plan = PhysicalPlan::Filter {
            predicate: Expr::Identifier("x".to_string()),
            input: Box::new(scan("t")),
        };
        let cost = UnitCostModel.cost(&plan);
        assert_eq!(cost.0, 2.0);
    }

    #[test]
    fn join_algorithm_costs_differ() {
        // Same shape, different algorithm: the hash join must be cheaper.
        let hash = inner_join(JoinAlgorithm::Hash);
        let nested = inner_join(JoinAlgorithm::NestedLoop);
        let model = UnitCostModel;
        assert!(model.cost(&hash).0 < model.cost(&nested).0);
    }

    #[test]
    fn stats_cost_uses_selectivity() {
        // `sales.region = 'us'` over a 100-row table with 50 distinct regions.
        let predicate = Expr::BinaryOp {
            left: Box::new(Expr::Identifier("sales.region".to_string())),
            op: BinaryOperator::Eq,
            right: Box::new(Expr::Literal(Literal::String("us".to_string()))),
        };
        let plan = PhysicalPlan::Filter {
            predicate,
            input: Box::new(scan("sales")),
        };

        let mut stats = StatsCache::new();
        stats.insert_table_stats("sales", TableStats { row_count: 100.0 });
        stats.insert_column_stats(
            "sales",
            "region",
            ColumnStats {
                distinct_count: 50.0,
                null_fraction: 0.0,
            },
        );

        let model = StatsCostModel::new(&stats);
        let cost = model.cost(&plan);
        assert!(cost.0 < 5.0);
    }
}
116
117fn join_penalty(plan: &PhysicalPlan) -> f64 {
118 match plan {
119 PhysicalPlan::Join { algorithm, .. } => match algorithm {
120 chryso_planner::JoinAlgorithm::Hash => 1.0,
121 chryso_planner::JoinAlgorithm::NestedLoop => 5.0,
122 },
123 _ => 0.0,
124 }
125}
126
127fn estimate_rows(plan: &PhysicalPlan, stats: &StatsCache) -> f64 {
128 match plan {
129 PhysicalPlan::TableScan { table } | PhysicalPlan::IndexScan { table, .. } => stats
130 .table_stats(table)
131 .map(|stats| stats.row_count)
132 .unwrap_or(1000.0),
133 PhysicalPlan::Dml { .. } => 1.0,
134 PhysicalPlan::Derived { input, .. } => estimate_rows(input, stats),
135 PhysicalPlan::Filter { predicate, input } => {
136 let base = estimate_rows(input, stats);
137 let table = single_table_name(input);
138 base * estimate_selectivity(predicate, stats, table.as_deref())
139 }
140 PhysicalPlan::Projection { input, .. } => estimate_rows(input, stats),
141 PhysicalPlan::Join { left, right, .. } => {
142 estimate_rows(left, stats) * estimate_rows(right, stats) * 0.1
143 }
144 PhysicalPlan::Aggregate { input, .. } => (estimate_rows(input, stats) * 0.1).max(1.0),
145 PhysicalPlan::Distinct { input } => (estimate_rows(input, stats) * 0.3).max(1.0),
146 PhysicalPlan::TopN { limit, input, .. } => estimate_rows(input, stats).min(*limit as f64),
147 PhysicalPlan::Sort { input, .. } => estimate_rows(input, stats),
148 PhysicalPlan::Limit { limit, input, .. } => match limit {
149 Some(limit) => estimate_rows(input, stats).min(*limit as f64),
150 None => estimate_rows(input, stats),
151 },
152 }
153}
154
155fn estimate_selectivity(
156 predicate: &chryso_core::ast::Expr,
157 stats: &StatsCache,
158 table: Option<&str>,
159) -> f64 {
160 use chryso_core::ast::{BinaryOperator, Expr};
161 match predicate {
162 Expr::BinaryOp { left, op, right } if matches!(op, BinaryOperator::And) => {
163 estimate_selectivity(left, stats, table) * estimate_selectivity(right, stats, table)
164 }
165 Expr::BinaryOp { left, op, right } if matches!(op, BinaryOperator::Or) => {
166 let left_sel = estimate_selectivity(left, stats, table);
167 let right_sel = estimate_selectivity(right, stats, table);
168 (left_sel + right_sel - left_sel * right_sel).min(1.0)
169 }
170 Expr::IsNull { expr, negated } => {
171 let (table_name, column_name) = match expr.as_ref() {
172 Expr::Identifier(name) => match name.split_once('.') {
173 Some((prefix, column)) => (Some(prefix), column),
174 None => (table, name.as_str()),
175 },
176 _ => (table, ""),
177 };
178 if let (Some(table_name), column_name) = (table_name, column_name) {
179 if !column_name.is_empty() {
180 if let Some(stats) = stats.column_stats(table_name, column_name) {
181 let base = stats.null_fraction;
182 return if *negated { 1.0 - base } else { base };
183 }
184 }
185 }
186 if *negated { 0.9 } else { 0.1 }
187 }
188 Expr::BinaryOp { left, op, right } => {
189 if let Some(selectivity) = estimate_eq_selectivity(left, right, stats, table) {
190 match op {
191 BinaryOperator::Eq => selectivity,
192 BinaryOperator::NotEq => (1.0 - selectivity).max(0.0),
193 BinaryOperator::Lt
194 | BinaryOperator::LtEq
195 | BinaryOperator::Gt
196 | BinaryOperator::GtEq => 0.3,
197 _ => 0.3,
198 }
199 } else {
200 0.3
201 }
202 }
203 _ => 0.5,
204 }
205}
206
207fn estimate_eq_selectivity(
208 left: &chryso_core::ast::Expr,
209 right: &chryso_core::ast::Expr,
210 stats: &StatsCache,
211 table: Option<&str>,
212) -> Option<f64> {
213 let (ident, literal) = match (left, right) {
214 (chryso_core::ast::Expr::Identifier(name), chryso_core::ast::Expr::Literal(_)) => {
215 (name, right)
216 }
217 (chryso_core::ast::Expr::Literal(_), chryso_core::ast::Expr::Identifier(name)) => {
218 (name, left)
219 }
220 _ => return None,
221 };
222 let _ = literal;
223 let (table_name, column_name) = match ident.split_once('.') {
224 Some((prefix, column)) => (Some(prefix), column),
225 None => (table, ident.as_str()),
226 };
227 let table_name = table_name?;
228 let stats = stats.column_stats(table_name, column_name)?;
229 let distinct = stats.distinct_count.max(1.0);
230 Some(1.0 / distinct)
231}
232
233fn single_table_name(plan: &PhysicalPlan) -> Option<String> {
234 match plan {
235 PhysicalPlan::TableScan { table } | PhysicalPlan::IndexScan { table, .. } => {
236 Some(table.clone())
237 }
238 PhysicalPlan::Filter { input, .. }
239 | PhysicalPlan::Projection { input, .. }
240 | PhysicalPlan::Aggregate { input, .. }
241 | PhysicalPlan::Distinct { input }
242 | PhysicalPlan::TopN { input, .. }
243 | PhysicalPlan::Sort { input, .. }
244 | PhysicalPlan::Limit { input, .. }
245 | PhysicalPlan::Derived { input, .. } => single_table_name(input),
246 PhysicalPlan::Join { .. } | PhysicalPlan::Dml { .. } => None,
247 }
248}