chryso_optimizer/
cost.rs

1pub use chryso_planner::cost::{Cost, CostModel};
2use chryso_metadata::StatsCache;
3use chryso_planner::PhysicalPlan;
4
5pub struct UnitCostModel;
6
7impl CostModel for UnitCostModel {
8    fn cost(&self, plan: &PhysicalPlan) -> Cost {
9        Cost(plan.node_count() as f64 + join_penalty(plan))
10    }
11}
12
13impl std::fmt::Debug for UnitCostModel {
14    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15        f.write_str("UnitCostModel")
16    }
17}
18
19pub struct StatsCostModel<'a> {
20    stats: &'a StatsCache,
21}
22
23impl<'a> StatsCostModel<'a> {
24    pub fn new(stats: &'a StatsCache) -> Self {
25        Self { stats }
26    }
27}
28
29impl CostModel for StatsCostModel<'_> {
30    fn cost(&self, plan: &PhysicalPlan) -> Cost {
31        Cost(estimate_rows(plan, self.stats) + join_penalty(plan))
32    }
33}
34
35impl std::fmt::Debug for StatsCostModel<'_> {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        f.write_str("StatsCostModel")
38    }
39}
40
#[cfg(test)]
mod tests {
    use super::{CostModel, StatsCache, StatsCostModel, UnitCostModel};
    use chryso_core::ast::{BinaryOperator, Expr, JoinType, Literal};
    use chryso_metadata::{ColumnStats, TableStats};
    use chryso_planner::{JoinAlgorithm, PhysicalPlan};

    /// Builds a bare table scan over `table`.
    fn scan(table: &str) -> PhysicalPlan {
        PhysicalPlan::TableScan {
            table: table.to_string(),
        }
    }

    /// Builds an inner join of scans over `t1` and `t2` using `algorithm`.
    fn inner_join(algorithm: JoinAlgorithm) -> PhysicalPlan {
        PhysicalPlan::Join {
            join_type: JoinType::Inner,
            algorithm,
            left: Box::new(scan("t1")),
            right: Box::new(scan("t2")),
            on: Expr::Identifier("t1.id = t2.id".to_string()),
        }
    }

    // A filter over a scan is two nodes and contains no join.
    #[test]
    fn unit_cost_counts_nodes() {
        let plan = PhysicalPlan::Filter {
            predicate: Expr::Identifier("x".to_string()),
            input: Box::new(scan("t")),
        };
        assert_eq!(UnitCostModel.cost(&plan).0, 2.0);
    }

    // Same plan shape, different algorithm: the hash join must be cheaper.
    #[test]
    fn join_algorithm_costs_differ() {
        let hash_cost = UnitCostModel.cost(&inner_join(JoinAlgorithm::Hash));
        let nested_cost = UnitCostModel.cost(&inner_join(JoinAlgorithm::NestedLoop));
        assert!(hash_cost.0 < nested_cost.0);
    }

    // 100 rows with 50 distinct regions: an equality filter keeps about 2 rows.
    #[test]
    fn stats_cost_uses_selectivity() {
        let region_is_us = Expr::BinaryOp {
            left: Box::new(Expr::Identifier("sales.region".to_string())),
            op: BinaryOperator::Eq,
            right: Box::new(Expr::Literal(Literal::String("us".to_string()))),
        };
        let plan = PhysicalPlan::Filter {
            predicate: region_is_us,
            input: Box::new(scan("sales")),
        };

        let mut stats = StatsCache::new();
        stats.insert_table_stats("sales", TableStats { row_count: 100.0 });
        stats.insert_column_stats(
            "sales",
            "region",
            ColumnStats {
                distinct_count: 50.0,
                null_fraction: 0.0,
            },
        );

        let cost = StatsCostModel::new(&stats).cost(&plan);
        assert!(cost.0 < 5.0);
    }
}
116
117fn join_penalty(plan: &PhysicalPlan) -> f64 {
118    match plan {
119        PhysicalPlan::Join { algorithm, .. } => match algorithm {
120            chryso_planner::JoinAlgorithm::Hash => 1.0,
121            chryso_planner::JoinAlgorithm::NestedLoop => 5.0,
122        },
123        _ => 0.0,
124    }
125}
126
127fn estimate_rows(plan: &PhysicalPlan, stats: &StatsCache) -> f64 {
128    match plan {
129        PhysicalPlan::TableScan { table } | PhysicalPlan::IndexScan { table, .. } => stats
130            .table_stats(table)
131            .map(|stats| stats.row_count)
132            .unwrap_or(1000.0),
133        PhysicalPlan::Dml { .. } => 1.0,
134        PhysicalPlan::Derived { input, .. } => estimate_rows(input, stats),
135        PhysicalPlan::Filter { predicate, input } => {
136            let base = estimate_rows(input, stats);
137            let table = single_table_name(input);
138            base * estimate_selectivity(predicate, stats, table.as_deref())
139        }
140        PhysicalPlan::Projection { input, .. } => estimate_rows(input, stats),
141        PhysicalPlan::Join { left, right, .. } => {
142            estimate_rows(left, stats) * estimate_rows(right, stats) * 0.1
143        }
144        PhysicalPlan::Aggregate { input, .. } => (estimate_rows(input, stats) * 0.1).max(1.0),
145        PhysicalPlan::Distinct { input } => (estimate_rows(input, stats) * 0.3).max(1.0),
146        PhysicalPlan::TopN { limit, input, .. } => estimate_rows(input, stats).min(*limit as f64),
147        PhysicalPlan::Sort { input, .. } => estimate_rows(input, stats),
148        PhysicalPlan::Limit { limit, input, .. } => match limit {
149            Some(limit) => estimate_rows(input, stats).min(*limit as f64),
150            None => estimate_rows(input, stats),
151        },
152    }
153}
154
155fn estimate_selectivity(
156    predicate: &chryso_core::ast::Expr,
157    stats: &StatsCache,
158    table: Option<&str>,
159) -> f64 {
160    use chryso_core::ast::{BinaryOperator, Expr};
161    match predicate {
162        Expr::BinaryOp { left, op, right } if matches!(op, BinaryOperator::And) => {
163            estimate_selectivity(left, stats, table) * estimate_selectivity(right, stats, table)
164        }
165        Expr::BinaryOp { left, op, right } if matches!(op, BinaryOperator::Or) => {
166            let left_sel = estimate_selectivity(left, stats, table);
167            let right_sel = estimate_selectivity(right, stats, table);
168            (left_sel + right_sel - left_sel * right_sel).min(1.0)
169        }
170        Expr::IsNull { expr, negated } => {
171            let (table_name, column_name) = match expr.as_ref() {
172                Expr::Identifier(name) => match name.split_once('.') {
173                    Some((prefix, column)) => (Some(prefix), column),
174                    None => (table, name.as_str()),
175                },
176                _ => (table, ""),
177            };
178            if let (Some(table_name), column_name) = (table_name, column_name) {
179                if !column_name.is_empty() {
180                    if let Some(stats) = stats.column_stats(table_name, column_name) {
181                        let base = stats.null_fraction;
182                        return if *negated { 1.0 - base } else { base };
183                    }
184                }
185            }
186            if *negated { 0.9 } else { 0.1 }
187        }
188        Expr::BinaryOp { left, op, right } => {
189            if let Some(selectivity) = estimate_eq_selectivity(left, right, stats, table) {
190                match op {
191                    BinaryOperator::Eq => selectivity,
192                    BinaryOperator::NotEq => (1.0 - selectivity).max(0.0),
193                    BinaryOperator::Lt
194                    | BinaryOperator::LtEq
195                    | BinaryOperator::Gt
196                    | BinaryOperator::GtEq => 0.3,
197                    _ => 0.3,
198                }
199            } else {
200                0.3
201            }
202        }
203        _ => 0.5,
204    }
205}
206
207fn estimate_eq_selectivity(
208    left: &chryso_core::ast::Expr,
209    right: &chryso_core::ast::Expr,
210    stats: &StatsCache,
211    table: Option<&str>,
212) -> Option<f64> {
213    let (ident, literal) = match (left, right) {
214        (chryso_core::ast::Expr::Identifier(name), chryso_core::ast::Expr::Literal(_)) => {
215            (name, right)
216        }
217        (chryso_core::ast::Expr::Literal(_), chryso_core::ast::Expr::Identifier(name)) => {
218            (name, left)
219        }
220        _ => return None,
221    };
222    let _ = literal;
223    let (table_name, column_name) = match ident.split_once('.') {
224        Some((prefix, column)) => (Some(prefix), column),
225        None => (table, ident.as_str()),
226    };
227    let table_name = table_name?;
228    let stats = stats.column_stats(table_name, column_name)?;
229    let distinct = stats.distinct_count.max(1.0);
230    Some(1.0 / distinct)
231}
232
233fn single_table_name(plan: &PhysicalPlan) -> Option<String> {
234    match plan {
235        PhysicalPlan::TableScan { table } | PhysicalPlan::IndexScan { table, .. } => {
236            Some(table.clone())
237        }
238        PhysicalPlan::Filter { input, .. }
239        | PhysicalPlan::Projection { input, .. }
240        | PhysicalPlan::Aggregate { input, .. }
241        | PhysicalPlan::Distinct { input }
242        | PhysicalPlan::TopN { input, .. }
243        | PhysicalPlan::Sort { input, .. }
244        | PhysicalPlan::Limit { input, .. }
245        | PhysicalPlan::Derived { input, .. } => single_table_name(input),
246        PhysicalPlan::Join { .. } | PhysicalPlan::Dml { .. } => None,
247    }
248}