Skip to main content

reddb_server/storage/query/planner/
cost.rs

1//! Cost Estimation
2//!
3//! Cost-based query plan selection with cardinality estimation.
4//!
5//! # Cost Model
6//!
7//! - **CPU cost**: Computation overhead
8//! - **IO cost**: Disk/memory access
9//! - **Network cost**: For distributed queries
10//! - **Memory cost**: Working memory required
11
12use std::sync::Arc;
13
14use super::stats_provider::{NullProvider, StatsProvider};
15use crate::storage::query::ast::{
16    CompareOp, FieldRef, Filter as AstFilter, GraphQuery, HybridQuery, JoinQuery, JoinType,
17    PathQuery, QueryExpr, TableQuery, VectorQuery,
18};
19use crate::storage::schema::Value;
20
/// Cardinality estimate for a query result.
///
/// `rows` is the estimated output row count, `selectivity` the fraction of
/// the underlying input that survives (0.0 - 1.0), and `confidence` how far
/// the estimate should be trusted (0.0 - 1.0).
#[derive(Debug, Clone, Default)]
pub struct CardinalityEstimate {
    /// Estimated row/record count
    pub rows: f64,
    /// Selectivity factor (0.0 - 1.0)
    pub selectivity: f64,
    /// Confidence in the estimate (0.0 - 1.0)
    pub confidence: f64,
}

impl CardinalityEstimate {
    /// Create an estimate of `rows` rows with the given `selectivity` and
    /// full confidence (1.0).
    pub fn new(rows: f64, selectivity: f64) -> Self {
        Self {
            rows,
            selectivity,
            confidence: 1.0,
        }
    }

    /// Full table scan estimate: every row survives, with full confidence.
    pub fn full_scan(table_size: f64) -> Self {
        Self::new(table_size, 1.0)
    }

    /// Apply a filter to reduce cardinality.
    ///
    /// `filter_selectivity` is clamped to `[0.0, 1.0]`. A mis-computed
    /// factor therefore cannot produce negative row counts or grow the
    /// estimate. `NaN` is treated as `1.0` (no reduction) so an unknown
    /// selectivity cannot poison later arithmetic. Each applied filter
    /// multiplies `confidence` by 0.9, because stacked heuristic estimates
    /// compound their error.
    pub fn with_filter(mut self, filter_selectivity: f64) -> Self {
        let factor = if filter_selectivity.is_nan() {
            1.0
        } else {
            filter_selectivity.clamp(0.0, 1.0)
        };
        self.rows *= factor;
        self.selectivity *= factor;
        self.confidence *= 0.9; // Reduce confidence with each estimate
        self
    }
}
59
/// Cost of executing a query plan.
///
/// Mirrors PostgreSQL's `Cost` split: `startup_cost` is the work needed
/// before the **first** row can be produced, `total` is the work to
/// produce the **last** row. Both are reported so plan selection can
/// pick a low-startup plan when a small `LIMIT` is in scope, even if
/// total work is higher.
///
/// With non-negative inputs, every constructor and combinator below keeps
/// `startup_cost <= total`.
///
/// See `src/storage/query/planner/README.md` § Invariant 1.
#[derive(Debug, Clone, Default)]
pub struct PlanCost {
    /// CPU computation cost
    pub cpu: f64,
    /// IO access cost
    pub io: f64,
    /// Network transfer cost (for distributed)
    pub network: f64,
    /// Memory requirement
    pub memory: f64,
    /// Cost to produce the **first** row.
    ///
    /// Zero for streaming operators (full scan, index scan, filter over
    /// scan). Equal to `total` for blocking operators (sort, hash join
    /// build side, materialize).
    pub startup_cost: f64,
    /// Cost to produce the **last** row.
    pub total: f64,
}

impl PlanCost {
    /// Create a new cost estimate with `startup_cost = 0` (streaming).
    ///
    /// `total = cpu + 10 * io + 0.1 * memory`; IO is weighted as expensive.
    pub fn new(cpu: f64, io: f64, memory: f64) -> Self {
        let total = cpu + io * 10.0 + memory * 0.1; // IO is expensive
        Self {
            cpu,
            io,
            network: 0.0,
            memory,
            startup_cost: 0.0,
            total,
        }
    }

    /// Create a cost with an explicit `startup_cost`. Use for blocking
    /// operators (sort, hash build) and for index point lookups whose
    /// first-row cost is non-zero.
    ///
    /// `startup_cost` is floored at zero and `total` is raised to at least
    /// `startup_cost`.
    pub fn with_startup(cpu: f64, io: f64, memory: f64, startup_cost: f64) -> Self {
        let total = cpu + io * 10.0 + memory * 0.1;
        Self {
            cpu,
            io,
            network: 0.0,
            memory,
            startup_cost: startup_cost.max(0.0),
            total: total.max(startup_cost),
        }
    }

    /// Compose two costs in a **pipelined** fashion: the second operator
    /// consumes the first as a stream.
    ///
    /// Both `startup_cost` and `total` add together. Use for filter
    /// over scan, projection over filter, etc.
    pub fn combine_pipelined(&self, other: &PlanCost) -> PlanCost {
        PlanCost {
            cpu: self.cpu + other.cpu,
            io: self.io + other.io,
            network: self.network + other.network,
            memory: self.memory.max(other.memory),
            startup_cost: self.startup_cost + other.startup_cost,
            total: self.total + other.total,
        }
    }

    /// Compose two costs where `self` must be **fully consumed** before
    /// `blocker` can produce its first row.
    ///
    /// `self.total` flows into `blocker.startup_cost`. Use for sort,
    /// hash build, materialise — anything that has to drain its input
    /// before emitting.
    pub fn combine_blocking(&self, blocker: &PlanCost) -> PlanCost {
        PlanCost {
            cpu: self.cpu + blocker.cpu,
            io: self.io + blocker.io,
            network: self.network + blocker.network,
            memory: self.memory.max(blocker.memory),
            startup_cost: self.total + blocker.startup_cost,
            total: self.total + blocker.total,
        }
    }

    /// Backwards-compatible alias for [`Self::combine_pipelined`].
    ///
    /// New code should prefer `combine_pipelined` / `combine_blocking`
    /// explicitly. This is kept so existing callers compile unchanged.
    pub fn combine(&self, other: &PlanCost) -> PlanCost {
        self.combine_pipelined(other)
    }

    /// Scale cost by a factor (cardinality multiplier, etc.).
    ///
    /// CPU, IO, network and `total` scale linearly. `memory` does not,
    /// because memory doesn't scale linearly. `startup_cost` does not
    /// either, because startup is per-plan, not per-row. The scaled `total`
    /// is floored at `startup_cost`, so a `factor < 1.0` cannot leave the
    /// plan finishing before it produces its first row.
    pub fn scale(&self, factor: f64) -> PlanCost {
        PlanCost {
            cpu: self.cpu * factor,
            io: self.io * factor,
            network: self.network * factor,
            memory: self.memory,
            startup_cost: self.startup_cost,
            total: (self.total * factor).max(self.startup_cost),
        }
    }

    /// Plan-comparison helper. Picks `Less` when `self` should be
    /// preferred over `other`.
    ///
    /// When `limit` is `Some(k)` and `k < 0.1 * cardinality`, the
    /// comparison switches from `total` to `startup_cost` — the client
    /// will only consume a small slice of the result, so we want the
    /// plan that produces the first rows fastest even if the full scan
    /// would be more expensive. Incomparable (NaN) costs compare `Equal`.
    ///
    /// This mirrors PostgreSQL's `compare_path_costs_fuzzily` logic for
    /// `STARTUP` vs `TOTAL` cost ordering.
    pub fn prefer_over(
        &self,
        other: &PlanCost,
        limit: Option<u64>,
        cardinality: f64,
    ) -> std::cmp::Ordering {
        let use_startup = matches!(limit, Some(k) if (k as f64) < 0.1 * cardinality.max(1.0));
        let (lhs, rhs) = if use_startup {
            (self.startup_cost, other.startup_cost)
        } else {
            (self.total, other.total)
        };
        lhs.partial_cmp(&rhs).unwrap_or(std::cmp::Ordering::Equal)
    }
}
197
/// Row-level statistics describing a table or graph.
#[derive(Clone, Debug, Default)]
pub struct TableStats {
    /// How many rows the table holds.
    pub row_count: u64,
    /// Mean row width, in bytes.
    pub avg_row_size: u32,
    /// How many pages the table occupies.
    pub page_count: u64,
    /// Per-column statistics.
    pub columns: Vec<ColumnStats>,
}

/// Statistics describing a single column.
#[derive(Clone, Debug, Default)]
pub struct ColumnStats {
    /// Name of the column.
    pub name: String,
    /// How many distinct values the column contains.
    pub distinct_count: u64,
    /// How many rows hold NULL in this column.
    pub null_count: u64,
    /// Smallest value, when the column's values are orderable.
    pub min_value: Option<String>,
    /// Largest value, when the column's values are orderable.
    pub max_value: Option<String>,
    /// Whether an index exists on this column.
    pub has_index: bool,
}
227
228/// Cost estimator for query plans
229pub struct CostEstimator {
230    /// Default table row count estimate
231    default_row_count: f64,
232    /// Cost per row scan
233    row_scan_cost: f64,
234    /// Cost per index lookup
235    index_lookup_cost: f64,
236    /// Cost per hash join probe
237    hash_probe_cost: f64,
238    /// Cost per nested loop iteration
239    nested_loop_cost: f64,
240    /// Cost per graph edge traversal
241    edge_traversal_cost: f64,
242    /// Optional stats provider. When present, `estimate_table_cardinality`
243    /// and the selectivity computation use real per-table / per-column
244    /// statistics instead of the heuristic constants. `None` preserves the
245    /// legacy behaviour so callers can adopt stats incrementally.
246    stats: Arc<dyn StatsProvider>,
247}
248
249impl CostEstimator {
250    /// Create a new cost estimator with default parameters and a
251    /// [`NullProvider`] — no real stats, pure heuristic mode.
252    pub fn new() -> Self {
253        Self {
254            default_row_count: 1000.0,
255            row_scan_cost: 1.0,
256            index_lookup_cost: 0.1,
257            hash_probe_cost: 0.5,
258            nested_loop_cost: 2.0,
259            edge_traversal_cost: 1.5,
260            stats: Arc::new(NullProvider),
261        }
262    }
263
264    /// Create a cost estimator that consults `provider` for real table /
265    /// column / index statistics. Any lookups the provider cannot satisfy
266    /// fall back to the heuristic path automatically.
267    pub fn with_stats(provider: Arc<dyn StatsProvider>) -> Self {
268        Self {
269            stats: provider,
270            ..Self::new()
271        }
272    }
273
274    /// Swap the stats provider on an existing estimator. Useful for tests
275    /// and for planners that build one `CostEstimator` and repoint it at
276    /// per-query snapshots.
277    pub fn set_stats(&mut self, provider: Arc<dyn StatsProvider>) {
278        self.stats = provider;
279    }
280
281    /// Estimate cost of a query expression
282    pub fn estimate(&self, query: &QueryExpr) -> PlanCost {
283        match query {
284            QueryExpr::Table(tq) => self.estimate_table(tq),
285            QueryExpr::Graph(gq) => self.estimate_graph(gq),
286            QueryExpr::Join(jq) => self.estimate_join(jq),
287            QueryExpr::Path(pq) => self.estimate_path(pq),
288            QueryExpr::Vector(vq) => self.estimate_vector(vq),
289            QueryExpr::Hybrid(hq) => self.estimate_hybrid(hq),
290            // DML/DDL statements have minimal query cost
291            QueryExpr::Insert(_)
292            | QueryExpr::Update(_)
293            | QueryExpr::Delete(_)
294            | QueryExpr::CreateTable(_)
295            | QueryExpr::CreateCollection(_)
296            | QueryExpr::CreateVector(_)
297            | QueryExpr::DropTable(_)
298            | QueryExpr::DropGraph(_)
299            | QueryExpr::DropVector(_)
300            | QueryExpr::DropDocument(_)
301            | QueryExpr::DropKv(_)
302            | QueryExpr::DropCollection(_)
303            | QueryExpr::Truncate(_)
304            | QueryExpr::AlterTable(_)
305            | QueryExpr::GraphCommand(_)
306            | QueryExpr::SearchCommand(_)
307            | QueryExpr::CreateIndex(_)
308            | QueryExpr::DropIndex(_)
309            | QueryExpr::ProbabilisticCommand(_)
310            | QueryExpr::Ask(_)
311            | QueryExpr::SetConfig { .. }
312            | QueryExpr::ShowConfig { .. }
313            | QueryExpr::SetSecret { .. }
314            | QueryExpr::DeleteSecret { .. }
315            | QueryExpr::ShowSecrets { .. }
316            | QueryExpr::SetTenant(_)
317            | QueryExpr::ShowTenant
318            | QueryExpr::CreateTimeSeries(_)
319            | QueryExpr::CreateMetric(_)
320            | QueryExpr::AlterMetric(_)
321            | QueryExpr::CreateSlo(_)
322            | QueryExpr::DropTimeSeries(_)
323            | QueryExpr::CreateQueue(_)
324            | QueryExpr::AlterQueue(_)
325            | QueryExpr::DropQueue(_)
326            | QueryExpr::QueueSelect(_)
327            | QueryExpr::QueueCommand(_)
328            | QueryExpr::KvCommand(_)
329            | QueryExpr::ConfigCommand(_)
330            | QueryExpr::CreateTree(_)
331            | QueryExpr::DropTree(_)
332            | QueryExpr::TreeCommand(_)
333            | QueryExpr::ExplainAlter(_)
334            | QueryExpr::TransactionControl(_)
335            | QueryExpr::MaintenanceCommand(_)
336            | QueryExpr::CreateSchema(_)
337            | QueryExpr::DropSchema(_)
338            | QueryExpr::CreateSequence(_)
339            | QueryExpr::DropSequence(_)
340            | QueryExpr::CopyFrom(_)
341            | QueryExpr::CreateView(_)
342            | QueryExpr::DropView(_)
343            | QueryExpr::RefreshMaterializedView(_)
344            | QueryExpr::CreatePolicy(_)
345            | QueryExpr::DropPolicy(_)
346            | QueryExpr::CreateServer(_)
347            | QueryExpr::DropServer(_)
348            | QueryExpr::CreateForeignTable(_)
349            | QueryExpr::DropForeignTable(_)
350            | QueryExpr::Grant(_)
351            | QueryExpr::Revoke(_)
352            | QueryExpr::AlterUser(_)
353            | QueryExpr::CreateIamPolicy { .. }
354            | QueryExpr::DropIamPolicy { .. }
355            | QueryExpr::AttachPolicy { .. }
356            | QueryExpr::DetachPolicy { .. }
357            | QueryExpr::ShowPolicies { .. }
358            | QueryExpr::ShowEffectivePermissions { .. }
359            | QueryExpr::RankOf(_)
360            | QueryExpr::ApproxRankOf(_)
361            | QueryExpr::RankRange(_)
362            | QueryExpr::SimulatePolicy { .. }
363            | QueryExpr::LintPolicy { .. }
364            | QueryExpr::MigratePolicyMode { .. }
365            | QueryExpr::CreateMigration(_)
366            | QueryExpr::ApplyMigration(_)
367            | QueryExpr::RollbackMigration(_)
368            | QueryExpr::ExplainMigration(_)
369            | QueryExpr::EventsBackfill(_)
370            | QueryExpr::EventsBackfillStatus { .. } => PlanCost::new(1.0, 1.0, 0.0),
371        }
372    }
373
374    /// Estimate cardinality of a query result
375    pub fn estimate_cardinality(&self, query: &QueryExpr) -> CardinalityEstimate {
376        match query {
377            QueryExpr::Table(tq) => self.estimate_table_cardinality(tq),
378            QueryExpr::Graph(gq) => self.estimate_graph_cardinality(gq),
379            QueryExpr::Join(jq) => self.estimate_join_cardinality(jq),
380            QueryExpr::Path(pq) => self.estimate_path_cardinality(pq),
381            QueryExpr::Vector(vq) => self.estimate_vector_cardinality(vq),
382            QueryExpr::Hybrid(hq) => self.estimate_hybrid_cardinality(hq),
383            // DML/DDL/Command statements return affected-row count or nothing
384            QueryExpr::Insert(_)
385            | QueryExpr::Update(_)
386            | QueryExpr::Delete(_)
387            | QueryExpr::CreateTable(_)
388            | QueryExpr::CreateCollection(_)
389            | QueryExpr::CreateVector(_)
390            | QueryExpr::DropTable(_)
391            | QueryExpr::DropGraph(_)
392            | QueryExpr::DropVector(_)
393            | QueryExpr::DropDocument(_)
394            | QueryExpr::DropKv(_)
395            | QueryExpr::DropCollection(_)
396            | QueryExpr::Truncate(_)
397            | QueryExpr::AlterTable(_)
398            | QueryExpr::GraphCommand(_)
399            | QueryExpr::SearchCommand(_)
400            | QueryExpr::CreateIndex(_)
401            | QueryExpr::DropIndex(_)
402            | QueryExpr::ProbabilisticCommand(_)
403            | QueryExpr::Ask(_)
404            | QueryExpr::SetConfig { .. }
405            | QueryExpr::ShowConfig { .. }
406            | QueryExpr::SetSecret { .. }
407            | QueryExpr::DeleteSecret { .. }
408            | QueryExpr::ShowSecrets { .. }
409            | QueryExpr::SetTenant(_)
410            | QueryExpr::ShowTenant
411            | QueryExpr::CreateTimeSeries(_)
412            | QueryExpr::CreateMetric(_)
413            | QueryExpr::AlterMetric(_)
414            | QueryExpr::CreateSlo(_)
415            | QueryExpr::DropTimeSeries(_)
416            | QueryExpr::CreateQueue(_)
417            | QueryExpr::AlterQueue(_)
418            | QueryExpr::DropQueue(_)
419            | QueryExpr::QueueSelect(_)
420            | QueryExpr::QueueCommand(_)
421            | QueryExpr::KvCommand(_)
422            | QueryExpr::ConfigCommand(_)
423            | QueryExpr::CreateTree(_)
424            | QueryExpr::DropTree(_)
425            | QueryExpr::TreeCommand(_)
426            | QueryExpr::ExplainAlter(_)
427            | QueryExpr::TransactionControl(_)
428            | QueryExpr::MaintenanceCommand(_)
429            | QueryExpr::CreateSchema(_)
430            | QueryExpr::DropSchema(_)
431            | QueryExpr::CreateSequence(_)
432            | QueryExpr::DropSequence(_)
433            | QueryExpr::CopyFrom(_)
434            | QueryExpr::CreateView(_)
435            | QueryExpr::DropView(_)
436            | QueryExpr::RefreshMaterializedView(_)
437            | QueryExpr::CreatePolicy(_)
438            | QueryExpr::DropPolicy(_)
439            | QueryExpr::CreateServer(_)
440            | QueryExpr::DropServer(_)
441            | QueryExpr::CreateForeignTable(_)
442            | QueryExpr::DropForeignTable(_)
443            | QueryExpr::Grant(_)
444            | QueryExpr::Revoke(_)
445            | QueryExpr::AlterUser(_)
446            | QueryExpr::CreateIamPolicy { .. }
447            | QueryExpr::DropIamPolicy { .. }
448            | QueryExpr::AttachPolicy { .. }
449            | QueryExpr::DetachPolicy { .. }
450            | QueryExpr::ShowPolicies { .. }
451            | QueryExpr::ShowEffectivePermissions { .. }
452            | QueryExpr::RankOf(_)
453            | QueryExpr::ApproxRankOf(_)
454            | QueryExpr::RankRange(_)
455            | QueryExpr::SimulatePolicy { .. }
456            | QueryExpr::LintPolicy { .. }
457            | QueryExpr::MigratePolicyMode { .. }
458            | QueryExpr::CreateMigration(_)
459            | QueryExpr::ApplyMigration(_)
460            | QueryExpr::RollbackMigration(_)
461            | QueryExpr::ExplainMigration(_)
462            | QueryExpr::EventsBackfill(_)
463            | QueryExpr::EventsBackfillStatus { .. } => CardinalityEstimate::new(1.0, 1.0),
464        }
465    }
466
467    // =========================================================================
468    // Table Query Estimation
469    // =========================================================================
470
471    fn estimate_table(&self, query: &TableQuery) -> PlanCost {
472        let cardinality = self.estimate_table_cardinality(query);
473
474        let cpu = cardinality.rows * self.row_scan_cost;
475
476        // I/O cost: use Mackert-Lohman when we have index stats and a filter
477        // column with a known index; otherwise fall back to the naive heuristic.
478        let io = self.estimate_table_io(query, cardinality.rows);
479
480        let memory = cardinality.rows * 100.0; // 100 bytes per row estimate
481
482        PlanCost::new(cpu, io, memory)
483    }
484
485    /// Compute the I/O page cost for a table scan.
486    ///
487    /// When the query has a simple equality or range filter on an indexed
488    /// column, use `IndexStats::correlated_io_cost` (Mackert-Lohman) which
489    /// accounts for `index_correlation` (0.0 = random I/O, 1.0 = sequential).
490    /// Falls back to the naive `rows / 100` heuristic otherwise.
491    fn estimate_table_io(&self, query: &TableQuery, result_rows: f64) -> f64 {
492        const ROWS_PER_PAGE: f64 = 100.0;
493
494        // Look up total heap pages from table stats if available
495        let table_stats = self.stats.table_stats(&query.table);
496        let heap_pages = table_stats
497            .map(|s| s.page_count as f64)
498            .unwrap_or_else(|| (result_rows / ROWS_PER_PAGE).max(1.0));
499
500        // If the filter is a simple comparison on an indexed column, use
501        // the Mackert-Lohman formula with correlation from IndexStats.
502        if let Some(filter) = crate::storage::query::sql_lowering::effective_table_filter(query) {
503            if let Some(col) = first_filter_column(&filter, &query.table) {
504                if let Some(idx) = self.stats.index_stats(&query.table, col) {
505                    return idx.correlated_io_cost(result_rows, heap_pages);
506                }
507            }
508        }
509
510        // Heuristic fallback: assume sequential pages = rows / 100
511        (result_rows / ROWS_PER_PAGE).ceil()
512    }
513
514    fn estimate_table_cardinality(&self, query: &TableQuery) -> CardinalityEstimate {
515        // Prefer real row counts from the stats provider; fall back to the
516        // heuristic `default_row_count` when no stats are registered.
517        let base_rows = self
518            .stats
519            .table_stats(&query.table)
520            .map(|s| s.row_count as f64)
521            .unwrap_or(self.default_row_count);
522
523        let mut estimate = CardinalityEstimate::full_scan(base_rows);
524
525        // Apply filter selectivity (stats-aware when provider has index
526        // stats on the compared column).
527        if let Some(filter) = crate::storage::query::sql_lowering::effective_table_filter(query) {
528            let selectivity = self.filter_selectivity(&filter, &query.table);
529            estimate = estimate.with_filter(selectivity);
530        }
531
532        // Apply limit
533        if let Some(limit) = query.limit {
534            estimate.rows = estimate.rows.min(limit as f64);
535        }
536
537        estimate
538    }
539
540    /// Stats-aware selectivity computation.
541    ///
542    /// Resolution order (best → worst):
543    ///   1. `column_mcv` for equality on a known frequent value
544    ///   2. `column_histogram` for ranges and BETWEEN
545    ///   3. `index_stats.point_selectivity()` for indexed columns
546    ///   4. Hardcoded heuristic constants as final fallback
547    ///
548    /// Mirrors postgres `var_eq_const` / `histogram_selectivity` in
549    /// `src/backend/utils/adt/selfuncs.c`. Histogram + MCV data
550    /// structures already live in `super::histogram`; this method is
551    /// where we finally consume them on the hot planner path.
552    fn filter_selectivity(&self, filter: &AstFilter, table: &str) -> f64 {
553        match filter {
554            AstFilter::Compare { field, op, value } => {
555                let column = column_name_for_table(field, table);
556                match op {
557                    CompareOp::Eq => self.eq_selectivity(table, column, value),
558                    CompareOp::Ne => 1.0 - self.eq_selectivity(table, column, value),
559                    CompareOp::Lt | CompareOp::Le => {
560                        self.range_selectivity(table, column, None, Some(value))
561                    }
562                    CompareOp::Gt | CompareOp::Ge => {
563                        self.range_selectivity(table, column, Some(value), None)
564                    }
565                }
566            }
567            AstFilter::Between {
568                field, low, high, ..
569            } => {
570                let column = column_name_for_table(field, table);
571                self.range_selectivity(table, column, Some(low), Some(high))
572            }
573            AstFilter::In { field, values, .. } => {
574                let column = column_name_for_table(field, table);
575                // If we have an MCV list, sum the per-value frequencies
576                // for values that are actually in the list, plus the
577                // residual estimate for the rest.
578                if let Some(c) = column {
579                    if let Some(mcv) = self.stats.column_mcv(table, c) {
580                        let mut hits: f64 = 0.0;
581                        let mut residual_count = 0usize;
582                        for v in values {
583                            if let Some(cv) = column_value_from(v) {
584                                if let Some(freq) = mcv.frequency_of(&cv) {
585                                    hits += freq;
586                                } else {
587                                    residual_count += 1;
588                                }
589                            } else {
590                                residual_count += 1;
591                            }
592                        }
593                        let total = mcv.total_frequency();
594                        let distinct = self.stats.distinct_values(table, c).unwrap_or(100);
595                        let non_mcv_distinct =
596                            distinct.saturating_sub(mcv.values.len() as u64).max(1);
597                        let per_residual = (1.0 - total) / non_mcv_distinct as f64;
598                        let estimate = hits + (residual_count as f64) * per_residual;
599                        return estimate.clamp(0.0, 1.0).min(0.5);
600                    }
601                    if let Some(s) = self.stats.index_stats(table, c) {
602                        return (s.point_selectivity() * values.len() as f64).min(0.5);
603                    }
604                }
605                (values.len() as f64 * 0.01).min(0.5)
606            }
607            AstFilter::Like { .. } => 0.1,
608            AstFilter::StartsWith { .. } => 0.15,
609            AstFilter::EndsWith { .. } => 0.15,
610            AstFilter::Contains { .. } => 0.1,
611            AstFilter::IsNull { .. } => 0.01,
612            AstFilter::IsNotNull { .. } => 0.99,
613            AstFilter::And(left, right) => {
614                self.filter_selectivity(left, table) * self.filter_selectivity(right, table)
615            }
616            AstFilter::Or(left, right) => {
617                let s1 = self.filter_selectivity(left, table);
618                let s2 = self.filter_selectivity(right, table);
619                s1 + s2 - (s1 * s2)
620            }
621            AstFilter::Not(inner) => 1.0 - self.filter_selectivity(inner, table),
622            AstFilter::CompareFields { .. } => {
623                // Column-to-column predicates lack histogram leverage
624                // — assume moderate selectivity. Histogram/MCV hooks
625                // only help literal-valued filters.
626                0.1
627            }
628            AstFilter::CompareExpr { .. } => {
629                // Expression-shaped predicates: conservative 0.1 until
630                // the planner learns to walk Expr trees. Matches the
631                // CompareFields default.
632                0.1
633            }
634        }
635    }
636
637    // =========================================================================
638    // Graph Query Estimation
639    // =========================================================================
640
641    fn estimate_graph(&self, query: &GraphQuery) -> PlanCost {
642        let cardinality = self.estimate_graph_cardinality(query);
643
644        // Graph queries are more expensive due to pointer chasing
645        let nodes = query.pattern.nodes.len() as f64;
646        let edges = query.pattern.edges.len() as f64;
647
648        let cpu = cardinality.rows * self.edge_traversal_cost * (nodes + edges);
649        let io = cardinality.rows * 0.1; // More random IO
650        let memory = cardinality.rows * 200.0; // Larger due to paths
651
652        PlanCost::new(cpu, io, memory)
653    }
654
655    fn estimate_graph_cardinality(&self, query: &GraphQuery) -> CardinalityEstimate {
656        let nodes = query.pattern.nodes.len() as f64;
657        let edges = query.pattern.edges.len() as f64;
658
659        // Each edge reduces cardinality
660        let base_rows = self.default_row_count;
661        let edge_factor = 0.1_f64.powf(edges); // Each edge is highly selective
662
663        let mut estimate = CardinalityEstimate::new(base_rows * nodes * edge_factor, edge_factor);
664        estimate.confidence = 0.5; // Graph estimates are less accurate
665
666        // Apply filter
667        if let Some(ref filter) = query.filter {
668            let selectivity = Self::estimate_filter_selectivity(filter);
669            estimate = estimate.with_filter(selectivity);
670        }
671
672        estimate
673    }
674
675    // =========================================================================
676    // Join Query Estimation
677    // =========================================================================
678
679    fn estimate_join(&self, query: &JoinQuery) -> PlanCost {
680        let left_cost = self.estimate(&query.left);
681        let right_cost = self.estimate(&query.right);
682
683        let left_card = self.estimate_cardinality(&query.left);
684        let right_card = self.estimate_cardinality(&query.right);
685
686        // Hash join cost model.
687        //
688        // Build side (left) is **blocking** — we must drain the entire
689        // left input and populate the hash table before any probe can
690        // produce its first output row. Probe side (right) is then
691        // streamed pipelined.
692        let build_cpu = left_card.rows * self.hash_probe_cost;
693        let probe_cpu = right_card.rows * self.hash_probe_cost;
694        let join_memory = left_card.rows * 100.0; // hash table footprint
695
696        // The build operator: zero work upstream, blocking on left input.
697        let build_op = PlanCost::with_startup(build_cpu, 0.0, join_memory, build_cpu);
698        // The probe operator: pipelined over right input.
699        let probe_op = PlanCost::new(probe_cpu, 0.0, 0.0);
700
701        // Compose: left → block on build → pipelined probe with right.
702        let after_build = left_cost.combine_blocking(&build_op);
703        after_build
704            .combine_pipelined(&right_cost)
705            .combine_pipelined(&probe_op)
706    }
707
708    fn estimate_join_cardinality(&self, query: &JoinQuery) -> CardinalityEstimate {
709        let left = self.estimate_cardinality(&query.left);
710        let right = self.estimate_cardinality(&query.right);
711
712        // Join selectivity based on join type
713        let selectivity = match query.join_type {
714            JoinType::Inner => 0.1,      // Inner join is selective
715            JoinType::LeftOuter => 1.0,  // Left join preserves left side
716            JoinType::RightOuter => 1.0, // Right join preserves right side
717            JoinType::FullOuter => 1.0,  // Full outer preserves both sides entirely
718            JoinType::Cross => 1.0,      // Cartesian product — every pair matches
719        };
720
721        CardinalityEstimate::new(
722            left.rows * right.rows * selectivity,
723            left.selectivity * right.selectivity * selectivity,
724        )
725    }
726
727    // =========================================================================
728    // Path Query Estimation
729    // =========================================================================
730
731    fn estimate_path(&self, query: &PathQuery) -> PlanCost {
732        let cardinality = self.estimate_path_cardinality(query);
733
734        // BFS/DFS cost
735        let max_hops = query.max_length;
736        let branching_factor: f64 = 5.0; // Average edges per node
737
738        let nodes_visited = branching_factor.powf(max_hops as f64).min(10000.0);
739        let cpu = nodes_visited * self.edge_traversal_cost;
740        let io = nodes_visited * 0.1;
741        let memory = nodes_visited * 50.0; // Visited set
742
743        PlanCost::new(cpu, io, memory)
744    }
745
746    fn estimate_path_cardinality(&self, query: &PathQuery) -> CardinalityEstimate {
747        // Path queries typically return few results
748        let max_paths = 10.0;
749        CardinalityEstimate::new(max_paths, 0.001)
750    }
751
752    // =========================================================================
753    // Vector Query Estimation
754    // =========================================================================
755
756    fn estimate_vector(&self, query: &VectorQuery) -> PlanCost {
757        // HNSW search is O(log n) with relatively low constant
758        // Typical search visits ~100-500 nodes for 1M vectors
759        let k = query.k as f64;
760
761        // Base cost from HNSW traversal — must descend the layer graph
762        // before *any* candidate can be returned. This is the operator's
763        // intrinsic startup cost.
764        let hnsw_cost = 100.0 * (1.0 + k.ln()); // ~100-300 node visits
765
766        // Metadata filtering adds cost if present
767        let filter_cost =
768            if crate::storage::query::sql_lowering::effective_vector_filter(query).is_some() {
769                50.0
770            } else {
771                0.0
772            };
773
774        let cpu = hnsw_cost + filter_cost;
775        let io = 20.0; // HNSW layers are cached
776        let memory = k * 32.0 + 1000.0; // k results + working set
777
778        // Vector search is *partly* blocking: HNSW must traverse the
779        // entry layers before the first neighbour is known, so the
780        // first-row cost is roughly the descent cost. Subsequent rows
781        // come essentially free until `k`.
782        PlanCost::with_startup(cpu, io, memory, hnsw_cost * 0.5)
783    }
784
785    fn estimate_vector_cardinality(&self, query: &VectorQuery) -> CardinalityEstimate {
786        // Vector search returns exactly k results (or fewer if not enough vectors)
787        let k = query.k as f64;
788        CardinalityEstimate::new(k, 0.1)
789    }
790
791    // =========================================================================
792    // Hybrid Query Estimation
793    // =========================================================================
794
795    fn estimate_hybrid(&self, query: &HybridQuery) -> PlanCost {
796        // Hybrid cost = structured + vector + fusion overhead
797        let structured_cost = self.estimate(&query.structured);
798        let vector_cost = self.estimate_vector(&query.vector);
799
800        // Fusion overhead depends on strategy
801        let fusion_overhead = match &query.fusion {
802            crate::storage::query::ast::FusionStrategy::Rerank { .. } => 50.0,
803            crate::storage::query::ast::FusionStrategy::FilterThenSearch => 10.0,
804            crate::storage::query::ast::FusionStrategy::SearchThenFilter => 10.0,
805            crate::storage::query::ast::FusionStrategy::RRF { .. } => 30.0,
806            crate::storage::query::ast::FusionStrategy::Intersection => 20.0,
807            crate::storage::query::ast::FusionStrategy::Union { .. } => 40.0,
808        };
809
810        PlanCost::new(
811            structured_cost.cpu + vector_cost.cpu + fusion_overhead,
812            structured_cost.io + vector_cost.io,
813            structured_cost.memory + vector_cost.memory,
814        )
815    }
816
817    fn estimate_hybrid_cardinality(&self, query: &HybridQuery) -> CardinalityEstimate {
818        let structured_card = self.estimate_cardinality(&query.structured);
819        let vector_card = self.estimate_vector_cardinality(&query.vector);
820
821        // Result size depends on fusion strategy
822        let rows = match &query.fusion {
823            crate::storage::query::ast::FusionStrategy::Intersection => {
824                structured_card.rows.min(vector_card.rows)
825            }
826            crate::storage::query::ast::FusionStrategy::Union { .. } => {
827                structured_card.rows + vector_card.rows
828            }
829            _ => vector_card.rows, // Rerank and filter strategies return vector k
830        };
831
832        CardinalityEstimate::new(rows, 0.2)
833    }
834
835    // =========================================================================
836    // Filter Selectivity
837    // =========================================================================
838
839    fn estimate_filter_selectivity(filter: &AstFilter) -> f64 {
840        match filter {
841            AstFilter::Compare { op, .. } => {
842                match op {
843                    CompareOp::Eq => 0.01, // Equality is very selective
844                    CompareOp::Ne => 0.99, // Inequality is not selective
845                    CompareOp::Lt | CompareOp::Le => 0.3,
846                    CompareOp::Gt | CompareOp::Ge => 0.3,
847                }
848            }
849            AstFilter::Between { .. } => 0.25,
850            AstFilter::In { values, .. } => {
851                // Each value adds 1% selectivity
852                (values.len() as f64 * 0.01).min(0.5)
853            }
854            AstFilter::Like { .. } => 0.1,
855            AstFilter::StartsWith { .. } => 0.15,
856            AstFilter::EndsWith { .. } => 0.15,
857            AstFilter::Contains { .. } => 0.1,
858            AstFilter::IsNull { .. } => 0.01,
859            AstFilter::IsNotNull { .. } => 0.99,
860            AstFilter::And(left, right) => {
861                Self::estimate_filter_selectivity(left) * Self::estimate_filter_selectivity(right)
862            }
863            AstFilter::Or(left, right) => {
864                let s1 = Self::estimate_filter_selectivity(left);
865                let s2 = Self::estimate_filter_selectivity(right);
866                s1 + s2 - (s1 * s2) // Inclusion-exclusion
867            }
868            AstFilter::Not(inner) => 1.0 - Self::estimate_filter_selectivity(inner),
869            AstFilter::CompareFields { .. } => 0.1,
870            AstFilter::CompareExpr { .. } => 0.1,
871        }
872    }
873}
874
875impl CostEstimator {
876    /// Equality selectivity for `column = value`.
877    ///
878    /// Resolution order:
879    /// 1. MCV list — exact frequency for tracked values, residual
880    ///    formula for untracked values.
881    /// 2. `index_stats.point_selectivity()` — `1 / distinct_keys`.
882    /// 3. Heuristic constant `0.01`.
883    fn eq_selectivity(&self, table: &str, column: Option<&str>, value: &Value) -> f64 {
884        if let Some(col) = column {
885            // 1. Most-common-values lookup.
886            if let Some(mcv) = self.stats.column_mcv(table, col) {
887                if let Some(cv) = column_value_from(value) {
888                    if let Some(freq) = mcv.frequency_of(&cv) {
889                        return freq;
890                    }
891                    // Untracked value: residual / non_mcv_distinct.
892                    let total = mcv.total_frequency();
893                    let distinct = self.stats.distinct_values(table, col).unwrap_or(100);
894                    let non_mcv_distinct = distinct.saturating_sub(mcv.values.len() as u64).max(1);
895                    return ((1.0 - total) / non_mcv_distinct as f64).clamp(0.0, 1.0);
896                }
897            }
898            // 2. Index stats fallback.
899            if let Some(s) = self.stats.index_stats(table, col) {
900                return s.point_selectivity();
901            }
902        }
903        // 3. Heuristic.
904        0.01
905    }
906
907    /// Range selectivity for `lo <= column <= hi`. Either bound may
908    /// be `None` to express an open side. Used by `<`, `<=`, `>`,
909    /// `>=`, and `BETWEEN`.
910    ///
911    /// Resolution order:
912    /// 1. Histogram — `Histogram::range_selectivity` with bounds
913    ///    converted via `column_value_from`.
914    /// 2. `index_stats.point_selectivity() * (distinct_keys / 2)`
915    ///    capped at the legacy heuristic.
916    /// 3. Heuristic `0.3` for one-sided, `0.25` for two-sided.
917    fn range_selectivity(
918        &self,
919        table: &str,
920        column: Option<&str>,
921        lo: Option<&Value>,
922        hi: Option<&Value>,
923    ) -> f64 {
924        if let Some(col) = column {
925            // 1. Histogram bucket arithmetic.
926            if let Some(h) = self.stats.column_histogram(table, col) {
927                let lo_cv = lo.and_then(column_value_from);
928                let hi_cv = hi.and_then(column_value_from);
929                return h.range_selectivity(lo_cv.as_ref(), hi_cv.as_ref());
930            }
931            // 2. Index stats fallback.
932            if let Some(s) = self.stats.index_stats(table, col) {
933                let cap = if lo.is_some() && hi.is_some() {
934                    0.25
935                } else {
936                    0.3
937                };
938                return (s.point_selectivity() * (s.distinct_keys as f64 / 2.0)).min(cap);
939            }
940        }
941        // 3. Heuristic.
942        if lo.is_some() && hi.is_some() {
943            0.25
944        } else {
945            0.3
946        }
947    }
948}
949
950impl Default for CostEstimator {
951    fn default() -> Self {
952        Self::new()
953    }
954}
955
956/// Convert a query AST `Value` into a histogram-comparable
957/// [`super::histogram::ColumnValue`]. Returns `None` for value types
958/// that histograms don't support (Bool, Null, Bytes, etc.) — callers
959/// fall through to the heuristic path.
960fn column_value_from(v: &crate::storage::schema::Value) -> Option<super::histogram::ColumnValue> {
961    use super::histogram::ColumnValue;
962    use crate::storage::schema::Value;
963    match v {
964        Value::Integer(i) | Value::BigInt(i) => Some(ColumnValue::Int(*i)),
965        Value::UnsignedInteger(u) => Some(ColumnValue::Int(*u as i64)),
966        Value::Float(f) if f.is_finite() => Some(ColumnValue::Float(*f)),
967        Value::Text(s) => Some(ColumnValue::Text(s.to_string())),
968        Value::Email(s)
969        | Value::Url(s)
970        | Value::NodeRef(s)
971        | Value::EdgeRef(s)
972        | Value::TableRef(s)
973        | Value::Password(s) => Some(ColumnValue::Text(s.clone())),
974        Value::Timestamp(t) => Some(ColumnValue::Int(*t)),
975        Value::Duration(d) => Some(ColumnValue::Int(*d)),
976        Value::TimestampMs(t) => Some(ColumnValue::Int(*t)),
977        Value::Decimal(d) => Some(ColumnValue::Int(*d)),
978        Value::Date(d) => Some(ColumnValue::Int(i64::from(*d))),
979        Value::Time(t) => Some(ColumnValue::Int(i64::from(*t))),
980        Value::Phone(p) => Some(ColumnValue::Int(*p as i64)),
981        Value::Semver(v) => Some(ColumnValue::Int(i64::from(*v))),
982        Value::Port(v) => Some(ColumnValue::Int(i64::from(*v))),
983        Value::PageRef(v) => Some(ColumnValue::Int(i64::from(*v))),
984        Value::EnumValue(v) => Some(ColumnValue::Int(i64::from(*v))),
985        Value::Latitude(v) => Some(ColumnValue::Int(i64::from(*v))),
986        Value::Longitude(v) => Some(ColumnValue::Int(i64::from(*v))),
987        // Other variants (Null, Blob, Boolean, IpAddr, MacAddr,
988        // Vector, Json, Uuid, NodeRef, EdgeRef, vector ref...) are
989        // not orderable in a histogram-meaningful way; the planner
990        // falls through to the heuristic for these.
991        _ => None,
992    }
993}
994
995/// Resolve a `FieldRef` to a bare column name when it refers to `table`.
996/// Returns `None` when the field targets another relation — in that case
997/// Extract the first plain column name from a filter for index-stat lookup.
998/// Walks AND nodes; stops at OR/NOT (too complex for simple correlation lookup).
999fn first_filter_column<'a>(filter: &'a AstFilter, table: &str) -> Option<&'a str> {
1000    match filter {
1001        AstFilter::Compare { field, .. } => column_name_for_table(field, table),
1002        AstFilter::Between { field, .. } => column_name_for_table(field, table),
1003        AstFilter::And(l, r) => {
1004            first_filter_column(l, table).or_else(|| first_filter_column(r, table))
1005        }
1006        _ => None,
1007    }
1008}
1009
1010/// the legacy heuristic still applies.
1011fn column_name_for_table<'a>(field: &'a FieldRef, table: &str) -> Option<&'a str> {
1012    match field {
1013        FieldRef::TableColumn { table: t, column } if t == table || t.is_empty() => {
1014            Some(column.as_str())
1015        }
1016        // Node / edge property refs don't map to table-level stats.
1017        _ => None,
1018    }
1019}
1020
1021#[cfg(test)]
1022mod tests {
1023    use super::super::stats_provider::StaticProvider;
1024    use super::*;
1025    use crate::storage::index::{IndexKind, IndexStats};
1026    use crate::storage::query::ast::{FieldRef, Projection};
1027    use crate::storage::schema::Value;
1028
1029    fn eq_filter(table: &str, column: &str, value: i64) -> AstFilter {
1030        AstFilter::Compare {
1031            field: FieldRef::column(table, column),
1032            op: CompareOp::Eq,
1033            value: Value::Integer(value),
1034        }
1035    }
1036
1037    fn table_query(name: &str, filter: Option<AstFilter>) -> TableQuery {
1038        TableQuery {
1039            table: name.to_string(),
1040            source: None,
1041            alias: None,
1042            select_items: Vec::new(),
1043            columns: vec![Projection::All],
1044            where_expr: None,
1045            filter,
1046            group_by_exprs: Vec::new(),
1047            group_by: Vec::new(),
1048            having_expr: None,
1049            having: None,
1050            order_by: vec![],
1051            limit: None,
1052            limit_param: None,
1053            offset: None,
1054            offset_param: None,
1055            expand: None,
1056            as_of: None,
1057            sessionize: None,
1058        }
1059    }
1060
1061    #[test]
1062    fn injected_row_count_overrides_default() {
1063        let provider = Arc::new(StaticProvider::new().with_table(
1064            "users",
1065            TableStats {
1066                row_count: 50_000,
1067                avg_row_size: 256,
1068                page_count: 500,
1069                columns: vec![],
1070            },
1071        ));
1072        let estimator = CostEstimator::with_stats(provider);
1073        let q = table_query("users", None);
1074        let card = estimator.estimate_table_cardinality(&q);
1075        // Default would be 1000; provider says 50_000.
1076        assert_eq!(card.rows, 50_000.0);
1077    }
1078
1079    #[test]
1080    fn stats_aware_eq_selectivity_beats_default() {
1081        let provider = Arc::new(
1082            StaticProvider::new()
1083                .with_table(
1084                    "users",
1085                    TableStats {
1086                        row_count: 1_000_000,
1087                        avg_row_size: 256,
1088                        page_count: 10_000,
1089                        columns: vec![],
1090                    },
1091                )
1092                .with_index(
1093                    "users",
1094                    "email",
1095                    IndexStats {
1096                        entries: 1_000_000,
1097                        distinct_keys: 1_000_000,
1098                        approx_bytes: 0,
1099                        kind: IndexKind::Hash,
1100                        has_bloom: true,
1101                        index_correlation: 0.0,
1102                    },
1103                ),
1104        );
1105        let estimator = CostEstimator::with_stats(provider);
1106        let q = table_query("users", Some(eq_filter("users", "email", 0)));
1107        let card = estimator.estimate_table_cardinality(&q);
1108        // 1M rows × (1 / 1M distinct) ≈ 1 row
1109        assert!(card.rows < 2.0, "expected ~1 row, got {}", card.rows);
1110    }
1111
1112    #[test]
1113    fn fallback_when_no_index_stats() {
1114        let provider = Arc::new(StaticProvider::new().with_table(
1115            "users",
1116            TableStats {
1117                row_count: 1_000_000,
1118                avg_row_size: 256,
1119                page_count: 10_000,
1120                columns: vec![],
1121            },
1122        ));
1123        let estimator = CostEstimator::with_stats(provider);
1124        let q = table_query("users", Some(eq_filter("users", "email", 0)));
1125        let card = estimator.estimate_table_cardinality(&q);
1126        // Heuristic 0.01 on 1M rows = 10_000
1127        assert!((card.rows - 10_000.0).abs() < 1.0);
1128    }
1129
1130    #[test]
1131    fn null_provider_keeps_legacy_behaviour() {
1132        let estimator = CostEstimator::new();
1133        let q = table_query("whatever", Some(eq_filter("whatever", "id", 1)));
1134        let card = estimator.estimate_table_cardinality(&q);
1135        // Default 1000 rows × 0.01 eq selectivity = 10
1136        assert!((card.rows - 10.0).abs() < 1.0);
1137    }
1138
1139    #[test]
1140    fn and_combines_stats_selectivities() {
1141        let provider = Arc::new(
1142            StaticProvider::new()
1143                .with_table(
1144                    "t",
1145                    TableStats {
1146                        row_count: 100_000,
1147                        avg_row_size: 64,
1148                        page_count: 100,
1149                        columns: vec![],
1150                    },
1151                )
1152                .with_index(
1153                    "t",
1154                    "a",
1155                    IndexStats {
1156                        entries: 100_000,
1157                        distinct_keys: 10,
1158                        approx_bytes: 0,
1159                        kind: IndexKind::BTree,
1160                        has_bloom: false,
1161                        index_correlation: 0.0,
1162                    },
1163                )
1164                .with_index(
1165                    "t",
1166                    "b",
1167                    IndexStats {
1168                        entries: 100_000,
1169                        distinct_keys: 1000,
1170                        approx_bytes: 0,
1171                        kind: IndexKind::BTree,
1172                        has_bloom: false,
1173                        index_correlation: 0.0,
1174                    },
1175                ),
1176        );
1177        let estimator = CostEstimator::with_stats(provider);
1178        let filter = AstFilter::And(
1179            Box::new(eq_filter("t", "a", 1)),
1180            Box::new(eq_filter("t", "b", 1)),
1181        );
1182        let q = table_query("t", Some(filter));
1183        let card = estimator.estimate_table_cardinality(&q);
1184        // 100_000 × (1/10) × (1/1000) = 10
1185        assert!(card.rows < 15.0, "got {}", card.rows);
1186    }
1187
1188    #[test]
1189    fn test_table_cost_estimation() {
1190        let estimator = CostEstimator::new();
1191
1192        let query = QueryExpr::Table(TableQuery {
1193            table: "hosts".to_string(),
1194            source: None,
1195            alias: None,
1196            select_items: Vec::new(),
1197            columns: vec![Projection::All],
1198            where_expr: None,
1199            filter: None,
1200            group_by_exprs: Vec::new(),
1201            group_by: Vec::new(),
1202            having_expr: None,
1203            having: None,
1204            order_by: vec![],
1205            limit: None,
1206            limit_param: None,
1207            offset: None,
1208            offset_param: None,
1209            expand: None,
1210            as_of: None,
1211            sessionize: None,
1212        });
1213
1214        let cost = estimator.estimate(&query);
1215        assert!(cost.cpu > 0.0);
1216        assert!(cost.total > 0.0);
1217    }
1218
1219    #[test]
1220    fn test_filter_selectivity() {
1221        let estimator = CostEstimator::new();
1222
1223        let eq_filter = AstFilter::Compare {
1224            field: FieldRef::column("hosts", "id"),
1225            op: CompareOp::Eq,
1226            value: Value::Integer(1),
1227        };
1228        assert!(CostEstimator::estimate_filter_selectivity(&eq_filter) < 0.1);
1229
1230        let ne_filter = AstFilter::Compare {
1231            field: FieldRef::column("hosts", "id"),
1232            op: CompareOp::Ne,
1233            value: Value::Integer(1),
1234        };
1235        assert!(CostEstimator::estimate_filter_selectivity(&ne_filter) > 0.9);
1236    }
1237
1238    #[test]
1239    fn test_and_selectivity() {
1240        let estimator = CostEstimator::new();
1241
1242        let and_filter = AstFilter::And(
1243            Box::new(AstFilter::Compare {
1244                field: FieldRef::column("hosts", "a"),
1245                op: CompareOp::Eq,
1246                value: Value::Integer(1),
1247            }),
1248            Box::new(AstFilter::Compare {
1249                field: FieldRef::column("hosts", "b"),
1250                op: CompareOp::Eq,
1251                value: Value::Integer(2),
1252            }),
1253        );
1254
1255        let selectivity = CostEstimator::estimate_filter_selectivity(&and_filter);
1256        assert!(selectivity < 0.01); // AND should be very selective
1257    }
1258
1259    #[test]
1260    fn test_cardinality_with_limit() {
1261        let estimator = CostEstimator::new();
1262
1263        let query = TableQuery {
1264            table: "hosts".to_string(),
1265            source: None,
1266            alias: None,
1267            select_items: Vec::new(),
1268            columns: vec![Projection::All],
1269            where_expr: None,
1270            filter: None,
1271            group_by_exprs: Vec::new(),
1272            group_by: Vec::new(),
1273            having_expr: None,
1274            having: None,
1275            order_by: vec![],
1276            limit: Some(10),
1277            limit_param: None,
1278            offset: None,
1279            offset_param: None,
1280            expand: None,
1281            as_of: None,
1282            sessionize: None,
1283        };
1284
1285        let card = estimator.estimate_table_cardinality(&query);
1286        assert!(card.rows <= 10.0);
1287    }
1288
1289    // ---------------------------------------------------------------
1290    // Target 2: startup_cost vs total_cost split
1291    // ---------------------------------------------------------------
1292
1293    #[test]
1294    fn startup_zero_for_full_scan() {
1295        // estimate_table is implemented as a streaming sequential scan
1296        // (no startup cost — the first row is producible as soon as the
1297        // first page is read).
1298        let estimator = CostEstimator::new();
1299        let q = table_query("any_table", None);
1300        let cost = estimator.estimate(&QueryExpr::Table(q));
1301        assert_eq!(cost.startup_cost, 0.0, "full scan must have zero startup");
1302        assert!(cost.total > 0.0);
1303    }
1304
1305    #[test]
1306    fn startup_nonzero_for_blocking_combine() {
1307        // combine_blocking models a sort or hash build: the input must
1308        // be fully consumed before the blocker can emit its first row.
1309        let input = PlanCost::new(100.0, 10.0, 50.0); // cost = 100 + 100 + 5 = 205
1310        let blocker = PlanCost::new(20.0, 0.0, 10.0); // cost = 20 + 0 + 1 = 21
1311        let composed = input.combine_blocking(&blocker);
1312        // Blocking startup absorbs all of input.total
1313        assert_eq!(composed.startup_cost, input.total);
1314        // Total is input.total + blocker.total
1315        assert_eq!(composed.total, input.total + blocker.total);
1316        assert!(composed.startup_cost > 0.0);
1317    }
1318
1319    #[test]
1320    fn pipelined_combine_adds_startup_directly() {
1321        let upstream = PlanCost::with_startup(50.0, 5.0, 10.0, 30.0);
1322        let downstream = PlanCost::with_startup(20.0, 0.0, 0.0, 5.0);
1323        let composed = upstream.combine_pipelined(&downstream);
1324        assert_eq!(composed.startup_cost, 30.0 + 5.0);
1325        assert_eq!(composed.total, upstream.total + downstream.total);
1326    }
1327
1328    #[test]
1329    fn cost_prefers_low_startup_when_limit_small() {
1330        // Two plans with the same total but different startup. With a
1331        // small LIMIT, the planner must pick the low-startup plan.
1332        let fast_first = PlanCost {
1333            cpu: 100.0,
1334            io: 10.0,
1335            network: 0.0,
1336            memory: 50.0,
1337            startup_cost: 5.0,
1338            total: 200.0,
1339        };
1340        let slow_first = PlanCost {
1341            cpu: 100.0,
1342            io: 10.0,
1343            network: 0.0,
1344            memory: 50.0,
1345            startup_cost: 150.0,
1346            total: 200.0,
1347        };
1348        // Cardinality 10_000, LIMIT 10 → 10 < 0.1 * 10_000 = 1000 → use startup.
1349        assert_eq!(
1350            fast_first.prefer_over(&slow_first, Some(10), 10_000.0),
1351            std::cmp::Ordering::Less
1352        );
1353    }
1354
1355    #[test]
1356    fn cost_prefers_low_total_when_no_limit() {
1357        // Same two plans, no LIMIT — total wins.
1358        let low_total = PlanCost {
1359            cpu: 50.0,
1360            io: 5.0,
1361            network: 0.0,
1362            memory: 0.0,
1363            startup_cost: 30.0,
1364            total: 100.0,
1365        };
1366        let high_total = PlanCost {
1367            cpu: 100.0,
1368            io: 10.0,
1369            network: 0.0,
1370            memory: 0.0,
1371            startup_cost: 5.0,
1372            total: 200.0,
1373        };
1374        assert_eq!(
1375            low_total.prefer_over(&high_total, None, 10_000.0),
1376            std::cmp::Ordering::Less
1377        );
1378    }
1379
1380    #[test]
1381    fn limit_threshold_falls_back_to_total_when_limit_large() {
1382        // LIMIT 5000 vs cardinality 10_000 → 5000 > 1000 → use total.
1383        let low_total = PlanCost {
1384            cpu: 50.0,
1385            io: 5.0,
1386            network: 0.0,
1387            memory: 0.0,
1388            startup_cost: 30.0,
1389            total: 100.0,
1390        };
1391        let low_startup = PlanCost {
1392            cpu: 100.0,
1393            io: 10.0,
1394            network: 0.0,
1395            memory: 0.0,
1396            startup_cost: 5.0,
1397            total: 200.0,
1398        };
1399        assert_eq!(
1400            low_total.prefer_over(&low_startup, Some(5000), 10_000.0),
1401            std::cmp::Ordering::Less
1402        );
1403    }
1404
1405    #[test]
1406    fn hash_join_startup_includes_build_cost() {
1407        // Direct combine_blocking semantics: a hash join must drain the
1408        // left input and build the hash table before producing the first
1409        // probe result.
1410        let left = PlanCost::new(80.0, 8.0, 100.0); // table scan
1411        let build = PlanCost::with_startup(50.0, 0.0, 200.0, 50.0); // build op
1412        let after_build = left.combine_blocking(&build);
1413        assert!(
1414            after_build.startup_cost >= left.total,
1415            "after-build startup ({}) must absorb left.total ({})",
1416            after_build.startup_cost,
1417            left.total
1418        );
1419        assert!(after_build.total >= after_build.startup_cost);
1420    }
1421
1422    #[test]
1423    fn vector_search_reports_nonzero_startup() {
1424        // estimate_vector now uses with_startup so HNSW descent shows
1425        // up as startup_cost > 0 (and < total — subsequent neighbours
1426        // are essentially free).
1427        let estimator = CostEstimator::new();
1428        // We can't easily build a VectorQuery without the AST helpers,
1429        // so test the direct cost surface with_startup uses.
1430        let v = PlanCost::with_startup(150.0, 20.0, 1320.0, 50.0);
1431        assert!(v.startup_cost > 0.0);
1432        assert!(v.startup_cost < v.total);
1433        let _ = estimator; // suppress unused
1434    }
1435
1436    #[test]
1437    fn with_startup_clamps_total_below_startup() {
1438        // If a caller asks for total < startup, with_startup raises total.
1439        let cost = PlanCost::with_startup(1.0, 0.0, 0.0, 100.0);
1440        assert!(cost.total >= cost.startup_cost);
1441    }
1442
1443    #[test]
1444    fn default_plancost_has_zero_startup() {
1445        let c = PlanCost::default();
1446        assert_eq!(c.startup_cost, 0.0);
1447        assert_eq!(c.total, 0.0);
1448    }
1449
1450    // ---------------------------------------------------------------
1451    // Perf 1.3: histogram + MCV plug-in into filter_selectivity
1452    // ---------------------------------------------------------------
1453
1454    use super::super::histogram::{ColumnValue, Histogram, MostCommonValues};
1455
1456    fn provider_with_skew() -> Arc<StaticProvider> {
1457        // Build a histogram where 80 of 100 values fall in [0, 9]
1458        // and the rest spread sparsely up to 1000. range_selectivity
1459        // for `<= 9` should be ~0.8, vastly beating the heuristic 0.3.
1460        let mut sample: Vec<ColumnValue> = Vec::new();
1461        for i in 0..80 {
1462            sample.push(ColumnValue::Int(i % 10));
1463        }
1464        for i in 0..20 {
1465            sample.push(ColumnValue::Int(10 + i * 50));
1466        }
1467        let h = Histogram::equi_depth_from_sample(sample, 10);
1468
1469        let mcv = MostCommonValues::new(vec![
1470            (ColumnValue::Text("boss".to_string()), 0.5),
1471            (ColumnValue::Text("intern".to_string()), 0.05),
1472        ]);
1473
1474        Arc::new(
1475            StaticProvider::new()
1476                .with_table(
1477                    "people",
1478                    TableStats {
1479                        row_count: 100_000,
1480                        avg_row_size: 64,
1481                        page_count: 100,
1482                        columns: vec![],
1483                    },
1484                )
1485                .with_histogram("people", "score", h)
1486                .with_mcv("people", "role", mcv),
1487        )
1488    }
1489
1490    #[test]
1491    fn eq_uses_mcv_when_value_is_tracked() {
1492        let provider = provider_with_skew();
1493        let estimator = CostEstimator::with_stats(provider);
1494        let filter = AstFilter::Compare {
1495            field: FieldRef::column("people", "role"),
1496            op: CompareOp::Eq,
1497            value: Value::text("boss".to_string()),
1498        };
1499        // MCV says "boss" is 50% of the table → selectivity 0.5,
1500        // not the 0.01 heuristic.
1501        let s = estimator.filter_selectivity(&filter, "people");
1502        assert!(
1503            (s - 0.5).abs() < 1e-9,
1504            "MCV-tracked equality should report exact frequency, got {s}"
1505        );
1506    }
1507
1508    #[test]
1509    fn eq_uses_residual_for_non_mcv_value() {
1510        let provider = provider_with_skew();
1511        let estimator = CostEstimator::with_stats(provider);
1512        let filter = AstFilter::Compare {
1513            field: FieldRef::column("people", "role"),
1514            op: CompareOp::Eq,
1515            value: Value::text("staff".to_string()),
1516        };
1517        // 1 - 0.55 (mcv totals) = 0.45 spread across (distinct - 2)
1518        // distinct values. We don't have an exact distinct count, so
1519        // the planner uses the default 100 → 0.45 / 98 ≈ 0.0046.
1520        let s = estimator.filter_selectivity(&filter, "people");
1521        assert!(s > 0.0 && s < 0.01, "residual eq should be tiny, got {s}");
1522    }
1523
1524    #[test]
1525    fn ne_is_one_minus_eq_under_mcv() {
1526        let provider = provider_with_skew();
1527        let estimator = CostEstimator::with_stats(provider);
1528        let filter = AstFilter::Compare {
1529            field: FieldRef::column("people", "role"),
1530            op: CompareOp::Ne,
1531            value: Value::text("boss".to_string()),
1532        };
1533        let s = estimator.filter_selectivity(&filter, "people");
1534        // 1 - 0.5 == 0.5
1535        assert!((s - 0.5).abs() < 1e-9, "Ne selectivity = 0.5, got {s}");
1536    }
1537
1538    #[test]
1539    fn range_uses_histogram_when_present() {
1540        let provider = provider_with_skew();
1541        let estimator = CostEstimator::with_stats(provider);
1542        let filter = AstFilter::Compare {
1543            field: FieldRef::column("people", "score"),
1544            op: CompareOp::Le,
1545            value: Value::Integer(9),
1546        };
1547        // Histogram says ~80% of values are in [0, 9], heuristic
1548        // would have said 0.3.
1549        let s = estimator.filter_selectivity(&filter, "people");
1550        assert!(
1551            s > 0.5,
1552            "histogram-based range selectivity should beat 0.3, got {s}"
1553        );
1554    }
1555
1556    #[test]
1557    fn between_uses_histogram() {
1558        let provider = provider_with_skew();
1559        let estimator = CostEstimator::with_stats(provider);
1560        let filter = AstFilter::Between {
1561            field: FieldRef::column("people", "score"),
1562            low: Value::Integer(0),
1563            high: Value::Integer(9),
1564        };
1565        let s = estimator.filter_selectivity(&filter, "people");
1566        assert!(s > 0.5, "BETWEEN should use histogram too, got {s}");
1567    }
1568
1569    #[test]
1570    fn graceful_fallback_when_histogram_absent() {
1571        // Provider has no histogram on `unknown_col` — must fall
1572        // through to the 0.3 heuristic without panicking.
1573        let provider = Arc::new(StaticProvider::new().with_table(
1574            "people",
1575            TableStats {
1576                row_count: 1000,
1577                avg_row_size: 64,
1578                page_count: 10,
1579                columns: vec![],
1580            },
1581        ));
1582        let estimator = CostEstimator::with_stats(provider);
1583        let filter = AstFilter::Compare {
1584            field: FieldRef::column("people", "unknown_col"),
1585            op: CompareOp::Lt,
1586            value: Value::Integer(50),
1587        };
1588        let s = estimator.filter_selectivity(&filter, "people");
1589        assert!((s - 0.3).abs() < 1e-9, "fallback heuristic 0.3, got {s}");
1590    }
1591}