Skip to main content

oxigdal_query/optimizer/
cost_model.rs

1//! Cost-based optimization model.
2
3use crate::parser::ast::*;
4use serde::{Deserialize, Serialize};
5
/// Cost estimate for a query plan.
///
/// Each resource dimension is tracked separately; [`Cost::total`] collapses
/// them into a single weighted scalar.
///
/// Note: the derived `PartialOrd` compares fields lexicographically in
/// declaration order (`cpu`, then `io`, then `memory`, then `network`), not by
/// the weighted total. To rank alternative plans, compare `total()` values.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Serialize, Deserialize)]
pub struct Cost {
    /// CPU cost (operations).
    pub cpu: f64,
    /// IO cost (bytes read).
    pub io: f64,
    /// Memory cost (bytes used).
    pub memory: f64,
    /// Network cost (bytes transferred).
    pub network: f64,
}
18
19impl Cost {
20    /// Create a new cost.
21    pub fn new(cpu: f64, io: f64, memory: f64, network: f64) -> Self {
22        Self {
23            cpu,
24            io,
25            memory,
26            network,
27        }
28    }
29
30    /// Zero cost.
31    pub fn zero() -> Self {
32        Self::new(0.0, 0.0, 0.0, 0.0)
33    }
34
35    /// Total cost with weighted factors.
36    pub fn total(&self) -> f64 {
37        // Weights for different cost components
38        const CPU_WEIGHT: f64 = 1.0;
39        const IO_WEIGHT: f64 = 10.0;
40        const MEMORY_WEIGHT: f64 = 0.1;
41        const NETWORK_WEIGHT: f64 = 20.0;
42
43        self.cpu * CPU_WEIGHT
44            + self.io * IO_WEIGHT
45            + self.memory * MEMORY_WEIGHT
46            + self.network * NETWORK_WEIGHT
47    }
48
49    /// Add two costs.
50    pub fn add(&self, other: &Cost) -> Cost {
51        Cost::new(
52            self.cpu + other.cpu,
53            self.io + other.io,
54            self.memory + other.memory,
55            self.network + other.network,
56        )
57    }
58
59    /// Multiply cost by a factor.
60    pub fn multiply(&self, factor: f64) -> Cost {
61        Cost::new(
62            self.cpu * factor,
63            self.io * factor,
64            self.memory * factor,
65            self.network * factor,
66        )
67    }
68}
69
/// Statistics for a table or relation.
///
/// Built with [`Statistics::new`] plus the `with_column` / `with_index`
/// builder methods, then registered with a `CostModel` under a table name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Statistics {
    /// Number of rows.
    pub row_count: usize,
    /// Average row size in bytes.
    pub row_size: usize,
    /// Per-column statistics; looked up by column name when estimating
    /// equality selectivity.
    pub columns: Vec<ColumnStatistics>,
    /// Available indexes.
    pub indexes: Vec<IndexStatistics>,
}
82
83impl Statistics {
84    /// Create new statistics.
85    pub fn new(row_count: usize, row_size: usize) -> Self {
86        Self {
87            row_count,
88            row_size,
89            columns: Vec::new(),
90            indexes: Vec::new(),
91        }
92    }
93
94    /// Total size in bytes.
95    pub fn total_size(&self) -> usize {
96        self.row_count * self.row_size
97    }
98
99    /// Add column statistics.
100    pub fn with_column(mut self, col_stats: ColumnStatistics) -> Self {
101        self.columns.push(col_stats);
102        self
103    }
104
105    /// Add index statistics.
106    pub fn with_index(mut self, idx_stats: IndexStatistics) -> Self {
107        self.indexes.push(idx_stats);
108        self
109    }
110}
111
/// Column statistics used for selectivity estimation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnStatistics {
    /// Column name; matched against column references in predicates.
    pub name: String,
    /// Number of distinct values. `0` makes equality selectivity `0.0`.
    pub distinct_count: usize,
    /// Number of null values.
    /// NOTE(review): not read by the selectivity estimates in this module.
    pub null_count: usize,
    /// Minimum value (if available). Only `Literal::Integer` bounds are used
    /// by range-selectivity estimation.
    pub min_value: Option<Literal>,
    /// Maximum value (if available). Only `Literal::Integer` bounds are used
    /// by range-selectivity estimation.
    pub max_value: Option<Literal>,
}
126
127impl ColumnStatistics {
128    /// Create new column statistics.
129    pub fn new(name: String, distinct_count: usize, null_count: usize) -> Self {
130        Self {
131            name,
132            distinct_count,
133            null_count,
134            min_value: None,
135            max_value: None,
136        }
137    }
138
139    /// Selectivity for equality predicate.
140    pub fn equality_selectivity(&self, _total_rows: usize) -> f64 {
141        if self.distinct_count == 0 {
142            return 0.0;
143        }
144        1.0 / self.distinct_count as f64
145    }
146
147    /// Selectivity for range predicate.
148    pub fn range_selectivity(&self, low: &Literal, high: &Literal) -> f64 {
149        // Simplified range selectivity estimation
150        match (&self.min_value, &self.max_value) {
151            (Some(min), Some(max)) => {
152                if let (Literal::Integer(min_val), Literal::Integer(max_val)) = (min, max) {
153                    if let (Literal::Integer(low_val), Literal::Integer(high_val)) = (low, high) {
154                        let range = (max_val - min_val) as f64;
155                        if range > 0.0 {
156                            let selected = (high_val - low_val) as f64;
157                            return (selected / range).clamp(0.0, 1.0);
158                        }
159                    }
160                }
161                // Default range selectivity
162                0.25
163            }
164            _ => 0.25, // Default range selectivity
165        }
166    }
167}
168
/// Index statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStatistics {
    /// Index name.
    pub name: String,
    /// Indexed columns.
    pub columns: Vec<String>,
    /// Index type (btree, rtree, hash); selects the lookup cost formula.
    pub index_type: IndexType,
    /// Index size in bytes. Scan costing reads a selectivity-proportional
    /// share of it, with a minimum of one 8192-byte page.
    pub size: usize,
    /// Height of index tree (for tree indexes). When `None`, lookup costing
    /// assumes a height of 4.
    pub height: Option<usize>,
}
183
/// Index type; determines which formula `IndexStatistics::lookup_cost` uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum IndexType {
    /// B-tree index (cost scales with tree height).
    BTree,
    /// R-tree index (spatial; cost scales with tree height).
    RTree,
    /// Hash index (constant lookup cost).
    Hash,
}
194
195impl IndexStatistics {
196    /// Create new index statistics.
197    pub fn new(name: String, columns: Vec<String>, index_type: IndexType, size: usize) -> Self {
198        Self {
199            name,
200            columns,
201            index_type,
202            size,
203            height: None,
204        }
205    }
206
207    /// Cost of index lookup.
208    pub fn lookup_cost(&self) -> Cost {
209        match self.index_type {
210            IndexType::BTree => {
211                // B-tree lookup: log(n) * page_size
212                let height = self.height.unwrap_or(4) as f64;
213                Cost::new(height * 100.0, height * 8192.0, 0.0, 0.0)
214            }
215            IndexType::RTree => {
216                // R-tree lookup: similar to B-tree
217                let height = self.height.unwrap_or(4) as f64;
218                Cost::new(height * 150.0, height * 8192.0, 0.0, 0.0)
219            }
220            IndexType::Hash => {
221                // Hash lookup: O(1) average case
222                Cost::new(50.0, 8192.0, 0.0, 0.0)
223            }
224        }
225    }
226
227    /// Cost of index scan.
228    pub fn scan_cost(&self, selectivity: f64) -> Cost {
229        let io = (self.size as f64 * selectivity).max(8192.0);
230        Cost::new(io / 100.0, io, 0.0, 0.0)
231    }
232}
233
/// Cost model for query operations.
///
/// Holds per-table [`Statistics`] keyed by table name in a `dashmap::DashMap`.
/// Statistics can therefore be registered and read through `&self`; no
/// `&mut` borrow is needed.
pub struct CostModel {
    /// Statistics cache, keyed by table name.
    statistics: dashmap::DashMap<String, Statistics>,
}
239
240impl CostModel {
241    /// Create a new cost model.
242    pub fn new() -> Self {
243        Self {
244            statistics: dashmap::DashMap::new(),
245        }
246    }
247
248    /// Register statistics for a table.
249    pub fn register_statistics(&self, table: String, stats: Statistics) {
250        self.statistics.insert(table, stats);
251    }
252
253    /// Get statistics for a table.
254    pub fn get_statistics(&self, table: &str) -> Option<Statistics> {
255        self.statistics.get(table).map(|s| s.clone())
256    }
257
258    /// Estimate cost of a table scan.
259    pub fn scan_cost(&self, table: &str) -> Cost {
260        if let Some(stats) = self.get_statistics(table) {
261            let total_size = stats.total_size() as f64;
262            Cost::new(
263                stats.row_count as f64 * 10.0,
264                total_size,
265                stats.row_size as f64,
266                0.0,
267            )
268        } else {
269            // Default cost for unknown table
270            Cost::new(1_000_000.0, 1_000_000_000.0, 1000.0, 0.0)
271        }
272    }
273
274    /// Estimate cost of a filter operation.
275    pub fn filter_cost(&self, input_rows: usize, selectivity: f64) -> Cost {
276        let output_rows = (input_rows as f64 * selectivity) as usize;
277        Cost::new(
278            input_rows as f64 * 2.0,
279            0.0,
280            output_rows as f64 * 100.0,
281            0.0,
282        )
283    }
284
285    /// Estimate cost of a join operation.
286    pub fn join_cost(&self, left_rows: usize, right_rows: usize, join_type: JoinType) -> Cost {
287        match join_type {
288            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
289                // Hash join cost
290                let build_cost = right_rows as f64 * 10.0;
291                let probe_cost = left_rows as f64 * 5.0;
292                let memory = right_rows as f64 * 100.0;
293                Cost::new(build_cost + probe_cost, 0.0, memory, 0.0)
294            }
295            JoinType::Cross => {
296                // Cross join cost (nested loop)
297                let total_ops = (left_rows * right_rows) as f64;
298                Cost::new(total_ops * 2.0, 0.0, total_ops * 100.0, 0.0)
299            }
300        }
301    }
302
303    /// Estimate cost of aggregation.
304    pub fn aggregate_cost(&self, input_rows: usize, group_count: usize) -> Cost {
305        Cost::new(
306            input_rows as f64 * 5.0,
307            0.0,
308            group_count as f64 * 200.0,
309            0.0,
310        )
311    }
312
313    /// Estimate cost of sorting.
314    pub fn sort_cost(&self, input_rows: usize) -> Cost {
315        // Sort cost: O(n log n)
316        let n = input_rows as f64;
317        let ops = n * n.log2();
318        Cost::new(ops * 10.0, 0.0, n * 100.0, 0.0)
319    }
320
321    /// Estimate selectivity of a predicate.
322    pub fn estimate_selectivity(&self, table: &str, expr: &Expr) -> f64 {
323        match expr {
324            Expr::BinaryOp { left, op, right } => match op {
325                BinaryOperator::Eq => {
326                    if let Expr::Column { name, .. } = &**left {
327                        if let Some(stats) = self.get_statistics(table) {
328                            if let Some(col_stats) = stats.columns.iter().find(|c| c.name == *name)
329                            {
330                                return col_stats.equality_selectivity(stats.row_count);
331                            }
332                        }
333                    }
334                    0.1 // Default equality selectivity
335                }
336                BinaryOperator::Lt
337                | BinaryOperator::LtEq
338                | BinaryOperator::Gt
339                | BinaryOperator::GtEq => 0.33, // Default range selectivity
340                BinaryOperator::And => {
341                    let left_sel = self.estimate_selectivity(table, left);
342                    let right_sel = self.estimate_selectivity(table, right);
343                    left_sel * right_sel
344                }
345                BinaryOperator::Or => {
346                    let left_sel = self.estimate_selectivity(table, left);
347                    let right_sel = self.estimate_selectivity(table, right);
348                    left_sel + right_sel - (left_sel * right_sel)
349                }
350                _ => 0.5, // Default selectivity
351            },
352            Expr::UnaryOp {
353                op: UnaryOperator::Not,
354                expr,
355            } => 1.0 - self.estimate_selectivity(table, expr),
356            Expr::Function { name, .. } => {
357                // Spatial predicates have lower selectivity
358                match name.to_uppercase().as_str() {
359                    "ST_INTERSECTS" | "ST_CONTAINS" | "ST_WITHIN" => 0.01,
360                    _ => 0.5,
361                }
362            }
363            _ => 0.5, // Default selectivity
364        }
365    }
366}
367
368impl Default for CostModel {
369    fn default() -> Self {
370        Self::new()
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cost_total() {
        let cost = Cost::new(100.0, 1000.0, 100.0, 500.0);
        // 100*1 + 1000*10 + 100*0.1 + 500*20
        assert!((cost.total() - 20_110.0).abs() < 1e-9);
        assert_eq!(Cost::zero().total(), 0.0);
    }

    #[test]
    fn test_cost_add() {
        let cost1 = Cost::new(100.0, 1000.0, 100.0, 0.0);
        let cost2 = Cost::new(50.0, 500.0, 50.0, 0.0);
        let total = cost1.add(&cost2);
        assert_eq!(total, Cost::new(150.0, 1500.0, 150.0, 0.0));
    }

    #[test]
    fn test_cost_multiply() {
        let cost = Cost::new(1.0, 2.0, 3.0, 4.0).multiply(2.5);
        assert_eq!(cost, Cost::new(2.5, 5.0, 7.5, 10.0));
    }

    #[test]
    fn test_statistics() {
        let stats = Statistics::new(1000, 100)
            .with_column(ColumnStatistics::new("id".to_string(), 1000, 0))
            .with_index(IndexStatistics::new(
                "idx_id".to_string(),
                vec!["id".to_string()],
                IndexType::BTree,
                10000,
            ));

        assert_eq!(stats.row_count, 1000);
        assert_eq!(stats.total_size(), 100_000);
        assert_eq!(stats.columns.len(), 1);
        assert_eq!(stats.indexes.len(), 1);
    }

    #[test]
    fn test_column_equality_selectivity() {
        let col = ColumnStatistics::new("id".to_string(), 1000, 0);
        assert_eq!(col.equality_selectivity(1000), 0.001);
        let empty = ColumnStatistics::new("id".to_string(), 0, 0);
        assert_eq!(empty.equality_selectivity(1000), 0.0);
    }

    #[test]
    fn test_index_lookup_cost_default_height() {
        let idx = IndexStatistics::new(
            "idx_id".to_string(),
            vec!["id".to_string()],
            IndexType::BTree,
            10_000,
        );
        assert_eq!(idx.lookup_cost(), Cost::new(400.0, 4.0 * 8192.0, 0.0, 0.0));
    }

    #[test]
    fn test_cost_model() {
        let model = CostModel::new();
        let stats = Statistics::new(10000, 100);
        model.register_statistics("users".to_string(), stats);

        let scan_cost = model.scan_cost("users");
        assert!(scan_cost.total() > 0.0);
        assert_eq!(scan_cost, Cost::new(100_000.0, 1_000_000.0, 100.0, 0.0));

        // Unregistered tables fall back to the fixed default cost.
        assert!(model.get_statistics("missing").is_none());
        assert_eq!(
            model.scan_cost("missing"),
            Cost::new(1_000_000.0, 1_000_000_000.0, 1000.0, 0.0)
        );
    }

    #[test]
    fn test_operator_costs() {
        let model = CostModel::default();
        assert_eq!(model.filter_cost(100, 0.5), Cost::new(200.0, 0.0, 5000.0, 0.0));
        assert_eq!(
            model.join_cost(10, 20, JoinType::Inner),
            Cost::new(250.0, 0.0, 2000.0, 0.0)
        );
        assert_eq!(
            model.join_cost(10, 20, JoinType::Cross),
            Cost::new(400.0, 0.0, 20_000.0, 0.0)
        );
        assert_eq!(model.aggregate_cost(100, 10), Cost::new(500.0, 0.0, 2000.0, 0.0));

        let sort = model.sort_cost(8);
        assert!((sort.cpu - 240.0).abs() < 1e-9);
        assert_eq!(sort.memory, 800.0);
    }
}