Skip to main content

grafeo_engine/query/optimizer/
cost.rs

//! Cost model for query optimization.
//!
//! Provides cost estimates for logical operators to guide optimization decisions.
4
5use crate::query::plan::{
6    AggregateOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp, LogicalOperator,
7    NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp,
8};
9
/// Estimated resource consumption of executing an operator.
///
/// Each resource dimension is stored separately; [`Cost::total`] and
/// [`Cost::total_weighted`] collapse them into a single comparable number.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Cost {
    /// Estimated CPU cycles / work units.
    pub cpu: f64,
    /// Estimated I/O operations (page reads).
    pub io: f64,
    /// Estimated memory usage in bytes.
    pub memory: f64,
    /// Network cost (for distributed queries).
    pub network: f64,
}

impl Cost {
    /// Returns a cost whose every component is `0.0`.
    #[must_use]
    pub fn zero() -> Self {
        Self::cpu(0.0)
    }

    /// Returns a cost with the given CPU work units and every other component `0.0`.
    #[must_use]
    pub fn cpu(cpu: f64) -> Self {
        Self {
            cpu,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Returns this cost with its I/O component set to `io`.
    ///
    /// The previous I/O value is replaced, not accumulated.
    #[must_use]
    pub fn with_io(self, io: f64) -> Self {
        Self { io, ..self }
    }

    /// Returns this cost with its memory component set to `memory`.
    ///
    /// The previous memory value is replaced, not accumulated.
    #[must_use]
    pub fn with_memory(self, memory: f64) -> Self {
        Self { memory, ..self }
    }

    /// Returns the total weighted cost.
    ///
    /// Uses default weights: CPU=1.0, IO=10.0, Memory=0.1, Network=100.0
    #[must_use]
    pub fn total(&self) -> f64 {
        let weighted_terms = [self.io * 10.0, self.memory * 0.1, self.network * 100.0];
        // Fold left-to-right starting from the CPU term (weight 1.0).
        weighted_terms
            .iter()
            .fold(self.cpu, |running, term| running + term)
    }

    /// Returns the total cost with caller-supplied weights.
    ///
    /// Only the CPU, I/O and memory components take part; the network
    /// component is not included in the result.
    #[must_use]
    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
        let cpu_part = self.cpu * cpu_weight;
        let io_part = self.io * io_weight;
        let memory_part = self.memory * mem_weight;
        cpu_part + io_part + memory_part
    }
}

/// Component-wise addition of two costs.
impl std::ops::Add for Cost {
    type Output = Self;

    fn add(mut self, other: Self) -> Self {
        self += other;
        self
    }
}

/// Component-wise in-place accumulation of another cost.
impl std::ops::AddAssign for Cost {
    fn add_assign(&mut self, other: Self) {
        *self = Self {
            cpu: self.cpu + other.cpu,
            io: self.io + other.io,
            memory: self.memory + other.memory,
            network: self.network + other.network,
        };
    }
}
98
/// Cost model for estimating operator costs.
///
/// Holds the per-unit constants that the estimators multiply row counts by.
#[derive(Debug, Clone)]
pub struct CostModel {
    /// Cost per tuple processed by CPU.
    cpu_tuple_cost: f64,
    /// Cost per I/O page read.
    // No estimator in this file reads this field yet, hence the lint exemption.
    #[allow(dead_code)]
    io_page_cost: f64,
    /// Cost per hash table lookup.
    hash_lookup_cost: f64,
    /// Cost per comparison in sorting.
    sort_comparison_cost: f64,
    /// Average tuple size in bytes.
    avg_tuple_size: f64,
    /// Page size in bytes.
    page_size: f64,
}
115
116impl CostModel {
117    /// Creates a new cost model with default parameters.
118    #[must_use]
119    pub fn new() -> Self {
120        Self {
121            cpu_tuple_cost: 0.01,
122            io_page_cost: 1.0,
123            hash_lookup_cost: 0.02,
124            sort_comparison_cost: 0.02,
125            avg_tuple_size: 100.0,
126            page_size: 8192.0,
127        }
128    }
129
130    /// Estimates the cost of a logical operator.
131    #[must_use]
132    pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
133        match op {
134            LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
135            LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
136            LogicalOperator::Project(project) => self.project_cost(project, cardinality),
137            LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
138            LogicalOperator::Join(join) => self.join_cost(join, cardinality),
139            LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
140            LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
141            LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
142            LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
143            LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
144            LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
145            LogicalOperator::Empty => Cost::zero(),
146            _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
147        }
148    }
149
150    /// Estimates the cost of a node scan.
151    fn node_scan_cost(&self, _scan: &NodeScanOp, cardinality: f64) -> Cost {
152        let pages = (cardinality * self.avg_tuple_size) / self.page_size;
153        Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
154    }
155
156    /// Estimates the cost of a filter operation.
157    fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
158        // Filter cost is just predicate evaluation per tuple
159        Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
160    }
161
162    /// Estimates the cost of a projection.
163    fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
164        // Cost depends on number of expressions evaluated
165        let expr_count = project.projections.len() as f64;
166        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
167    }
168
169    /// Estimates the cost of an expand operation.
170    fn expand_cost(&self, _expand: &ExpandOp, cardinality: f64) -> Cost {
171        // Expand involves adjacency list lookups
172        let lookup_cost = cardinality * self.hash_lookup_cost;
173        // Assume average fanout of 10 for edge traversal
174        let avg_fanout = 10.0;
175        let output_cost = cardinality * avg_fanout * self.cpu_tuple_cost;
176        Cost::cpu(lookup_cost + output_cost)
177    }
178
179    /// Estimates the cost of a join operation.
180    fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
181        // Cost depends on join type
182        match join.join_type {
183            JoinType::Cross => {
184                // Cross join is O(n * m)
185                Cost::cpu(cardinality * self.cpu_tuple_cost)
186            }
187            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
188                // Hash join: build phase + probe phase
189                // Assume left side is build, right side is probe
190                let build_cardinality = cardinality.sqrt(); // Rough estimate
191                let probe_cardinality = cardinality.sqrt();
192
193                // Build hash table
194                let build_cost = build_cardinality * self.hash_lookup_cost;
195                let memory_cost = build_cardinality * self.avg_tuple_size;
196
197                // Probe hash table
198                let probe_cost = probe_cardinality * self.hash_lookup_cost;
199
200                // Output cost
201                let output_cost = cardinality * self.cpu_tuple_cost;
202
203                Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
204            }
205            JoinType::Semi | JoinType::Anti => {
206                // Semi/anti joins are typically cheaper
207                let build_cardinality = cardinality.sqrt();
208                let probe_cardinality = cardinality.sqrt();
209
210                let build_cost = build_cardinality * self.hash_lookup_cost;
211                let probe_cost = probe_cardinality * self.hash_lookup_cost;
212
213                Cost::cpu(build_cost + probe_cost)
214                    .with_memory(build_cardinality * self.avg_tuple_size)
215            }
216        }
217    }
218
219    /// Estimates the cost of an aggregation.
220    fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
221        // Hash aggregation cost
222        let hash_cost = cardinality * self.hash_lookup_cost;
223
224        // Aggregate function evaluation
225        let agg_count = agg.aggregates.len() as f64;
226        let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
227
228        // Memory for hash table (estimated distinct groups)
229        let distinct_groups = (cardinality / 10.0).max(1.0); // Assume 10% distinct
230        let memory_cost = distinct_groups * self.avg_tuple_size;
231
232        Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
233    }
234
235    /// Estimates the cost of a sort operation.
236    fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
237        if cardinality <= 1.0 {
238            return Cost::zero();
239        }
240
241        // Sort is O(n log n) comparisons
242        let comparisons = cardinality * cardinality.log2();
243        let key_count = sort.keys.len() as f64;
244
245        // Memory for sorting (full input materialization)
246        let memory_cost = cardinality * self.avg_tuple_size;
247
248        Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
249    }
250
251    /// Estimates the cost of a distinct operation.
252    fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
253        // Hash-based distinct
254        let hash_cost = cardinality * self.hash_lookup_cost;
255        let memory_cost = cardinality * self.avg_tuple_size * 0.5; // Assume 50% distinct
256
257        Cost::cpu(hash_cost).with_memory(memory_cost)
258    }
259
260    /// Estimates the cost of a limit operation.
261    fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
262        // Limit is very cheap - just counting
263        Cost::cpu(limit.count as f64 * self.cpu_tuple_cost * 0.1)
264    }
265
266    /// Estimates the cost of a skip operation.
267    fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
268        // Skip requires scanning through skipped rows
269        Cost::cpu(skip.count as f64 * self.cpu_tuple_cost)
270    }
271
272    /// Estimates the cost of a return operation.
273    fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
274        // Return materializes results
275        let expr_count = ret.items.len() as f64;
276        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
277    }
278
279    /// Compares two costs and returns the cheaper one.
280    #[must_use]
281    pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
282        if a.total() <= b.total() { a } else { b }
283    }
284}
285
286impl Default for CostModel {
287    fn default() -> Self {
288        Self::new()
289    }
290}
291
/// Unit tests for [`Cost`] arithmetic and the per-operator estimators of [`CostModel`].
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::plan::{
        AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
        Projection, ReturnItem, SortOrder,
    };

    /// Builds a projection of a bare variable with no alias.
    fn variable_projection(name: &str) -> Projection {
        Projection {
            expression: LogicalExpression::Variable(name.to_string()),
            alias: None,
        }
    }

    #[test]
    fn test_cost_addition() {
        let a = Cost::cpu(10.0).with_io(5.0);
        let b = Cost::cpu(20.0).with_memory(100.0);
        let c = a + b;

        assert!((c.cpu - 30.0).abs() < 0.001);
        assert!((c.io - 5.0).abs() < 0.001);
        assert!((c.memory - 100.0).abs() < 0.001);
        assert!((c.network).abs() < 0.001);
    }

    #[test]
    fn test_cost_total() {
        let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
        // Total = 10 + 1*10 + 100*0.1 = 10 + 10 + 10 = 30
        assert!((cost.total() - 30.0).abs() < 0.001);

        // Network is weighted 100x: 30 + 1*100 = 130
        let with_network = Cost {
            network: 1.0,
            ..cost
        };
        assert!((with_network.total() - 130.0).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_node_scan() {
        let model = CostModel::new();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: Some("Person".to_string()),
            input: None,
        };
        let cost = model.node_scan_cost(&scan, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.io > 0.0);
    }

    #[test]
    fn test_cost_model_sort() {
        let model = CostModel::new();
        let sort = SortOp {
            keys: vec![],
            input: Box::new(LogicalOperator::Empty),
        };

        let cost_100 = model.sort_cost(&sort, 100.0);
        let cost_1000 = model.sort_cost(&sort, 1000.0);

        // Sorting 1000 rows should be more expensive than 100 rows
        assert!(cost_1000.total() > cost_100.total());
    }

    #[test]
    fn test_cost_zero() {
        let cost = Cost::zero();
        assert!((cost.cpu).abs() < 0.001);
        assert!((cost.io).abs() < 0.001);
        assert!((cost.memory).abs() < 0.001);
        assert!((cost.network).abs() < 0.001);
        assert!((cost.total()).abs() < 0.001);
    }

    #[test]
    fn test_cost_add_assign() {
        let mut cost = Cost::cpu(10.0);
        cost += Cost::cpu(5.0).with_io(2.0);
        assert!((cost.cpu - 15.0).abs() < 0.001);
        assert!((cost.io - 2.0).abs() < 0.001);
    }

    #[test]
    fn test_cost_total_weighted() {
        let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
        // With custom weights: cpu*2 + io*5 + mem*0.5 = 20 + 10 + 50 = 80
        let total = cost.total_weighted(2.0, 5.0, 0.5);
        assert!((total - 80.0).abs() < 0.001);

        // The network component does not take part in the custom-weighted total.
        let with_network = Cost {
            network: 5.0,
            ..cost
        };
        assert!((with_network.total_weighted(2.0, 5.0, 0.5) - 80.0).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_filter() {
        let model = CostModel::new();
        let filter = FilterOp {
            predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.filter_cost(&filter, 1000.0);

        // Filter cost is CPU only
        assert!(cost.cpu > 0.0);
        assert!((cost.io).abs() < 0.001);
        assert!((cost.memory).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_project() {
        let model = CostModel::new();
        let project_one = ProjectOp {
            projections: vec![variable_projection("a")],
            input: Box::new(LogicalOperator::Empty),
        };
        let project_two = ProjectOp {
            projections: vec![variable_projection("a"), variable_projection("b")],
            input: Box::new(LogicalOperator::Empty),
        };
        let cost_one = model.project_cost(&project_one, 1000.0);
        let cost_two = model.project_cost(&project_two, 1000.0);

        // Cost should scale with number of projections
        assert!(cost_one.cpu > 0.0);
        assert!(cost_two.cpu > cost_one.cpu);
        assert!((cost_two.cpu - 2.0 * cost_one.cpu).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_expand() {
        let model = CostModel::new();
        let expand = ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: None,
            direction: ExpandDirection::Outgoing,
            edge_type: None,
            min_hops: 1,
            max_hops: Some(1),
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.expand_cost(&expand, 1000.0);

        // Expand involves hash lookups and output generation
        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_hash_join() {
        let model = CostModel::new();
        let join = JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type: JoinType::Inner,
            conditions: vec![JoinCondition {
                left: LogicalExpression::Variable("a".to_string()),
                right: LogicalExpression::Variable("b".to_string()),
            }],
        };
        let cost = model.join_cost(&join, 10000.0);

        // Hash join has CPU cost and memory cost
        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_cross_join() {
        let model = CostModel::new();
        let join = JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type: JoinType::Cross,
            conditions: vec![],
        };
        let cost = model.join_cost(&join, 1000000.0);

        // Cross join is charged CPU work but builds no hash table
        assert!(cost.cpu > 0.0);
        assert!((cost.memory).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_semi_join() {
        let model = CostModel::new();
        let join = JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type: JoinType::Semi,
            conditions: vec![],
        };
        let cost_semi = model.join_cost(&join, 1000.0);

        let inner_join = JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type: JoinType::Inner,
            conditions: vec![],
        };
        let cost_inner = model.join_cost(&inner_join, 1000.0);

        assert!(cost_semi.cpu > 0.0);
        assert!(cost_inner.cpu > 0.0);

        // Semi join skips the output cost, so it is cheaper than the inner join
        assert!(cost_semi.cpu < cost_inner.cpu);
        assert!(cost_semi.total() < cost_inner.total());
    }

    #[test]
    fn test_cost_model_aggregate() {
        let model = CostModel::new();
        let agg = AggregateOp {
            group_by: vec![],
            aggregates: vec![
                AggregateExpr {
                    function: AggregateFunction::Count,
                    expression: None,
                    distinct: false,
                    alias: Some("cnt".to_string()),
                },
                AggregateExpr {
                    function: AggregateFunction::Sum,
                    expression: Some(LogicalExpression::Variable("x".to_string())),
                    distinct: false,
                    alias: Some("total".to_string()),
                },
            ],
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.aggregate_cost(&agg, 1000.0);

        // Aggregation has hash cost and memory cost
        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_distinct() {
        let model = CostModel::new();
        let distinct = DistinctOp {
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.distinct_cost(&distinct, 1000.0);

        // Distinct uses hash set
        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_limit() {
        let model = CostModel::new();
        let limit = LimitOp {
            count: 10,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.limit_cost(&limit, 1000.0);

        // Limit is very cheap
        assert!(cost.cpu > 0.0);
        assert!(cost.cpu < 1.0); // Should be minimal
    }

    #[test]
    fn test_cost_model_skip() {
        let model = CostModel::new();
        let skip = SkipOp {
            count: 100,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.skip_cost(&skip, 1000.0);

        // Skip must scan through skipped rows
        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_return() {
        let model = CostModel::new();
        let ret = ReturnOp {
            items: vec![
                ReturnItem {
                    expression: LogicalExpression::Variable("a".to_string()),
                    alias: None,
                },
                ReturnItem {
                    expression: LogicalExpression::Variable("b".to_string()),
                    alias: None,
                },
            ],
            distinct: false,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.return_cost(&ret, 1000.0);

        // Return materializes results
        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_cheaper() {
        let model = CostModel::new();
        let cheap = Cost::cpu(10.0);
        let expensive = Cost::cpu(100.0);

        assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
        assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());

        // On a tie the first argument is returned
        let tie_first = Cost::cpu(5.0);
        let tie_second = Cost::cpu(5.0);
        assert!(std::ptr::eq(
            model.cheaper(&tie_first, &tie_second),
            &tie_first
        ));
    }

    #[test]
    fn test_cost_comparison_prefers_lower_total() {
        let model = CostModel::new();
        // High CPU, low IO
        let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
        // Low CPU, high IO
        let io_heavy = Cost::cpu(10.0).with_io(20.0);

        // IO is weighted 10x, so io_heavy = 10 + 200 = 210, cpu_heavy = 100 + 10 = 110
        assert!(cpu_heavy.total() < io_heavy.total());
        assert_eq!(
            model.cheaper(&cpu_heavy, &io_heavy).total(),
            cpu_heavy.total()
        );
    }

    #[test]
    fn test_cost_model_sort_with_keys() {
        let model = CostModel::new();
        let sort_single = SortOp {
            keys: vec![crate::query::plan::SortKey {
                expression: LogicalExpression::Variable("a".to_string()),
                order: SortOrder::Ascending,
            }],
            input: Box::new(LogicalOperator::Empty),
        };
        let sort_multi = SortOp {
            keys: vec![
                crate::query::plan::SortKey {
                    expression: LogicalExpression::Variable("a".to_string()),
                    order: SortOrder::Ascending,
                },
                crate::query::plan::SortKey {
                    expression: LogicalExpression::Variable("b".to_string()),
                    order: SortOrder::Descending,
                },
            ],
            input: Box::new(LogicalOperator::Empty),
        };

        let cost_single = model.sort_cost(&sort_single, 1000.0);
        let cost_multi = model.sort_cost(&sort_multi, 1000.0);

        // More sort keys = more comparisons
        assert!(cost_multi.cpu > cost_single.cpu);
    }

    #[test]
    fn test_cost_model_empty_operator() {
        let model = CostModel::new();
        let cost = model.estimate(&LogicalOperator::Empty, 0.0);
        assert!((cost.total()).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_default() {
        let model = CostModel::default();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: None,
            input: None,
        };
        let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
        assert!(cost.total() > 0.0);
    }
}