1use crate::parser::ast::*;
4use serde::{Deserialize, Serialize};
5
/// Estimated cost of an operation, split into four independent components.
///
/// All components are plain `f64` values; this type does not define their
/// units. Note that the derived `PartialOrd` compares the fields in
/// declaration order (`cpu` first), which is not the same as comparing
/// [`Cost::total`].
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Serialize, Deserialize)]
pub struct Cost {
    /// CPU component of the cost.
    pub cpu: f64,
    /// I/O component of the cost.
    pub io: f64,
    /// Memory component of the cost.
    pub memory: f64,
    /// Network component of the cost.
    pub network: f64,
}
18
19impl Cost {
20 pub fn new(cpu: f64, io: f64, memory: f64, network: f64) -> Self {
22 Self {
23 cpu,
24 io,
25 memory,
26 network,
27 }
28 }
29
30 pub fn zero() -> Self {
32 Self::new(0.0, 0.0, 0.0, 0.0)
33 }
34
35 pub fn total(&self) -> f64 {
37 const CPU_WEIGHT: f64 = 1.0;
39 const IO_WEIGHT: f64 = 10.0;
40 const MEMORY_WEIGHT: f64 = 0.1;
41 const NETWORK_WEIGHT: f64 = 20.0;
42
43 self.cpu * CPU_WEIGHT
44 + self.io * IO_WEIGHT
45 + self.memory * MEMORY_WEIGHT
46 + self.network * NETWORK_WEIGHT
47 }
48
49 pub fn add(&self, other: &Cost) -> Cost {
51 Cost::new(
52 self.cpu + other.cpu,
53 self.io + other.io,
54 self.memory + other.memory,
55 self.network + other.network,
56 )
57 }
58
59 pub fn multiply(&self, factor: f64) -> Cost {
61 Cost::new(
62 self.cpu * factor,
63 self.io * factor,
64 self.memory * factor,
65 self.network * factor,
66 )
67 }
68}
69
/// Row, column and index statistics for one data set.
///
/// `CostModel` in this file stores one `Statistics` value per table name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Statistics {
    /// Number of rows.
    pub row_count: usize,
    /// Size of a single row; multiplied by `row_count` in `total_size`.
    /// NOTE(review): presumably bytes — confirm with the code that fills this in.
    pub row_size: usize,
    /// Per-column statistics, in the order they were added.
    pub columns: Vec<ColumnStatistics>,
    /// Per-index statistics, in the order they were added.
    pub indexes: Vec<IndexStatistics>,
}
82
83impl Statistics {
84 pub fn new(row_count: usize, row_size: usize) -> Self {
86 Self {
87 row_count,
88 row_size,
89 columns: Vec::new(),
90 indexes: Vec::new(),
91 }
92 }
93
94 pub fn total_size(&self) -> usize {
96 self.row_count * self.row_size
97 }
98
99 pub fn with_column(mut self, col_stats: ColumnStatistics) -> Self {
101 self.columns.push(col_stats);
102 self
103 }
104
105 pub fn with_index(mut self, idx_stats: IndexStatistics) -> Self {
107 self.indexes.push(idx_stats);
108 self
109 }
110}
111
/// Statistics describing a single column.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnStatistics {
    /// Column name; `CostModel::estimate_selectivity` matches it against the
    /// column referenced in a predicate.
    pub name: String,
    /// Number of distinct values; drives `equality_selectivity`.
    pub distinct_count: usize,
    /// Number of null values. Not read by any estimate in this file.
    pub null_count: usize,
    /// Smallest value in the column, if known.
    pub min_value: Option<Literal>,
    /// Largest value in the column, if known.
    pub max_value: Option<Literal>,
}
126
127impl ColumnStatistics {
128 pub fn new(name: String, distinct_count: usize, null_count: usize) -> Self {
130 Self {
131 name,
132 distinct_count,
133 null_count,
134 min_value: None,
135 max_value: None,
136 }
137 }
138
139 pub fn equality_selectivity(&self, _total_rows: usize) -> f64 {
141 if self.distinct_count == 0 {
142 return 0.0;
143 }
144 1.0 / self.distinct_count as f64
145 }
146
147 pub fn range_selectivity(&self, low: &Literal, high: &Literal) -> f64 {
149 match (&self.min_value, &self.max_value) {
151 (Some(min), Some(max)) => {
152 if let (Literal::Integer(min_val), Literal::Integer(max_val)) = (min, max) {
153 if let (Literal::Integer(low_val), Literal::Integer(high_val)) = (low, high) {
154 let range = (max_val - min_val) as f64;
155 if range > 0.0 {
156 let selected = (high_val - low_val) as f64;
157 return (selected / range).clamp(0.0, 1.0);
158 }
159 }
160 }
161 0.25
163 }
164 _ => 0.25, }
166 }
167}
168
/// Statistics describing a single index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStatistics {
    /// Index name.
    pub name: String,
    /// Names of the indexed columns.
    pub columns: Vec<String>,
    /// Kind of index; selects the formula used by `lookup_cost`.
    pub index_type: IndexType,
    /// Index size; `scan_cost` scales it by the selectivity to get the I/O cost.
    /// NOTE(review): presumably bytes — confirm with the code that fills this in.
    pub size: usize,
    /// Tree height, if known. `lookup_cost` assumes 4 when this is `None`.
    pub height: Option<usize>,
}
183
/// Kind of index described by an [`IndexStatistics`] entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum IndexType {
    /// B-tree index; lookup cost scales with the tree height.
    BTree,
    /// R-tree index; lookup cost scales with the tree height, with a higher
    /// CPU factor than `BTree`.
    RTree,
    /// Hash index; lookup cost is a fixed amount, independent of height.
    Hash,
}
194
195impl IndexStatistics {
196 pub fn new(name: String, columns: Vec<String>, index_type: IndexType, size: usize) -> Self {
198 Self {
199 name,
200 columns,
201 index_type,
202 size,
203 height: None,
204 }
205 }
206
207 pub fn lookup_cost(&self) -> Cost {
209 match self.index_type {
210 IndexType::BTree => {
211 let height = self.height.unwrap_or(4) as f64;
213 Cost::new(height * 100.0, height * 8192.0, 0.0, 0.0)
214 }
215 IndexType::RTree => {
216 let height = self.height.unwrap_or(4) as f64;
218 Cost::new(height * 150.0, height * 8192.0, 0.0, 0.0)
219 }
220 IndexType::Hash => {
221 Cost::new(50.0, 8192.0, 0.0, 0.0)
223 }
224 }
225 }
226
227 pub fn scan_cost(&self, selectivity: f64) -> Cost {
229 let io = (self.size as f64 * selectivity).max(8192.0);
230 Cost::new(io / 100.0, io, 0.0, 0.0)
231 }
232}
233
/// Cost estimator backed by per-table [`Statistics`].
pub struct CostModel {
    /// Registered statistics, keyed by table name. Stored in a `DashMap` so
    /// entries can be inserted through a shared `&self` reference.
    statistics: dashmap::DashMap<String, Statistics>,
}
239
240impl CostModel {
241 pub fn new() -> Self {
243 Self {
244 statistics: dashmap::DashMap::new(),
245 }
246 }
247
248 pub fn register_statistics(&self, table: String, stats: Statistics) {
250 self.statistics.insert(table, stats);
251 }
252
253 pub fn get_statistics(&self, table: &str) -> Option<Statistics> {
255 self.statistics.get(table).map(|s| s.clone())
256 }
257
258 pub fn scan_cost(&self, table: &str) -> Cost {
260 if let Some(stats) = self.get_statistics(table) {
261 let total_size = stats.total_size() as f64;
262 Cost::new(
263 stats.row_count as f64 * 10.0,
264 total_size,
265 stats.row_size as f64,
266 0.0,
267 )
268 } else {
269 Cost::new(1_000_000.0, 1_000_000_000.0, 1000.0, 0.0)
271 }
272 }
273
274 pub fn filter_cost(&self, input_rows: usize, selectivity: f64) -> Cost {
276 let output_rows = (input_rows as f64 * selectivity) as usize;
277 Cost::new(
278 input_rows as f64 * 2.0,
279 0.0,
280 output_rows as f64 * 100.0,
281 0.0,
282 )
283 }
284
285 pub fn join_cost(&self, left_rows: usize, right_rows: usize, join_type: JoinType) -> Cost {
287 match join_type {
288 JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
289 let build_cost = right_rows as f64 * 10.0;
291 let probe_cost = left_rows as f64 * 5.0;
292 let memory = right_rows as f64 * 100.0;
293 Cost::new(build_cost + probe_cost, 0.0, memory, 0.0)
294 }
295 JoinType::Cross => {
296 let total_ops = (left_rows * right_rows) as f64;
298 Cost::new(total_ops * 2.0, 0.0, total_ops * 100.0, 0.0)
299 }
300 }
301 }
302
303 pub fn aggregate_cost(&self, input_rows: usize, group_count: usize) -> Cost {
305 Cost::new(
306 input_rows as f64 * 5.0,
307 0.0,
308 group_count as f64 * 200.0,
309 0.0,
310 )
311 }
312
313 pub fn sort_cost(&self, input_rows: usize) -> Cost {
315 let n = input_rows as f64;
317 let ops = n * n.log2();
318 Cost::new(ops * 10.0, 0.0, n * 100.0, 0.0)
319 }
320
321 pub fn estimate_selectivity(&self, table: &str, expr: &Expr) -> f64 {
323 match expr {
324 Expr::BinaryOp { left, op, right } => match op {
325 BinaryOperator::Eq => {
326 if let Expr::Column { name, .. } = &**left {
327 if let Some(stats) = self.get_statistics(table) {
328 if let Some(col_stats) = stats.columns.iter().find(|c| c.name == *name)
329 {
330 return col_stats.equality_selectivity(stats.row_count);
331 }
332 }
333 }
334 0.1 }
336 BinaryOperator::Lt
337 | BinaryOperator::LtEq
338 | BinaryOperator::Gt
339 | BinaryOperator::GtEq => 0.33, BinaryOperator::And => {
341 let left_sel = self.estimate_selectivity(table, left);
342 let right_sel = self.estimate_selectivity(table, right);
343 left_sel * right_sel
344 }
345 BinaryOperator::Or => {
346 let left_sel = self.estimate_selectivity(table, left);
347 let right_sel = self.estimate_selectivity(table, right);
348 left_sel + right_sel - (left_sel * right_sel)
349 }
350 _ => 0.5, },
352 Expr::UnaryOp {
353 op: UnaryOperator::Not,
354 expr,
355 } => 1.0 - self.estimate_selectivity(table, expr),
356 Expr::Function { name, .. } => {
357 match name.to_uppercase().as_str() {
359 "ST_INTERSECTS" | "ST_CONTAINS" | "ST_WITHIN" => 0.01,
360 _ => 0.5,
361 }
362 }
363 _ => 0.5, }
365 }
366}
367
368impl Default for CostModel {
369 fn default() -> Self {
370 Self::new()
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// `total` applies the fixed weights 1 / 10 / 0.1 / 20 to the components.
    #[test]
    fn test_cost_total() {
        let cost = Cost::new(100.0, 1000.0, 100.0, 500.0);
        // 100 * 1 + 1000 * 10 + 100 * 0.1 + 500 * 20
        assert!((cost.total() - 20_110.0).abs() < 1e-9);
        assert_eq!(Cost::zero().total(), 0.0);
    }

    /// `add` sums every component independently.
    #[test]
    fn test_cost_add() {
        let cost1 = Cost::new(100.0, 1000.0, 100.0, 0.0);
        let cost2 = Cost::new(50.0, 500.0, 50.0, 0.0);
        let total = cost1.add(&cost2);
        assert_eq!(total.cpu, 150.0);
        assert_eq!(total.io, 1500.0);
        assert_eq!(total.memory, 150.0);
        assert_eq!(total.network, 0.0);
    }

    /// The builder helpers append entries and leave the row figures untouched.
    #[test]
    fn test_statistics() {
        let stats = Statistics::new(1000, 100)
            .with_column(ColumnStatistics::new("id".to_string(), 1000, 0))
            .with_index(IndexStatistics::new(
                "idx_id".to_string(),
                vec!["id".to_string()],
                IndexType::BTree,
                10000,
            ));

        assert_eq!(stats.row_count, 1000);
        assert_eq!(stats.total_size(), 100_000);
        assert_eq!(stats.columns.len(), 1);
        assert_eq!(stats.indexes.len(), 1);
    }

    /// Scan cost is derived from registered statistics and falls back to a
    /// fixed cost for unknown tables.
    #[test]
    fn test_cost_model() {
        let model = CostModel::new();
        let stats = Statistics::new(10000, 100);
        model.register_statistics("users".to_string(), stats);

        let scan_cost = model.scan_cost("users");
        assert_eq!(scan_cost.cpu, 100_000.0);
        assert_eq!(scan_cost.io, 1_000_000.0);
        assert_eq!(scan_cost.memory, 100.0);
        assert_eq!(scan_cost.network, 0.0);

        let unknown = model.scan_cost("missing");
        assert_eq!(
            unknown,
            Cost::new(1_000_000.0, 1_000_000_000.0, 1000.0, 0.0)
        );
    }
}
417}