Skip to main content

grafeo_engine/query/optimizer/
cost.rs

1//! Cost model for query optimization.
2//!
3//! Provides cost estimates for logical operators to guide optimization decisions.
4
5use crate::query::plan::{
6    AggregateOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp, LogicalOperator,
7    NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp, VectorJoinOp, VectorScanOp,
8};
9
/// Cost of an operation.
///
/// Represents the estimated resource consumption of executing an operator.
/// Each resource is tracked separately so callers can weigh them as they see
/// fit (see [`Cost::total`] and [`Cost::total_weighted`]).
///
/// `Cost::default()` is the all-zero cost, identical to [`Cost::zero`].
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct Cost {
    /// Estimated CPU cycles / work units.
    pub cpu: f64,
    /// Estimated I/O operations (page reads).
    pub io: f64,
    /// Estimated memory usage in bytes.
    pub memory: f64,
    /// Network cost (for distributed queries).
    pub network: f64,
}

impl Cost {
    /// Weight applied to the CPU component by [`Cost::total`].
    pub const DEFAULT_CPU_WEIGHT: f64 = 1.0;
    /// Weight applied to the I/O component by [`Cost::total`].
    pub const DEFAULT_IO_WEIGHT: f64 = 10.0;
    /// Weight applied to the memory component by [`Cost::total`].
    pub const DEFAULT_MEMORY_WEIGHT: f64 = 0.1;
    /// Weight applied to the network component by [`Cost::total`].
    pub const DEFAULT_NETWORK_WEIGHT: f64 = 100.0;

    /// Creates a zero cost.
    #[must_use]
    pub fn zero() -> Self {
        Self {
            cpu: 0.0,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Creates a cost consisting only of the given CPU work units.
    #[must_use]
    pub fn cpu(cpu: f64) -> Self {
        Self {
            cpu,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Sets the I/O cost, replacing any previously set value.
    #[must_use]
    pub fn with_io(mut self, io: f64) -> Self {
        self.io = io;
        self
    }

    /// Sets the memory cost, replacing any previously set value.
    #[must_use]
    pub fn with_memory(mut self, memory: f64) -> Self {
        self.memory = memory;
        self
    }

    /// Sets the network cost, replacing any previously set value.
    #[must_use]
    pub fn with_network(mut self, network: f64) -> Self {
        self.network = network;
        self
    }

    /// Returns the total weighted cost.
    ///
    /// Uses the default weights: CPU=1.0, IO=10.0, Memory=0.1, Network=100.0
    /// (see the `DEFAULT_*_WEIGHT` constants).
    #[must_use]
    pub fn total(&self) -> f64 {
        self.cpu * Self::DEFAULT_CPU_WEIGHT
            + self.io * Self::DEFAULT_IO_WEIGHT
            + self.memory * Self::DEFAULT_MEMORY_WEIGHT
            + self.network * Self::DEFAULT_NETWORK_WEIGHT
    }

    /// Returns the total cost with custom weights for CPU, I/O and memory.
    ///
    /// Note: the network component is not included in this total.
    #[must_use]
    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
        self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
    }
}

impl std::ops::Add for Cost {
    type Output = Self;

    /// Component-wise sum of two costs.
    fn add(self, other: Self) -> Self {
        Self {
            cpu: self.cpu + other.cpu,
            io: self.io + other.io,
            memory: self.memory + other.memory,
            network: self.network + other.network,
        }
    }
}

impl std::ops::AddAssign for Cost {
    /// Component-wise in-place addition.
    fn add_assign(&mut self, other: Self) {
        self.cpu += other.cpu;
        self.io += other.io;
        self.memory += other.memory;
        self.network += other.network;
    }
}

impl std::iter::Sum for Cost {
    /// Adds up a sequence of costs component-wise, starting from [`Cost::zero`].
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.fold(Self::zero(), |acc, cost| acc + cost)
    }
}

impl<'a> std::iter::Sum<&'a Cost> for Cost {
    /// Adds up a sequence of borrowed costs component-wise.
    fn sum<I: Iterator<Item = &'a Cost>>(iter: I) -> Self {
        iter.fold(Self::zero(), |acc, cost| acc + *cost)
    }
}
98
/// Cost model for estimating operator costs.
///
/// Holds the per-unit constants that turn a cardinality estimate into a
/// [`Cost`]. Build one with `CostModel::new()` (or `CostModel::default()`) and
/// optionally refine the edge fanout with `CostModel::with_avg_fanout`.
#[derive(Debug, Clone, PartialEq)]
pub struct CostModel {
    /// Cost per tuple processed by CPU.
    cpu_tuple_cost: f64,
    /// Cost per hash table lookup.
    hash_lookup_cost: f64,
    /// Cost per comparison in sorting.
    sort_comparison_cost: f64,
    /// Average tuple size in bytes.
    avg_tuple_size: f64,
    /// Page size in bytes.
    page_size: f64,
    /// Average edge fanout (edges per node), from statistics.
    avg_fanout: f64,
}
114
115impl CostModel {
116    /// Creates a new cost model with default parameters.
117    #[must_use]
118    pub fn new() -> Self {
119        Self {
120            cpu_tuple_cost: 0.01,
121            hash_lookup_cost: 0.02,
122            sort_comparison_cost: 0.02,
123            avg_tuple_size: 100.0,
124            page_size: 8192.0,
125            avg_fanout: 10.0,
126        }
127    }
128
129    /// Creates a cost model with fanout derived from actual statistics.
130    #[must_use]
131    pub fn with_avg_fanout(mut self, avg_fanout: f64) -> Self {
132        self.avg_fanout = if avg_fanout > 0.0 { avg_fanout } else { 10.0 };
133        self
134    }
135
136    /// Estimates the cost of a logical operator.
137    #[must_use]
138    pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
139        match op {
140            LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
141            LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
142            LogicalOperator::Project(project) => self.project_cost(project, cardinality),
143            LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
144            LogicalOperator::Join(join) => self.join_cost(join, cardinality),
145            LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
146            LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
147            LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
148            LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
149            LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
150            LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
151            LogicalOperator::Empty => Cost::zero(),
152            LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
153            LogicalOperator::VectorJoin(join) => self.vector_join_cost(join, cardinality),
154            _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
155        }
156    }
157
158    /// Estimates the cost of a node scan.
159    fn node_scan_cost(&self, _scan: &NodeScanOp, cardinality: f64) -> Cost {
160        let pages = (cardinality * self.avg_tuple_size) / self.page_size;
161        Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
162    }
163
164    /// Estimates the cost of a filter operation.
165    fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
166        // Filter cost is just predicate evaluation per tuple
167        Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
168    }
169
170    /// Estimates the cost of a projection.
171    fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
172        // Cost depends on number of expressions evaluated
173        let expr_count = project.projections.len() as f64;
174        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
175    }
176
177    /// Estimates the cost of an expand operation.
178    fn expand_cost(&self, _expand: &ExpandOp, cardinality: f64) -> Cost {
179        // Expand involves adjacency list lookups
180        let lookup_cost = cardinality * self.hash_lookup_cost;
181        let output_cost = cardinality * self.avg_fanout * self.cpu_tuple_cost;
182        Cost::cpu(lookup_cost + output_cost)
183    }
184
185    /// Estimates the cost of a join operation.
186    fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
187        // Cost depends on join type
188        match join.join_type {
189            JoinType::Cross => {
190                // Cross join is O(n * m)
191                Cost::cpu(cardinality * self.cpu_tuple_cost)
192            }
193            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
194                // Hash join: build phase + probe phase
195                // Assume left side is build, right side is probe
196                let build_cardinality = cardinality.sqrt(); // Rough estimate
197                let probe_cardinality = cardinality.sqrt();
198
199                // Build hash table
200                let build_cost = build_cardinality * self.hash_lookup_cost;
201                let memory_cost = build_cardinality * self.avg_tuple_size;
202
203                // Probe hash table
204                let probe_cost = probe_cardinality * self.hash_lookup_cost;
205
206                // Output cost
207                let output_cost = cardinality * self.cpu_tuple_cost;
208
209                Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
210            }
211            JoinType::Semi | JoinType::Anti => {
212                // Semi/anti joins are typically cheaper
213                let build_cardinality = cardinality.sqrt();
214                let probe_cardinality = cardinality.sqrt();
215
216                let build_cost = build_cardinality * self.hash_lookup_cost;
217                let probe_cost = probe_cardinality * self.hash_lookup_cost;
218
219                Cost::cpu(build_cost + probe_cost)
220                    .with_memory(build_cardinality * self.avg_tuple_size)
221            }
222        }
223    }
224
225    /// Estimates the cost of an aggregation.
226    fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
227        // Hash aggregation cost
228        let hash_cost = cardinality * self.hash_lookup_cost;
229
230        // Aggregate function evaluation
231        let agg_count = agg.aggregates.len() as f64;
232        let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
233
234        // Memory for hash table (estimated distinct groups)
235        let distinct_groups = (cardinality / 10.0).max(1.0); // Assume 10% distinct
236        let memory_cost = distinct_groups * self.avg_tuple_size;
237
238        Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
239    }
240
241    /// Estimates the cost of a sort operation.
242    fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
243        if cardinality <= 1.0 {
244            return Cost::zero();
245        }
246
247        // Sort is O(n log n) comparisons
248        let comparisons = cardinality * cardinality.log2();
249        let key_count = sort.keys.len() as f64;
250
251        // Memory for sorting (full input materialization)
252        let memory_cost = cardinality * self.avg_tuple_size;
253
254        Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
255    }
256
257    /// Estimates the cost of a distinct operation.
258    fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
259        // Hash-based distinct
260        let hash_cost = cardinality * self.hash_lookup_cost;
261        let memory_cost = cardinality * self.avg_tuple_size * 0.5; // Assume 50% distinct
262
263        Cost::cpu(hash_cost).with_memory(memory_cost)
264    }
265
266    /// Estimates the cost of a limit operation.
267    fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
268        // Limit is very cheap - just counting
269        Cost::cpu(limit.count as f64 * self.cpu_tuple_cost * 0.1)
270    }
271
272    /// Estimates the cost of a skip operation.
273    fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
274        // Skip requires scanning through skipped rows
275        Cost::cpu(skip.count as f64 * self.cpu_tuple_cost)
276    }
277
278    /// Estimates the cost of a return operation.
279    fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
280        // Return materializes results
281        let expr_count = ret.items.len() as f64;
282        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
283    }
284
285    /// Estimates the cost of a vector scan operation.
286    ///
287    /// HNSW index search is O(log N) per query, while brute-force is O(N).
288    /// This estimates the HNSW case with ef search parameter.
289    fn vector_scan_cost(&self, scan: &VectorScanOp, cardinality: f64) -> Cost {
290        // k determines output cardinality
291        let k = scan.k as f64;
292
293        // HNSW search cost: O(ef * log(N)) distance computations
294        // Assume ef = 64 (default), N = cardinality
295        let ef = 64.0;
296        let n = cardinality.max(1.0);
297        let search_cost = if scan.index_name.is_some() {
298            // HNSW: O(ef * log N)
299            ef * n.ln() * self.cpu_tuple_cost * 10.0 // Distance computation is ~10x regular tuple
300        } else {
301            // Brute-force: O(N)
302            n * self.cpu_tuple_cost * 10.0
303        };
304
305        // Memory for candidate heap
306        let memory = k * self.avg_tuple_size * 2.0;
307
308        Cost::cpu(search_cost).with_memory(memory)
309    }
310
311    /// Estimates the cost of a vector join operation.
312    ///
313    /// Vector join performs k-NN search for each input row.
314    fn vector_join_cost(&self, join: &VectorJoinOp, cardinality: f64) -> Cost {
315        let k = join.k as f64;
316
317        // Each input row triggers a vector search
318        // Assume brute-force for hybrid queries (no index specified typically)
319        let per_row_search_cost = if join.index_name.is_some() {
320            // HNSW: O(ef * log N)
321            let ef = 64.0;
322            let n = cardinality.max(1.0);
323            ef * n.ln() * self.cpu_tuple_cost * 10.0
324        } else {
325            // Brute-force: O(N) per input row
326            cardinality * self.cpu_tuple_cost * 10.0
327        };
328
329        // Total cost: input_rows * search_cost
330        // For vector join, cardinality is typically input cardinality * k
331        let input_cardinality = (cardinality / k).max(1.0);
332        let total_search_cost = input_cardinality * per_row_search_cost;
333
334        // Memory for results
335        let memory = cardinality * self.avg_tuple_size;
336
337        Cost::cpu(total_search_cost).with_memory(memory)
338    }
339
340    /// Compares two costs and returns the cheaper one.
341    #[must_use]
342    pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
343        if a.total() <= b.total() { a } else { b }
344    }
345
346    /// Estimates the cost of a worst-case optimal join (WCOJ/leapfrog join).
347    ///
348    /// WCOJ is optimal for cyclic patterns like triangles. Traditional binary
349    /// hash joins are O(N²) for triangles; WCOJ achieves O(N^1.5) by processing
350    /// all relations simultaneously using sorted iterators.
351    ///
352    /// # Arguments
353    /// * `num_relations` - Number of relations participating in the join
354    /// * `cardinalities` - Cardinality of each input relation
355    /// * `output_cardinality` - Expected output cardinality
356    ///
357    /// # Cost Model
358    /// - Materialization: O(sum of cardinalities) to build trie indexes
359    /// - Intersection: O(output * log(min_cardinality)) for leapfrog seek operations
360    /// - Memory: Trie storage for all inputs
361    #[must_use]
362    pub fn leapfrog_join_cost(
363        &self,
364        num_relations: usize,
365        cardinalities: &[f64],
366        output_cardinality: f64,
367    ) -> Cost {
368        if cardinalities.is_empty() {
369            return Cost::zero();
370        }
371
372        let total_input: f64 = cardinalities.iter().sum();
373        let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
374
375        // Materialization phase: build trie indexes for each input
376        let materialize_cost = total_input * self.cpu_tuple_cost * 2.0; // Sorting + trie building
377
378        // Intersection phase: leapfrog seeks are O(log n) per relation
379        let seek_cost = if min_card > 1.0 {
380            output_cardinality * (num_relations as f64) * min_card.log2() * self.hash_lookup_cost
381        } else {
382            output_cardinality * self.cpu_tuple_cost
383        };
384
385        // Output materialization
386        let output_cost = output_cardinality * self.cpu_tuple_cost;
387
388        // Memory: trie storage (roughly 2x input size for sorted index)
389        let memory = total_input * self.avg_tuple_size * 2.0;
390
391        Cost::cpu(materialize_cost + seek_cost + output_cost).with_memory(memory)
392    }
393
394    /// Compares hash join cost vs leapfrog join cost for a cyclic pattern.
395    ///
396    /// Returns true if leapfrog (WCOJ) is estimated to be cheaper.
397    #[must_use]
398    pub fn prefer_leapfrog_join(
399        &self,
400        num_relations: usize,
401        cardinalities: &[f64],
402        output_cardinality: f64,
403    ) -> bool {
404        if num_relations < 3 || cardinalities.len() < 3 {
405            // Leapfrog is only beneficial for multi-way joins (3+)
406            return false;
407        }
408
409        let leapfrog_cost =
410            self.leapfrog_join_cost(num_relations, cardinalities, output_cardinality);
411
412        // Estimate cascade of binary hash joins
413        // For N relations, we need N-1 joins
414        // Each join produces intermediate results that feed the next
415        let mut hash_cascade_cost = Cost::zero();
416        let mut intermediate_cardinality = cardinalities[0];
417
418        for card in &cardinalities[1..] {
419            // Hash join cost: build + probe + output
420            let join_output = (intermediate_cardinality * card).sqrt(); // Estimated selectivity
421            let join = JoinOp {
422                left: Box::new(LogicalOperator::Empty),
423                right: Box::new(LogicalOperator::Empty),
424                join_type: JoinType::Inner,
425                conditions: vec![],
426            };
427            hash_cascade_cost += self.join_cost(&join, join_output);
428            intermediate_cardinality = join_output;
429        }
430
431        leapfrog_cost.total() < hash_cascade_cost.total()
432    }
433
434    /// Estimates cost for factorized execution (compressed intermediate results).
435    ///
436    /// Factorized execution avoids materializing full cross products by keeping
437    /// results in a compressed "factorized" form. This is beneficial for multi-hop
438    /// traversals where intermediate results can explode.
439    ///
440    /// Returns the reduction factor (1.0 = no benefit, lower = more compression).
441    #[must_use]
442    pub fn factorized_benefit(&self, avg_fanout: f64, num_hops: usize) -> f64 {
443        if num_hops <= 1 || avg_fanout <= 1.0 {
444            return 1.0; // No benefit for single hop or low fanout
445        }
446
447        // Factorized representation compresses repeated prefixes
448        // Compression ratio improves with higher fanout and more hops
449        // Full materialization: fanout^hops
450        // Factorized: sum(fanout^i for i in 1..=hops) ≈ fanout^(hops+1) / (fanout - 1)
451
452        let full_size = avg_fanout.powi(num_hops as i32);
453        let factorized_size = if avg_fanout > 1.0 {
454            (avg_fanout.powi(num_hops as i32 + 1) - 1.0) / (avg_fanout - 1.0)
455        } else {
456            num_hops as f64
457        };
458
459        (factorized_size / full_size).min(1.0)
460    }
461}
462
463impl Default for CostModel {
464    fn default() -> Self {
465        Self::new()
466    }
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::plan::{
        AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
        Projection, ReturnItem, SortOrder,
    };

    /// Builds a single-hop outgoing expand from `a` to `b` over an empty input.
    fn sample_expand() -> ExpandOp {
        ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: None,
            direction: ExpandDirection::Outgoing,
            edge_type: None,
            min_hops: 1,
            max_hops: Some(1),
            input: Box::new(LogicalOperator::Empty),
            path_alias: None,
        }
    }

    /// Builds a join of the given type between two empty inputs.
    fn sample_join(join_type: JoinType, conditions: Vec<JoinCondition>) -> JoinOp {
        JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type,
            conditions,
        }
    }

    /// Builds an un-aliased projection of the variable `name`.
    fn variable_projection(name: &str) -> Projection {
        Projection {
            expression: LogicalExpression::Variable(name.to_string()),
            alias: None,
        }
    }

    #[test]
    fn test_cost_addition() {
        let a = Cost::cpu(10.0).with_io(5.0);
        let b = Cost::cpu(20.0).with_memory(100.0);
        let c = a + b;

        assert!((c.cpu - 30.0).abs() < 0.001);
        assert!((c.io - 5.0).abs() < 0.001);
        assert!((c.memory - 100.0).abs() < 0.001);
        assert!((c.network).abs() < 0.001);
    }

    #[test]
    fn test_cost_total() {
        let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
        // Total = 10 + 1*10 + 100*0.1 = 10 + 10 + 10 = 30
        assert!((cost.total() - 30.0).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_node_scan() {
        let model = CostModel::new();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: Some("Person".to_string()),
            input: None,
        };
        let cost = model.node_scan_cost(&scan, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.io > 0.0);
    }

    #[test]
    fn test_cost_model_sort() {
        let model = CostModel::new();
        let sort = SortOp {
            keys: vec![],
            input: Box::new(LogicalOperator::Empty),
        };

        let cost_100 = model.sort_cost(&sort, 100.0);
        let cost_1000 = model.sort_cost(&sort, 1000.0);

        // Sorting 1000 rows should be more expensive than 100 rows
        assert!(cost_1000.total() > cost_100.total());
    }

    #[test]
    fn test_cost_zero() {
        let cost = Cost::zero();
        assert!((cost.cpu).abs() < 0.001);
        assert!((cost.io).abs() < 0.001);
        assert!((cost.memory).abs() < 0.001);
        assert!((cost.network).abs() < 0.001);
        assert!((cost.total()).abs() < 0.001);
    }

    #[test]
    fn test_cost_add_assign() {
        let mut cost = Cost::cpu(10.0);
        cost += Cost::cpu(5.0).with_io(2.0);
        assert!((cost.cpu - 15.0).abs() < 0.001);
        assert!((cost.io - 2.0).abs() < 0.001);
        assert!((cost.memory).abs() < 0.001);
    }

    #[test]
    fn test_cost_total_weighted() {
        let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
        // With custom weights: cpu*2 + io*5 + mem*0.5 = 20 + 10 + 50 = 80
        let total = cost.total_weighted(2.0, 5.0, 0.5);
        assert!((total - 80.0).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_filter() {
        let model = CostModel::new();
        let filter = FilterOp {
            predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.filter_cost(&filter, 1000.0);

        // Filter cost is CPU only
        assert!(cost.cpu > 0.0);
        assert!((cost.io).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_project() {
        let model = CostModel::new();
        let single = ProjectOp {
            projections: vec![variable_projection("a")],
            input: Box::new(LogicalOperator::Empty),
        };
        let double = ProjectOp {
            projections: vec![variable_projection("a"), variable_projection("b")],
            input: Box::new(LogicalOperator::Empty),
        };
        let cost_single = model.project_cost(&single, 1000.0);
        let cost_double = model.project_cost(&double, 1000.0);

        // Cost should scale with number of projections
        assert!(cost_single.cpu > 0.0);
        assert!(cost_double.cpu > cost_single.cpu);
    }

    #[test]
    fn test_cost_model_expand() {
        let model = CostModel::new();
        let expand = sample_expand();
        let cost = model.expand_cost(&expand, 1000.0);
        let cost_larger = model.expand_cost(&expand, 2000.0);

        // Expand involves hash lookups and output generation
        assert!(cost.cpu > 0.0);
        // More input rows mean more adjacency lookups and more output
        assert!(cost_larger.cpu > cost.cpu);
    }

    #[test]
    fn test_with_avg_fanout_invalid_values_fall_back_to_default() {
        let expand = sample_expand();
        let baseline = CostModel::new().expand_cost(&expand, 1000.0);

        // Zero, negative and NaN fanouts are rejected in favour of the default
        for bad in [0.0, -3.0, f64::NAN] {
            let cost = CostModel::new()
                .with_avg_fanout(bad)
                .expand_cost(&expand, 1000.0);
            assert!((cost.cpu - baseline.cpu).abs() < 1e-9);
        }

        // A valid, larger fanout makes expansion more expensive
        let higher = CostModel::new()
            .with_avg_fanout(20.0)
            .expand_cost(&expand, 1000.0);
        assert!(higher.cpu > baseline.cpu);
    }

    #[test]
    fn test_cost_model_hash_join() {
        let model = CostModel::new();
        let join = sample_join(
            JoinType::Inner,
            vec![JoinCondition {
                left: LogicalExpression::Variable("a".to_string()),
                right: LogicalExpression::Variable("b".to_string()),
            }],
        );
        let cost = model.join_cost(&join, 10000.0);

        // Hash join has CPU cost and memory cost
        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_cross_join() {
        let model = CostModel::new();
        let join = sample_join(JoinType::Cross, vec![]);
        let cost = model.join_cost(&join, 1000000.0);
        let cost_small = model.join_cost(&join, 1000.0);

        // Cross join is expensive and grows with the number of produced pairs
        assert!(cost.cpu > 0.0);
        assert!(cost.cpu > cost_small.cpu);
    }

    #[test]
    fn test_cost_model_semi_join() {
        let model = CostModel::new();
        let cost_semi = model.join_cost(&sample_join(JoinType::Semi, vec![]), 1000.0);
        let cost_inner = model.join_cost(&sample_join(JoinType::Inner, vec![]), 1000.0);

        // Semi join skips output materialization, so it is cheaper than inner join
        assert!(cost_semi.cpu > 0.0);
        assert!(cost_inner.cpu > 0.0);
        assert!(cost_semi.cpu < cost_inner.cpu);
    }

    #[test]
    fn test_cost_model_aggregate() {
        let model = CostModel::new();
        let agg = AggregateOp {
            group_by: vec![],
            aggregates: vec![
                AggregateExpr {
                    function: AggregateFunction::Count,
                    expression: None,
                    distinct: false,
                    alias: Some("cnt".to_string()),
                    percentile: None,
                },
                AggregateExpr {
                    function: AggregateFunction::Sum,
                    expression: Some(LogicalExpression::Variable("x".to_string())),
                    distinct: false,
                    alias: Some("total".to_string()),
                    percentile: None,
                },
            ],
            input: Box::new(LogicalOperator::Empty),
            having: None,
        };
        let cost = model.aggregate_cost(&agg, 1000.0);

        // Aggregation has hash cost and memory cost
        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_distinct() {
        let model = CostModel::new();
        let distinct = DistinctOp {
            input: Box::new(LogicalOperator::Empty),
            columns: None,
        };
        let cost = model.distinct_cost(&distinct, 1000.0);

        // Distinct uses hash set
        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_limit() {
        let model = CostModel::new();
        let limit = LimitOp {
            count: 10,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.limit_cost(&limit, 1000.0);

        // Limit is very cheap
        assert!(cost.cpu > 0.0);
        assert!(cost.cpu < 1.0); // Should be minimal
    }

    #[test]
    fn test_cost_model_skip() {
        let model = CostModel::new();
        let skip = SkipOp {
            count: 100,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.skip_cost(&skip, 1000.0);

        // Skip must scan through skipped rows
        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_return() {
        let model = CostModel::new();
        let ret = ReturnOp {
            items: vec![
                ReturnItem {
                    expression: LogicalExpression::Variable("a".to_string()),
                    alias: None,
                },
                ReturnItem {
                    expression: LogicalExpression::Variable("b".to_string()),
                    alias: None,
                },
            ],
            distinct: false,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.return_cost(&ret, 1000.0);

        // Return materializes results
        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_cheaper() {
        let model = CostModel::new();
        let cheap = Cost::cpu(10.0);
        let expensive = Cost::cpu(100.0);

        assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
        assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
    }

    #[test]
    fn test_cost_comparison_prefers_lower_total() {
        let model = CostModel::new();
        // High CPU, low IO
        let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
        // Low CPU, high IO
        let io_heavy = Cost::cpu(10.0).with_io(20.0);

        // IO is weighted 10x, so io_heavy = 10 + 200 = 210, cpu_heavy = 100 + 10 = 110
        assert!(cpu_heavy.total() < io_heavy.total());
        assert_eq!(
            model.cheaper(&cpu_heavy, &io_heavy).total(),
            cpu_heavy.total()
        );
    }

    #[test]
    fn test_cost_model_sort_with_keys() {
        let model = CostModel::new();
        let sort_single = SortOp {
            keys: vec![crate::query::plan::SortKey {
                expression: LogicalExpression::Variable("a".to_string()),
                order: SortOrder::Ascending,
            }],
            input: Box::new(LogicalOperator::Empty),
        };
        let sort_multi = SortOp {
            keys: vec![
                crate::query::plan::SortKey {
                    expression: LogicalExpression::Variable("a".to_string()),
                    order: SortOrder::Ascending,
                },
                crate::query::plan::SortKey {
                    expression: LogicalExpression::Variable("b".to_string()),
                    order: SortOrder::Descending,
                },
            ],
            input: Box::new(LogicalOperator::Empty),
        };

        let cost_single = model.sort_cost(&sort_single, 1000.0);
        let cost_multi = model.sort_cost(&sort_multi, 1000.0);

        // More sort keys = more comparisons
        assert!(cost_multi.cpu > cost_single.cpu);
    }

    #[test]
    fn test_cost_model_empty_operator() {
        let model = CostModel::new();
        let cost = model.estimate(&LogicalOperator::Empty, 0.0);
        assert!((cost.total()).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_default() {
        let model = CostModel::default();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: None,
            input: None,
        };
        let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
        assert!(cost.total() > 0.0);
    }

    #[test]
    fn test_leapfrog_join_cost() {
        let model = CostModel::new();

        // Three-way join (triangle pattern)
        let cardinalities = vec![1000.0, 1000.0, 1000.0];
        let cost = model.leapfrog_join_cost(3, &cardinalities, 100.0);

        // Should have CPU cost for materialization and intersection
        assert!(cost.cpu > 0.0);
        // Should have memory cost for trie storage
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_leapfrog_join_cost_empty() {
        let model = CostModel::new();
        let cost = model.leapfrog_join_cost(0, &[], 0.0);
        assert!((cost.total()).abs() < 0.001);
    }

    #[test]
    fn test_prefer_leapfrog_join_for_triangles() {
        let model = CostModel::new();

        // Compare costs for triangle pattern
        let cardinalities = vec![10000.0, 10000.0, 10000.0];
        let output = 1000.0;

        let leapfrog_cost = model.leapfrog_join_cost(3, &cardinalities, output);

        // Leapfrog should have reasonable cost for triangle patterns
        assert!(leapfrog_cost.cpu > 0.0);
        assert!(leapfrog_cost.memory > 0.0);

        // The prefer_leapfrog_join method compares against a hash cascade;
        // the actual preference depends on the cost parameters, so this only
        // checks that the comparison runs without panicking.
        let _prefer = model.prefer_leapfrog_join(3, &cardinalities, output);
    }

    #[test]
    fn test_prefer_leapfrog_join_binary_case() {
        let model = CostModel::new();

        // Binary join should NOT prefer leapfrog (need 3+ relations)
        let cardinalities = vec![1000.0, 1000.0];
        let prefer = model.prefer_leapfrog_join(2, &cardinalities, 500.0);
        assert!(!prefer, "Binary joins should use hash join, not leapfrog");
    }

    #[test]
    fn test_factorized_benefit_single_hop() {
        let model = CostModel::new();

        // Single hop: no factorization benefit
        let benefit = model.factorized_benefit(10.0, 1);
        assert!(
            (benefit - 1.0).abs() < 0.001,
            "Single hop should have no benefit"
        );
    }

    #[test]
    fn test_factorized_benefit_multi_hop() {
        let model = CostModel::new();

        // Multi-hop with high fanout
        let benefit = model.factorized_benefit(10.0, 3);

        // The reduction factor is a ratio in (0, 1]
        assert!(benefit <= 1.0, "Benefit should be <= 1.0");
        assert!(benefit > 0.0, "Benefit should be positive");
    }

    #[test]
    fn test_factorized_benefit_low_fanout() {
        let model = CostModel::new();

        // Low fanout: the reduction factor must still never exceed 1.0
        let benefit = model.factorized_benefit(1.5, 2);
        assert!(benefit <= 1.0, "Benefit ratio must never exceed 1.0");
    }
}