Skip to main content

sochdb_query/
cost_optimizer.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Cost-Based Query Optimizer with Cardinality Estimation (Task 6)
19//!
20//! Provides cost-based query optimization for SOCH-QL with:
21//! - Cardinality estimation using sketches (HyperLogLog, CountMin)
22//! - Index selection: compare cost(table_scan) vs cost(index_seek)
23//! - Column projection pushdown to LSCS layer
24//! - Token-budget-aware planning
25//!
26//! ## Cost Model
27//!
28//! cost(plan) = I/O_cost + CPU_cost + memory_cost
29//!
30//! I/O_cost = blocks_read × C_seq + seeks × C_random
31//! Where:
32//!   C_seq = 0.1 ms/block (sequential read)
33//!   C_random = 5 ms/seek (random seek)
34//!
35//! CPU_cost = rows_processed × C_filter + sorts × N × log(N) × C_compare
36//!
37//! ## Selectivity Estimation
38//!
39//! Uses CountMinSketch for predicate selectivity and HyperLogLog for distinct counts.
40//!
41//! ## Token Budget Planning
42//!
43//! Given max_tokens, estimates result size and injects LIMIT clause:
44//!   max_rows = (max_tokens - header_tokens) / tokens_per_row
45
46use parking_lot::RwLock;
47use std::collections::{HashMap, HashSet};
48use std::sync::Arc;
49use std::time::{SystemTime, UNIX_EPOCH};
50
51// ============================================================================
52// Cost Model Constants
53// ============================================================================
54
/// Cost model configuration with empirically-derived constants.
///
/// Cost terms are expressed in milliseconds so that I/O, CPU and memory costs
/// can be summed directly (see the module-level cost model).
#[derive(Debug, Clone)]
pub struct CostModelConfig {
    /// Sequential I/O cost per block (ms)
    pub c_seq: f64,
    /// Random I/O cost per seek (ms)
    pub c_random: f64,
    /// CPU cost per row filter (ms)
    pub c_filter: f64,
    /// CPU cost per comparison during sort (ms)
    pub c_compare: f64,
    /// Block size in bytes
    pub block_size: usize,
    /// B-tree fanout for index cost estimation
    pub btree_fanout: usize,
    /// Memory bandwidth (bytes/ms)
    pub memory_bandwidth: f64,
}

impl Default for CostModelConfig {
    fn default() -> Self {
        Self {
            c_seq: 0.1,        // 0.1 ms per block sequential
            c_random: 5.0,     // 5 ms per random seek
            c_filter: 0.001,   // 0.001 ms per row filter
            c_compare: 0.0001, // 0.0001 ms per comparison
            block_size: 4096,  // 4 KB blocks
            btree_fanout: 100, // 100 entries per B-tree node
            // 10 GB/s = 10^10 bytes per 10^3 ms = 10^7 bytes/ms.
            // (10_000 bytes/ms would be only 10 MB/s.)
            memory_bandwidth: 10_000_000.0,
        }
    }
}
87
88// ============================================================================
89// Statistics for Cardinality Estimation
90// ============================================================================
91
/// Table statistics for cost estimation.
///
/// Cached by the optimizer under `name`; when a table has no cached stats the
/// optimizer falls back to fixed defaults.
#[derive(Debug, Clone)]
pub struct TableStats {
    /// Table name (also the key under which `update_stats` caches these stats)
    pub name: String,
    /// Total row count
    pub row_count: u64,
    /// Total size in bytes (converted to blocks for the sequential-I/O scan cost)
    pub size_bytes: u64,
    /// Column statistics, keyed by column name
    pub column_stats: HashMap<String, ColumnStats>,
    /// Available indices (candidates for an index seek)
    pub indices: Vec<IndexStats>,
    /// Last update timestamp
    /// (NOTE(review): unit/epoch is not established in this part of the file — confirm)
    pub last_updated: u64,
}

/// Per-column statistics used for predicate selectivity estimation.
#[derive(Debug, Clone)]
pub struct ColumnStats {
    /// Column name
    pub name: String,
    /// Distinct value count (from HyperLogLog); equality on a value outside the
    /// MCV list is estimated as `1 / distinct_count`
    pub distinct_count: u64,
    /// Null count (divided by the table's `row_count` to obtain the null fraction)
    pub null_count: u64,
    /// Minimum value (if orderable)
    pub min_value: Option<String>,
    /// Maximum value (if orderable)
    pub max_value: Option<String>,
    /// Average length in bytes (for variable-length types)
    pub avg_length: f64,
    /// Most common values with frequencies. A matching frequency is returned
    /// directly as the selectivity of `column = value`, i.e. it is a fraction of rows
    pub mcv: Vec<(String, f64)>,
    /// Histogram buckets for range queries (`<`, `<=`, `>`, `>=`, `BETWEEN`)
    pub histogram: Option<Histogram>,
}
129
/// Histogram for range selectivity estimation.
///
/// `boundaries` are the split points between consecutive buckets and
/// `counts[i]` is the number of rows in bucket `i`. Bucket `i` spans
/// `boundaries[i - 1]..=boundaries[i]`; the first bucket is unbounded below and
/// the last is unbounded above. A well-formed histogram therefore has ascending
/// boundaries and `counts.len() == boundaries.len() + 1`.
#[derive(Debug, Clone)]
pub struct Histogram {
    /// Bucket boundaries (split points between buckets, ascending)
    pub boundaries: Vec<f64>,
    /// Row count per bucket
    pub counts: Vec<u64>,
    /// Total rows in histogram
    pub total_rows: u64,
}

impl Histogram {
    /// Estimate selectivity for the range predicate `min <= x <= max`.
    ///
    /// A `None` bound leaves that side open. Buckets fully inside the range
    /// contribute all their rows. Finite buckets that only partially overlap it
    /// contribute a linearly interpolated share, assuming values are uniform
    /// within a bucket. Unbounded edge buckets cannot be interpolated, so they
    /// contribute all their rows when they overlap the range by more than a
    /// single point. Consequently a single-point range (`min == max`) selects
    /// nothing from a bucket of non-zero width.
    ///
    /// Returns:
    /// - `0.5` when the histogram is empty (`total_rows == 0`);
    /// - `0.0` for an empty range (`min > max`, or a NaN bound);
    /// - otherwise the selected fraction, clamped to `[0, 1]`.
    ///
    /// Malformed histograms (more buckets than the boundaries describe) do not
    /// panic. A bucket whose bound is missing is treated as unbounded on that side.
    pub fn estimate_range_selectivity(&self, min: Option<f64>, max: Option<f64>) -> f64 {
        if self.total_rows == 0 {
            return 0.5; // Default
        }

        let lo = min.unwrap_or(f64::NEG_INFINITY);
        let hi = max.unwrap_or(f64::INFINITY);
        // `!(lo <= hi)` rejects both inverted ranges and NaN bounds.
        if !(lo <= hi) {
            return 0.0;
        }

        let selected_rows: f64 = self
            .counts
            .iter()
            .enumerate()
            .map(|(i, &count)| {
                let bucket_min = i
                    .checked_sub(1)
                    .and_then(|prev| self.boundaries.get(prev).copied())
                    .unwrap_or(f64::NEG_INFINITY);
                let bucket_max = self.boundaries.get(i).copied().unwrap_or(f64::INFINITY);
                count as f64 * Self::overlap_fraction(bucket_min, bucket_max, lo, hi)
            })
            .sum();

        // Stale stats can make the bucket counts exceed `total_rows`.
        (selected_rows / self.total_rows as f64).clamp(0.0, 1.0)
    }

    /// Fraction of the bucket `[bucket_min, bucket_max]` lying inside `[lo, hi]`.
    fn overlap_fraction(bucket_min: f64, bucket_max: f64, lo: f64, hi: f64) -> f64 {
        let covered = hi.min(bucket_max) - lo.max(bucket_min);
        if !(covered >= 0.0) {
            return 0.0; // Disjoint from the range.
        }
        let width = bucket_max - bucket_min;
        if !(width > 0.0) {
            // Zero-width (single-value) or mis-ordered bucket that lies in range.
            return 1.0;
        }
        if covered == 0.0 {
            return 0.0; // The range only touches the bucket's edge.
        }
        if width.is_infinite() {
            return 1.0; // Unbounded bucket: no basis for interpolation.
        }
        (covered / width).min(1.0)
    }
}
177
/// Index statistics used to cost an index seek.
#[derive(Debug, Clone)]
pub struct IndexStats {
    /// Index name
    pub name: String,
    /// Indexed columns (presumably in key order — confirm). Only the first
    /// (leading) column is checked when deciding whether the index can serve a predicate
    pub columns: Vec<String>,
    /// Is primary key; the cost model treats it as clustered (no row-fetch cost)
    pub is_primary: bool,
    /// Is unique
    pub is_unique: bool,
    /// Index type
    pub index_type: IndexType,
    /// Number of leaf pages
    pub leaf_pages: u64,
    /// Tree height (for B-tree); each level is costed as one random I/O
    pub height: u32,
    /// Average entries per leaf page; used to estimate how many leaf pages the
    /// matching rows span
    pub avg_leaf_density: f64,
}

/// Index types
///
/// NOTE(review): the visible index cost estimate does not read `index_type`;
/// every type is costed with the same tree-traversal model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexType {
    BTree,
    Hash,
    LSM,
    Learned,
    Vector,
    Bloom,
}
209
210// ============================================================================
211// Query Predicates and Operations
212// ============================================================================
213
/// Query predicate tree, used both for cost estimation and as a residual filter.
///
/// Leaf variants compare a single column against string-encoded literal(s);
/// `And`/`Or`/`Not` combine sub-predicates.
#[derive(Debug, Clone)]
pub enum Predicate {
    /// Equality: column = value
    Eq { column: String, value: String },
    /// Inequality: column != value
    Ne { column: String, value: String },
    /// Less than: column < value
    Lt { column: String, value: String },
    /// Less than or equal: column <= value
    Le { column: String, value: String },
    /// Greater than: column > value
    Gt { column: String, value: String },
    /// Greater than or equal: column >= value
    Ge { column: String, value: String },
    /// Between: column BETWEEN min AND max
    Between {
        column: String,
        min: String,
        max: String,
    },
    /// In list: column IN (v1, v2, ...)
    In { column: String, values: Vec<String> },
    /// Like: column LIKE pattern
    Like { column: String, pattern: String },
    /// Is null: column IS NULL
    IsNull { column: String },
    /// Is not null: column IS NOT NULL
    IsNotNull { column: String },
    /// And: pred1 AND pred2
    And(Box<Predicate>, Box<Predicate>),
    /// Or: pred1 OR pred2
    Or(Box<Predicate>, Box<Predicate>),
    /// Not: NOT pred
    Not(Box<Predicate>),
}

impl Predicate {
    /// Return the set of distinct column names mentioned anywhere in this predicate.
    pub fn referenced_columns(&self) -> HashSet<String> {
        let mut found = HashSet::new();
        self.collect_columns(&mut found);
        found
    }

    /// Insert every column named in this predicate tree into `cols`.
    ///
    /// Walks nested `AND`/`OR`/`NOT` operands with an explicit work stack
    /// instead of recursion.
    fn collect_columns(&self, cols: &mut HashSet<String>) {
        let mut pending: Vec<&Predicate> = vec![self];
        while let Some(node) = pending.pop() {
            match node {
                Self::And(left, right) | Self::Or(left, right) => {
                    pending.push(&**right);
                    pending.push(&**left);
                }
                Self::Not(inner) => pending.push(&**inner),
                Self::Eq { column, .. }
                | Self::Ne { column, .. }
                | Self::Lt { column, .. }
                | Self::Le { column, .. }
                | Self::Gt { column, .. }
                | Self::Ge { column, .. }
                | Self::Between { column, .. }
                | Self::In { column, .. }
                | Self::Like { column, .. }
                | Self::IsNull { column }
                | Self::IsNotNull { column } => {
                    cols.insert(column.clone());
                }
            }
        }
    }
}
282
283// ============================================================================
284// Physical Plan Operators
285// ============================================================================
286
/// Physical query plan node.
///
/// Each node's `estimated_cost` covers only that node's own work;
/// `CostBasedOptimizer::get_plan_cost` adds the costs of its inputs on top.
/// `estimated_rows`, where present, is the node's row-count estimate. For
/// `TableScan` the optimizer currently records the table's total row count.
#[derive(Debug, Clone)]
pub enum PhysicalPlan {
    /// Table scan (full or partial) reading `columns`, with an optional `predicate`
    TableScan {
        table: String,
        columns: Vec<String>,
        predicate: Option<Box<Predicate>>,
        estimated_rows: u64,
        estimated_cost: f64,
    },
    /// Index seek over `key_range` of `index`. `predicate` holds the full original
    /// predicate; the key range may be derived from only part of it
    IndexSeek {
        table: String,
        index: String,
        columns: Vec<String>,
        key_range: KeyRange,
        predicate: Option<Box<Predicate>>,
        estimated_rows: u64,
        estimated_cost: f64,
    },
    /// Filter operator
    Filter {
        input: Box<PhysicalPlan>,
        predicate: Predicate,
        estimated_rows: u64,
        estimated_cost: f64,
    },
    /// Project operator (column subset); row count is taken from `input`
    Project {
        input: Box<PhysicalPlan>,
        columns: Vec<String>,
        estimated_cost: f64,
    },
    /// Sort operator; row count is taken from `input`
    Sort {
        input: Box<PhysicalPlan>,
        order_by: Vec<(String, SortDirection)>,
        estimated_cost: f64,
    },
    /// Limit operator (the optimizer builds it with `offset: 0` and a zero own cost)
    Limit {
        input: Box<PhysicalPlan>,
        limit: u64,
        offset: u64,
        estimated_cost: f64,
    },
    /// Nested loop join
    NestedLoopJoin {
        outer: Box<PhysicalPlan>,
        inner: Box<PhysicalPlan>,
        condition: Predicate,
        join_type: JoinType,
        estimated_rows: u64,
        estimated_cost: f64,
    },
    /// Hash join
    HashJoin {
        build: Box<PhysicalPlan>,
        probe: Box<PhysicalPlan>,
        build_keys: Vec<String>,
        probe_keys: Vec<String>,
        join_type: JoinType,
        estimated_rows: u64,
        estimated_cost: f64,
    },
    /// Merge join
    MergeJoin {
        left: Box<PhysicalPlan>,
        right: Box<PhysicalPlan>,
        left_keys: Vec<String>,
        right_keys: Vec<String>,
        join_type: JoinType,
        estimated_rows: u64,
        estimated_cost: f64,
    },
    /// Aggregate operator
    Aggregate {
        input: Box<PhysicalPlan>,
        group_by: Vec<String>,
        aggregates: Vec<AggregateExpr>,
        estimated_rows: u64,
        estimated_cost: f64,
    },
}
372
/// Byte-key range scanned by an index seek.
///
/// A `None` bound leaves the range open in that direction; the `*_inclusive`
/// flags say whether the corresponding bound key itself is included.
#[derive(Debug, Clone)]
pub struct KeyRange {
    pub start: Option<Vec<u8>>,
    pub end: Option<Vec<u8>>,
    pub start_inclusive: bool,
    pub end_inclusive: bool,
}

impl KeyRange {
    /// The unbounded range covering every key (both flags inclusive).
    pub fn all() -> Self {
        Self::range(None, None, true)
    }

    /// The single-key range `[key, key]`.
    pub fn point(key: Vec<u8>) -> Self {
        Self::range(Some(key.clone()), Some(key), true)
    }

    /// A range from `start` to `end` whose two bounds share one `inclusive` flag.
    pub fn range(start: Option<Vec<u8>>, end: Option<Vec<u8>>, inclusive: bool) -> Self {
        Self {
            start_inclusive: inclusive,
            end_inclusive: inclusive,
            start,
            end,
        }
    }
}
410
/// Sort direction for one `ORDER BY` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
    Ascending,
    Descending,
}

/// Join type
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    Inner,
    Left,
    Right,
    Full,
    Cross,
}

/// Aggregate expression: `function` applied to `column`, exposed as `alias`.
#[derive(Debug, Clone)]
pub struct AggregateExpr {
    pub function: AggregateFunction,
    /// Input column; `None` presumably for column-less aggregates such as `COUNT(*)` — confirm
    pub column: Option<String>,
    /// Output name of the aggregate
    pub alias: String,
}

/// Aggregate functions
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    Count,
    Sum,
    Avg,
    Min,
    Max,
    CountDistinct,
}
446
447// ============================================================================
448// Cost-Based Query Optimizer
449// ============================================================================
450
/// Cost-based query optimizer.
///
/// Holds the cost model, a shared table-statistics cache, optional token-budget
/// settings and a plan cache. The caches sit behind `Arc<RwLock<..>>`, so
/// statistics can be updated through `&self`.
pub struct CostBasedOptimizer {
    /// Cost model configuration
    config: CostModelConfig,
    /// Table statistics cache, keyed by table name
    stats_cache: Arc<RwLock<HashMap<String, TableStats>>>,
    /// Token budget for result limiting (`None` = no budget-derived row cap)
    token_budget: Option<u64>,
    /// Estimated tokens per row (25.0 by default; set by `with_token_budget`)
    tokens_per_row: f64,
    /// Plan cache: u64 key -> (plan, timestamp_us).
    /// NOTE(review): the key is presumably a hash of (table, predicate, limit) — confirm
    /// against the code that populates this cache
    plan_cache: Arc<RwLock<HashMap<u64, (PhysicalPlan, u64)>>>,
    /// Plan cache TTL in microseconds (default 5 seconds)
    plan_cache_ttl_us: u64,
}
466
467impl CostBasedOptimizer {
468    pub fn new(config: CostModelConfig) -> Self {
469        Self {
470            config,
471            stats_cache: Arc::new(RwLock::new(HashMap::new())),
472            token_budget: None,
473            tokens_per_row: 25.0, // Default estimate
474            plan_cache: Arc::new(RwLock::new(HashMap::new())),
475            plan_cache_ttl_us: 5_000_000, // 5 seconds
476        }
477    }
478
479    /// Set plan cache TTL
480    pub fn with_plan_cache_ttl_ms(mut self, ttl_ms: u64) -> Self {
481        self.plan_cache_ttl_us = ttl_ms * 1000;
482        self
483    }
484
485    /// Set token budget for result limiting
486    pub fn with_token_budget(mut self, budget: u64, tokens_per_row: f64) -> Self {
487        self.token_budget = Some(budget);
488        self.tokens_per_row = tokens_per_row;
489        self
490    }
491
492    /// Update table statistics
493    pub fn update_stats(&self, stats: TableStats) {
494        self.stats_cache.write().insert(stats.name.clone(), stats);
495    }
496
497    /// Get table statistics
498    pub fn get_stats(&self, table: &str) -> Option<TableStats> {
499        self.stats_cache.read().get(table).cloned()
500    }
501
502    /// Optimize a SELECT query
503    pub fn optimize(
504        &self,
505        table: &str,
506        columns: Vec<String>,
507        predicate: Option<Predicate>,
508        order_by: Vec<(String, SortDirection)>,
509        limit: Option<u64>,
510    ) -> PhysicalPlan {
511        let stats = self.get_stats(table);
512
513        // Calculate token-aware limit
514        let effective_limit = self.calculate_token_limit(limit);
515
516        // Get best access path (scan vs index)
517        let mut plan = self.choose_access_path(table, &columns, predicate.as_ref(), &stats);
518
519        // Apply column projection pushdown
520        plan = self.apply_projection_pushdown(plan, columns.clone());
521
522        // Apply sorting if needed
523        if !order_by.is_empty() {
524            plan = self.add_sort(plan, order_by, &stats);
525        }
526
527        // Apply limit
528        if let Some(lim) = effective_limit {
529            plan = PhysicalPlan::Limit {
530                estimated_cost: 0.0,
531                input: Box::new(plan),
532                limit: lim,
533                offset: 0,
534            };
535        }
536
537        plan
538    }
539
540    /// Calculate token-aware limit
541    fn calculate_token_limit(&self, user_limit: Option<u64>) -> Option<u64> {
542        match (self.token_budget, user_limit) {
543            (Some(budget), Some(limit)) => {
544                let header_tokens = 50u64;
545                let usable = budget.saturating_sub(header_tokens);
546                let max_rows = (usable as f64 / self.tokens_per_row).max(1.0) as u64;
547                Some(limit.min(max_rows))
548            }
549            (Some(budget), None) => {
550                let header_tokens = 50u64;
551                let usable = budget.saturating_sub(header_tokens);
552                let max_rows = (usable as f64 / self.tokens_per_row).max(1.0) as u64;
553                Some(max_rows)
554            }
555            (None, limit) => limit,
556        }
557    }
558
559    /// Choose best access path (table scan vs index seek)
560    fn choose_access_path(
561        &self,
562        table: &str,
563        columns: &[String],
564        predicate: Option<&Predicate>,
565        stats: &Option<TableStats>,
566    ) -> PhysicalPlan {
567        let row_count = stats.as_ref().map(|s| s.row_count).unwrap_or(10000);
568        let size_bytes = stats
569            .as_ref()
570            .map(|s| s.size_bytes)
571            .unwrap_or(row_count * 100);
572
573        // Calculate table scan cost
574        let scan_cost = self.estimate_scan_cost(row_count, size_bytes, predicate);
575
576        // Try to find a suitable index
577        let mut best_index_cost = f64::MAX;
578        let mut best_index: Option<&IndexStats> = None;
579
580        if let Some(table_stats) = stats.as_ref()
581            && let Some(pred) = predicate
582        {
583            let pred_columns = pred.referenced_columns();
584
585            for index in &table_stats.indices {
586                if self.index_covers_predicate(index, &pred_columns) {
587                    let selectivity = self.estimate_selectivity(pred, table_stats);
588                    let index_cost = self.estimate_index_cost(index, row_count, selectivity);
589
590                    if index_cost < best_index_cost {
591                        best_index_cost = index_cost;
592                        best_index = Some(index);
593                    }
594                }
595            }
596        }
597
598        // Choose cheaper option
599        if best_index_cost < scan_cost {
600            let index = best_index.unwrap();
601            let selectivity = predicate
602                .map(|p| self.estimate_selectivity(p, stats.as_ref().unwrap()))
603                .unwrap_or(1.0);
604
605            PhysicalPlan::IndexSeek {
606                table: table.to_string(),
607                index: index.name.clone(),
608                columns: columns.to_vec(),
609                key_range: predicate
610                    .map(|p| Self::derive_key_range(p))
611                    .unwrap_or_else(KeyRange::all),
612                predicate: predicate.map(|p| Box::new(p.clone())),
613                estimated_rows: (row_count as f64 * selectivity).max(1.0) as u64,
614                estimated_cost: best_index_cost,
615            }
616        } else {
617            PhysicalPlan::TableScan {
618                table: table.to_string(),
619                columns: columns.to_vec(),
620                predicate: predicate.map(|p| Box::new(p.clone())),
621                estimated_rows: row_count,
622                estimated_cost: scan_cost,
623            }
624        }
625    }
626
627    /// Check if index covers predicate columns
628    fn index_covers_predicate(&self, index: &IndexStats, pred_columns: &HashSet<String>) -> bool {
629        // Index is useful if it covers at least the first column of the predicate
630        if let Some(first_col) = index.columns.first() {
631            pred_columns.contains(first_col)
632        } else {
633            false
634        }
635    }
636
637    /// Estimate table scan cost
638    ///
639    /// I/O: sequential read all blocks
640    /// CPU: evaluate predicate against every row
641    fn estimate_scan_cost(
642        &self,
643        row_count: u64,
644        size_bytes: u64,
645        _predicate: Option<&Predicate>,
646    ) -> f64 {
647        let blocks = (size_bytes as f64 / self.config.block_size as f64).ceil().max(1.0) as u64;
648
649        // I/O cost: must read all blocks regardless of predicate
650        let io_cost = blocks as f64 * self.config.c_seq;
651
652        // CPU cost: evaluate predicate on every row (scan reads them all)
653        let cpu_cost = row_count as f64 * self.config.c_filter;
654
655        io_cost + cpu_cost
656    }
657
658    /// Estimate index seek cost
659    ///
660    /// Index cost = tree_traversal + leaf_scan + row_fetch
661    fn estimate_index_cost(&self, index: &IndexStats, total_rows: u64, selectivity: f64) -> f64 {
662        // Tree traversal cost (random I/O for each level)
663        let tree_cost = index.height as f64 * self.config.c_random;
664
665        // Leaf scan cost (sequential for matching range)
666        let matching_rows = (total_rows as f64 * selectivity) as u64;
667        let leaf_pages_scanned = (matching_rows as f64 / index.avg_leaf_density).ceil() as u64;
668        let leaf_cost = leaf_pages_scanned as f64 * self.config.c_seq;
669
670        // Row fetch cost (random if not clustered)
671        let fetch_cost = if index.is_primary {
672            0.0 // Clustered index, no extra fetch
673        } else {
674            matching_rows.min(1000) as f64 * self.config.c_random * 0.1 // Batch optimization
675        };
676
677        tree_cost + leaf_cost + fetch_cost
678    }
679
680    /// Estimate predicate selectivity
681    #[allow(clippy::only_used_in_recursion)]
682    fn estimate_selectivity(&self, predicate: &Predicate, stats: &TableStats) -> f64 {
683        match predicate {
684            Predicate::Eq { column, value } => {
685                if let Some(col_stats) = stats.column_stats.get(column) {
686                    // Check MCV first
687                    for (mcv_val, freq) in &col_stats.mcv {
688                        if mcv_val == value {
689                            return *freq;
690                        }
691                    }
692                    // Otherwise use uniform distribution
693                    1.0 / col_stats.distinct_count.max(1) as f64
694                } else {
695                    0.1 // Default 10%
696                }
697            }
698            Predicate::Ne { .. } => 0.9, // 90% pass
699            Predicate::Lt { column, value }
700            | Predicate::Le { column, value }
701            | Predicate::Gt { column, value }
702            | Predicate::Ge { column, value } => {
703                if let Some(col_stats) = stats.column_stats.get(column) {
704                    if let Some(ref hist) = col_stats.histogram {
705                        let val: f64 = value.parse().unwrap_or(0.0);
706                        match predicate {
707                            Predicate::Lt { .. } | Predicate::Le { .. } => {
708                                hist.estimate_range_selectivity(None, Some(val))
709                            }
710                            _ => hist.estimate_range_selectivity(Some(val), None),
711                        }
712                    } else {
713                        0.25 // Default 25%
714                    }
715                } else {
716                    0.25
717                }
718            }
719            Predicate::Between { column, min, max } => {
720                if let Some(col_stats) = stats.column_stats.get(column) {
721                    if let Some(ref hist) = col_stats.histogram {
722                        let min_val: f64 = min.parse().unwrap_or(0.0);
723                        let max_val: f64 = max.parse().unwrap_or(f64::MAX);
724                        hist.estimate_range_selectivity(Some(min_val), Some(max_val))
725                    } else {
726                        0.2
727                    }
728                } else {
729                    0.2
730                }
731            }
732            Predicate::In { column, values } => {
733                if let Some(col_stats) = stats.column_stats.get(column) {
734                    (values.len() as f64 / col_stats.distinct_count.max(1) as f64).min(1.0)
735                } else {
736                    (values.len() as f64 * 0.1).min(0.5)
737                }
738            }
739            Predicate::Like { .. } => 0.15, // Default 15%
740            Predicate::IsNull { column } => {
741                if let Some(col_stats) = stats.column_stats.get(column) {
742                    col_stats.null_count as f64 / stats.row_count.max(1) as f64
743                } else {
744                    0.01
745                }
746            }
747            Predicate::IsNotNull { column } => {
748                if let Some(col_stats) = stats.column_stats.get(column) {
749                    1.0 - (col_stats.null_count as f64 / stats.row_count.max(1) as f64)
750                } else {
751                    0.99
752                }
753            }
754            Predicate::And(left, right) => {
755                // Assume independence
756                self.estimate_selectivity(left, stats) * self.estimate_selectivity(right, stats)
757            }
758            Predicate::Or(left, right) => {
759                let s1 = self.estimate_selectivity(left, stats);
760                let s2 = self.estimate_selectivity(right, stats);
761                // P(A or B) = P(A) + P(B) - P(A and B)
762                (s1 + s2 - s1 * s2).min(1.0)
763            }
764            Predicate::Not(inner) => 1.0 - self.estimate_selectivity(inner, stats),
765        }
766    }
767
768    /// Derive key range from predicate for index seek
769    fn derive_key_range(predicate: &Predicate) -> KeyRange {
770        match predicate {
771            Predicate::Eq { value, .. } => KeyRange::point(value.as_bytes().to_vec()),
772            Predicate::Lt { value, .. } | Predicate::Le { value, .. } => {
773                KeyRange::range(None, Some(value.as_bytes().to_vec()), matches!(predicate, Predicate::Le { .. }))
774            }
775            Predicate::Gt { value, .. } | Predicate::Ge { value, .. } => {
776                KeyRange::range(Some(value.as_bytes().to_vec()), None, matches!(predicate, Predicate::Ge { .. }))
777            }
778            Predicate::Between { min, max, .. } => KeyRange {
779                start: Some(min.as_bytes().to_vec()),
780                end: Some(max.as_bytes().to_vec()),
781                start_inclusive: true,
782                end_inclusive: true,
783            },
784            Predicate::And(left, _) => Self::derive_key_range(left),
785            _ => KeyRange::all(),
786        }
787    }
788
789    /// Apply column projection pushdown
790    ///
791    /// Reduces I/O cost proportionally to the fraction of columns selected.
792    fn apply_projection_pushdown(&self, plan: PhysicalPlan, columns: Vec<String>) -> PhysicalPlan {
793        match plan {
794            PhysicalPlan::TableScan {
795                ref table,
796                predicate,
797                estimated_rows,
798                estimated_cost,
799                columns: ref all_columns,
800                ..
801            } => {
802                // Cost reduction proportional to column selectivity
803                let col_ratio = if all_columns.is_empty() || columns.is_empty() {
804                    1.0
805                } else {
806                    (columns.len() as f64 / all_columns.len().max(1) as f64).clamp(0.1, 1.0)
807                };
808                PhysicalPlan::TableScan {
809                    table: table.clone(),
810                    columns,
811                    predicate,
812                    estimated_rows,
813                    estimated_cost: estimated_cost * col_ratio,
814                }
815            }
816            PhysicalPlan::IndexSeek {
817                table,
818                index,
819                key_range,
820                predicate,
821                estimated_rows,
822                estimated_cost,
823                ..
824            } => {
825                PhysicalPlan::IndexSeek {
826                    table,
827                    index,
828                    columns, // Pushed down columns
829                    key_range,
830                    predicate,
831                    estimated_rows,
832                    estimated_cost,
833                }
834            }
835            other => PhysicalPlan::Project {
836                input: Box::new(other),
837                columns,
838                estimated_cost: 0.0,
839            },
840        }
841    }
842
843    /// Add sort operator
844    fn add_sort(
845        &self,
846        plan: PhysicalPlan,
847        order_by: Vec<(String, SortDirection)>,
848        _stats: &Option<TableStats>,
849    ) -> PhysicalPlan {
850        let estimated_rows = self.get_plan_rows(&plan);
851        let sort_cost = if estimated_rows > 0 {
852            estimated_rows as f64 * (estimated_rows as f64).log2() * self.config.c_compare
853        } else {
854            0.0
855        };
856
857        PhysicalPlan::Sort {
858            input: Box::new(plan),
859            order_by,
860            estimated_cost: sort_cost,
861        }
862    }
863
864    /// Get estimated rows from a plan
865    #[allow(clippy::only_used_in_recursion)]
866    fn get_plan_rows(&self, plan: &PhysicalPlan) -> u64 {
867        match plan {
868            PhysicalPlan::TableScan { estimated_rows, .. }
869            | PhysicalPlan::IndexSeek { estimated_rows, .. }
870            | PhysicalPlan::Filter { estimated_rows, .. }
871            | PhysicalPlan::Aggregate { estimated_rows, .. }
872            | PhysicalPlan::NestedLoopJoin { estimated_rows, .. }
873            | PhysicalPlan::HashJoin { estimated_rows, .. }
874            | PhysicalPlan::MergeJoin { estimated_rows, .. } => *estimated_rows,
875            PhysicalPlan::Project { input, .. } | PhysicalPlan::Sort { input, .. } => {
876                self.get_plan_rows(input)
877            }
878            PhysicalPlan::Limit { limit, .. } => *limit,
879        }
880    }
881
882    /// Get estimated cost from a plan
883    #[allow(clippy::only_used_in_recursion)]
884    pub fn get_plan_cost(&self, plan: &PhysicalPlan) -> f64 {
885        match plan {
886            PhysicalPlan::TableScan { estimated_cost, .. } => *estimated_cost,
887            PhysicalPlan::IndexSeek { estimated_cost, .. } => *estimated_cost,
888            PhysicalPlan::Filter {
889                estimated_cost,
890                input,
891                ..
892            } => *estimated_cost + self.get_plan_cost(input),
893            PhysicalPlan::Project {
894                estimated_cost,
895                input,
896                ..
897            } => *estimated_cost + self.get_plan_cost(input),
898            PhysicalPlan::Sort {
899                estimated_cost,
900                input,
901                ..
902            } => *estimated_cost + self.get_plan_cost(input),
903            PhysicalPlan::Limit {
904                estimated_cost,
905                input,
906                ..
907            } => *estimated_cost + self.get_plan_cost(input),
908            PhysicalPlan::NestedLoopJoin {
909                estimated_cost,
910                outer,
911                inner,
912                ..
913            } => *estimated_cost + self.get_plan_cost(outer) + self.get_plan_cost(inner),
914            PhysicalPlan::HashJoin {
915                estimated_cost,
916                build,
917                probe,
918                ..
919            } => *estimated_cost + self.get_plan_cost(build) + self.get_plan_cost(probe),
920            PhysicalPlan::MergeJoin {
921                estimated_cost,
922                left,
923                right,
924                ..
925            } => *estimated_cost + self.get_plan_cost(left) + self.get_plan_cost(right),
926            PhysicalPlan::Aggregate {
927                estimated_cost,
928                input,
929                ..
930            } => *estimated_cost + self.get_plan_cost(input),
931        }
932    }
933
934    /// Generate EXPLAIN output
935    pub fn explain(&self, plan: &PhysicalPlan) -> String {
936        self.explain_impl(plan, 0)
937    }
938
939    fn explain_impl(&self, plan: &PhysicalPlan, indent: usize) -> String {
940        let prefix = "  ".repeat(indent);
941        let cost = self.get_plan_cost(plan);
942
943        match plan {
944            PhysicalPlan::TableScan {
945                table,
946                columns,
947                estimated_rows,
948                ..
949            } => {
950                format!(
951                    "{}TableScan [table={}, columns={:?}, rows={}, cost={:.2}ms]",
952                    prefix, table, columns, estimated_rows, cost
953                )
954            }
955            PhysicalPlan::IndexSeek {
956                table,
957                index,
958                columns,
959                estimated_rows,
960                ..
961            } => {
962                format!(
963                    "{}IndexSeek [table={}, index={}, columns={:?}, rows={}, cost={:.2}ms]",
964                    prefix, table, index, columns, estimated_rows, cost
965                )
966            }
967            PhysicalPlan::Filter {
968                input,
969                estimated_rows,
970                ..
971            } => {
972                format!(
973                    "{}Filter [rows={}, cost={:.2}ms]\n{}",
974                    prefix,
975                    estimated_rows,
976                    cost,
977                    self.explain_impl(input, indent + 1)
978                )
979            }
980            PhysicalPlan::Project { input, columns, .. } => {
981                format!(
982                    "{}Project [columns={:?}, cost={:.2}ms]\n{}",
983                    prefix,
984                    columns,
985                    cost,
986                    self.explain_impl(input, indent + 1)
987                )
988            }
989            PhysicalPlan::Sort {
990                input, order_by, ..
991            } => {
992                let order: Vec<_> = order_by
993                    .iter()
994                    .map(|(c, d)| format!("{} {:?}", c, d))
995                    .collect();
996                format!(
997                    "{}Sort [order={:?}, cost={:.2}ms]\n{}",
998                    prefix,
999                    order,
1000                    cost,
1001                    self.explain_impl(input, indent + 1)
1002                )
1003            }
1004            PhysicalPlan::Limit {
1005                input,
1006                limit,
1007                offset,
1008                ..
1009            } => {
1010                format!(
1011                    "{}Limit [limit={}, offset={}, cost={:.2}ms]\n{}",
1012                    prefix,
1013                    limit,
1014                    offset,
1015                    cost,
1016                    self.explain_impl(input, indent + 1)
1017                )
1018            }
1019            PhysicalPlan::HashJoin {
1020                build,
1021                probe,
1022                join_type,
1023                estimated_rows,
1024                ..
1025            } => {
1026                format!(
1027                    "{}HashJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
1028                    prefix,
1029                    join_type,
1030                    estimated_rows,
1031                    cost,
1032                    self.explain_impl(build, indent + 1),
1033                    self.explain_impl(probe, indent + 1)
1034                )
1035            }
1036            PhysicalPlan::MergeJoin {
1037                left,
1038                right,
1039                join_type,
1040                estimated_rows,
1041                ..
1042            } => {
1043                format!(
1044                    "{}MergeJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
1045                    prefix,
1046                    join_type,
1047                    estimated_rows,
1048                    cost,
1049                    self.explain_impl(left, indent + 1),
1050                    self.explain_impl(right, indent + 1)
1051                )
1052            }
1053            PhysicalPlan::NestedLoopJoin {
1054                outer,
1055                inner,
1056                join_type,
1057                estimated_rows,
1058                ..
1059            } => {
1060                format!(
1061                    "{}NestedLoopJoin [type={:?}, rows={}, cost={:.2}ms]\n{}\n{}",
1062                    prefix,
1063                    join_type,
1064                    estimated_rows,
1065                    cost,
1066                    self.explain_impl(outer, indent + 1),
1067                    self.explain_impl(inner, indent + 1)
1068                )
1069            }
1070            PhysicalPlan::Aggregate {
1071                input,
1072                group_by,
1073                aggregates,
1074                estimated_rows,
1075                ..
1076            } => {
1077                let aggs: Vec<_> = aggregates
1078                    .iter()
1079                    .map(|a| format!("{:?}({})", a.function, a.column.as_deref().unwrap_or("*")))
1080                    .collect();
1081                format!(
1082                    "{}Aggregate [group_by={:?}, aggs={:?}, rows={}, cost={:.2}ms]\n{}",
1083                    prefix,
1084                    group_by,
1085                    aggs,
1086                    estimated_rows,
1087                    cost,
1088                    self.explain_impl(input, indent + 1)
1089                )
1090            }
1091        }
1092    }
1093}
1094
1095// ============================================================================
1096// Plan Cache & Stats Helpers
1097// ============================================================================
1098
1099impl CostBasedOptimizer {
1100    /// Evict stale entries from the plan cache.
1101    pub fn evict_stale_plans(&self) {
1102        let now = Self::now_us();
1103        self.plan_cache
1104            .write()
1105            .retain(|_, (_, ts)| now.saturating_sub(*ts) < self.plan_cache_ttl_us);
1106    }
1107
1108    /// Clear the entire plan cache (call after DDL or bulk load).
1109    pub fn invalidate_plan_cache(&self) {
1110        self.plan_cache.write().clear();
1111    }
1112
1113    /// Collect fresh statistics for a table from row data.
1114    ///
1115    /// Pass an iterator of (column_name, value_as_string) pairs per row.
1116    /// This builds column stats with distinct counts and optional histograms.
1117    pub fn collect_stats(
1118        &self,
1119        table_name: &str,
1120        row_count: u64,
1121        size_bytes: u64,
1122        column_values: HashMap<String, Vec<String>>,
1123        indices: Vec<IndexStats>,
1124    ) {
1125        let mut column_stats = HashMap::new();
1126        for (col_name, values) in &column_values {
1127            let distinct: HashSet<&String> = values.iter().collect();
1128            let null_count = values.iter().filter(|v| v.is_empty()).count() as u64;
1129            let avg_length = if values.is_empty() {
1130                0.0
1131            } else {
1132                values.iter().map(|v| v.len()).sum::<usize>() as f64 / values.len() as f64
1133            };
1134
1135            // Build histogram for numeric columns (try parse first 10 values)
1136            let is_numeric = values.iter().take(10).all(|v| v.parse::<f64>().is_ok());
1137            let histogram = if is_numeric && values.len() >= 10 {
1138                let mut nums: Vec<f64> = values.iter().filter_map(|v| v.parse().ok()).collect();
1139                nums.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
1140                let bucket_count = 10.min(nums.len());
1141                let bucket_size = nums.len() / bucket_count;
1142                let mut boundaries = Vec::new();
1143                let mut counts = Vec::new();
1144                for i in 0..bucket_count {
1145                    let end = if i == bucket_count - 1 {
1146                        nums.len()
1147                    } else {
1148                        (i + 1) * bucket_size
1149                    };
1150                    let start = i * bucket_size;
1151                    boundaries.push(nums[end - 1]);
1152                    counts.push((end - start) as u64);
1153                }
1154                Some(Histogram {
1155                    boundaries,
1156                    counts,
1157                    total_rows: nums.len() as u64,
1158                })
1159            } else {
1160                None
1161            };
1162
1163            // Build MCV (top 5 most common values)
1164            let mut freq_map: HashMap<&String, usize> = HashMap::new();
1165            for v in values {
1166                *freq_map.entry(v).or_insert(0) += 1;
1167            }
1168            let total = values.len() as f64;
1169            let mut mcv: Vec<(String, f64)> = freq_map
1170                .iter()
1171                .map(|(k, &v)| ((*k).clone(), v as f64 / total))
1172                .collect();
1173            mcv.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1174            mcv.truncate(5);
1175
1176            column_stats.insert(
1177                col_name.clone(),
1178                ColumnStats {
1179                    name: col_name.clone(),
1180                    distinct_count: distinct.len() as u64,
1181                    null_count,
1182                    min_value: values.iter().min().cloned(),
1183                    max_value: values.iter().max().cloned(),
1184                    avg_length,
1185                    mcv,
1186                    histogram,
1187                },
1188            );
1189        }
1190
1191        self.update_stats(TableStats {
1192            name: table_name.to_string(),
1193            row_count,
1194            size_bytes,
1195            column_stats,
1196            indices,
1197            last_updated: Self::now_us(),
1198        });
1199
1200        // Invalidate cached plans for this table
1201        self.invalidate_plan_cache();
1202    }
1203
1204    /// Check if stats are stale (older than threshold)
1205    pub fn stats_age_us(&self, table: &str) -> Option<u64> {
1206        self.stats_cache.read().get(table).map(|s| {
1207            Self::now_us().saturating_sub(s.last_updated)
1208        })
1209    }
1210
1211    fn now_us() -> u64 {
1212        SystemTime::now()
1213            .duration_since(UNIX_EPOCH)
1214            .unwrap_or_default()
1215            .as_micros() as u64
1216    }
1217}
1218
1219// ============================================================================
1220// Join Order Optimizer (Dynamic Programming)
1221// ============================================================================
1222
1223/// Join order optimizer using dynamic programming
1224pub struct JoinOrderOptimizer {
1225    /// Table statistics
1226    stats: HashMap<String, TableStats>,
1227    /// Cost model
1228    config: CostModelConfig,
1229}
1230
1231impl JoinOrderOptimizer {
1232    pub fn new(config: CostModelConfig) -> Self {
1233        Self {
1234            stats: HashMap::new(),
1235            config,
1236        }
1237    }
1238
1239    /// Add table statistics
1240    pub fn add_stats(&mut self, stats: TableStats) {
1241        self.stats.insert(stats.name.clone(), stats);
1242    }
1243
1244    /// Find optimal join order using dynamic programming
1245    ///
1246    /// Time: O(2^n × n^2) where n = number of tables
1247    /// Practical for n ≤ 10
1248    pub fn find_optimal_order(
1249        &self,
1250        tables: &[String],
1251        join_conditions: &[(String, String, String, String)], // (table1, col1, table2, col2)
1252    ) -> Vec<(String, String)> {
1253        let n = tables.len();
1254        if n <= 1 {
1255            return vec![];
1256        }
1257
1258        // dp[mask] = (cost, join_order)
1259        let mut dp: HashMap<u32, (f64, Vec<(String, String)>)> = HashMap::new();
1260
1261        // Base case: single tables
1262        for (i, _table) in tables.iter().enumerate() {
1263            let mask = 1u32 << i;
1264            dp.insert(mask, (0.0, vec![]));
1265        }
1266
1267        // Build up larger subsets
1268        for size in 2..=n {
1269            for mask in 0..(1u32 << n) {
1270                if mask.count_ones() != size as u32 {
1271                    continue;
1272                }
1273
1274                let mut best_cost = f64::MAX;
1275                let mut best_order = vec![];
1276
1277                // Try all ways to split into two non-empty subsets
1278                for sub in 1..mask {
1279                    if sub & mask != sub || sub == 0 {
1280                        continue;
1281                    }
1282                    let other = mask ^ sub;
1283                    if other == 0 {
1284                        continue;
1285                    }
1286
1287                    // Check if there's a join between sub and other
1288                    if !self.has_join_condition(tables, sub, other, join_conditions) {
1289                        continue;
1290                    }
1291
1292                    if let (Some((cost1, order1)), Some((cost2, order2))) =
1293                        (dp.get(&sub), dp.get(&other))
1294                    {
1295                        let join_cost = self.estimate_join_cost(tables, sub, other);
1296                        let total_cost = cost1 + cost2 + join_cost;
1297
1298                        if total_cost < best_cost {
1299                            best_cost = total_cost;
1300                            best_order = order1.clone();
1301                            best_order.extend(order2.clone());
1302
1303                            // Add the join
1304                            let (t1, t2) =
1305                                self.get_join_tables(tables, sub, other, join_conditions);
1306                            if let Some((t1, t2)) = Some((t1, t2)) {
1307                                best_order.push((t1, t2));
1308                            }
1309                        }
1310                    }
1311                }
1312
1313                if best_cost < f64::MAX {
1314                    dp.insert(mask, (best_cost, best_order));
1315                }
1316            }
1317        }
1318
1319        let full_mask = (1u32 << n) - 1;
1320        dp.get(&full_mask)
1321            .map(|(_, order)| order.clone())
1322            .unwrap_or_default()
1323    }
1324
1325    fn has_join_condition(
1326        &self,
1327        tables: &[String],
1328        mask1: u32,
1329        mask2: u32,
1330        conditions: &[(String, String, String, String)],
1331    ) -> bool {
1332        for (t1, _, t2, _) in conditions {
1333            let idx1 = tables.iter().position(|t| t == t1);
1334            let idx2 = tables.iter().position(|t| t == t2);
1335
1336            if let (Some(i1), Some(i2)) = (idx1, idx2) {
1337                let in_mask1 = (mask1 >> i1) & 1 == 1;
1338                let in_mask2 = (mask2 >> i2) & 1 == 1;
1339
1340                if in_mask1 && in_mask2 {
1341                    return true;
1342                }
1343            }
1344        }
1345        false
1346    }
1347
1348    fn get_join_tables(
1349        &self,
1350        tables: &[String],
1351        mask1: u32,
1352        mask2: u32,
1353        conditions: &[(String, String, String, String)],
1354    ) -> (String, String) {
1355        for (t1, _, t2, _) in conditions {
1356            let idx1 = tables.iter().position(|t| t == t1);
1357            let idx2 = tables.iter().position(|t| t == t2);
1358
1359            if let (Some(i1), Some(i2)) = (idx1, idx2) {
1360                let t1_in_mask1 = (mask1 >> i1) & 1 == 1;
1361                let t2_in_mask2 = (mask2 >> i2) & 1 == 1;
1362
1363                if t1_in_mask1 && t2_in_mask2 {
1364                    return (t1.clone(), t2.clone());
1365                }
1366            }
1367        }
1368        (String::new(), String::new())
1369    }
1370
1371    fn estimate_join_cost(&self, tables: &[String], mask1: u32, mask2: u32) -> f64 {
1372        let rows1 = self.estimate_rows_for_mask(tables, mask1);
1373        let rows2 = self.estimate_rows_for_mask(tables, mask2);
1374
1375        // Hash join cost estimate
1376        // Build cost + probe cost
1377        let build_cost = rows1 as f64 * self.config.c_filter;
1378        let probe_cost = rows2 as f64 * self.config.c_filter;
1379
1380        build_cost + probe_cost
1381    }
1382
1383    fn estimate_rows_for_mask(&self, tables: &[String], mask: u32) -> u64 {
1384        let mut total = 1u64;
1385
1386        for (i, table) in tables.iter().enumerate() {
1387            if (mask >> i) & 1 == 1 {
1388                let rows = self.stats.get(table).map(|s| s.row_count).unwrap_or(1000);
1389                total = total.saturating_mul(rows);
1390            }
1391        }
1392
1393        // Apply default selectivity for joins
1394        let num_tables = mask.count_ones();
1395        if num_tables > 1 {
1396            total = (total as f64 * 0.1f64.powi(num_tables as i32 - 1)) as u64;
1397        }
1398
1399        total.max(1)
1400    }
1401}
1402
1403// ============================================================================
1404// Tests
1405// ============================================================================
1406
1407#[cfg(test)]
1408mod tests {
1409    use super::*;
1410
1411    fn create_test_stats() -> TableStats {
1412        let mut column_stats = HashMap::new();
1413        column_stats.insert(
1414            "id".to_string(),
1415            ColumnStats {
1416                name: "id".to_string(),
1417                distinct_count: 100000,
1418                null_count: 0,
1419                min_value: Some("1".to_string()),
1420                max_value: Some("100000".to_string()),
1421                avg_length: 8.0,
1422                mcv: vec![],
1423                histogram: None,
1424            },
1425        );
1426        column_stats.insert(
1427            "score".to_string(),
1428            ColumnStats {
1429                name: "score".to_string(),
1430                distinct_count: 100,
1431                null_count: 1000,
1432                min_value: Some("0".to_string()),
1433                max_value: Some("100".to_string()),
1434                avg_length: 8.0,
1435                mcv: vec![("50".to_string(), 0.05)],
1436                histogram: Some(Histogram {
1437                    boundaries: vec![25.0, 50.0, 75.0, 100.0],
1438                    counts: vec![25000, 25000, 25000, 25000],
1439                    total_rows: 100000,
1440                }),
1441            },
1442        );
1443
1444        TableStats {
1445            name: "users".to_string(),
1446            row_count: 100000,
1447            size_bytes: 10_000_000, // 10 MB
1448            column_stats,
1449            indices: vec![
1450                IndexStats {
1451                    name: "pk_users".to_string(),
1452                    columns: vec!["id".to_string()],
1453                    is_primary: true,
1454                    is_unique: true,
1455                    index_type: IndexType::BTree,
1456                    leaf_pages: 1000,
1457                    height: 3,
1458                    avg_leaf_density: 100.0,
1459                },
1460                IndexStats {
1461                    name: "idx_score".to_string(),
1462                    columns: vec!["score".to_string()],
1463                    is_primary: false,
1464                    is_unique: false,
1465                    index_type: IndexType::BTree,
1466                    leaf_pages: 500,
1467                    height: 2,
1468                    avg_leaf_density: 200.0,
1469                },
1470            ],
1471            last_updated: 0,
1472        }
1473    }
1474
1475    #[test]
1476    fn test_selectivity_estimation() {
1477        let config = CostModelConfig::default();
1478        let optimizer = CostBasedOptimizer::new(config);
1479
1480        let stats = create_test_stats();
1481        optimizer.update_stats(stats.clone());
1482
1483        // Equality predicate
1484        let pred = Predicate::Eq {
1485            column: "id".to_string(),
1486            value: "12345".to_string(),
1487        };
1488        let sel = optimizer.estimate_selectivity(&pred, &stats);
1489        assert!(sel < 0.001); // Should be very selective
1490
1491        // Range predicate with histogram
1492        // Note: For histogram boundaries [25, 50, 75, 100] with equal distribution,
1493        // Gt{75} includes buckets with bucket_max >= 75, which is buckets 2 and 3 (50%)
1494        let pred = Predicate::Gt {
1495            column: "score".to_string(),
1496            value: "75".to_string(),
1497        };
1498        let sel = optimizer.estimate_selectivity(&pred, &stats);
1499        assert!(sel > 0.4 && sel < 0.6); // ~50% from histogram (2 of 4 buckets)
1500    }
1501
1502    #[test]
1503    fn test_access_path_selection() {
1504        let config = CostModelConfig::default();
1505        let optimizer = CostBasedOptimizer::new(config);
1506
1507        let stats = create_test_stats();
1508        optimizer.update_stats(stats);
1509
1510        // High selectivity predicate should use index
1511        let pred = Predicate::Eq {
1512            column: "id".to_string(),
1513            value: "12345".to_string(),
1514        };
1515        let plan = optimizer.optimize(
1516            "users",
1517            vec!["id".to_string(), "score".to_string()],
1518            Some(pred),
1519            vec![],
1520            None,
1521        );
1522
1523        match plan {
1524            PhysicalPlan::IndexSeek { index, .. } => {
1525                assert_eq!(index, "pk_users");
1526            }
1527            _ => panic!("Expected IndexSeek for equality on primary key"),
1528        }
1529    }
1530
1531    #[test]
1532    fn test_token_budget_limit() {
1533        let config = CostModelConfig::default();
1534        let optimizer = CostBasedOptimizer::new(config).with_token_budget(2048, 25.0);
1535
1536        // With 2048 token budget and 25 tokens/row:
1537        // max_rows = (2048 - 50) / 25 = ~80
1538        let plan = optimizer.optimize("users", vec!["id".to_string()], None, vec![], None);
1539
1540        match plan {
1541            PhysicalPlan::Limit { limit, .. } => {
1542                assert!(limit <= 80);
1543            }
1544            _ => panic!("Expected Limit to be injected"),
1545        }
1546    }
1547
1548    #[test]
1549    fn test_explain_output() {
1550        let config = CostModelConfig::default();
1551        let optimizer = CostBasedOptimizer::new(config);
1552
1553        let stats = create_test_stats();
1554        optimizer.update_stats(stats);
1555
1556        let plan = optimizer.optimize(
1557            "users",
1558            vec!["id".to_string(), "score".to_string()],
1559            Some(Predicate::Gt {
1560                column: "score".to_string(),
1561                value: "80".to_string(),
1562            }),
1563            vec![("score".to_string(), SortDirection::Descending)],
1564            Some(10),
1565        );
1566
1567        let explain = optimizer.explain(&plan);
1568        assert!(explain.contains("Limit"));
1569        assert!(explain.contains("Sort"));
1570    }
1571
1572    // ================================================================
1573    // Production-grade tests
1574    // ================================================================
1575
1576    #[test]
1577    fn test_token_budget_underflow_safety() {
1578        // Ensure small budget doesn't panic (saturating_sub)
1579        let config = CostModelConfig::default();
1580        let optimizer = CostBasedOptimizer::new(config).with_token_budget(10, 25.0);
1581
1582        let plan = optimizer.optimize("users", vec!["id".to_string()], None, vec![], None);
1583        match plan {
1584            PhysicalPlan::Limit { limit, .. } => {
1585                assert!(limit >= 1, "Must return at least 1 row");
1586            }
1587            _ => panic!("Expected Limit"),
1588        }
1589    }
1590
1591    #[test]
1592    fn test_index_seek_derives_key_range() {
1593        let config = CostModelConfig::default();
1594        let optimizer = CostBasedOptimizer::new(config);
1595        optimizer.update_stats(create_test_stats());
1596
1597        let plan = optimizer.optimize(
1598            "users",
1599            vec!["id".to_string()],
1600            Some(Predicate::Eq {
1601                column: "id".to_string(),
1602                value: "42".to_string(),
1603            }),
1604            vec![],
1605            None,
1606        );
1607
1608        match plan {
1609            PhysicalPlan::IndexSeek { key_range, .. } => {
1610                assert!(key_range.start.is_some(), "KeyRange must derive from Eq predicate");
1611                assert_eq!(key_range.start, key_range.end, "Eq predicate → point key range");
1612            }
1613            _ => panic!("Expected IndexSeek"),
1614        }
1615    }
1616
1617    #[test]
1618    fn test_range_predicate_key_range() {
1619        let config = CostModelConfig::default();
1620        let optimizer = CostBasedOptimizer::new(config);
1621        optimizer.update_stats(create_test_stats());
1622
1623        let plan = optimizer.optimize(
1624            "users",
1625            vec!["score".to_string()],
1626            Some(Predicate::Between {
1627                column: "score".to_string(),
1628                min: "10".to_string(),
1629                max: "90".to_string(),
1630            }),
1631            vec![],
1632            None,
1633        );
1634
1635        match plan {
1636            PhysicalPlan::IndexSeek { key_range, .. } => {
1637                assert!(key_range.start.is_some());
1638                assert!(key_range.end.is_some());
1639                assert!(key_range.start_inclusive);
1640                assert!(key_range.end_inclusive);
1641            }
1642            _ => {} // May choose scan if cheaper — that's OK
1643        }
1644    }
1645
1646    #[test]
1647    fn test_projection_pushdown_proportional_reduction() {
1648        let config = CostModelConfig::default();
1649        let optimizer = CostBasedOptimizer::new(config);
1650        optimizer.update_stats(create_test_stats());
1651
1652        // Select 1 of 2 columns → ~50% cost reduction on table scan
1653        let plan_all = optimizer.optimize(
1654            "users",
1655            vec!["id".to_string(), "score".to_string()],
1656            None,
1657            vec![],
1658            Some(100),
1659        );
1660        let plan_single = optimizer.optimize(
1661            "users",
1662            vec!["id".to_string()],
1663            None,
1664            vec![],
1665            Some(100),
1666        );
1667
1668        let cost_all = optimizer.get_plan_cost(&plan_all);
1669        let cost_single = optimizer.get_plan_cost(&plan_single);
1670        // Single column should cost less than or equal to all columns
1671        assert!(cost_single <= cost_all, "Projection should reduce cost: {} vs {}", cost_single, cost_all);
1672    }
1673
1674    #[test]
1675    fn test_collect_stats_builds_histogram() {
1676        let config = CostModelConfig::default();
1677        let optimizer = CostBasedOptimizer::new(config);
1678
1679        let mut column_values = HashMap::new();
1680        let scores: Vec<String> = (0..100).map(|i| i.to_string()).collect();
1681        column_values.insert("score".to_string(), scores);
1682
1683        optimizer.collect_stats("test_table", 100, 10000, column_values, vec![]);
1684
1685        let stats = optimizer.get_stats("test_table").unwrap();
1686        assert_eq!(stats.row_count, 100);
1687        let score_stats = stats.column_stats.get("score").unwrap();
1688        assert_eq!(score_stats.distinct_count, 100);
1689        assert!(score_stats.histogram.is_some(), "Numeric column should get histogram");
1690        assert!(!score_stats.mcv.is_empty(), "Should build MCV list");
1691    }
1692
1693    #[test]
1694    fn test_plan_cache_invalidation() {
1695        let config = CostModelConfig::default();
1696        let optimizer = CostBasedOptimizer::new(config);
1697
1698        // Collecting stats should invalidate cache
1699        let mut col = HashMap::new();
1700        col.insert("x".to_string(), vec!["1".to_string()]);
1701        optimizer.collect_stats("t", 1, 100, col.clone(), vec![]);
1702
1703        // Cache should be empty after stats collection
1704        assert!(optimizer.plan_cache.read().is_empty());
1705    }
1706
1707    #[test]
1708    fn test_stats_age_tracking() {
1709        let config = CostModelConfig::default();
1710        let optimizer = CostBasedOptimizer::new(config);
1711
1712        assert!(optimizer.stats_age_us("unknown").is_none());
1713
1714        let mut col = HashMap::new();
1715        col.insert("x".to_string(), vec!["1".to_string()]);
1716        optimizer.collect_stats("t", 1, 100, col, vec![]);
1717
1718        let age = optimizer.stats_age_us("t").unwrap();
1719        assert!(age < 1_000_000, "Stats should be fresh (< 1 second old)");
1720    }
1721
1722    #[test]
1723    fn test_scan_cost_reads_all_blocks() {
1724        // Scan cost must NOT multiply by selectivity — scans read everything
1725        let config = CostModelConfig::default();
1726        let optimizer = CostBasedOptimizer::new(config.clone());
1727        let no_pred = optimizer.estimate_scan_cost(1000, 4096 * 10, None);
1728        let with_pred = optimizer.estimate_scan_cost(
1729            1000,
1730            4096 * 10,
1731            Some(&Predicate::Eq {
1732                column: "x".to_string(),
1733                value: "1".to_string(),
1734            }),
1735        );
1736        // Scan cost should be the same regardless of predicate
1737        // (scan reads all blocks; predicate doesn't reduce I/O)
1738        assert!(
1739            (no_pred - with_pred).abs() < 0.001,
1740            "Scan cost should not depend on predicate: {} vs {}",
1741            no_pred,
1742            with_pred
1743        );
1744    }
1745
1746    #[test]
1747    fn test_index_wins_over_scan_for_point_lookup() {
1748        let config = CostModelConfig::default();
1749        let optimizer = CostBasedOptimizer::new(config);
1750        optimizer.update_stats(create_test_stats());
1751
1752        let scan_cost = optimizer.estimate_scan_cost(100000, 10_000_000, None);
1753
1754        // Index cost for a point lookup should be orders of magnitude cheaper
1755        let pk_index = &create_test_stats().indices[0]; // pk_users
1756        let index_cost = optimizer.estimate_index_cost(pk_index, 100000, 0.00001);
1757
1758        assert!(
1759            index_cost < scan_cost * 0.1,
1760            "Index point lookup ({:.2}) should be <10% of scan cost ({:.2})",
1761            index_cost,
1762            scan_cost
1763        );
1764    }
1765
1766    #[test]
1767    fn test_no_stats_defaults_to_scan() {
1768        let config = CostModelConfig::default();
1769        let optimizer = CostBasedOptimizer::new(config);
1770        // No stats loaded — optimizer should still work with defaults
1771        let plan = optimizer.optimize(
1772            "unknown_table",
1773            vec!["col1".to_string()],
1774            Some(Predicate::Eq {
1775                column: "col1".to_string(),
1776                value: "x".to_string(),
1777            }),
1778            vec![],
1779            None,
1780        );
1781        // Should produce a valid plan (TableScan with default estimates)
1782        match plan {
1783            PhysicalPlan::TableScan { estimated_rows, .. } => {
1784                assert!(estimated_rows > 0, "Default row estimate must be positive");
1785            }
1786            PhysicalPlan::IndexSeek { .. } => {} // also fine with no stats
1787            _ => panic!("Expected TableScan or IndexSeek for unknown table"),
1788        }
1789    }
1790
1791    #[test]
1792    fn test_compound_predicate_selectivity() {
1793        let stats = create_test_stats();
1794        let config = CostModelConfig::default();
1795        let optimizer = CostBasedOptimizer::new(config);
1796
1797        // AND: independent → multiply
1798        let and_pred = Predicate::And(
1799            Box::new(Predicate::Eq {
1800                column: "id".to_string(),
1801                value: "1".to_string(),
1802            }),
1803            Box::new(Predicate::IsNotNull {
1804                column: "score".to_string(),
1805            }),
1806        );
1807        let sel = optimizer.estimate_selectivity(&and_pred, &stats);
1808        let eq_sel = optimizer.estimate_selectivity(
1809            &Predicate::Eq { column: "id".to_string(), value: "1".to_string() },
1810            &stats,
1811        );
1812        assert!(sel < eq_sel, "AND must be more selective than either child");
1813
1814        // OR: P(A∪B) = P(A)+P(B)-P(A∩B)
1815        let or_pred = Predicate::Or(
1816            Box::new(Predicate::Eq {
1817                column: "id".to_string(),
1818                value: "1".to_string(),
1819            }),
1820            Box::new(Predicate::Eq {
1821                column: "id".to_string(),
1822                value: "2".to_string(),
1823            }),
1824        );
1825        let sel = optimizer.estimate_selectivity(&or_pred, &stats);
1826        assert!(sel > eq_sel, "OR must be less selective than either child");
1827        assert!(sel <= 1.0, "Selectivity must be <= 1.0");
1828    }
1829
1830    #[test]
1831    fn test_join_order_optimizer() {
1832        let mut join_opt = JoinOrderOptimizer::new(CostModelConfig::default());
1833        join_opt.add_stats(TableStats {
1834            name: "orders".to_string(),
1835            row_count: 1000000,
1836            size_bytes: 100_000_000,
1837            column_stats: HashMap::new(),
1838            indices: vec![],
1839            last_updated: 0,
1840        });
1841        join_opt.add_stats(TableStats {
1842            name: "users".to_string(),
1843            row_count: 10000,
1844            size_bytes: 1_000_000,
1845            column_stats: HashMap::new(),
1846            indices: vec![],
1847            last_updated: 0,
1848        });
1849
1850        let order = join_opt.find_optimal_order(
1851            &["orders".to_string(), "users".to_string()],
1852            &[("orders".to_string(), "user_id".to_string(), "users".to_string(), "id".to_string())],
1853        );
1854        assert!(!order.is_empty(), "Should find a join order");
1855    }
1856}