Skip to main content

grafeo_engine/query/optimizer/
cost.rs

1//! Cost model for query optimization.
2//!
3//! Provides cost estimates for logical operators to guide optimization decisions.
4
5use crate::query::plan::{
6    AggregateOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp, LogicalOperator,
7    NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp,
8};
9
/// Cost of an operation.
///
/// Represents the estimated resource consumption of executing an operator.
/// The components are tracked separately and combined into a single comparable
/// scalar by `Cost::total` (default weights) or one of the weighted variants.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Cost {
    /// Estimated CPU cycles / work units.
    pub cpu: f64,
    /// Estimated I/O operations (page reads).
    pub io: f64,
    /// Estimated memory usage in bytes.
    pub memory: f64,
    /// Network cost (for distributed queries).
    pub network: f64,
}

impl Cost {
    /// Weight applied to the `cpu` component by `total`.
    const DEFAULT_CPU_WEIGHT: f64 = 1.0;
    /// Weight applied to the `io` component by `total`.
    const DEFAULT_IO_WEIGHT: f64 = 10.0;
    /// Weight applied to the `memory` component by `total`.
    const DEFAULT_MEMORY_WEIGHT: f64 = 0.1;
    /// Weight applied to the `network` component by `total`.
    const DEFAULT_NETWORK_WEIGHT: f64 = 100.0;

    /// Creates a zero cost.
    #[must_use]
    pub fn zero() -> Self {
        Self {
            cpu: 0.0,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Creates a cost consisting only of CPU work units.
    #[must_use]
    pub fn cpu(cpu: f64) -> Self {
        Self {
            cpu,
            ..Self::zero()
        }
    }

    /// Sets the I/O component, replacing any previously set value.
    ///
    /// This is a builder-style setter, not an accumulator: use `+` / `+=`
    /// to combine costs.
    #[must_use]
    pub fn with_io(mut self, io: f64) -> Self {
        self.io = io;
        self
    }

    /// Sets the memory component, replacing any previously set value.
    #[must_use]
    pub fn with_memory(mut self, memory: f64) -> Self {
        self.memory = memory;
        self
    }

    /// Sets the network component, replacing any previously set value.
    #[must_use]
    pub fn with_network(mut self, network: f64) -> Self {
        self.network = network;
        self
    }

    /// Returns the total weighted cost.
    ///
    /// Uses default weights: CPU=1.0, IO=10.0, Memory=0.1, Network=100.0
    #[must_use]
    pub fn total(&self) -> f64 {
        self.total_weighted_with_network(
            Self::DEFAULT_CPU_WEIGHT,
            Self::DEFAULT_IO_WEIGHT,
            Self::DEFAULT_MEMORY_WEIGHT,
            Self::DEFAULT_NETWORK_WEIGHT,
        )
    }

    /// Returns the total cost with custom weights for CPU, I/O and memory.
    ///
    /// The `network` component is **not** included. Use
    /// `total_weighted_with_network` when network cost must be counted.
    #[must_use]
    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
        self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
    }

    /// Returns the total cost with custom weights for every component,
    /// including network.
    #[must_use]
    pub fn total_weighted_with_network(
        &self,
        cpu_weight: f64,
        io_weight: f64,
        mem_weight: f64,
        network_weight: f64,
    ) -> f64 {
        self.cpu * cpu_weight
            + self.io * io_weight
            + self.memory * mem_weight
            + self.network * network_weight
    }
}

impl std::ops::Add for Cost {
    type Output = Self;

    /// Component-wise sum of two costs.
    fn add(self, other: Self) -> Self {
        Self {
            cpu: self.cpu + other.cpu,
            io: self.io + other.io,
            memory: self.memory + other.memory,
            network: self.network + other.network,
        }
    }
}

impl std::ops::AddAssign for Cost {
    /// Component-wise in-place accumulation of another cost.
    fn add_assign(&mut self, other: Self) {
        self.cpu += other.cpu;
        self.io += other.io;
        self.memory += other.memory;
        self.network += other.network;
    }
}
98
99/// Cost model for estimating operator costs.
100pub struct CostModel {
101    /// Cost per tuple processed by CPU.
102    cpu_tuple_cost: f64,
103    /// Cost per I/O page read.
104    #[allow(dead_code)]
105    io_page_cost: f64,
106    /// Cost per hash table lookup.
107    hash_lookup_cost: f64,
108    /// Cost per comparison in sorting.
109    sort_comparison_cost: f64,
110    /// Average tuple size in bytes.
111    avg_tuple_size: f64,
112    /// Page size in bytes.
113    page_size: f64,
114}
115
116impl CostModel {
117    /// Creates a new cost model with default parameters.
118    #[must_use]
119    pub fn new() -> Self {
120        Self {
121            cpu_tuple_cost: 0.01,
122            io_page_cost: 1.0,
123            hash_lookup_cost: 0.02,
124            sort_comparison_cost: 0.02,
125            avg_tuple_size: 100.0,
126            page_size: 8192.0,
127        }
128    }
129
130    /// Estimates the cost of a logical operator.
131    #[must_use]
132    pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
133        match op {
134            LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
135            LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
136            LogicalOperator::Project(project) => self.project_cost(project, cardinality),
137            LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
138            LogicalOperator::Join(join) => self.join_cost(join, cardinality),
139            LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
140            LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
141            LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
142            LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
143            LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
144            LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
145            LogicalOperator::Empty => Cost::zero(),
146            _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
147        }
148    }
149
150    /// Estimates the cost of a node scan.
151    fn node_scan_cost(&self, _scan: &NodeScanOp, cardinality: f64) -> Cost {
152        let pages = (cardinality * self.avg_tuple_size) / self.page_size;
153        Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
154    }
155
156    /// Estimates the cost of a filter operation.
157    fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
158        // Filter cost is just predicate evaluation per tuple
159        Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
160    }
161
162    /// Estimates the cost of a projection.
163    fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
164        // Cost depends on number of expressions evaluated
165        let expr_count = project.projections.len() as f64;
166        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
167    }
168
169    /// Estimates the cost of an expand operation.
170    fn expand_cost(&self, _expand: &ExpandOp, cardinality: f64) -> Cost {
171        // Expand involves adjacency list lookups
172        let lookup_cost = cardinality * self.hash_lookup_cost;
173        // Assume average fanout of 10 for edge traversal
174        let avg_fanout = 10.0;
175        let output_cost = cardinality * avg_fanout * self.cpu_tuple_cost;
176        Cost::cpu(lookup_cost + output_cost)
177    }
178
179    /// Estimates the cost of a join operation.
180    fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
181        // Cost depends on join type
182        match join.join_type {
183            JoinType::Cross => {
184                // Cross join is O(n * m)
185                Cost::cpu(cardinality * self.cpu_tuple_cost)
186            }
187            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
188                // Hash join: build phase + probe phase
189                // Assume left side is build, right side is probe
190                let build_cardinality = cardinality.sqrt(); // Rough estimate
191                let probe_cardinality = cardinality.sqrt();
192
193                // Build hash table
194                let build_cost = build_cardinality * self.hash_lookup_cost;
195                let memory_cost = build_cardinality * self.avg_tuple_size;
196
197                // Probe hash table
198                let probe_cost = probe_cardinality * self.hash_lookup_cost;
199
200                // Output cost
201                let output_cost = cardinality * self.cpu_tuple_cost;
202
203                Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
204            }
205            JoinType::Semi | JoinType::Anti => {
206                // Semi/anti joins are typically cheaper
207                let build_cardinality = cardinality.sqrt();
208                let probe_cardinality = cardinality.sqrt();
209
210                let build_cost = build_cardinality * self.hash_lookup_cost;
211                let probe_cost = probe_cardinality * self.hash_lookup_cost;
212
213                Cost::cpu(build_cost + probe_cost)
214                    .with_memory(build_cardinality * self.avg_tuple_size)
215            }
216        }
217    }
218
219    /// Estimates the cost of an aggregation.
220    fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
221        // Hash aggregation cost
222        let hash_cost = cardinality * self.hash_lookup_cost;
223
224        // Aggregate function evaluation
225        let agg_count = agg.aggregates.len() as f64;
226        let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
227
228        // Memory for hash table (estimated distinct groups)
229        let distinct_groups = (cardinality / 10.0).max(1.0); // Assume 10% distinct
230        let memory_cost = distinct_groups * self.avg_tuple_size;
231
232        Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
233    }
234
235    /// Estimates the cost of a sort operation.
236    fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
237        if cardinality <= 1.0 {
238            return Cost::zero();
239        }
240
241        // Sort is O(n log n) comparisons
242        let comparisons = cardinality * cardinality.log2();
243        let key_count = sort.keys.len() as f64;
244
245        // Memory for sorting (full input materialization)
246        let memory_cost = cardinality * self.avg_tuple_size;
247
248        Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
249    }
250
251    /// Estimates the cost of a distinct operation.
252    fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
253        // Hash-based distinct
254        let hash_cost = cardinality * self.hash_lookup_cost;
255        let memory_cost = cardinality * self.avg_tuple_size * 0.5; // Assume 50% distinct
256
257        Cost::cpu(hash_cost).with_memory(memory_cost)
258    }
259
260    /// Estimates the cost of a limit operation.
261    fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
262        // Limit is very cheap - just counting
263        Cost::cpu(limit.count as f64 * self.cpu_tuple_cost * 0.1)
264    }
265
266    /// Estimates the cost of a skip operation.
267    fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
268        // Skip requires scanning through skipped rows
269        Cost::cpu(skip.count as f64 * self.cpu_tuple_cost)
270    }
271
272    /// Estimates the cost of a return operation.
273    fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
274        // Return materializes results
275        let expr_count = ret.items.len() as f64;
276        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
277    }
278
279    /// Compares two costs and returns the cheaper one.
280    #[must_use]
281    pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
282        if a.total() <= b.total() { a } else { b }
283    }
284
285    /// Estimates the cost of a worst-case optimal join (WCOJ/leapfrog join).
286    ///
287    /// WCOJ is optimal for cyclic patterns like triangles. Traditional binary
288    /// hash joins are O(N²) for triangles; WCOJ achieves O(N^1.5) by processing
289    /// all relations simultaneously using sorted iterators.
290    ///
291    /// # Arguments
292    /// * `num_relations` - Number of relations participating in the join
293    /// * `cardinalities` - Cardinality of each input relation
294    /// * `output_cardinality` - Expected output cardinality
295    ///
296    /// # Cost Model
297    /// - Materialization: O(sum of cardinalities) to build trie indexes
298    /// - Intersection: O(output * log(min_cardinality)) for leapfrog seek operations
299    /// - Memory: Trie storage for all inputs
300    #[must_use]
301    pub fn leapfrog_join_cost(
302        &self,
303        num_relations: usize,
304        cardinalities: &[f64],
305        output_cardinality: f64,
306    ) -> Cost {
307        if cardinalities.is_empty() {
308            return Cost::zero();
309        }
310
311        let total_input: f64 = cardinalities.iter().sum();
312        let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
313
314        // Materialization phase: build trie indexes for each input
315        let materialize_cost = total_input * self.cpu_tuple_cost * 2.0; // Sorting + trie building
316
317        // Intersection phase: leapfrog seeks are O(log n) per relation
318        let seek_cost = if min_card > 1.0 {
319            output_cardinality * (num_relations as f64) * min_card.log2() * self.hash_lookup_cost
320        } else {
321            output_cardinality * self.cpu_tuple_cost
322        };
323
324        // Output materialization
325        let output_cost = output_cardinality * self.cpu_tuple_cost;
326
327        // Memory: trie storage (roughly 2x input size for sorted index)
328        let memory = total_input * self.avg_tuple_size * 2.0;
329
330        Cost::cpu(materialize_cost + seek_cost + output_cost).with_memory(memory)
331    }
332
333    /// Compares hash join cost vs leapfrog join cost for a cyclic pattern.
334    ///
335    /// Returns true if leapfrog (WCOJ) is estimated to be cheaper.
336    #[must_use]
337    pub fn prefer_leapfrog_join(
338        &self,
339        num_relations: usize,
340        cardinalities: &[f64],
341        output_cardinality: f64,
342    ) -> bool {
343        if num_relations < 3 || cardinalities.len() < 3 {
344            // Leapfrog is only beneficial for multi-way joins (3+)
345            return false;
346        }
347
348        let leapfrog_cost =
349            self.leapfrog_join_cost(num_relations, cardinalities, output_cardinality);
350
351        // Estimate cascade of binary hash joins
352        // For N relations, we need N-1 joins
353        // Each join produces intermediate results that feed the next
354        let mut hash_cascade_cost = Cost::zero();
355        let mut intermediate_cardinality = cardinalities[0];
356
357        for card in &cardinalities[1..] {
358            // Hash join cost: build + probe + output
359            let join_output = (intermediate_cardinality * card).sqrt(); // Estimated selectivity
360            let join = JoinOp {
361                left: Box::new(LogicalOperator::Empty),
362                right: Box::new(LogicalOperator::Empty),
363                join_type: JoinType::Inner,
364                conditions: vec![],
365            };
366            hash_cascade_cost += self.join_cost(&join, join_output);
367            intermediate_cardinality = join_output;
368        }
369
370        leapfrog_cost.total() < hash_cascade_cost.total()
371    }
372
373    /// Estimates cost for factorized execution (compressed intermediate results).
374    ///
375    /// Factorized execution avoids materializing full cross products by keeping
376    /// results in a compressed "factorized" form. This is beneficial for multi-hop
377    /// traversals where intermediate results can explode.
378    ///
379    /// Returns the reduction factor (1.0 = no benefit, lower = more compression).
380    #[must_use]
381    pub fn factorized_benefit(&self, avg_fanout: f64, num_hops: usize) -> f64 {
382        if num_hops <= 1 || avg_fanout <= 1.0 {
383            return 1.0; // No benefit for single hop or low fanout
384        }
385
386        // Factorized representation compresses repeated prefixes
387        // Compression ratio improves with higher fanout and more hops
388        // Full materialization: fanout^hops
389        // Factorized: sum(fanout^i for i in 1..=hops) ≈ fanout^(hops+1) / (fanout - 1)
390
391        let full_size = avg_fanout.powi(num_hops as i32);
392        let factorized_size = if avg_fanout > 1.0 {
393            (avg_fanout.powi(num_hops as i32 + 1) - 1.0) / (avg_fanout - 1.0)
394        } else {
395            num_hops as f64
396        };
397
398        (factorized_size / full_size).min(1.0)
399    }
400}
401
402impl Default for CostModel {
403    fn default() -> Self {
404        Self::new()
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::plan::{
        AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
        Projection, ReturnItem, SortKey, SortOrder,
    };

    /// Tolerance for floating-point comparisons of cost components.
    const EPS: f64 = 0.001;

    /// Builds a variable reference expression.
    fn var(name: &str) -> LogicalExpression {
        LogicalExpression::Variable(name.to_string())
    }

    /// Builds a join of the given type over two empty inputs with no conditions.
    fn empty_join(join_type: JoinType) -> JoinOp {
        JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type,
            conditions: vec![],
        }
    }

    #[test]
    fn test_cost_addition() {
        let a = Cost::cpu(10.0).with_io(5.0);
        let b = Cost::cpu(20.0).with_memory(100.0);
        let c = a + b;

        assert!((c.cpu - 30.0).abs() < EPS);
        assert!((c.io - 5.0).abs() < EPS);
        assert!((c.memory - 100.0).abs() < EPS);
    }

    #[test]
    fn test_cost_total() {
        let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
        // Total = 10 + 1*10 + 100*0.1 = 10 + 10 + 10 = 30
        assert!((cost.total() - 30.0).abs() < EPS);
    }

    #[test]
    fn test_cost_builders_overwrite() {
        // `with_*` setters replace the component rather than accumulating.
        let cost = Cost::cpu(1.0).with_io(2.0).with_io(3.0);
        assert!((cost.io - 3.0).abs() < EPS);
        let cost = cost.with_memory(5.0).with_memory(7.0);
        assert!((cost.memory - 7.0).abs() < EPS);
    }

    #[test]
    fn test_cost_model_node_scan() {
        let model = CostModel::new();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: Some("Person".to_string()),
            input: None,
        };
        let cost = model.node_scan_cost(&scan, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.io > 0.0);
    }

    #[test]
    fn test_cost_model_sort() {
        let model = CostModel::new();
        let sort = SortOp {
            keys: vec![],
            input: Box::new(LogicalOperator::Empty),
        };

        let cost_100 = model.sort_cost(&sort, 100.0);
        let cost_1000 = model.sort_cost(&sort, 1000.0);

        // Sorting 1000 rows should be more expensive than 100 rows
        assert!(cost_1000.total() > cost_100.total());
    }

    #[test]
    fn test_cost_model_sort_trivial_input() {
        let model = CostModel::new();
        let sort = SortOp {
            keys: vec![SortKey {
                expression: var("a"),
                order: SortOrder::Ascending,
            }],
            input: Box::new(LogicalOperator::Empty),
        };

        // Zero or one row needs no sorting at all.
        assert!(model.sort_cost(&sort, 0.0).total().abs() < EPS);
        assert!(model.sort_cost(&sort, 1.0).total().abs() < EPS);
    }

    #[test]
    fn test_cost_zero() {
        let cost = Cost::zero();
        assert!((cost.cpu).abs() < EPS);
        assert!((cost.io).abs() < EPS);
        assert!((cost.memory).abs() < EPS);
        assert!((cost.network).abs() < EPS);
        assert!((cost.total()).abs() < EPS);
    }

    #[test]
    fn test_cost_add_assign() {
        let mut cost = Cost::cpu(10.0);
        cost += Cost::cpu(5.0).with_io(2.0);
        assert!((cost.cpu - 15.0).abs() < EPS);
        assert!((cost.io - 2.0).abs() < EPS);
    }

    #[test]
    fn test_cost_total_weighted() {
        let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
        // With custom weights: cpu*2 + io*5 + mem*0.5 = 20 + 10 + 50 = 80
        let total = cost.total_weighted(2.0, 5.0, 0.5);
        assert!((total - 80.0).abs() < EPS);
    }

    #[test]
    fn test_cost_model_filter() {
        let model = CostModel::new();
        let filter = FilterOp {
            predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.filter_cost(&filter, 1000.0);

        // Filter cost is CPU only
        assert!(cost.cpu > 0.0);
        assert!((cost.io).abs() < EPS);
    }

    #[test]
    fn test_cost_model_project() {
        let model = CostModel::new();
        let projection = |name: &str| Projection {
            expression: var(name),
            alias: None,
        };
        let single = ProjectOp {
            projections: vec![projection("a")],
            input: Box::new(LogicalOperator::Empty),
        };
        let double = ProjectOp {
            projections: vec![projection("a"), projection("b")],
            input: Box::new(LogicalOperator::Empty),
        };
        let cost_single = model.project_cost(&single, 1000.0);
        let cost_double = model.project_cost(&double, 1000.0);

        // Cost scales linearly with the number of projections
        assert!(cost_single.cpu > 0.0);
        assert!((cost_double.cpu - 2.0 * cost_single.cpu).abs() < EPS);
    }

    #[test]
    fn test_cost_model_expand() {
        let model = CostModel::new();
        let expand = ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: None,
            direction: ExpandDirection::Outgoing,
            edge_type: None,
            min_hops: 1,
            max_hops: Some(1),
            input: Box::new(LogicalOperator::Empty),
            path_alias: None,
        };
        let cost = model.expand_cost(&expand, 1000.0);

        // Expand involves hash lookups and output generation
        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_hash_join() {
        let model = CostModel::new();
        let join = JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type: JoinType::Inner,
            conditions: vec![JoinCondition {
                left: var("a"),
                right: var("b"),
            }],
        };
        let cost = model.join_cost(&join, 10000.0);

        // Hash join has CPU cost and memory cost
        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_cross_join() {
        let model = CostModel::new();
        let join = empty_join(JoinType::Cross);
        let cost = model.join_cost(&join, 1_000_000.0);
        let doubled = model.join_cost(&join, 2_000_000.0);

        // Cross join pays for every generated row, linearly in output size
        assert!(cost.cpu > 0.0);
        assert!((doubled.cpu - 2.0 * cost.cpu).abs() < EPS);
    }

    #[test]
    fn test_cost_model_semi_join() {
        let model = CostModel::new();
        let cost_semi = model.join_cost(&empty_join(JoinType::Semi), 1000.0);
        let cost_inner = model.join_cost(&empty_join(JoinType::Inner), 1000.0);

        // Semi join skips output materialization, so it is cheaper than inner
        assert!(cost_semi.cpu > 0.0);
        assert!(cost_inner.cpu > 0.0);
        assert!(cost_semi.cpu < cost_inner.cpu);
    }

    #[test]
    fn test_cost_model_aggregate() {
        let model = CostModel::new();
        let agg = AggregateOp {
            group_by: vec![],
            aggregates: vec![
                AggregateExpr {
                    function: AggregateFunction::Count,
                    expression: None,
                    distinct: false,
                    alias: Some("cnt".to_string()),
                    percentile: None,
                },
                AggregateExpr {
                    function: AggregateFunction::Sum,
                    expression: Some(var("x")),
                    distinct: false,
                    alias: Some("total".to_string()),
                    percentile: None,
                },
            ],
            input: Box::new(LogicalOperator::Empty),
            having: None,
        };
        let cost = model.aggregate_cost(&agg, 1000.0);

        // Aggregation has CPU cost and holds state in memory
        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_distinct() {
        let model = CostModel::new();
        let distinct = DistinctOp {
            input: Box::new(LogicalOperator::Empty),
            columns: None,
        };
        let cost = model.distinct_cost(&distinct, 1000.0);

        // Distinct uses hash set
        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_limit() {
        let model = CostModel::new();
        let limit = LimitOp {
            count: 10,
            input: Box::new(LogicalOperator::Empty),
        };
        let skip = SkipOp {
            count: 10,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.limit_cost(&limit, 1000.0);

        // Limit is very cheap, and cheaper than skipping the same number of rows
        assert!(cost.cpu > 0.0);
        assert!(cost.cpu < 1.0);
        assert!(cost.cpu < model.skip_cost(&skip, 1000.0).cpu);
    }

    #[test]
    fn test_cost_model_skip() {
        let model = CostModel::new();
        let skip = SkipOp {
            count: 100,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.skip_cost(&skip, 1000.0);

        // Skip must scan through skipped rows
        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_return() {
        let model = CostModel::new();
        let ret = ReturnOp {
            items: vec![
                ReturnItem {
                    expression: var("a"),
                    alias: None,
                },
                ReturnItem {
                    expression: var("b"),
                    alias: None,
                },
            ],
            distinct: false,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.return_cost(&ret, 1000.0);

        // Return materializes results
        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_cheaper() {
        let model = CostModel::new();
        let cheap = Cost::cpu(10.0);
        let expensive = Cost::cpu(100.0);

        assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
        assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
    }

    #[test]
    fn test_cost_comparison_prefers_lower_total() {
        let model = CostModel::new();
        // High CPU, low IO
        let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
        // Low CPU, high IO
        let io_heavy = Cost::cpu(10.0).with_io(20.0);

        // IO is weighted 10x, so io_heavy = 10 + 200 = 210, cpu_heavy = 100 + 10 = 110
        assert!(cpu_heavy.total() < io_heavy.total());
        assert_eq!(
            model.cheaper(&cpu_heavy, &io_heavy).total(),
            cpu_heavy.total()
        );
    }

    #[test]
    fn test_cost_model_sort_with_keys() {
        let model = CostModel::new();
        let sort_single = SortOp {
            keys: vec![SortKey {
                expression: var("a"),
                order: SortOrder::Ascending,
            }],
            input: Box::new(LogicalOperator::Empty),
        };
        let sort_multi = SortOp {
            keys: vec![
                SortKey {
                    expression: var("a"),
                    order: SortOrder::Ascending,
                },
                SortKey {
                    expression: var("b"),
                    order: SortOrder::Descending,
                },
            ],
            input: Box::new(LogicalOperator::Empty),
        };

        let cost_single = model.sort_cost(&sort_single, 1000.0);
        let cost_multi = model.sort_cost(&sort_multi, 1000.0);

        // More sort keys = more comparisons
        assert!(cost_multi.cpu > cost_single.cpu);
    }

    #[test]
    fn test_cost_model_empty_operator() {
        let model = CostModel::new();
        let cost = model.estimate(&LogicalOperator::Empty, 0.0);
        assert!((cost.total()).abs() < EPS);
    }

    #[test]
    fn test_cost_model_default() {
        let model = CostModel::default();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: None,
            input: None,
        };
        let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
        assert!(cost.total() > 0.0);
    }

    #[test]
    fn test_leapfrog_join_cost() {
        let model = CostModel::new();

        // Three-way join (triangle pattern)
        let cardinalities = vec![1000.0, 1000.0, 1000.0];
        let cost = model.leapfrog_join_cost(3, &cardinalities, 100.0);

        // Should have CPU cost for materialization and intersection
        assert!(cost.cpu > 0.0);
        // Should have memory cost for trie storage
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_leapfrog_join_cost_empty() {
        let model = CostModel::new();
        let cost = model.leapfrog_join_cost(0, &[], 0.0);
        assert!((cost.total()).abs() < EPS);
    }

    #[test]
    fn test_prefer_leapfrog_join_for_triangles() {
        let model = CostModel::new();

        // Compare costs for triangle pattern
        let cardinalities = vec![10000.0, 10000.0, 10000.0];
        let output = 1000.0;

        let leapfrog_cost = model.leapfrog_join_cost(3, &cardinalities, output);

        // Leapfrog should have a finite, non-trivial cost for triangle patterns
        assert!(leapfrog_cost.cpu > 0.0);
        assert!(leapfrog_cost.memory > 0.0);
        assert!(leapfrog_cost.total().is_finite());

        // The actual preference depends on specific cost parameters; this only
        // checks that the comparison against the hash cascade runs without panicking.
        let _prefer = model.prefer_leapfrog_join(3, &cardinalities, output);
    }

    #[test]
    fn test_prefer_leapfrog_join_binary_case() {
        let model = CostModel::new();

        // Binary join should NOT prefer leapfrog (need 3+ relations)
        let cardinalities = vec![1000.0, 1000.0];
        let prefer = model.prefer_leapfrog_join(2, &cardinalities, 500.0);
        assert!(!prefer, "Binary joins should use hash join, not leapfrog");

        // Claiming three relations with only two cardinalities is also rejected
        let prefer = model.prefer_leapfrog_join(3, &cardinalities, 500.0);
        assert!(!prefer, "Need at least three cardinalities to consider leapfrog");
    }

    #[test]
    fn test_factorized_benefit_single_hop() {
        let model = CostModel::new();

        // Single hop: no factorization benefit
        let benefit = model.factorized_benefit(10.0, 1);
        assert!((benefit - 1.0).abs() < EPS, "Single hop should have no benefit");
    }

    #[test]
    fn test_factorized_benefit_multi_hop() {
        let model = CostModel::new();

        // Multi-hop with high fanout
        let benefit = model.factorized_benefit(10.0, 3);

        // The result is a size ratio (factorized / flat) and must lie in (0, 1]
        assert!(benefit <= 1.0, "Benefit should be <= 1.0");
        assert!(benefit > 0.0, "Benefit should be positive");
    }

    #[test]
    fn test_factorized_benefit_low_fanout() {
        let model = CostModel::new();

        // Fanout just above 1 with two hops: ratio must still be bounded by 1
        let benefit = model.factorized_benefit(1.5, 2);
        assert!(benefit <= 1.0, "Benefit ratio should never exceed 1.0");

        // Fanout of at most 1 never expands, so there is nothing to compress
        let benefit = model.factorized_benefit(0.5, 3);
        assert!((benefit - 1.0).abs() < EPS, "Fanout <= 1 should have no benefit");
    }
}