1use chryso_metadata::StatsCache;
2use chryso_planner::PhysicalPlan;
3pub use chryso_planner::cost::{Cost, CostModel};
4use serde::{Deserialize, Serialize, de::DeserializeOwned};
5use std::fs;
6use std::path::Path;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9#[serde(default)]
10pub struct CostModelConfig {
11 pub scan: f64,
12 pub filter: f64,
13 pub projection: f64,
14 pub join: f64,
15 pub sort: f64,
16 pub aggregate: f64,
17 pub limit: f64,
18 pub derived: f64,
19 pub dml: f64,
20 pub join_hash_multiplier: f64,
21 pub join_nested_multiplier: f64,
22 pub max_cost: f64,
23}
24
25impl Default for CostModelConfig {
26 fn default() -> Self {
27 Self {
28 scan: 1.0,
29 filter: 0.5,
30 projection: 0.1,
31 join: 5.0,
32 sort: 3.0,
33 aggregate: 4.0,
34 limit: 0.05,
35 derived: 0.1,
36 dml: 1.0,
37 join_hash_multiplier: 1.0,
38 join_nested_multiplier: 5.0,
39 max_cost: 1.0e18,
40 }
41 }
42}
43
44impl CostModelConfig {
45 pub const PARAM_SCAN: &'static str = "optimizer.cost.scan";
46 pub const PARAM_FILTER: &'static str = "optimizer.cost.filter";
47 pub const PARAM_PROJECTION: &'static str = "optimizer.cost.projection";
48 pub const PARAM_JOIN: &'static str = "optimizer.cost.join";
49 pub const PARAM_SORT: &'static str = "optimizer.cost.sort";
50 pub const PARAM_AGGREGATE: &'static str = "optimizer.cost.aggregate";
51 pub const PARAM_LIMIT: &'static str = "optimizer.cost.limit";
52 pub const PARAM_DERIVED: &'static str = "optimizer.cost.derived";
53 pub const PARAM_DML: &'static str = "optimizer.cost.dml";
54 pub const PARAM_JOIN_HASH_MULTIPLIER: &'static str = "optimizer.cost.join_hash_multiplier";
55 pub const PARAM_JOIN_NESTED_MULTIPLIER: &'static str = "optimizer.cost.join_nested_multiplier";
56 pub const PARAM_MAX_COST: &'static str = "optimizer.cost.max_cost";
57
58 pub fn load_from_path(path: impl AsRef<Path>) -> chryso_core::error::ChrysoResult<Self> {
59 let value: CostModelConfig = load_config_from_path(path, "cost config")?;
60 value.validate()?;
61 Ok(value)
62 }
63
64 pub fn validate(&self) -> chryso_core::error::ChrysoResult<()> {
65 let mut invalid = Vec::new();
66 for (name, value) in [
67 ("scan", self.scan),
68 ("filter", self.filter),
69 ("projection", self.projection),
70 ("join", self.join),
71 ("sort", self.sort),
72 ("aggregate", self.aggregate),
73 ("limit", self.limit),
74 ("derived", self.derived),
75 ("dml", self.dml),
76 ("join_hash_multiplier", self.join_hash_multiplier),
77 ("join_nested_multiplier", self.join_nested_multiplier),
78 ("max_cost", self.max_cost),
79 ] {
80 if !value.is_finite() || value <= 0.0 {
81 invalid.push(name);
82 }
83 }
84 if self.join_hash_multiplier < 1.0 {
85 invalid.push("join_hash_multiplier");
86 }
87 if self.join_nested_multiplier < 1.0 {
88 invalid.push("join_nested_multiplier");
89 }
90 if invalid.is_empty() {
91 Ok(())
92 } else {
93 Err(chryso_core::error::ChrysoError::new(format!(
94 "invalid cost config fields: {}",
95 invalid.join(", ")
96 )))
97 }
98 }
99
100 pub fn apply_system_params(
101 &self,
102 registry: &chryso_core::system_params::SystemParamRegistry,
103 tenant: Option<&str>,
104 ) -> Self {
105 let mut updated = self.clone();
106 let apply = |key: &str, target: &mut f64| {
107 if let Some(value) = registry.get_f64(tenant, key) {
108 if value.is_finite() && value > 0.0 {
109 *target = value;
110 }
111 }
112 };
113 apply(Self::PARAM_SCAN, &mut updated.scan);
114 apply(Self::PARAM_FILTER, &mut updated.filter);
115 apply(Self::PARAM_PROJECTION, &mut updated.projection);
116 apply(Self::PARAM_JOIN, &mut updated.join);
117 apply(Self::PARAM_SORT, &mut updated.sort);
118 apply(Self::PARAM_AGGREGATE, &mut updated.aggregate);
119 apply(Self::PARAM_LIMIT, &mut updated.limit);
120 apply(Self::PARAM_DERIVED, &mut updated.derived);
121 apply(Self::PARAM_DML, &mut updated.dml);
122 apply(
123 Self::PARAM_JOIN_HASH_MULTIPLIER,
124 &mut updated.join_hash_multiplier,
125 );
126 apply(
127 Self::PARAM_JOIN_NESTED_MULTIPLIER,
128 &mut updated.join_nested_multiplier,
129 );
130 apply(Self::PARAM_MAX_COST, &mut updated.max_cost);
131 updated
132 }
133}
134
/// Cost model that sums per-operator weights from [`CostModelConfig::default`] over
/// every plan node, without consulting table statistics.
pub struct UnitCostModel;
136
137impl CostModel for UnitCostModel {
138 fn cost(&self, plan: &PhysicalPlan) -> Cost {
139 let default = CostModelConfig::default();
140 Cost(total_weight(plan, &default))
141 }
142}
143
144impl std::fmt::Debug for UnitCostModel {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 f.write_str("UnitCostModel")
147 }
148}
149
150pub struct UnitCostModelWithConfig {
151 config: CostModelConfig,
152}
153
154impl UnitCostModelWithConfig {
155 pub fn new(config: CostModelConfig) -> Self {
156 Self { config }
157 }
158}
159
160impl CostModel for UnitCostModelWithConfig {
161 fn cost(&self, plan: &PhysicalPlan) -> Cost {
162 Cost(total_weight(plan, &self.config))
163 }
164}
165
/// Cost model that scales each plan node's weight by its estimated output row count,
/// using table and column statistics from a [`StatsCache`].
pub struct StatsCostModel<'a> {
    /// Statistics consulted for row-count and selectivity estimates.
    stats: &'a StatsCache,
    /// Operator weights: either the defaults, or a config that passed
    /// [`CostModelConfig::validate`] (invalid configs fall back to the defaults).
    config: CostModelConfig,
}
170
171impl<'a> StatsCostModel<'a> {
172 pub fn new(stats: &'a StatsCache) -> Self {
173 Self {
174 stats,
175 config: CostModelConfig::default(),
176 }
177 }
178
179 pub fn with_config(stats: &'a StatsCache, config: CostModelConfig) -> Self {
180 let validated = if config.validate().is_ok() {
181 config
182 } else {
183 CostModelConfig::default()
184 };
185 Self {
186 stats,
187 config: validated,
188 }
189 }
190}
191
192impl CostModel for StatsCostModel<'_> {
193 fn cost(&self, plan: &PhysicalPlan) -> Cost {
194 let mut cost = total_stats_cost(plan, self.stats, &self.config);
195 if !cost.is_finite() || cost > self.config.max_cost {
196 cost = self.config.max_cost;
197 }
198 Cost(cost)
199 }
200}
201
202impl std::fmt::Debug for StatsCostModel<'_> {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 f.write_str("StatsCostModel")
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::{CostModel, CostModelConfig, StatsCache, StatsCostModel, UnitCostModel};
    use chryso_core::system_params::{SystemParamRegistry, SystemParamValue};
    use chryso_metadata::ColumnStats;
    use chryso_planner::PhysicalPlan;

    #[test]
    fn unit_cost_counts_nodes() {
        let plan = PhysicalPlan::Filter {
            predicate: chryso_core::ast::Expr::Identifier("x".to_string()),
            input: Box::new(PhysicalPlan::TableScan {
                table: "t".to_string(),
            }),
        };
        let cost = UnitCostModel.cost(&plan);
        assert_eq!(cost.0, 1.5);
    }

    #[test]
    fn join_algorithm_costs_differ() {
        let left = PhysicalPlan::TableScan {
            table: "t1".to_string(),
        };
        let right = PhysicalPlan::TableScan {
            table: "t2".to_string(),
        };
        let hash = PhysicalPlan::Join {
            join_type: chryso_core::ast::JoinType::Inner,
            algorithm: chryso_planner::JoinAlgorithm::Hash,
            left: Box::new(left.clone()),
            right: Box::new(right.clone()),
            on: chryso_core::ast::Expr::Identifier("t1.id = t2.id".to_string()),
        };
        let nested = PhysicalPlan::Join {
            join_type: chryso_core::ast::JoinType::Inner,
            algorithm: chryso_planner::JoinAlgorithm::NestedLoop,
            left: Box::new(left),
            right: Box::new(right),
            on: chryso_core::ast::Expr::Identifier("t1.id = t2.id".to_string()),
        };
        let model = UnitCostModel;
        assert!(model.cost(&hash).0 < model.cost(&nested).0);
    }

    #[test]
    fn stats_cost_uses_selectivity() {
        let plan = PhysicalPlan::Filter {
            predicate: chryso_core::ast::Expr::BinaryOp {
                left: Box::new(chryso_core::ast::Expr::Identifier(
                    "sales.region".to_string(),
                )),
                op: chryso_core::ast::BinaryOperator::Eq,
                right: Box::new(chryso_core::ast::Expr::Literal(
                    chryso_core::ast::Literal::String("us".to_string()),
                )),
            },
            input: Box::new(PhysicalPlan::TableScan {
                table: "sales".to_string(),
            }),
        };
        let mut stats = StatsCache::new();
        stats.insert_table_stats("sales", chryso_metadata::TableStats { row_count: 100.0 });
        stats.insert_column_stats(
            "sales",
            "region",
            ColumnStats {
                distinct_count: 50.0,
                null_fraction: 0.0,
            },
        );
        let model = StatsCostModel::new(&stats);
        let selective = model.cost(&plan);

        stats.insert_column_stats(
            "sales",
            "region",
            ColumnStats {
                distinct_count: 1.0,
                null_fraction: 0.0,
            },
        );
        let model = StatsCostModel::new(&stats);
        let non_selective = model.cost(&plan);
        assert!(selective.0 < non_selective.0);
    }

    #[test]
    fn config_validation_rejects_non_positive() {
        let mut config = CostModelConfig::default();
        config.join = 0.0;
        let err = config.validate().expect_err("invalid config");
        assert!(err.to_string().contains("join"));
    }

    #[test]
    fn default_config_is_valid() {
        assert!(CostModelConfig::default().validate().is_ok());
    }

    #[test]
    fn config_validation_rejects_join_multiplier_below_one() {
        let config = CostModelConfig {
            join_nested_multiplier: 0.5,
            ..CostModelConfig::default()
        };
        let err = config.validate().expect_err("multiplier below one");
        assert!(err.to_string().contains("join_nested_multiplier"));
    }

    #[test]
    fn system_params_override_cost_config() {
        let registry = SystemParamRegistry::new();
        registry.set_default_param(CostModelConfig::PARAM_FILTER, SystemParamValue::Float(0.9));
        let config = CostModelConfig::default();
        let updated = config.apply_system_params(&registry, Some("tenant"));
        assert_eq!(updated.filter, 0.9);
    }

    #[test]
    fn system_params_ignore_invalid_values() {
        let registry = SystemParamRegistry::new();
        registry.set_default_param(CostModelConfig::PARAM_SORT, SystemParamValue::Float(0.0));
        let config = CostModelConfig::default();
        let updated = config.apply_system_params(&registry, Some("tenant"));
        assert_eq!(updated.sort, config.sort);
    }
}
321
322pub(crate) fn load_config_from_path<T: DeserializeOwned>(
323 path: impl AsRef<Path>,
324 label: &str,
325) -> chryso_core::error::ChrysoResult<T> {
326 let content = fs::read_to_string(path.as_ref()).map_err(|err| {
327 chryso_core::error::ChrysoError::new(format!("read {label} failed: {err}"))
328 })?;
329 if path
330 .as_ref()
331 .extension()
332 .and_then(|ext| ext.to_str())
333 .map(|ext| ext.eq_ignore_ascii_case("toml"))
334 .unwrap_or(false)
335 {
336 toml::from_str(&content).map_err(|err| {
337 chryso_core::error::ChrysoError::new(format!("parse toml {label} failed: {err}"))
338 })
339 } else {
340 serde_json::from_str(&content).map_err(|err| {
341 chryso_core::error::ChrysoError::new(format!("parse json {label} failed: {err}"))
342 })
343 }
344}
345
346fn local_join_penalty(plan: &PhysicalPlan, config: &CostModelConfig) -> f64 {
347 match plan {
348 PhysicalPlan::Join { algorithm, .. } => match algorithm {
349 chryso_planner::JoinAlgorithm::Hash => {
350 config.join * (config.join_hash_multiplier - 1.0)
351 }
352 chryso_planner::JoinAlgorithm::NestedLoop => {
353 config.join * (config.join_nested_multiplier - 1.0)
354 }
355 },
356 _ => 0.0,
357 }
358}
359
360fn node_weight(plan: &PhysicalPlan, config: &CostModelConfig) -> f64 {
361 match plan {
362 PhysicalPlan::TableScan { .. } | PhysicalPlan::IndexScan { .. } => config.scan,
363 PhysicalPlan::Filter { .. } => config.filter,
364 PhysicalPlan::Projection { .. } => config.projection,
365 PhysicalPlan::Join { .. } => config.join,
366 PhysicalPlan::Aggregate { .. } => config.aggregate,
367 PhysicalPlan::Distinct { .. } => config.aggregate,
368 PhysicalPlan::TopN { .. } => config.sort,
369 PhysicalPlan::Sort { .. } => config.sort,
370 PhysicalPlan::Limit { .. } => config.limit,
371 PhysicalPlan::Derived { .. } => config.derived,
372 PhysicalPlan::Dml { .. } => config.dml,
373 }
374}
375
376fn total_weight(plan: &PhysicalPlan, config: &CostModelConfig) -> f64 {
377 let base = node_weight(plan, config) + local_join_penalty(plan, config);
379 let children = match plan {
380 PhysicalPlan::Join { left, right, .. } => {
381 total_weight(left, config) + total_weight(right, config)
382 }
383 PhysicalPlan::Filter { input, .. }
384 | PhysicalPlan::Projection { input, .. }
385 | PhysicalPlan::Aggregate { input, .. }
386 | PhysicalPlan::Distinct { input }
387 | PhysicalPlan::TopN { input, .. }
388 | PhysicalPlan::Sort { input, .. }
389 | PhysicalPlan::Limit { input, .. }
390 | PhysicalPlan::Derived { input, .. } => total_weight(input, config),
391 PhysicalPlan::TableScan { .. }
392 | PhysicalPlan::IndexScan { .. }
393 | PhysicalPlan::Dml { .. } => 0.0,
394 };
395 base + children
396}
397
398fn total_stats_cost(plan: &PhysicalPlan, stats: &StatsCache, config: &CostModelConfig) -> f64 {
399 let rows = estimate_rows(plan, stats);
401 let mut cost = rows * node_weight(plan, config) + local_join_penalty(plan, config);
402 cost += match plan {
403 PhysicalPlan::Join { left, right, .. } => {
404 total_stats_cost(left, stats, config) + total_stats_cost(right, stats, config)
405 }
406 PhysicalPlan::Filter { input, .. }
407 | PhysicalPlan::Projection { input, .. }
408 | PhysicalPlan::Aggregate { input, .. }
409 | PhysicalPlan::Distinct { input }
410 | PhysicalPlan::TopN { input, .. }
411 | PhysicalPlan::Sort { input, .. }
412 | PhysicalPlan::Limit { input, .. }
413 | PhysicalPlan::Derived { input, .. } => total_stats_cost(input, stats, config),
414 PhysicalPlan::TableScan { .. }
415 | PhysicalPlan::IndexScan { .. }
416 | PhysicalPlan::Dml { .. } => 0.0,
417 };
418 cost
419}
420
421fn estimate_rows(plan: &PhysicalPlan, stats: &StatsCache) -> f64 {
422 match plan {
423 PhysicalPlan::TableScan { table } | PhysicalPlan::IndexScan { table, .. } => stats
424 .table_stats(table)
425 .map(|stats| stats.row_count)
426 .unwrap_or(1000.0),
427 PhysicalPlan::Dml { .. } => 1.0,
428 PhysicalPlan::Derived { input, .. } => estimate_rows(input, stats),
429 PhysicalPlan::Filter { predicate, input } => {
430 let base = estimate_rows(input, stats);
431 let table = single_table_name(input);
432 base * estimate_selectivity(predicate, stats, table.as_deref())
433 }
434 PhysicalPlan::Projection { input, .. } => estimate_rows(input, stats),
435 PhysicalPlan::Join { left, right, .. } => {
436 estimate_rows(left, stats) * estimate_rows(right, stats) * 0.1
437 }
438 PhysicalPlan::Aggregate { input, .. } => (estimate_rows(input, stats) * 0.1).max(1.0),
439 PhysicalPlan::Distinct { input } => (estimate_rows(input, stats) * 0.3).max(1.0),
440 PhysicalPlan::TopN { limit, input, .. } => estimate_rows(input, stats).min(*limit as f64),
441 PhysicalPlan::Sort { input, .. } => estimate_rows(input, stats),
442 PhysicalPlan::Limit { limit, input, .. } => match limit {
443 Some(limit) => estimate_rows(input, stats).min(*limit as f64),
444 None => estimate_rows(input, stats),
445 },
446 }
447}
448
449fn estimate_selectivity(
450 predicate: &chryso_core::ast::Expr,
451 stats: &StatsCache,
452 table: Option<&str>,
453) -> f64 {
454 use chryso_core::ast::{BinaryOperator, Expr};
455 match predicate {
456 Expr::BinaryOp { left, op, right } if matches!(op, BinaryOperator::And) => {
457 estimate_selectivity(left, stats, table) * estimate_selectivity(right, stats, table)
458 }
459 Expr::BinaryOp { left, op, right } if matches!(op, BinaryOperator::Or) => {
460 let left_sel = estimate_selectivity(left, stats, table);
461 let right_sel = estimate_selectivity(right, stats, table);
462 (left_sel + right_sel - left_sel * right_sel).min(1.0)
463 }
464 Expr::IsNull { expr, negated } => {
465 let (table_name, column_name) = match expr.as_ref() {
466 Expr::Identifier(name) => match name.split_once('.') {
467 Some((prefix, column)) => (Some(prefix), column),
468 None => (table, name.as_str()),
469 },
470 _ => (table, ""),
471 };
472 if let (Some(table_name), column_name) = (table_name, column_name) {
473 if !column_name.is_empty() {
474 if let Some(stats) = stats.column_stats(table_name, column_name) {
475 let base = stats.null_fraction;
476 return if *negated { 1.0 - base } else { base };
477 }
478 }
479 }
480 if *negated { 0.9 } else { 0.1 }
481 }
482 Expr::BinaryOp { left, op, right } => {
483 if let Some(selectivity) = estimate_eq_selectivity(left, right, stats, table) {
484 match op {
485 BinaryOperator::Eq => selectivity,
486 BinaryOperator::NotEq => (1.0 - selectivity).max(0.0),
487 BinaryOperator::Lt
488 | BinaryOperator::LtEq
489 | BinaryOperator::Gt
490 | BinaryOperator::GtEq => 0.3,
491 _ => 0.3,
492 }
493 } else {
494 0.3
495 }
496 }
497 _ => 0.5,
498 }
499}
500
501fn estimate_eq_selectivity(
502 left: &chryso_core::ast::Expr,
503 right: &chryso_core::ast::Expr,
504 stats: &StatsCache,
505 table: Option<&str>,
506) -> Option<f64> {
507 let (ident, literal) = match (left, right) {
508 (chryso_core::ast::Expr::Identifier(name), chryso_core::ast::Expr::Literal(_)) => {
509 (name, right)
510 }
511 (chryso_core::ast::Expr::Literal(_), chryso_core::ast::Expr::Identifier(name)) => {
512 (name, left)
513 }
514 _ => return None,
515 };
516 let _ = literal;
517 let (table_name, column_name) = match ident.split_once('.') {
518 Some((prefix, column)) => (Some(prefix), column),
519 None => (table, ident.as_str()),
520 };
521 let table_name = table_name?;
522 let stats = stats.column_stats(table_name, column_name)?;
523 let distinct = stats.distinct_count.max(1.0);
524 Some(1.0 / distinct)
525}
526
527fn single_table_name(plan: &PhysicalPlan) -> Option<String> {
528 match plan {
529 PhysicalPlan::TableScan { table } | PhysicalPlan::IndexScan { table, .. } => {
530 Some(table.clone())
531 }
532 PhysicalPlan::Filter { input, .. }
533 | PhysicalPlan::Projection { input, .. }
534 | PhysicalPlan::Aggregate { input, .. }
535 | PhysicalPlan::Distinct { input }
536 | PhysicalPlan::TopN { input, .. }
537 | PhysicalPlan::Sort { input, .. }
538 | PhysicalPlan::Limit { input, .. }
539 | PhysicalPlan::Derived { input, .. } => single_table_name(input),
540 PhysicalPlan::Join { .. } | PhysicalPlan::Dml { .. } => None,
541 }
542}