Skip to main content

chryso_optimizer/
cost.rs

1use chryso_metadata::StatsCache;
2use chryso_planner::PhysicalPlan;
3pub use chryso_planner::cost::{Cost, CostModel};
4use serde::{Deserialize, Serialize, de::DeserializeOwned};
5use std::fs;
6use std::path::Path;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9#[serde(default)]
10pub struct CostModelConfig {
11    pub scan: f64,
12    pub filter: f64,
13    pub projection: f64,
14    pub join: f64,
15    pub sort: f64,
16    pub aggregate: f64,
17    pub limit: f64,
18    pub derived: f64,
19    pub dml: f64,
20    pub join_hash_multiplier: f64,
21    pub join_nested_multiplier: f64,
22    pub max_cost: f64,
23}
24
25impl Default for CostModelConfig {
26    fn default() -> Self {
27        Self {
28            scan: 1.0,
29            filter: 0.5,
30            projection: 0.1,
31            join: 5.0,
32            sort: 3.0,
33            aggregate: 4.0,
34            limit: 0.05,
35            derived: 0.1,
36            dml: 1.0,
37            join_hash_multiplier: 1.0,
38            join_nested_multiplier: 5.0,
39            max_cost: 1.0e18,
40        }
41    }
42}
43
44impl CostModelConfig {
45    pub const PARAM_SCAN: &'static str = "optimizer.cost.scan";
46    pub const PARAM_FILTER: &'static str = "optimizer.cost.filter";
47    pub const PARAM_PROJECTION: &'static str = "optimizer.cost.projection";
48    pub const PARAM_JOIN: &'static str = "optimizer.cost.join";
49    pub const PARAM_SORT: &'static str = "optimizer.cost.sort";
50    pub const PARAM_AGGREGATE: &'static str = "optimizer.cost.aggregate";
51    pub const PARAM_LIMIT: &'static str = "optimizer.cost.limit";
52    pub const PARAM_DERIVED: &'static str = "optimizer.cost.derived";
53    pub const PARAM_DML: &'static str = "optimizer.cost.dml";
54    pub const PARAM_JOIN_HASH_MULTIPLIER: &'static str = "optimizer.cost.join_hash_multiplier";
55    pub const PARAM_JOIN_NESTED_MULTIPLIER: &'static str = "optimizer.cost.join_nested_multiplier";
56    pub const PARAM_MAX_COST: &'static str = "optimizer.cost.max_cost";
57
58    pub fn load_from_path(path: impl AsRef<Path>) -> chryso_core::error::ChrysoResult<Self> {
59        let value: CostModelConfig = load_config_from_path(path, "cost config")?;
60        value.validate()?;
61        Ok(value)
62    }
63
64    pub fn validate(&self) -> chryso_core::error::ChrysoResult<()> {
65        let mut invalid = Vec::new();
66        for (name, value) in [
67            ("scan", self.scan),
68            ("filter", self.filter),
69            ("projection", self.projection),
70            ("join", self.join),
71            ("sort", self.sort),
72            ("aggregate", self.aggregate),
73            ("limit", self.limit),
74            ("derived", self.derived),
75            ("dml", self.dml),
76            ("join_hash_multiplier", self.join_hash_multiplier),
77            ("join_nested_multiplier", self.join_nested_multiplier),
78            ("max_cost", self.max_cost),
79        ] {
80            if !value.is_finite() || value <= 0.0 {
81                invalid.push(name);
82            }
83        }
84        if self.join_hash_multiplier < 1.0 {
85            invalid.push("join_hash_multiplier");
86        }
87        if self.join_nested_multiplier < 1.0 {
88            invalid.push("join_nested_multiplier");
89        }
90        if invalid.is_empty() {
91            Ok(())
92        } else {
93            Err(chryso_core::error::ChrysoError::new(format!(
94                "invalid cost config fields: {}",
95                invalid.join(", ")
96            )))
97        }
98    }
99
100    pub fn apply_system_params(
101        &self,
102        registry: &chryso_core::system_params::SystemParamRegistry,
103        tenant: Option<&str>,
104    ) -> Self {
105        let mut updated = self.clone();
106        let apply = |key: &str, target: &mut f64| {
107            if let Some(value) = registry.get_f64(tenant, key) {
108                if value.is_finite() && value > 0.0 {
109                    *target = value;
110                }
111            }
112        };
113        apply(Self::PARAM_SCAN, &mut updated.scan);
114        apply(Self::PARAM_FILTER, &mut updated.filter);
115        apply(Self::PARAM_PROJECTION, &mut updated.projection);
116        apply(Self::PARAM_JOIN, &mut updated.join);
117        apply(Self::PARAM_SORT, &mut updated.sort);
118        apply(Self::PARAM_AGGREGATE, &mut updated.aggregate);
119        apply(Self::PARAM_LIMIT, &mut updated.limit);
120        apply(Self::PARAM_DERIVED, &mut updated.derived);
121        apply(Self::PARAM_DML, &mut updated.dml);
122        apply(
123            Self::PARAM_JOIN_HASH_MULTIPLIER,
124            &mut updated.join_hash_multiplier,
125        );
126        apply(
127            Self::PARAM_JOIN_NESTED_MULTIPLIER,
128            &mut updated.join_nested_multiplier,
129        );
130        apply(Self::PARAM_MAX_COST, &mut updated.max_cost);
131        updated
132    }
133}
134
135pub struct UnitCostModel;
136
137impl CostModel for UnitCostModel {
138    fn cost(&self, plan: &PhysicalPlan) -> Cost {
139        let default = CostModelConfig::default();
140        Cost(total_weight(plan, &default))
141    }
142}
143
144impl std::fmt::Debug for UnitCostModel {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        f.write_str("UnitCostModel")
147    }
148}
149
150pub struct UnitCostModelWithConfig {
151    config: CostModelConfig,
152}
153
154impl UnitCostModelWithConfig {
155    pub fn new(config: CostModelConfig) -> Self {
156        Self { config }
157    }
158}
159
160impl CostModel for UnitCostModelWithConfig {
161    fn cost(&self, plan: &PhysicalPlan) -> Cost {
162        Cost(total_weight(plan, &self.config))
163    }
164}
165
166pub struct StatsCostModel<'a> {
167    stats: &'a StatsCache,
168    config: CostModelConfig,
169}
170
171impl<'a> StatsCostModel<'a> {
172    pub fn new(stats: &'a StatsCache) -> Self {
173        Self {
174            stats,
175            config: CostModelConfig::default(),
176        }
177    }
178
179    pub fn with_config(stats: &'a StatsCache, config: CostModelConfig) -> Self {
180        let validated = if config.validate().is_ok() {
181            config
182        } else {
183            CostModelConfig::default()
184        };
185        Self {
186            stats,
187            config: validated,
188        }
189    }
190}
191
192impl CostModel for StatsCostModel<'_> {
193    fn cost(&self, plan: &PhysicalPlan) -> Cost {
194        let mut cost = total_stats_cost(plan, self.stats, &self.config);
195        if !cost.is_finite() || cost > self.config.max_cost {
196            cost = self.config.max_cost;
197        }
198        Cost(cost)
199    }
200}
201
202impl std::fmt::Debug for StatsCostModel<'_> {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        f.write_str("StatsCostModel")
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::{CostModel, CostModelConfig, StatsCache, StatsCostModel, UnitCostModel};
    use chryso_core::ast::{BinaryOperator, Expr, JoinType, Literal};
    use chryso_core::system_params::{SystemParamRegistry, SystemParamValue};
    use chryso_metadata::{ColumnStats, TableStats};
    use chryso_planner::{JoinAlgorithm, PhysicalPlan};

    /// Builds a table scan over `table`.
    fn scan(table: &str) -> PhysicalPlan {
        PhysicalPlan::TableScan {
            table: table.to_string(),
        }
    }

    /// Builds an inner join of `t1` and `t2` using the given algorithm.
    fn inner_join(algorithm: JoinAlgorithm) -> PhysicalPlan {
        PhysicalPlan::Join {
            join_type: JoinType::Inner,
            algorithm,
            left: Box::new(scan("t1")),
            right: Box::new(scan("t2")),
            on: Expr::Identifier("t1.id = t2.id".to_string()),
        }
    }

    /// Column statistics with the given distinct count and no nulls.
    fn region_stats(distinct_count: f64) -> ColumnStats {
        ColumnStats {
            distinct_count,
            null_fraction: 0.0,
        }
    }

    #[test]
    fn unit_cost_counts_nodes() {
        let plan = PhysicalPlan::Filter {
            predicate: Expr::Identifier("x".to_string()),
            input: Box::new(scan("t")),
        };
        // Default filter weight (0.5) plus default scan weight (1.0).
        assert_eq!(UnitCostModel.cost(&plan).0, 1.5);
    }

    #[test]
    fn join_algorithm_costs_differ() {
        let hash_cost = UnitCostModel.cost(&inner_join(JoinAlgorithm::Hash));
        let nested_cost = UnitCostModel.cost(&inner_join(JoinAlgorithm::NestedLoop));
        assert!(hash_cost.0 < nested_cost.0);
    }

    #[test]
    fn stats_cost_uses_selectivity() {
        let plan = PhysicalPlan::Filter {
            predicate: Expr::BinaryOp {
                left: Box::new(Expr::Identifier("sales.region".to_string())),
                op: BinaryOperator::Eq,
                right: Box::new(Expr::Literal(Literal::String("us".to_string()))),
            },
            input: Box::new(scan("sales")),
        };
        let mut stats = StatsCache::new();
        stats.insert_table_stats("sales", TableStats { row_count: 100.0 });

        // Many distinct values: the equality predicate keeps few rows.
        stats.insert_column_stats("sales", "region", region_stats(50.0));
        let selective = StatsCostModel::new(&stats).cost(&plan);

        // A single distinct value: the predicate keeps every row.
        stats.insert_column_stats("sales", "region", region_stats(1.0));
        let non_selective = StatsCostModel::new(&stats).cost(&plan);

        assert!(selective.0 < non_selective.0);
    }

    #[test]
    fn config_validation_rejects_non_positive() {
        let config = CostModelConfig {
            join: 0.0,
            ..CostModelConfig::default()
        };
        let err = config.validate().expect_err("invalid config");
        assert!(err.to_string().contains("join"));
    }

    #[test]
    fn system_params_override_cost_config() {
        let registry = SystemParamRegistry::new();
        registry.set_default_param(CostModelConfig::PARAM_FILTER, SystemParamValue::Float(0.9));
        let updated = CostModelConfig::default().apply_system_params(&registry, Some("tenant"));
        assert_eq!(updated.filter, 0.9);
    }

    #[test]
    fn system_params_ignore_invalid_values() {
        let registry = SystemParamRegistry::new();
        registry.set_default_param(CostModelConfig::PARAM_SORT, SystemParamValue::Float(0.0));
        let baseline = CostModelConfig::default();
        let updated = baseline.apply_system_params(&registry, Some("tenant"));
        assert_eq!(updated.sort, baseline.sort);
    }
}
321
322pub(crate) fn load_config_from_path<T: DeserializeOwned>(
323    path: impl AsRef<Path>,
324    label: &str,
325) -> chryso_core::error::ChrysoResult<T> {
326    let content = fs::read_to_string(path.as_ref()).map_err(|err| {
327        chryso_core::error::ChrysoError::new(format!("read {label} failed: {err}"))
328    })?;
329    if path
330        .as_ref()
331        .extension()
332        .and_then(|ext| ext.to_str())
333        .map(|ext| ext.eq_ignore_ascii_case("toml"))
334        .unwrap_or(false)
335    {
336        toml::from_str(&content).map_err(|err| {
337            chryso_core::error::ChrysoError::new(format!("parse toml {label} failed: {err}"))
338        })
339    } else {
340        serde_json::from_str(&content).map_err(|err| {
341            chryso_core::error::ChrysoError::new(format!("parse json {label} failed: {err}"))
342        })
343    }
344}
345
346fn local_join_penalty(plan: &PhysicalPlan, config: &CostModelConfig) -> f64 {
347    match plan {
348        PhysicalPlan::Join { algorithm, .. } => match algorithm {
349            chryso_planner::JoinAlgorithm::Hash => {
350                config.join * (config.join_hash_multiplier - 1.0)
351            }
352            chryso_planner::JoinAlgorithm::NestedLoop => {
353                config.join * (config.join_nested_multiplier - 1.0)
354            }
355        },
356        _ => 0.0,
357    }
358}
359
360fn node_weight(plan: &PhysicalPlan, config: &CostModelConfig) -> f64 {
361    match plan {
362        PhysicalPlan::TableScan { .. } | PhysicalPlan::IndexScan { .. } => config.scan,
363        PhysicalPlan::Filter { .. } => config.filter,
364        PhysicalPlan::Projection { .. } => config.projection,
365        PhysicalPlan::Join { .. } => config.join,
366        PhysicalPlan::Aggregate { .. } => config.aggregate,
367        PhysicalPlan::Distinct { .. } => config.aggregate,
368        PhysicalPlan::TopN { .. } => config.sort,
369        PhysicalPlan::Sort { .. } => config.sort,
370        PhysicalPlan::Limit { .. } => config.limit,
371        PhysicalPlan::Derived { .. } => config.derived,
372        PhysicalPlan::Dml { .. } => config.dml,
373    }
374}
375
376fn total_weight(plan: &PhysicalPlan, config: &CostModelConfig) -> f64 {
377    // Unit cost uses configurable weights for every node in the tree.
378    let base = node_weight(plan, config) + local_join_penalty(plan, config);
379    let children = match plan {
380        PhysicalPlan::Join { left, right, .. } => {
381            total_weight(left, config) + total_weight(right, config)
382        }
383        PhysicalPlan::Filter { input, .. }
384        | PhysicalPlan::Projection { input, .. }
385        | PhysicalPlan::Aggregate { input, .. }
386        | PhysicalPlan::Distinct { input }
387        | PhysicalPlan::TopN { input, .. }
388        | PhysicalPlan::Sort { input, .. }
389        | PhysicalPlan::Limit { input, .. }
390        | PhysicalPlan::Derived { input, .. } => total_weight(input, config),
391        PhysicalPlan::TableScan { .. }
392        | PhysicalPlan::IndexScan { .. }
393        | PhysicalPlan::Dml { .. } => 0.0,
394    };
395    base + children
396}
397
398fn total_stats_cost(plan: &PhysicalPlan, stats: &StatsCache, config: &CostModelConfig) -> f64 {
399    // Stats cost applies selectivity per node and accumulates subtree contributions.
400    let rows = estimate_rows(plan, stats);
401    let mut cost = rows * node_weight(plan, config) + local_join_penalty(plan, config);
402    cost += match plan {
403        PhysicalPlan::Join { left, right, .. } => {
404            total_stats_cost(left, stats, config) + total_stats_cost(right, stats, config)
405        }
406        PhysicalPlan::Filter { input, .. }
407        | PhysicalPlan::Projection { input, .. }
408        | PhysicalPlan::Aggregate { input, .. }
409        | PhysicalPlan::Distinct { input }
410        | PhysicalPlan::TopN { input, .. }
411        | PhysicalPlan::Sort { input, .. }
412        | PhysicalPlan::Limit { input, .. }
413        | PhysicalPlan::Derived { input, .. } => total_stats_cost(input, stats, config),
414        PhysicalPlan::TableScan { .. }
415        | PhysicalPlan::IndexScan { .. }
416        | PhysicalPlan::Dml { .. } => 0.0,
417    };
418    cost
419}
420
421fn estimate_rows(plan: &PhysicalPlan, stats: &StatsCache) -> f64 {
422    match plan {
423        PhysicalPlan::TableScan { table } | PhysicalPlan::IndexScan { table, .. } => stats
424            .table_stats(table)
425            .map(|stats| stats.row_count)
426            .unwrap_or(1000.0),
427        PhysicalPlan::Dml { .. } => 1.0,
428        PhysicalPlan::Derived { input, .. } => estimate_rows(input, stats),
429        PhysicalPlan::Filter { predicate, input } => {
430            let base = estimate_rows(input, stats);
431            let table = single_table_name(input);
432            base * estimate_selectivity(predicate, stats, table.as_deref())
433        }
434        PhysicalPlan::Projection { input, .. } => estimate_rows(input, stats),
435        PhysicalPlan::Join { left, right, .. } => {
436            estimate_rows(left, stats) * estimate_rows(right, stats) * 0.1
437        }
438        PhysicalPlan::Aggregate { input, .. } => (estimate_rows(input, stats) * 0.1).max(1.0),
439        PhysicalPlan::Distinct { input } => (estimate_rows(input, stats) * 0.3).max(1.0),
440        PhysicalPlan::TopN { limit, input, .. } => estimate_rows(input, stats).min(*limit as f64),
441        PhysicalPlan::Sort { input, .. } => estimate_rows(input, stats),
442        PhysicalPlan::Limit { limit, input, .. } => match limit {
443            Some(limit) => estimate_rows(input, stats).min(*limit as f64),
444            None => estimate_rows(input, stats),
445        },
446    }
447}
448
449fn estimate_selectivity(
450    predicate: &chryso_core::ast::Expr,
451    stats: &StatsCache,
452    table: Option<&str>,
453) -> f64 {
454    use chryso_core::ast::{BinaryOperator, Expr};
455    match predicate {
456        Expr::BinaryOp { left, op, right } if matches!(op, BinaryOperator::And) => {
457            estimate_selectivity(left, stats, table) * estimate_selectivity(right, stats, table)
458        }
459        Expr::BinaryOp { left, op, right } if matches!(op, BinaryOperator::Or) => {
460            let left_sel = estimate_selectivity(left, stats, table);
461            let right_sel = estimate_selectivity(right, stats, table);
462            (left_sel + right_sel - left_sel * right_sel).min(1.0)
463        }
464        Expr::IsNull { expr, negated } => {
465            let (table_name, column_name) = match expr.as_ref() {
466                Expr::Identifier(name) => match name.split_once('.') {
467                    Some((prefix, column)) => (Some(prefix), column),
468                    None => (table, name.as_str()),
469                },
470                _ => (table, ""),
471            };
472            if let (Some(table_name), column_name) = (table_name, column_name) {
473                if !column_name.is_empty() {
474                    if let Some(stats) = stats.column_stats(table_name, column_name) {
475                        let base = stats.null_fraction;
476                        return if *negated { 1.0 - base } else { base };
477                    }
478                }
479            }
480            if *negated { 0.9 } else { 0.1 }
481        }
482        Expr::BinaryOp { left, op, right } => {
483            if let Some(selectivity) = estimate_eq_selectivity(left, right, stats, table) {
484                match op {
485                    BinaryOperator::Eq => selectivity,
486                    BinaryOperator::NotEq => (1.0 - selectivity).max(0.0),
487                    BinaryOperator::Lt
488                    | BinaryOperator::LtEq
489                    | BinaryOperator::Gt
490                    | BinaryOperator::GtEq => 0.3,
491                    _ => 0.3,
492                }
493            } else {
494                0.3
495            }
496        }
497        _ => 0.5,
498    }
499}
500
501fn estimate_eq_selectivity(
502    left: &chryso_core::ast::Expr,
503    right: &chryso_core::ast::Expr,
504    stats: &StatsCache,
505    table: Option<&str>,
506) -> Option<f64> {
507    let (ident, literal) = match (left, right) {
508        (chryso_core::ast::Expr::Identifier(name), chryso_core::ast::Expr::Literal(_)) => {
509            (name, right)
510        }
511        (chryso_core::ast::Expr::Literal(_), chryso_core::ast::Expr::Identifier(name)) => {
512            (name, left)
513        }
514        _ => return None,
515    };
516    let _ = literal;
517    let (table_name, column_name) = match ident.split_once('.') {
518        Some((prefix, column)) => (Some(prefix), column),
519        None => (table, ident.as_str()),
520    };
521    let table_name = table_name?;
522    let stats = stats.column_stats(table_name, column_name)?;
523    let distinct = stats.distinct_count.max(1.0);
524    Some(1.0 / distinct)
525}
526
527fn single_table_name(plan: &PhysicalPlan) -> Option<String> {
528    match plan {
529        PhysicalPlan::TableScan { table } | PhysicalPlan::IndexScan { table, .. } => {
530            Some(table.clone())
531        }
532        PhysicalPlan::Filter { input, .. }
533        | PhysicalPlan::Projection { input, .. }
534        | PhysicalPlan::Aggregate { input, .. }
535        | PhysicalPlan::Distinct { input }
536        | PhysicalPlan::TopN { input, .. }
537        | PhysicalPlan::Sort { input, .. }
538        | PhysicalPlan::Limit { input, .. }
539        | PhysicalPlan::Derived { input, .. } => single_table_name(input),
540        PhysicalPlan::Join { .. } | PhysicalPlan::Dml { .. } => None,
541    }
542}