Skip to main content

grafeo_engine/query/optimizer/
cost.rs

//! Cost model for query optimization.
//!
//! Provides cost estimates for logical operators to guide optimization decisions.
4
5use crate::query::plan::{
6    AggregateOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp, LogicalOperator,
7    NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp, VectorJoinOp, VectorScanOp,
8};
9
/// Cost of an operation.
///
/// Represents the estimated resource consumption of executing an operator.
/// Each component is tracked separately; [`Cost::total`] collapses them into a
/// single comparable number.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Cost {
    /// Estimated CPU cycles / work units.
    pub cpu: f64,
    /// Estimated I/O operations (page reads).
    pub io: f64,
    /// Estimated memory usage in bytes.
    pub memory: f64,
    /// Network cost (for distributed queries).
    pub network: f64,
}

impl Cost {
    /// Creates a cost with every component set to zero.
    #[must_use]
    pub fn zero() -> Self {
        Self {
            cpu: 0.0,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Creates a cost consisting only of `cpu` work units.
    #[must_use]
    pub fn cpu(cpu: f64) -> Self {
        Self {
            cpu,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Sets the I/O component to `io`, replacing any previous value.
    ///
    /// This does not accumulate: `Cost::cpu(1.0).with_io(5.0).with_io(2.0)` has
    /// `io == 2.0`. Use `+` / `+=` to combine costs.
    #[must_use]
    pub fn with_io(mut self, io: f64) -> Self {
        self.io = io;
        self
    }

    /// Sets the memory component to `memory` bytes, replacing any previous value.
    ///
    /// Like [`Cost::with_io`], this overwrites rather than accumulates.
    #[must_use]
    pub fn with_memory(mut self, memory: f64) -> Self {
        self.memory = memory;
        self
    }

    /// Returns the total weighted cost.
    ///
    /// Uses default weights: CPU=1.0, IO=10.0, Memory=0.1, Network=100.0
    #[must_use]
    pub fn total(&self) -> f64 {
        self.cpu + self.io * 10.0 + self.memory * 0.1 + self.network * 100.0
    }

    /// Returns the total cost with custom weights for CPU, I/O and memory.
    ///
    /// The network component is **not** included, so
    /// `total_weighted(1.0, 10.0, 0.1)` equals [`Cost::total`] only when
    /// `network == 0.0`.
    #[must_use]
    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
        self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
    }
}
76
77impl std::ops::Add for Cost {
78    type Output = Self;
79
80    fn add(self, other: Self) -> Self {
81        Self {
82            cpu: self.cpu + other.cpu,
83            io: self.io + other.io,
84            memory: self.memory + other.memory,
85            network: self.network + other.network,
86        }
87    }
88}
89
90impl std::ops::AddAssign for Cost {
91    fn add_assign(&mut self, other: Self) {
92        self.cpu += other.cpu;
93        self.io += other.io;
94        self.memory += other.memory;
95        self.network += other.network;
96    }
97}
98
/// Cost model for estimating operator costs.
///
/// Holds the per-unit constants that the estimators multiply row counts by.
/// Construct it with `CostModel::new()` or `CostModel::default()`.
#[derive(Debug, Clone)]
pub struct CostModel {
    /// Cost per tuple processed by CPU.
    cpu_tuple_cost: f64,
    /// Cost per I/O page read.
    // Not every estimator reads this field; the allow keeps builds warning-free
    // if it ends up unused.
    #[allow(dead_code)]
    io_page_cost: f64,
    /// Cost per hash table lookup.
    hash_lookup_cost: f64,
    /// Cost per comparison in sorting.
    sort_comparison_cost: f64,
    /// Average tuple size in bytes.
    avg_tuple_size: f64,
    /// Page size in bytes.
    page_size: f64,
}
115
116impl CostModel {
117    /// Creates a new cost model with default parameters.
118    #[must_use]
119    pub fn new() -> Self {
120        Self {
121            cpu_tuple_cost: 0.01,
122            io_page_cost: 1.0,
123            hash_lookup_cost: 0.02,
124            sort_comparison_cost: 0.02,
125            avg_tuple_size: 100.0,
126            page_size: 8192.0,
127        }
128    }
129
130    /// Estimates the cost of a logical operator.
131    #[must_use]
132    pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
133        match op {
134            LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
135            LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
136            LogicalOperator::Project(project) => self.project_cost(project, cardinality),
137            LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
138            LogicalOperator::Join(join) => self.join_cost(join, cardinality),
139            LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
140            LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
141            LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
142            LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
143            LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
144            LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
145            LogicalOperator::Empty => Cost::zero(),
146            LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
147            LogicalOperator::VectorJoin(join) => self.vector_join_cost(join, cardinality),
148            _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
149        }
150    }
151
152    /// Estimates the cost of a node scan.
153    fn node_scan_cost(&self, _scan: &NodeScanOp, cardinality: f64) -> Cost {
154        let pages = (cardinality * self.avg_tuple_size) / self.page_size;
155        Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
156    }
157
158    /// Estimates the cost of a filter operation.
159    fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
160        // Filter cost is just predicate evaluation per tuple
161        Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
162    }
163
164    /// Estimates the cost of a projection.
165    fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
166        // Cost depends on number of expressions evaluated
167        let expr_count = project.projections.len() as f64;
168        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
169    }
170
171    /// Estimates the cost of an expand operation.
172    fn expand_cost(&self, _expand: &ExpandOp, cardinality: f64) -> Cost {
173        // Expand involves adjacency list lookups
174        let lookup_cost = cardinality * self.hash_lookup_cost;
175        // Assume average fanout of 10 for edge traversal
176        let avg_fanout = 10.0;
177        let output_cost = cardinality * avg_fanout * self.cpu_tuple_cost;
178        Cost::cpu(lookup_cost + output_cost)
179    }
180
181    /// Estimates the cost of a join operation.
182    fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
183        // Cost depends on join type
184        match join.join_type {
185            JoinType::Cross => {
186                // Cross join is O(n * m)
187                Cost::cpu(cardinality * self.cpu_tuple_cost)
188            }
189            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
190                // Hash join: build phase + probe phase
191                // Assume left side is build, right side is probe
192                let build_cardinality = cardinality.sqrt(); // Rough estimate
193                let probe_cardinality = cardinality.sqrt();
194
195                // Build hash table
196                let build_cost = build_cardinality * self.hash_lookup_cost;
197                let memory_cost = build_cardinality * self.avg_tuple_size;
198
199                // Probe hash table
200                let probe_cost = probe_cardinality * self.hash_lookup_cost;
201
202                // Output cost
203                let output_cost = cardinality * self.cpu_tuple_cost;
204
205                Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
206            }
207            JoinType::Semi | JoinType::Anti => {
208                // Semi/anti joins are typically cheaper
209                let build_cardinality = cardinality.sqrt();
210                let probe_cardinality = cardinality.sqrt();
211
212                let build_cost = build_cardinality * self.hash_lookup_cost;
213                let probe_cost = probe_cardinality * self.hash_lookup_cost;
214
215                Cost::cpu(build_cost + probe_cost)
216                    .with_memory(build_cardinality * self.avg_tuple_size)
217            }
218        }
219    }
220
221    /// Estimates the cost of an aggregation.
222    fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
223        // Hash aggregation cost
224        let hash_cost = cardinality * self.hash_lookup_cost;
225
226        // Aggregate function evaluation
227        let agg_count = agg.aggregates.len() as f64;
228        let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
229
230        // Memory for hash table (estimated distinct groups)
231        let distinct_groups = (cardinality / 10.0).max(1.0); // Assume 10% distinct
232        let memory_cost = distinct_groups * self.avg_tuple_size;
233
234        Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
235    }
236
237    /// Estimates the cost of a sort operation.
238    fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
239        if cardinality <= 1.0 {
240            return Cost::zero();
241        }
242
243        // Sort is O(n log n) comparisons
244        let comparisons = cardinality * cardinality.log2();
245        let key_count = sort.keys.len() as f64;
246
247        // Memory for sorting (full input materialization)
248        let memory_cost = cardinality * self.avg_tuple_size;
249
250        Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
251    }
252
253    /// Estimates the cost of a distinct operation.
254    fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
255        // Hash-based distinct
256        let hash_cost = cardinality * self.hash_lookup_cost;
257        let memory_cost = cardinality * self.avg_tuple_size * 0.5; // Assume 50% distinct
258
259        Cost::cpu(hash_cost).with_memory(memory_cost)
260    }
261
262    /// Estimates the cost of a limit operation.
263    fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
264        // Limit is very cheap - just counting
265        Cost::cpu(limit.count as f64 * self.cpu_tuple_cost * 0.1)
266    }
267
268    /// Estimates the cost of a skip operation.
269    fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
270        // Skip requires scanning through skipped rows
271        Cost::cpu(skip.count as f64 * self.cpu_tuple_cost)
272    }
273
274    /// Estimates the cost of a return operation.
275    fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
276        // Return materializes results
277        let expr_count = ret.items.len() as f64;
278        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
279    }
280
281    /// Estimates the cost of a vector scan operation.
282    ///
283    /// HNSW index search is O(log N) per query, while brute-force is O(N).
284    /// This estimates the HNSW case with ef search parameter.
285    fn vector_scan_cost(&self, scan: &VectorScanOp, cardinality: f64) -> Cost {
286        // k determines output cardinality
287        let k = scan.k as f64;
288
289        // HNSW search cost: O(ef * log(N)) distance computations
290        // Assume ef = 64 (default), N = cardinality
291        let ef = 64.0;
292        let n = cardinality.max(1.0);
293        let search_cost = if scan.index_name.is_some() {
294            // HNSW: O(ef * log N)
295            ef * n.ln() * self.cpu_tuple_cost * 10.0 // Distance computation is ~10x regular tuple
296        } else {
297            // Brute-force: O(N)
298            n * self.cpu_tuple_cost * 10.0
299        };
300
301        // Memory for candidate heap
302        let memory = k * self.avg_tuple_size * 2.0;
303
304        Cost::cpu(search_cost).with_memory(memory)
305    }
306
307    /// Estimates the cost of a vector join operation.
308    ///
309    /// Vector join performs k-NN search for each input row.
310    fn vector_join_cost(&self, join: &VectorJoinOp, cardinality: f64) -> Cost {
311        let k = join.k as f64;
312
313        // Each input row triggers a vector search
314        // Assume brute-force for hybrid queries (no index specified typically)
315        let per_row_search_cost = if join.index_name.is_some() {
316            // HNSW: O(ef * log N)
317            let ef = 64.0;
318            let n = cardinality.max(1.0);
319            ef * n.ln() * self.cpu_tuple_cost * 10.0
320        } else {
321            // Brute-force: O(N) per input row
322            cardinality * self.cpu_tuple_cost * 10.0
323        };
324
325        // Total cost: input_rows * search_cost
326        // For vector join, cardinality is typically input cardinality * k
327        let input_cardinality = (cardinality / k).max(1.0);
328        let total_search_cost = input_cardinality * per_row_search_cost;
329
330        // Memory for results
331        let memory = cardinality * self.avg_tuple_size;
332
333        Cost::cpu(total_search_cost).with_memory(memory)
334    }
335
336    /// Compares two costs and returns the cheaper one.
337    #[must_use]
338    pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
339        if a.total() <= b.total() { a } else { b }
340    }
341
342    /// Estimates the cost of a worst-case optimal join (WCOJ/leapfrog join).
343    ///
344    /// WCOJ is optimal for cyclic patterns like triangles. Traditional binary
345    /// hash joins are O(N²) for triangles; WCOJ achieves O(N^1.5) by processing
346    /// all relations simultaneously using sorted iterators.
347    ///
348    /// # Arguments
349    /// * `num_relations` - Number of relations participating in the join
350    /// * `cardinalities` - Cardinality of each input relation
351    /// * `output_cardinality` - Expected output cardinality
352    ///
353    /// # Cost Model
354    /// - Materialization: O(sum of cardinalities) to build trie indexes
355    /// - Intersection: O(output * log(min_cardinality)) for leapfrog seek operations
356    /// - Memory: Trie storage for all inputs
357    #[must_use]
358    pub fn leapfrog_join_cost(
359        &self,
360        num_relations: usize,
361        cardinalities: &[f64],
362        output_cardinality: f64,
363    ) -> Cost {
364        if cardinalities.is_empty() {
365            return Cost::zero();
366        }
367
368        let total_input: f64 = cardinalities.iter().sum();
369        let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
370
371        // Materialization phase: build trie indexes for each input
372        let materialize_cost = total_input * self.cpu_tuple_cost * 2.0; // Sorting + trie building
373
374        // Intersection phase: leapfrog seeks are O(log n) per relation
375        let seek_cost = if min_card > 1.0 {
376            output_cardinality * (num_relations as f64) * min_card.log2() * self.hash_lookup_cost
377        } else {
378            output_cardinality * self.cpu_tuple_cost
379        };
380
381        // Output materialization
382        let output_cost = output_cardinality * self.cpu_tuple_cost;
383
384        // Memory: trie storage (roughly 2x input size for sorted index)
385        let memory = total_input * self.avg_tuple_size * 2.0;
386
387        Cost::cpu(materialize_cost + seek_cost + output_cost).with_memory(memory)
388    }
389
390    /// Compares hash join cost vs leapfrog join cost for a cyclic pattern.
391    ///
392    /// Returns true if leapfrog (WCOJ) is estimated to be cheaper.
393    #[must_use]
394    pub fn prefer_leapfrog_join(
395        &self,
396        num_relations: usize,
397        cardinalities: &[f64],
398        output_cardinality: f64,
399    ) -> bool {
400        if num_relations < 3 || cardinalities.len() < 3 {
401            // Leapfrog is only beneficial for multi-way joins (3+)
402            return false;
403        }
404
405        let leapfrog_cost =
406            self.leapfrog_join_cost(num_relations, cardinalities, output_cardinality);
407
408        // Estimate cascade of binary hash joins
409        // For N relations, we need N-1 joins
410        // Each join produces intermediate results that feed the next
411        let mut hash_cascade_cost = Cost::zero();
412        let mut intermediate_cardinality = cardinalities[0];
413
414        for card in &cardinalities[1..] {
415            // Hash join cost: build + probe + output
416            let join_output = (intermediate_cardinality * card).sqrt(); // Estimated selectivity
417            let join = JoinOp {
418                left: Box::new(LogicalOperator::Empty),
419                right: Box::new(LogicalOperator::Empty),
420                join_type: JoinType::Inner,
421                conditions: vec![],
422            };
423            hash_cascade_cost += self.join_cost(&join, join_output);
424            intermediate_cardinality = join_output;
425        }
426
427        leapfrog_cost.total() < hash_cascade_cost.total()
428    }
429
430    /// Estimates cost for factorized execution (compressed intermediate results).
431    ///
432    /// Factorized execution avoids materializing full cross products by keeping
433    /// results in a compressed "factorized" form. This is beneficial for multi-hop
434    /// traversals where intermediate results can explode.
435    ///
436    /// Returns the reduction factor (1.0 = no benefit, lower = more compression).
437    #[must_use]
438    pub fn factorized_benefit(&self, avg_fanout: f64, num_hops: usize) -> f64 {
439        if num_hops <= 1 || avg_fanout <= 1.0 {
440            return 1.0; // No benefit for single hop or low fanout
441        }
442
443        // Factorized representation compresses repeated prefixes
444        // Compression ratio improves with higher fanout and more hops
445        // Full materialization: fanout^hops
446        // Factorized: sum(fanout^i for i in 1..=hops) ≈ fanout^(hops+1) / (fanout - 1)
447
448        let full_size = avg_fanout.powi(num_hops as i32);
449        let factorized_size = if avg_fanout > 1.0 {
450            (avg_fanout.powi(num_hops as i32 + 1) - 1.0) / (avg_fanout - 1.0)
451        } else {
452            num_hops as f64
453        };
454
455        (factorized_size / full_size).min(1.0)
456    }
457}
458
459impl Default for CostModel {
460    fn default() -> Self {
461        Self::new()
462    }
463}
464
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::plan::{
        AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
        Projection, ReturnItem, SortKey, SortOrder,
    };

    /// Absolute tolerance for comparing floating-point cost components.
    const EPS: f64 = 1e-9;

    /// Asserts that `actual` is within [`EPS`] of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected}, got {actual}"
        );
    }

    /// Builds a join of the given type over two empty inputs with no conditions.
    fn empty_join(join_type: JoinType) -> JoinOp {
        JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type,
            conditions: vec![],
        }
    }

    /// Builds an un-aliased projection of the variable `name`.
    fn project_var(name: &str) -> Projection {
        Projection {
            expression: LogicalExpression::Variable(name.to_string()),
            alias: None,
        }
    }

    /// Builds a sort key on the variable `name`.
    fn sort_key(name: &str, order: SortOrder) -> SortKey {
        SortKey {
            expression: LogicalExpression::Variable(name.to_string()),
            order,
        }
    }

    #[test]
    fn test_cost_addition() {
        let a = Cost::cpu(10.0).with_io(5.0);
        let b = Cost::cpu(20.0).with_memory(100.0);
        let c = a + b;

        assert_close(c.cpu, 30.0);
        assert_close(c.io, 5.0);
        assert_close(c.memory, 100.0);
        assert_close(c.network, 0.0);
    }

    #[test]
    fn test_cost_total() {
        let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
        // Total = 10 + 1*10 + 100*0.1 = 10 + 10 + 10 = 30
        assert_close(cost.total(), 30.0);
    }

    #[test]
    fn test_cost_total_includes_network() {
        let cost = Cost {
            network: 1.0,
            ..Cost::zero()
        };
        // Network is weighted 100x by `total` but ignored by `total_weighted`.
        assert_close(cost.total(), 100.0);
        assert_close(cost.total_weighted(1.0, 10.0, 0.1), 0.0);
    }

    #[test]
    fn test_cost_builders_replace_previous_value() {
        let cost = Cost::cpu(1.0)
            .with_io(5.0)
            .with_io(2.0)
            .with_memory(40.0)
            .with_memory(10.0);
        assert_close(cost.io, 2.0);
        assert_close(cost.memory, 10.0);
    }

    #[test]
    fn test_cost_model_node_scan() {
        let model = CostModel::new();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: Some("Person".to_string()),
            input: None,
        };
        let cost = model.node_scan_cost(&scan, 1000.0);

        // 1000 tuples * 0.01 CPU; 1000 * 100 bytes spread over 8192-byte pages.
        assert_close(cost.cpu, 10.0);
        assert_close(cost.io, 100_000.0 / 8192.0);
    }

    #[test]
    fn test_cost_model_sort() {
        let model = CostModel::new();
        let sort = SortOp {
            keys: vec![],
            input: Box::new(LogicalOperator::Empty),
        };

        let cost_100 = model.sort_cost(&sort, 100.0);
        let cost_1000 = model.sort_cost(&sort, 1000.0);

        // Sorting 1000 rows should be more expensive than 100 rows
        assert!(cost_1000.total() > cost_100.total());
    }

    #[test]
    fn test_cost_model_sort_trivial_input() {
        let model = CostModel::new();
        let sort = SortOp {
            keys: vec![sort_key("a", SortOrder::Ascending)],
            input: Box::new(LogicalOperator::Empty),
        };

        // Zero or one row needs no sorting at all.
        assert_eq!(model.sort_cost(&sort, 1.0), Cost::zero());
        assert_eq!(model.sort_cost(&sort, 0.0), Cost::zero());
    }

    #[test]
    fn test_cost_zero() {
        let cost = Cost::zero();
        assert_close(cost.cpu, 0.0);
        assert_close(cost.io, 0.0);
        assert_close(cost.memory, 0.0);
        assert_close(cost.network, 0.0);
        assert_close(cost.total(), 0.0);
    }

    #[test]
    fn test_cost_add_assign() {
        let mut cost = Cost::cpu(10.0);
        cost += Cost::cpu(5.0).with_io(2.0);
        assert_close(cost.cpu, 15.0);
        assert_close(cost.io, 2.0);
    }

    #[test]
    fn test_cost_total_weighted() {
        let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
        // With custom weights: cpu*2 + io*5 + mem*0.5 = 20 + 10 + 50 = 80
        assert_close(cost.total_weighted(2.0, 5.0, 0.5), 80.0);
    }

    #[test]
    fn test_cost_model_filter() {
        let model = CostModel::new();
        let filter = FilterOp {
            predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.filter_cost(&filter, 1000.0);

        // Filter cost is CPU only: 1000 * 0.01 * 1.5
        assert_close(cost.cpu, 15.0);
        assert_close(cost.io, 0.0);
    }

    #[test]
    fn test_cost_model_project() {
        let model = CostModel::new();
        let single = ProjectOp {
            projections: vec![project_var("a")],
            input: Box::new(LogicalOperator::Empty),
        };
        let double = ProjectOp {
            projections: vec![project_var("a"), project_var("b")],
            input: Box::new(LogicalOperator::Empty),
        };
        let cost_single = model.project_cost(&single, 1000.0);
        let cost_double = model.project_cost(&double, 1000.0);

        // Cost scales linearly with the number of projections
        assert_close(cost_single.cpu, 10.0);
        assert_close(cost_double.cpu, 2.0 * cost_single.cpu);
    }

    #[test]
    fn test_cost_model_expand() {
        let model = CostModel::new();
        let expand = ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: None,
            direction: ExpandDirection::Outgoing,
            edge_type: None,
            min_hops: 1,
            max_hops: Some(1),
            input: Box::new(LogicalOperator::Empty),
            path_alias: None,
        };
        let cost = model.expand_cost(&expand, 1000.0);

        // 1000 lookups * 0.02 + 1000 rows * fanout 10 * 0.01
        assert_close(cost.cpu, 120.0);
    }

    #[test]
    fn test_cost_model_hash_join() {
        let model = CostModel::new();
        let join = JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type: JoinType::Inner,
            conditions: vec![JoinCondition {
                left: LogicalExpression::Variable("a".to_string()),
                right: LogicalExpression::Variable("b".to_string()),
            }],
        };
        let cost = model.join_cost(&join, 10000.0);

        // Build/probe sides are sqrt(10000) = 100 rows each.
        // CPU = 100 * 0.02 * 2 + 10000 * 0.01; memory = 100 rows * 100 bytes.
        assert_close(cost.cpu, 104.0);
        assert_close(cost.memory, 10_000.0);
    }

    #[test]
    fn test_cost_model_cross_join() {
        let model = CostModel::new();
        let join = empty_join(JoinType::Cross);
        let cost = model.join_cost(&join, 1_000_000.0);

        // One CPU tuple cost per produced row
        assert_close(cost.cpu, 10_000.0);
        assert_close(cost.memory, 0.0);
    }

    #[test]
    fn test_cost_model_semi_join() {
        let model = CostModel::new();
        let cost_semi = model.join_cost(&empty_join(JoinType::Semi), 1000.0);
        let cost_inner = model.join_cost(&empty_join(JoinType::Inner), 1000.0);

        // Semi join skips the output cost but builds the same hash table
        assert!(cost_semi.cpu > 0.0);
        assert!(cost_semi.cpu < cost_inner.cpu);
        assert_close(cost_semi.memory, cost_inner.memory);
    }

    #[test]
    fn test_cost_model_aggregate() {
        let model = CostModel::new();
        let agg = AggregateOp {
            group_by: vec![],
            aggregates: vec![
                AggregateExpr {
                    function: AggregateFunction::Count,
                    expression: None,
                    distinct: false,
                    alias: Some("cnt".to_string()),
                    percentile: None,
                },
                AggregateExpr {
                    function: AggregateFunction::Sum,
                    expression: Some(LogicalExpression::Variable("x".to_string())),
                    distinct: false,
                    alias: Some("total".to_string()),
                    percentile: None,
                },
            ],
            input: Box::new(LogicalOperator::Empty),
            having: None,
        };
        let cost = model.aggregate_cost(&agg, 1000.0);

        // CPU = 1000 * 0.02 + 1000 * 0.01 * 2; memory = 100 groups * 100 bytes
        assert_close(cost.cpu, 40.0);
        assert_close(cost.memory, 10_000.0);
    }

    #[test]
    fn test_cost_model_distinct() {
        let model = CostModel::new();
        let distinct = DistinctOp {
            input: Box::new(LogicalOperator::Empty),
            columns: None,
        };
        let cost = model.distinct_cost(&distinct, 1000.0);

        // Hash set: 1000 lookups, half of the rows kept in memory
        assert_close(cost.cpu, 20.0);
        assert_close(cost.memory, 50_000.0);
    }

    #[test]
    fn test_cost_model_limit() {
        let model = CostModel::new();
        let limit = LimitOp {
            count: 10,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.limit_cost(&limit, 1000.0);

        // Limit is very cheap: 10 rows * 0.01 * 0.1
        assert_close(cost.cpu, 0.01);
    }

    #[test]
    fn test_cost_model_skip() {
        let model = CostModel::new();
        let skip = SkipOp {
            count: 100,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.skip_cost(&skip, 1000.0);

        // Skip must read every skipped row
        assert_close(cost.cpu, 1.0);
    }

    #[test]
    fn test_cost_model_return() {
        let model = CostModel::new();
        let ret = ReturnOp {
            items: vec![
                ReturnItem {
                    expression: LogicalExpression::Variable("a".to_string()),
                    alias: None,
                },
                ReturnItem {
                    expression: LogicalExpression::Variable("b".to_string()),
                    alias: None,
                },
            ],
            distinct: false,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.return_cost(&ret, 1000.0);

        // Two returned expressions per row
        assert_close(cost.cpu, 20.0);
    }

    #[test]
    fn test_cost_cheaper() {
        let model = CostModel::new();
        let cheap = Cost::cpu(10.0);
        let expensive = Cost::cpu(100.0);

        assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
        assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
    }

    #[test]
    fn test_cost_comparison_prefers_lower_total() {
        let model = CostModel::new();
        // High CPU, low IO
        let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
        // Low CPU, high IO
        let io_heavy = Cost::cpu(10.0).with_io(20.0);

        // IO is weighted 10x, so io_heavy = 10 + 200 = 210, cpu_heavy = 100 + 10 = 110
        assert!(cpu_heavy.total() < io_heavy.total());
        assert_eq!(
            model.cheaper(&cpu_heavy, &io_heavy).total(),
            cpu_heavy.total()
        );
    }

    #[test]
    fn test_cost_model_sort_with_keys() {
        let model = CostModel::new();
        let sort_single = SortOp {
            keys: vec![sort_key("a", SortOrder::Ascending)],
            input: Box::new(LogicalOperator::Empty),
        };
        let sort_multi = SortOp {
            keys: vec![
                sort_key("a", SortOrder::Ascending),
                sort_key("b", SortOrder::Descending),
            ],
            input: Box::new(LogicalOperator::Empty),
        };

        let cost_single = model.sort_cost(&sort_single, 1000.0);
        let cost_multi = model.sort_cost(&sort_multi, 1000.0);

        // CPU scales linearly with the number of sort keys
        assert!(cost_multi.cpu > cost_single.cpu);
        assert_close(cost_multi.cpu, 2.0 * cost_single.cpu);
    }

    #[test]
    fn test_cost_model_empty_operator() {
        let model = CostModel::new();
        let cost = model.estimate(&LogicalOperator::Empty, 0.0);
        assert_close(cost.total(), 0.0);
    }

    #[test]
    fn test_cost_model_default() {
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: None,
            input: None,
        };
        let op = LogicalOperator::NodeScan(scan);
        let cost = CostModel::default().estimate(&op, 100.0);

        assert!(cost.total() > 0.0);
        // `default` uses the same parameters as `new`
        assert_eq!(cost, CostModel::new().estimate(&op, 100.0));
    }

    #[test]
    fn test_leapfrog_join_cost() {
        let model = CostModel::new();

        // Three-way join (triangle pattern)
        let cardinalities = vec![1000.0, 1000.0, 1000.0];
        let cost = model.leapfrog_join_cost(3, &cardinalities, 100.0);

        // Should have CPU cost for materialization and intersection
        assert!(cost.cpu > 0.0);
        // Should have memory cost for trie storage
        assert!(cost.memory > 0.0);

        // A larger output requires more leapfrog seeks
        let bigger = model.leapfrog_join_cost(3, &cardinalities, 1000.0);
        assert!(bigger.total() > cost.total());
    }

    #[test]
    fn test_leapfrog_join_cost_empty() {
        let model = CostModel::new();
        let cost = model.leapfrog_join_cost(0, &[], 0.0);
        assert_close(cost.total(), 0.0);
    }

    #[test]
    fn test_prefer_leapfrog_join_for_triangles() {
        let model = CostModel::new();

        // Compare costs for triangle pattern
        let cardinalities = vec![10000.0, 10000.0, 10000.0];
        let output = 1000.0;

        let leapfrog_cost = model.leapfrog_join_cost(3, &cardinalities, output);

        // Leapfrog should have reasonable cost for triangle patterns
        assert!(leapfrog_cost.cpu > 0.0);
        assert!(leapfrog_cost.memory > 0.0);

        // The actual preference depends on the cost parameters; the decision
        // must at least be deterministic.
        let prefer = model.prefer_leapfrog_join(3, &cardinalities, output);
        assert_eq!(prefer, model.prefer_leapfrog_join(3, &cardinalities, output));
    }

    #[test]
    fn test_prefer_leapfrog_join_binary_case() {
        let model = CostModel::new();

        // Binary join should NOT prefer leapfrog (need 3+ relations)
        let cardinalities = vec![1000.0, 1000.0];
        let prefer = model.prefer_leapfrog_join(2, &cardinalities, 500.0);
        assert!(!prefer, "Binary joins should use hash join, not leapfrog");

        // Too few cardinalities also disables leapfrog, whatever num_relations says
        assert!(!model.prefer_leapfrog_join(3, &cardinalities, 500.0));
    }

    #[test]
    fn test_factorized_benefit_single_hop() {
        let model = CostModel::new();

        // Single hop: no factorization benefit
        let benefit = model.factorized_benefit(10.0, 1);
        assert!(
            (benefit - 1.0).abs() < 0.001,
            "Single hop should have no benefit"
        );
    }

    #[test]
    fn test_factorized_benefit_no_fanout() {
        let model = CostModel::new();

        // Fanout <= 1 or zero hops never benefits from factorization
        assert_close(model.factorized_benefit(1.0, 3), 1.0);
        assert_close(model.factorized_benefit(0.5, 3), 1.0);
        assert_close(model.factorized_benefit(10.0, 0), 1.0);
    }

    #[test]
    fn test_factorized_benefit_multi_hop() {
        let model = CostModel::new();

        // Multi-hop with high fanout
        let benefit = model.factorized_benefit(10.0, 3);

        // The benefit is a compression ratio in (0, 1]
        assert!(benefit <= 1.0, "Benefit should be <= 1.0");
        assert!(benefit > 0.0, "Benefit should be positive");
    }

    #[test]
    fn test_factorized_benefit_low_fanout() {
        let model = CostModel::new();

        // Low fanout: the ratio must still be capped at 1.0
        let benefit = model.factorized_benefit(1.5, 2);
        assert!(benefit <= 1.0, "Benefit should never exceed 1.0");
    }
}