Skip to main content

grafeo_engine/query/optimizer/
cost.rs

1//! Cost model for query optimization.
2//!
3//! Provides cost estimates for logical operators to guide optimization decisions.
4
5use crate::query::plan::{
6    AggregateOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp, LogicalOperator,
7    NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp,
8};
9
/// Estimated resource consumption of executing an operator.
///
/// The components are kept separate so they can be weighted independently;
/// [`Cost::total`] folds them into a single comparable scalar.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Cost {
    /// Estimated CPU cycles / work units.
    pub cpu: f64,
    /// Estimated I/O operations (page reads).
    pub io: f64,
    /// Estimated memory usage in bytes.
    pub memory: f64,
    /// Network cost (for distributed queries).
    pub network: f64,
}

impl Cost {
    /// Weight applied to the `cpu` component by [`Cost::total`].
    const DEFAULT_CPU_WEIGHT: f64 = 1.0;
    /// Weight applied to the `io` component by [`Cost::total`].
    const DEFAULT_IO_WEIGHT: f64 = 10.0;
    /// Weight applied to the `memory` component by [`Cost::total`].
    const DEFAULT_MEMORY_WEIGHT: f64 = 0.1;
    /// Weight applied to the `network` component by [`Cost::total`].
    const DEFAULT_NETWORK_WEIGHT: f64 = 100.0;

    /// Creates a zero cost.
    #[must_use]
    pub fn zero() -> Self {
        Self {
            cpu: 0.0,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Creates a cost consisting only of CPU work units.
    #[must_use]
    pub fn cpu(cpu: f64) -> Self {
        Self { cpu, ..Self::zero() }
    }

    /// Sets the I/O component, replacing any value it already had.
    ///
    /// This does not accumulate; combine two costs with `+` to add them.
    #[must_use]
    pub fn with_io(mut self, io: f64) -> Self {
        self.io = io;
        self
    }

    /// Sets the memory component, replacing any value it already had.
    ///
    /// This does not accumulate; combine two costs with `+` to add them.
    #[must_use]
    pub fn with_memory(mut self, memory: f64) -> Self {
        self.memory = memory;
        self
    }

    /// Sets the network component, replacing any value it already had.
    ///
    /// This does not accumulate; combine two costs with `+` to add them.
    #[must_use]
    pub fn with_network(mut self, network: f64) -> Self {
        self.network = network;
        self
    }

    /// Returns the total weighted cost.
    ///
    /// Uses default weights: CPU=1.0, IO=10.0, Memory=0.1, Network=100.0.
    #[must_use]
    pub fn total(&self) -> f64 {
        self.total_weighted_with_network(
            Self::DEFAULT_CPU_WEIGHT,
            Self::DEFAULT_IO_WEIGHT,
            Self::DEFAULT_MEMORY_WEIGHT,
            Self::DEFAULT_NETWORK_WEIGHT,
        )
    }

    /// Returns the total cost with custom CPU, I/O and memory weights.
    ///
    /// The `network` component is **not** included. Use
    /// [`Cost::total_weighted_with_network`] when network cost must count.
    #[must_use]
    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
        self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
    }

    /// Returns the total cost with custom weights for all four components.
    #[must_use]
    pub fn total_weighted_with_network(
        &self,
        cpu_weight: f64,
        io_weight: f64,
        mem_weight: f64,
        net_weight: f64,
    ) -> f64 {
        self.cpu * cpu_weight
            + self.io * io_weight
            + self.memory * mem_weight
            + self.network * net_weight
    }
}

impl std::ops::Add for Cost {
    type Output = Self;

    /// Component-wise sum of two costs.
    fn add(self, other: Self) -> Self {
        Self {
            cpu: self.cpu + other.cpu,
            io: self.io + other.io,
            memory: self.memory + other.memory,
            network: self.network + other.network,
        }
    }
}

impl std::ops::AddAssign for Cost {
    /// Component-wise in-place sum.
    fn add_assign(&mut self, other: Self) {
        self.cpu += other.cpu;
        self.io += other.io;
        self.memory += other.memory;
        self.network += other.network;
    }
}

impl std::iter::Sum for Cost {
    /// Sums an iterator of costs component-wise, starting from [`Cost::zero`].
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.fold(Self::zero(), |acc, c| acc + c)
    }
}
98
99/// Cost model for estimating operator costs.
100pub struct CostModel {
101    /// Cost per tuple processed by CPU.
102    cpu_tuple_cost: f64,
103    /// Cost per I/O page read.
104    #[allow(dead_code)]
105    io_page_cost: f64,
106    /// Cost per hash table lookup.
107    hash_lookup_cost: f64,
108    /// Cost per comparison in sorting.
109    sort_comparison_cost: f64,
110    /// Average tuple size in bytes.
111    avg_tuple_size: f64,
112    /// Page size in bytes.
113    page_size: f64,
114}
115
116impl CostModel {
117    /// Creates a new cost model with default parameters.
118    #[must_use]
119    pub fn new() -> Self {
120        Self {
121            cpu_tuple_cost: 0.01,
122            io_page_cost: 1.0,
123            hash_lookup_cost: 0.02,
124            sort_comparison_cost: 0.02,
125            avg_tuple_size: 100.0,
126            page_size: 8192.0,
127        }
128    }
129
130    /// Estimates the cost of a logical operator.
131    #[must_use]
132    pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
133        match op {
134            LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
135            LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
136            LogicalOperator::Project(project) => self.project_cost(project, cardinality),
137            LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
138            LogicalOperator::Join(join) => self.join_cost(join, cardinality),
139            LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
140            LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
141            LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
142            LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
143            LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
144            LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
145            LogicalOperator::Empty => Cost::zero(),
146            _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
147        }
148    }
149
150    /// Estimates the cost of a node scan.
151    fn node_scan_cost(&self, _scan: &NodeScanOp, cardinality: f64) -> Cost {
152        let pages = (cardinality * self.avg_tuple_size) / self.page_size;
153        Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
154    }
155
156    /// Estimates the cost of a filter operation.
157    fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
158        // Filter cost is just predicate evaluation per tuple
159        Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
160    }
161
162    /// Estimates the cost of a projection.
163    fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
164        // Cost depends on number of expressions evaluated
165        let expr_count = project.projections.len() as f64;
166        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
167    }
168
169    /// Estimates the cost of an expand operation.
170    fn expand_cost(&self, _expand: &ExpandOp, cardinality: f64) -> Cost {
171        // Expand involves adjacency list lookups
172        let lookup_cost = cardinality * self.hash_lookup_cost;
173        // Assume average fanout of 10 for edge traversal
174        let avg_fanout = 10.0;
175        let output_cost = cardinality * avg_fanout * self.cpu_tuple_cost;
176        Cost::cpu(lookup_cost + output_cost)
177    }
178
179    /// Estimates the cost of a join operation.
180    fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
181        // Cost depends on join type
182        match join.join_type {
183            JoinType::Cross => {
184                // Cross join is O(n * m)
185                Cost::cpu(cardinality * self.cpu_tuple_cost)
186            }
187            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
188                // Hash join: build phase + probe phase
189                // Assume left side is build, right side is probe
190                let build_cardinality = cardinality.sqrt(); // Rough estimate
191                let probe_cardinality = cardinality.sqrt();
192
193                // Build hash table
194                let build_cost = build_cardinality * self.hash_lookup_cost;
195                let memory_cost = build_cardinality * self.avg_tuple_size;
196
197                // Probe hash table
198                let probe_cost = probe_cardinality * self.hash_lookup_cost;
199
200                // Output cost
201                let output_cost = cardinality * self.cpu_tuple_cost;
202
203                Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
204            }
205            JoinType::Semi | JoinType::Anti => {
206                // Semi/anti joins are typically cheaper
207                let build_cardinality = cardinality.sqrt();
208                let probe_cardinality = cardinality.sqrt();
209
210                let build_cost = build_cardinality * self.hash_lookup_cost;
211                let probe_cost = probe_cardinality * self.hash_lookup_cost;
212
213                Cost::cpu(build_cost + probe_cost)
214                    .with_memory(build_cardinality * self.avg_tuple_size)
215            }
216        }
217    }
218
219    /// Estimates the cost of an aggregation.
220    fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
221        // Hash aggregation cost
222        let hash_cost = cardinality * self.hash_lookup_cost;
223
224        // Aggregate function evaluation
225        let agg_count = agg.aggregates.len() as f64;
226        let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
227
228        // Memory for hash table (estimated distinct groups)
229        let distinct_groups = (cardinality / 10.0).max(1.0); // Assume 10% distinct
230        let memory_cost = distinct_groups * self.avg_tuple_size;
231
232        Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
233    }
234
235    /// Estimates the cost of a sort operation.
236    fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
237        if cardinality <= 1.0 {
238            return Cost::zero();
239        }
240
241        // Sort is O(n log n) comparisons
242        let comparisons = cardinality * cardinality.log2();
243        let key_count = sort.keys.len() as f64;
244
245        // Memory for sorting (full input materialization)
246        let memory_cost = cardinality * self.avg_tuple_size;
247
248        Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
249    }
250
251    /// Estimates the cost of a distinct operation.
252    fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
253        // Hash-based distinct
254        let hash_cost = cardinality * self.hash_lookup_cost;
255        let memory_cost = cardinality * self.avg_tuple_size * 0.5; // Assume 50% distinct
256
257        Cost::cpu(hash_cost).with_memory(memory_cost)
258    }
259
260    /// Estimates the cost of a limit operation.
261    fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
262        // Limit is very cheap - just counting
263        Cost::cpu(limit.count as f64 * self.cpu_tuple_cost * 0.1)
264    }
265
266    /// Estimates the cost of a skip operation.
267    fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
268        // Skip requires scanning through skipped rows
269        Cost::cpu(skip.count as f64 * self.cpu_tuple_cost)
270    }
271
272    /// Estimates the cost of a return operation.
273    fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
274        // Return materializes results
275        let expr_count = ret.items.len() as f64;
276        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
277    }
278
279    /// Compares two costs and returns the cheaper one.
280    #[must_use]
281    pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
282        if a.total() <= b.total() { a } else { b }
283    }
284}
285
286impl Default for CostModel {
287    fn default() -> Self {
288        Self::new()
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::plan::{
        AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
        Projection, ReturnItem, SortKey, SortOrder,
    };

    /// Shorthand for a variable-reference expression.
    fn var(name: &str) -> LogicalExpression {
        LogicalExpression::Variable(name.to_string())
    }

    /// Shorthand for an empty child operator.
    fn empty() -> Box<LogicalOperator> {
        Box::new(LogicalOperator::Empty)
    }

    /// Builds a join of the given type over two empty inputs.
    fn join(join_type: JoinType, conditions: Vec<JoinCondition>) -> JoinOp {
        JoinOp {
            left: empty(),
            right: empty(),
            join_type,
            conditions,
        }
    }

    #[test]
    fn test_cost_addition() {
        let a = Cost::cpu(10.0).with_io(5.0);
        let b = Cost::cpu(20.0).with_memory(100.0);
        let c = a + b;

        assert!((c.cpu - 30.0).abs() < 0.001);
        assert!((c.io - 5.0).abs() < 0.001);
        assert!((c.memory - 100.0).abs() < 0.001);
    }

    #[test]
    fn test_cost_total() {
        let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
        // Total = 10 + 1*10 + 100*0.1 = 10 + 10 + 10 = 30
        assert!((cost.total() - 30.0).abs() < 0.001);
    }

    #[test]
    fn test_cost_total_includes_network() {
        let cost = Cost {
            network: 1.0,
            ..Cost::zero()
        };
        // Network is weighted 100x by default.
        assert!((cost.total() - 100.0).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_node_scan() {
        let model = CostModel::new();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: Some("Person".to_string()),
            input: None,
        };
        let cost = model.node_scan_cost(&scan, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.io > 0.0);
    }

    #[test]
    fn test_cost_model_sort() {
        let model = CostModel::new();
        let sort = SortOp {
            keys: vec![],
            input: empty(),
        };

        let cost_100 = model.sort_cost(&sort, 100.0);
        let cost_1000 = model.sort_cost(&sort, 1000.0);

        // Sorting 1000 rows should be more expensive than 100 rows
        assert!(cost_1000.total() > cost_100.total());
    }

    #[test]
    fn test_cost_zero() {
        let cost = Cost::zero();
        assert!((cost.cpu).abs() < 0.001);
        assert!((cost.io).abs() < 0.001);
        assert!((cost.memory).abs() < 0.001);
        assert!((cost.network).abs() < 0.001);
        assert!((cost.total()).abs() < 0.001);
    }

    #[test]
    fn test_cost_add_assign() {
        let mut cost = Cost::cpu(10.0);
        cost += Cost::cpu(5.0).with_io(2.0);
        assert!((cost.cpu - 15.0).abs() < 0.001);
        assert!((cost.io - 2.0).abs() < 0.001);
    }

    #[test]
    fn test_cost_total_weighted() {
        let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
        // With custom weights: cpu*2 + io*5 + mem*0.5 = 20 + 10 + 50 = 80
        let total = cost.total_weighted(2.0, 5.0, 0.5);
        assert!((total - 80.0).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_filter() {
        let model = CostModel::new();
        let filter = FilterOp {
            predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
            input: empty(),
        };
        let cost = model.filter_cost(&filter, 1000.0);

        // Filter cost is CPU only
        assert!(cost.cpu > 0.0);
        assert!((cost.io).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_project() {
        let model = CostModel::new();
        let one = ProjectOp {
            projections: vec![Projection {
                expression: var("a"),
                alias: None,
            }],
            input: empty(),
        };
        let two = ProjectOp {
            projections: vec![
                Projection {
                    expression: var("a"),
                    alias: None,
                },
                Projection {
                    expression: var("b"),
                    alias: None,
                },
            ],
            input: empty(),
        };
        let cost_one = model.project_cost(&one, 1000.0);
        let cost_two = model.project_cost(&two, 1000.0);

        // Cost should scale with number of projections
        assert!(cost_one.cpu > 0.0);
        assert!(cost_two.cpu > cost_one.cpu);
    }

    #[test]
    fn test_cost_model_expand() {
        let model = CostModel::new();
        let expand = ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: None,
            direction: ExpandDirection::Outgoing,
            edge_type: None,
            min_hops: 1,
            max_hops: Some(1),
            input: empty(),
            path_alias: None,
        };
        let cost = model.expand_cost(&expand, 1000.0);

        // Expand involves hash lookups and output generation
        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_hash_join() {
        let model = CostModel::new();
        let inner = join(
            JoinType::Inner,
            vec![JoinCondition {
                left: var("a"),
                right: var("b"),
            }],
        );
        let cost = model.join_cost(&inner, 10000.0);

        // Hash join has CPU cost and memory cost
        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_cross_join() {
        let model = CostModel::new();
        let cross = join(JoinType::Cross, vec![]);
        let cost_small = model.join_cost(&cross, 1000.0);
        let cost_large = model.join_cost(&cross, 1_000_000.0);

        // Cross join cost grows with the number of produced rows
        assert!(cost_large.cpu > 0.0);
        assert!(cost_large.cpu > cost_small.cpu);
    }

    #[test]
    fn test_cost_model_semi_join() {
        let model = CostModel::new();
        let cost_semi = model.join_cost(&join(JoinType::Semi, vec![]), 1000.0);
        let cost_inner = model.join_cost(&join(JoinType::Inner, vec![]), 1000.0);

        // Semi join skips the per-output-row charge, so it is cheaper
        assert!(cost_semi.cpu > 0.0);
        assert!(cost_inner.cpu > 0.0);
        assert!(cost_semi.total() < cost_inner.total());
    }

    #[test]
    fn test_cost_model_aggregate() {
        let model = CostModel::new();
        let agg = AggregateOp {
            group_by: vec![],
            aggregates: vec![
                AggregateExpr {
                    function: AggregateFunction::Count,
                    expression: None,
                    distinct: false,
                    alias: Some("cnt".to_string()),
                    percentile: None,
                },
                AggregateExpr {
                    function: AggregateFunction::Sum,
                    expression: Some(var("x")),
                    distinct: false,
                    alias: Some("total".to_string()),
                    percentile: None,
                },
            ],
            input: empty(),
            having: None,
        };
        let cost = model.aggregate_cost(&agg, 1000.0);

        // Aggregation has CPU cost and memory cost
        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_distinct() {
        let model = CostModel::new();
        let distinct = DistinctOp {
            input: empty(),
            columns: None,
        };
        let cost = model.distinct_cost(&distinct, 1000.0);

        // Distinct uses hash set
        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_limit() {
        let model = CostModel::new();
        let limit = LimitOp {
            count: 10,
            input: empty(),
        };
        let cost = model.limit_cost(&limit, 1000.0);

        // Limit is very cheap
        assert!(cost.cpu > 0.0);
        assert!(cost.cpu < 1.0); // Should be minimal
    }

    #[test]
    fn test_cost_model_skip() {
        let model = CostModel::new();
        let small = SkipOp {
            count: 10,
            input: empty(),
        };
        let large = SkipOp {
            count: 100,
            input: empty(),
        };
        let cost_small = model.skip_cost(&small, 1000.0);
        let cost_large = model.skip_cost(&large, 1000.0);

        // Skip must scan through skipped rows, so more skipped rows cost more
        assert!(cost_large.cpu > 0.0);
        assert!(cost_large.cpu > cost_small.cpu);
    }

    #[test]
    fn test_cost_model_return() {
        let model = CostModel::new();
        let ret = ReturnOp {
            items: vec![
                ReturnItem {
                    expression: var("a"),
                    alias: None,
                },
                ReturnItem {
                    expression: var("b"),
                    alias: None,
                },
            ],
            distinct: false,
            input: empty(),
        };
        let cost = model.return_cost(&ret, 1000.0);

        // Return materializes results
        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_cheaper() {
        let model = CostModel::new();
        let cheap = Cost::cpu(10.0);
        let expensive = Cost::cpu(100.0);

        assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
        assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
    }

    #[test]
    fn test_cost_cheaper_tie_prefers_first() {
        let model = CostModel::new();
        let a = Cost::cpu(5.0);
        let b = Cost::cpu(5.0);

        assert!(std::ptr::eq(model.cheaper(&a, &b), &a));
    }

    #[test]
    fn test_cost_comparison_prefers_lower_total() {
        let model = CostModel::new();
        // High CPU, low IO
        let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
        // Low CPU, high IO
        let io_heavy = Cost::cpu(10.0).with_io(20.0);

        // IO is weighted 10x, so io_heavy = 10 + 200 = 210, cpu_heavy = 100 + 10 = 110
        assert!(cpu_heavy.total() < io_heavy.total());
        assert_eq!(
            model.cheaper(&cpu_heavy, &io_heavy).total(),
            cpu_heavy.total()
        );
    }

    #[test]
    fn test_cost_model_sort_with_keys() {
        let model = CostModel::new();
        let sort_single = SortOp {
            keys: vec![SortKey {
                expression: var("a"),
                order: SortOrder::Ascending,
            }],
            input: empty(),
        };
        let sort_multi = SortOp {
            keys: vec![
                SortKey {
                    expression: var("a"),
                    order: SortOrder::Ascending,
                },
                SortKey {
                    expression: var("b"),
                    order: SortOrder::Descending,
                },
            ],
            input: empty(),
        };

        let cost_single = model.sort_cost(&sort_single, 1000.0);
        let cost_multi = model.sort_cost(&sort_multi, 1000.0);

        // More sort keys = more comparisons
        assert!(cost_multi.cpu > cost_single.cpu);
    }

    #[test]
    fn test_cost_model_empty_operator() {
        let model = CostModel::new();
        let cost = model.estimate(&LogicalOperator::Empty, 0.0);
        assert!((cost.total()).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_default() {
        let model = CostModel::default();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: None,
            input: None,
        };
        let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
        assert!(cost.total() > 0.0);
    }
}