Skip to main content

grafeo_engine/query/optimizer/
cost.rs

1//! Cost model for query optimization.
2//!
3//! Provides cost estimates for logical operators to guide optimization decisions.
4
5use crate::query::plan::{
6    AggregateOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp, LogicalOperator,
7    NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp, VectorJoinOp, VectorScanOp,
8};
9
/// Cost of an operation.
///
/// Represents the estimated resource consumption of executing an operator,
/// broken down per resource. Costs combine component-wise with `+`, `+=` and
/// [`Iterator::sum`]; plans are compared through [`Cost::total`].
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct Cost {
    /// Estimated CPU cycles / work units.
    pub cpu: f64,
    /// Estimated I/O operations (page reads).
    pub io: f64,
    /// Estimated memory usage in bytes.
    pub memory: f64,
    /// Network cost (for distributed queries).
    pub network: f64,
}

impl Cost {
    /// Weight of the I/O component in [`Cost::total`].
    const IO_WEIGHT: f64 = 10.0;
    /// Weight of the memory component in [`Cost::total`].
    const MEMORY_WEIGHT: f64 = 0.1;
    /// Weight of the network component in [`Cost::total`].
    const NETWORK_WEIGHT: f64 = 100.0;

    /// Creates a zero cost (identical to `Cost::default()`).
    #[must_use]
    pub fn zero() -> Self {
        Self {
            cpu: 0.0,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Creates a cost consisting only of `cpu` work units.
    #[must_use]
    pub fn cpu(cpu: f64) -> Self {
        Self { cpu, ..Self::zero() }
    }

    /// Sets the I/O component, replacing any previous value.
    ///
    /// This does not accumulate: `Cost::zero().with_io(5.0).with_io(2.0)` has
    /// `io == 2.0`. Use `+` / `+=` to accumulate costs.
    #[must_use]
    pub fn with_io(mut self, io: f64) -> Self {
        self.io = io;
        self
    }

    /// Sets the memory component (bytes), replacing any previous value.
    #[must_use]
    pub fn with_memory(mut self, memory: f64) -> Self {
        self.memory = memory;
        self
    }

    /// Sets the network component, replacing any previous value.
    #[must_use]
    pub fn with_network(mut self, network: f64) -> Self {
        self.network = network;
        self
    }

    /// Returns the total weighted cost.
    ///
    /// Uses default weights: CPU=1.0, IO=10.0, Memory=0.1, Network=100.0
    #[must_use]
    pub fn total(&self) -> f64 {
        self.cpu
            + self.io * Self::IO_WEIGHT
            + self.memory * Self::MEMORY_WEIGHT
            + self.network * Self::NETWORK_WEIGHT
    }

    /// Returns the total cost with custom CPU, I/O and memory weights.
    ///
    /// The network component is NOT included in this total; use
    /// [`Cost::total`] when network cost must be taken into account.
    #[must_use]
    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
        self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
    }
}

impl std::ops::Add for Cost {
    type Output = Self;

    /// Adds two costs component-wise.
    fn add(self, other: Self) -> Self {
        Self {
            cpu: self.cpu + other.cpu,
            io: self.io + other.io,
            memory: self.memory + other.memory,
            network: self.network + other.network,
        }
    }
}

impl std::ops::AddAssign for Cost {
    /// Accumulates `other` into `self` component-wise.
    fn add_assign(&mut self, other: Self) {
        *self = *self + other;
    }
}

impl std::iter::Sum for Cost {
    /// Sums a sequence of costs component-wise; an empty sequence sums to zero.
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.fold(Self::zero(), |acc, cost| acc + cost)
    }
}
98
/// Cost model for estimating operator costs.
///
/// Holds the per-unit parameters the estimation formulas are built from.
#[derive(Debug, Clone)]
pub struct CostModel {
    /// CPU work units charged per tuple processed.
    cpu_tuple_cost: f64,
    /// CPU work units charged per hash-table lookup or insert.
    hash_lookup_cost: f64,
    /// CPU work units charged per comparison while sorting.
    sort_comparison_cost: f64,
    /// Average tuple size in bytes (drives memory and page estimates).
    avg_tuple_size: f64,
    /// Page size in bytes (drives I/O estimates).
    page_size: f64,
}
112
113impl CostModel {
114    /// Creates a new cost model with default parameters.
115    #[must_use]
116    pub fn new() -> Self {
117        Self {
118            cpu_tuple_cost: 0.01,
119            hash_lookup_cost: 0.02,
120            sort_comparison_cost: 0.02,
121            avg_tuple_size: 100.0,
122            page_size: 8192.0,
123        }
124    }
125
126    /// Estimates the cost of a logical operator.
127    #[must_use]
128    pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
129        match op {
130            LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
131            LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
132            LogicalOperator::Project(project) => self.project_cost(project, cardinality),
133            LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
134            LogicalOperator::Join(join) => self.join_cost(join, cardinality),
135            LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
136            LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
137            LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
138            LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
139            LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
140            LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
141            LogicalOperator::Empty => Cost::zero(),
142            LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
143            LogicalOperator::VectorJoin(join) => self.vector_join_cost(join, cardinality),
144            _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
145        }
146    }
147
148    /// Estimates the cost of a node scan.
149    fn node_scan_cost(&self, _scan: &NodeScanOp, cardinality: f64) -> Cost {
150        let pages = (cardinality * self.avg_tuple_size) / self.page_size;
151        Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
152    }
153
154    /// Estimates the cost of a filter operation.
155    fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
156        // Filter cost is just predicate evaluation per tuple
157        Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
158    }
159
160    /// Estimates the cost of a projection.
161    fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
162        // Cost depends on number of expressions evaluated
163        let expr_count = project.projections.len() as f64;
164        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
165    }
166
167    /// Estimates the cost of an expand operation.
168    fn expand_cost(&self, _expand: &ExpandOp, cardinality: f64) -> Cost {
169        // Expand involves adjacency list lookups
170        let lookup_cost = cardinality * self.hash_lookup_cost;
171        // Assume average fanout of 10 for edge traversal
172        let avg_fanout = 10.0;
173        let output_cost = cardinality * avg_fanout * self.cpu_tuple_cost;
174        Cost::cpu(lookup_cost + output_cost)
175    }
176
177    /// Estimates the cost of a join operation.
178    fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
179        // Cost depends on join type
180        match join.join_type {
181            JoinType::Cross => {
182                // Cross join is O(n * m)
183                Cost::cpu(cardinality * self.cpu_tuple_cost)
184            }
185            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
186                // Hash join: build phase + probe phase
187                // Assume left side is build, right side is probe
188                let build_cardinality = cardinality.sqrt(); // Rough estimate
189                let probe_cardinality = cardinality.sqrt();
190
191                // Build hash table
192                let build_cost = build_cardinality * self.hash_lookup_cost;
193                let memory_cost = build_cardinality * self.avg_tuple_size;
194
195                // Probe hash table
196                let probe_cost = probe_cardinality * self.hash_lookup_cost;
197
198                // Output cost
199                let output_cost = cardinality * self.cpu_tuple_cost;
200
201                Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
202            }
203            JoinType::Semi | JoinType::Anti => {
204                // Semi/anti joins are typically cheaper
205                let build_cardinality = cardinality.sqrt();
206                let probe_cardinality = cardinality.sqrt();
207
208                let build_cost = build_cardinality * self.hash_lookup_cost;
209                let probe_cost = probe_cardinality * self.hash_lookup_cost;
210
211                Cost::cpu(build_cost + probe_cost)
212                    .with_memory(build_cardinality * self.avg_tuple_size)
213            }
214        }
215    }
216
217    /// Estimates the cost of an aggregation.
218    fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
219        // Hash aggregation cost
220        let hash_cost = cardinality * self.hash_lookup_cost;
221
222        // Aggregate function evaluation
223        let agg_count = agg.aggregates.len() as f64;
224        let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
225
226        // Memory for hash table (estimated distinct groups)
227        let distinct_groups = (cardinality / 10.0).max(1.0); // Assume 10% distinct
228        let memory_cost = distinct_groups * self.avg_tuple_size;
229
230        Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
231    }
232
233    /// Estimates the cost of a sort operation.
234    fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
235        if cardinality <= 1.0 {
236            return Cost::zero();
237        }
238
239        // Sort is O(n log n) comparisons
240        let comparisons = cardinality * cardinality.log2();
241        let key_count = sort.keys.len() as f64;
242
243        // Memory for sorting (full input materialization)
244        let memory_cost = cardinality * self.avg_tuple_size;
245
246        Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
247    }
248
249    /// Estimates the cost of a distinct operation.
250    fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
251        // Hash-based distinct
252        let hash_cost = cardinality * self.hash_lookup_cost;
253        let memory_cost = cardinality * self.avg_tuple_size * 0.5; // Assume 50% distinct
254
255        Cost::cpu(hash_cost).with_memory(memory_cost)
256    }
257
258    /// Estimates the cost of a limit operation.
259    fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
260        // Limit is very cheap - just counting
261        Cost::cpu(limit.count as f64 * self.cpu_tuple_cost * 0.1)
262    }
263
264    /// Estimates the cost of a skip operation.
265    fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
266        // Skip requires scanning through skipped rows
267        Cost::cpu(skip.count as f64 * self.cpu_tuple_cost)
268    }
269
270    /// Estimates the cost of a return operation.
271    fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
272        // Return materializes results
273        let expr_count = ret.items.len() as f64;
274        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
275    }
276
277    /// Estimates the cost of a vector scan operation.
278    ///
279    /// HNSW index search is O(log N) per query, while brute-force is O(N).
280    /// This estimates the HNSW case with ef search parameter.
281    fn vector_scan_cost(&self, scan: &VectorScanOp, cardinality: f64) -> Cost {
282        // k determines output cardinality
283        let k = scan.k as f64;
284
285        // HNSW search cost: O(ef * log(N)) distance computations
286        // Assume ef = 64 (default), N = cardinality
287        let ef = 64.0;
288        let n = cardinality.max(1.0);
289        let search_cost = if scan.index_name.is_some() {
290            // HNSW: O(ef * log N)
291            ef * n.ln() * self.cpu_tuple_cost * 10.0 // Distance computation is ~10x regular tuple
292        } else {
293            // Brute-force: O(N)
294            n * self.cpu_tuple_cost * 10.0
295        };
296
297        // Memory for candidate heap
298        let memory = k * self.avg_tuple_size * 2.0;
299
300        Cost::cpu(search_cost).with_memory(memory)
301    }
302
303    /// Estimates the cost of a vector join operation.
304    ///
305    /// Vector join performs k-NN search for each input row.
306    fn vector_join_cost(&self, join: &VectorJoinOp, cardinality: f64) -> Cost {
307        let k = join.k as f64;
308
309        // Each input row triggers a vector search
310        // Assume brute-force for hybrid queries (no index specified typically)
311        let per_row_search_cost = if join.index_name.is_some() {
312            // HNSW: O(ef * log N)
313            let ef = 64.0;
314            let n = cardinality.max(1.0);
315            ef * n.ln() * self.cpu_tuple_cost * 10.0
316        } else {
317            // Brute-force: O(N) per input row
318            cardinality * self.cpu_tuple_cost * 10.0
319        };
320
321        // Total cost: input_rows * search_cost
322        // For vector join, cardinality is typically input cardinality * k
323        let input_cardinality = (cardinality / k).max(1.0);
324        let total_search_cost = input_cardinality * per_row_search_cost;
325
326        // Memory for results
327        let memory = cardinality * self.avg_tuple_size;
328
329        Cost::cpu(total_search_cost).with_memory(memory)
330    }
331
332    /// Compares two costs and returns the cheaper one.
333    #[must_use]
334    pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
335        if a.total() <= b.total() { a } else { b }
336    }
337
338    /// Estimates the cost of a worst-case optimal join (WCOJ/leapfrog join).
339    ///
340    /// WCOJ is optimal for cyclic patterns like triangles. Traditional binary
341    /// hash joins are O(N²) for triangles; WCOJ achieves O(N^1.5) by processing
342    /// all relations simultaneously using sorted iterators.
343    ///
344    /// # Arguments
345    /// * `num_relations` - Number of relations participating in the join
346    /// * `cardinalities` - Cardinality of each input relation
347    /// * `output_cardinality` - Expected output cardinality
348    ///
349    /// # Cost Model
350    /// - Materialization: O(sum of cardinalities) to build trie indexes
351    /// - Intersection: O(output * log(min_cardinality)) for leapfrog seek operations
352    /// - Memory: Trie storage for all inputs
353    #[must_use]
354    pub fn leapfrog_join_cost(
355        &self,
356        num_relations: usize,
357        cardinalities: &[f64],
358        output_cardinality: f64,
359    ) -> Cost {
360        if cardinalities.is_empty() {
361            return Cost::zero();
362        }
363
364        let total_input: f64 = cardinalities.iter().sum();
365        let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
366
367        // Materialization phase: build trie indexes for each input
368        let materialize_cost = total_input * self.cpu_tuple_cost * 2.0; // Sorting + trie building
369
370        // Intersection phase: leapfrog seeks are O(log n) per relation
371        let seek_cost = if min_card > 1.0 {
372            output_cardinality * (num_relations as f64) * min_card.log2() * self.hash_lookup_cost
373        } else {
374            output_cardinality * self.cpu_tuple_cost
375        };
376
377        // Output materialization
378        let output_cost = output_cardinality * self.cpu_tuple_cost;
379
380        // Memory: trie storage (roughly 2x input size for sorted index)
381        let memory = total_input * self.avg_tuple_size * 2.0;
382
383        Cost::cpu(materialize_cost + seek_cost + output_cost).with_memory(memory)
384    }
385
386    /// Compares hash join cost vs leapfrog join cost for a cyclic pattern.
387    ///
388    /// Returns true if leapfrog (WCOJ) is estimated to be cheaper.
389    #[must_use]
390    pub fn prefer_leapfrog_join(
391        &self,
392        num_relations: usize,
393        cardinalities: &[f64],
394        output_cardinality: f64,
395    ) -> bool {
396        if num_relations < 3 || cardinalities.len() < 3 {
397            // Leapfrog is only beneficial for multi-way joins (3+)
398            return false;
399        }
400
401        let leapfrog_cost =
402            self.leapfrog_join_cost(num_relations, cardinalities, output_cardinality);
403
404        // Estimate cascade of binary hash joins
405        // For N relations, we need N-1 joins
406        // Each join produces intermediate results that feed the next
407        let mut hash_cascade_cost = Cost::zero();
408        let mut intermediate_cardinality = cardinalities[0];
409
410        for card in &cardinalities[1..] {
411            // Hash join cost: build + probe + output
412            let join_output = (intermediate_cardinality * card).sqrt(); // Estimated selectivity
413            let join = JoinOp {
414                left: Box::new(LogicalOperator::Empty),
415                right: Box::new(LogicalOperator::Empty),
416                join_type: JoinType::Inner,
417                conditions: vec![],
418            };
419            hash_cascade_cost += self.join_cost(&join, join_output);
420            intermediate_cardinality = join_output;
421        }
422
423        leapfrog_cost.total() < hash_cascade_cost.total()
424    }
425
426    /// Estimates cost for factorized execution (compressed intermediate results).
427    ///
428    /// Factorized execution avoids materializing full cross products by keeping
429    /// results in a compressed "factorized" form. This is beneficial for multi-hop
430    /// traversals where intermediate results can explode.
431    ///
432    /// Returns the reduction factor (1.0 = no benefit, lower = more compression).
433    #[must_use]
434    pub fn factorized_benefit(&self, avg_fanout: f64, num_hops: usize) -> f64 {
435        if num_hops <= 1 || avg_fanout <= 1.0 {
436            return 1.0; // No benefit for single hop or low fanout
437        }
438
439        // Factorized representation compresses repeated prefixes
440        // Compression ratio improves with higher fanout and more hops
441        // Full materialization: fanout^hops
442        // Factorized: sum(fanout^i for i in 1..=hops) ≈ fanout^(hops+1) / (fanout - 1)
443
444        let full_size = avg_fanout.powi(num_hops as i32);
445        let factorized_size = if avg_fanout > 1.0 {
446            (avg_fanout.powi(num_hops as i32 + 1) - 1.0) / (avg_fanout - 1.0)
447        } else {
448            num_hops as f64
449        };
450
451        (factorized_size / full_size).min(1.0)
452    }
453}
454
455impl Default for CostModel {
456    fn default() -> Self {
457        Self::new()
458    }
459}
460
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::plan::{
        AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
        Projection, ReturnItem, SortOrder,
    };

    /// Absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-9;

    /// Builds a join over two empty inputs with the given type and conditions.
    fn join_of(join_type: JoinType, conditions: Vec<JoinCondition>) -> JoinOp {
        JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type,
            conditions,
        }
    }

    /// Builds an un-aliased projection of variable `name`.
    fn projection(name: &str) -> Projection {
        Projection {
            expression: LogicalExpression::Variable(name.to_string()),
            alias: None,
        }
    }

    /// Builds a sort key over variable `name`.
    fn sort_key(name: &str, order: SortOrder) -> crate::query::plan::SortKey {
        crate::query::plan::SortKey {
            expression: LogicalExpression::Variable(name.to_string()),
            order,
        }
    }

    #[test]
    fn test_cost_addition() {
        let a = Cost::cpu(10.0).with_io(5.0);
        let b = Cost::cpu(20.0).with_memory(100.0);
        let c = a + b;

        assert!((c.cpu - 30.0).abs() < EPS);
        assert!((c.io - 5.0).abs() < EPS);
        assert!((c.memory - 100.0).abs() < EPS);
        assert!(c.network.abs() < EPS);
    }

    #[test]
    fn test_cost_total() {
        let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
        // Total = 10 + 1*10 + 100*0.1 = 10 + 10 + 10 = 30
        assert!((cost.total() - 30.0).abs() < EPS);
    }

    #[test]
    fn test_cost_model_node_scan() {
        let model = CostModel::new();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: Some("Person".to_string()),
            input: None,
        };
        let cost = model.node_scan_cost(&scan, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.io > 0.0);

        // Page reads grow linearly with the number of scanned rows.
        let double = model.node_scan_cost(&scan, 2000.0);
        assert!((double.io - 2.0 * cost.io).abs() < EPS);
    }

    #[test]
    fn test_cost_model_sort() {
        let model = CostModel::new();
        let sort = SortOp {
            keys: vec![],
            input: Box::new(LogicalOperator::Empty),
        };

        let cost_100 = model.sort_cost(&sort, 100.0);
        let cost_1000 = model.sort_cost(&sort, 1000.0);

        // Sorting 1000 rows should be more expensive than 100 rows
        assert!(cost_1000.total() > cost_100.total());

        // Zero or one row needs no sorting at all.
        assert_eq!(model.sort_cost(&sort, 1.0), Cost::zero());
        assert_eq!(model.sort_cost(&sort, 0.0), Cost::zero());
    }

    #[test]
    fn test_cost_zero() {
        let cost = Cost::zero();
        assert!(cost.cpu.abs() < EPS);
        assert!(cost.io.abs() < EPS);
        assert!(cost.memory.abs() < EPS);
        assert!(cost.network.abs() < EPS);
        assert!(cost.total().abs() < EPS);
    }

    #[test]
    fn test_cost_add_assign() {
        let mut cost = Cost::cpu(10.0);
        cost += Cost::cpu(5.0).with_io(2.0);
        assert!((cost.cpu - 15.0).abs() < EPS);
        assert!((cost.io - 2.0).abs() < EPS);
    }

    #[test]
    fn test_cost_total_weighted() {
        let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
        // With custom weights: cpu*2 + io*5 + mem*0.5 = 20 + 10 + 50 = 80
        let total = cost.total_weighted(2.0, 5.0, 0.5);
        assert!((total - 80.0).abs() < EPS);
    }

    #[test]
    fn test_cost_model_filter() {
        let model = CostModel::new();
        let filter = FilterOp {
            predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.filter_cost(&filter, 1000.0);

        // Filter cost is CPU only
        assert!(cost.cpu > 0.0);
        assert!(cost.io.abs() < EPS);
    }

    #[test]
    fn test_cost_model_project() {
        let model = CostModel::new();
        let single = ProjectOp {
            projections: vec![projection("a")],
            input: Box::new(LogicalOperator::Empty),
        };
        let double = ProjectOp {
            projections: vec![projection("a"), projection("b")],
            input: Box::new(LogicalOperator::Empty),
        };
        let one = model.project_cost(&single, 1000.0);
        let two = model.project_cost(&double, 1000.0);

        // Cost scales linearly with the number of projections.
        assert!(one.cpu > 0.0);
        assert!((two.cpu - 2.0 * one.cpu).abs() < EPS);
    }

    #[test]
    fn test_cost_model_expand() {
        let model = CostModel::new();
        let expand = ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: None,
            direction: ExpandDirection::Outgoing,
            edge_type: None,
            min_hops: 1,
            max_hops: Some(1),
            input: Box::new(LogicalOperator::Empty),
            path_alias: None,
        };
        let cost = model.expand_cost(&expand, 1000.0);

        // Expand involves hash lookups and output generation
        assert!(cost.cpu > 0.0);
        assert!(model.expand_cost(&expand, 2000.0).cpu > cost.cpu);
    }

    #[test]
    fn test_cost_model_hash_join() {
        let model = CostModel::new();
        let join = join_of(
            JoinType::Inner,
            vec![JoinCondition {
                left: LogicalExpression::Variable("a".to_string()),
                right: LogicalExpression::Variable("b".to_string()),
            }],
        );
        let cost = model.join_cost(&join, 10000.0);

        // Hash join has CPU cost and memory cost
        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_cross_join() {
        let model = CostModel::new();
        let join = join_of(JoinType::Cross, vec![]);
        let big = model.join_cost(&join, 1_000_000.0);
        let small = model.join_cost(&join, 1000.0);

        // Cross join cost grows with the size of the product.
        assert!(big.cpu > 0.0);
        assert!(big.cpu > small.cpu);
    }

    #[test]
    fn test_cost_model_semi_join() {
        let model = CostModel::new();
        let cost_semi = model.join_cost(&join_of(JoinType::Semi, vec![]), 1000.0);
        let cost_inner = model.join_cost(&join_of(JoinType::Inner, vec![]), 1000.0);

        // Semi join skips output materialization, so it is cheaper than inner.
        assert!(cost_semi.cpu > 0.0);
        assert!(cost_inner.cpu > 0.0);
        assert!(cost_semi.cpu < cost_inner.cpu);
    }

    #[test]
    fn test_cost_model_aggregate() {
        let model = CostModel::new();
        let agg = AggregateOp {
            group_by: vec![],
            aggregates: vec![
                AggregateExpr {
                    function: AggregateFunction::Count,
                    expression: None,
                    distinct: false,
                    alias: Some("cnt".to_string()),
                    percentile: None,
                },
                AggregateExpr {
                    function: AggregateFunction::Sum,
                    expression: Some(LogicalExpression::Variable("x".to_string())),
                    distinct: false,
                    alias: Some("total".to_string()),
                    percentile: None,
                },
            ],
            input: Box::new(LogicalOperator::Empty),
            having: None,
        };
        let cost = model.aggregate_cost(&agg, 1000.0);

        // Aggregation has hash cost and memory cost
        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_distinct() {
        let model = CostModel::new();
        let distinct = DistinctOp {
            input: Box::new(LogicalOperator::Empty),
            columns: None,
        };
        let cost = model.distinct_cost(&distinct, 1000.0);

        // Distinct uses hash set
        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_limit() {
        let model = CostModel::new();
        let limit = LimitOp {
            count: 10,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.limit_cost(&limit, 1000.0);

        // Limit is very cheap
        assert!(cost.cpu > 0.0);
        assert!(cost.cpu < 1.0); // Should be minimal

        // Once the input exceeds the limit, more input does not cost more.
        let huge_input = model.limit_cost(&limit, 1_000_000.0);
        assert!((huge_input.cpu - cost.cpu).abs() < EPS);
    }

    #[test]
    fn test_cost_model_skip() {
        let model = CostModel::new();
        let skip = SkipOp {
            count: 100,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.skip_cost(&skip, 1000.0);

        // Skip must scan through skipped rows
        assert!(cost.cpu > 0.0);

        let skip_more = SkipOp {
            count: 200,
            input: Box::new(LogicalOperator::Empty),
        };
        assert!(model.skip_cost(&skip_more, 1000.0).cpu > cost.cpu);
    }

    #[test]
    fn test_cost_model_return() {
        let model = CostModel::new();
        let ret = ReturnOp {
            items: vec![
                ReturnItem {
                    expression: LogicalExpression::Variable("a".to_string()),
                    alias: None,
                },
                ReturnItem {
                    expression: LogicalExpression::Variable("b".to_string()),
                    alias: None,
                },
            ],
            distinct: false,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.return_cost(&ret, 1000.0);

        // Return materializes results
        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_cheaper() {
        let model = CostModel::new();
        let cheap = Cost::cpu(10.0);
        let expensive = Cost::cpu(100.0);

        assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
        assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
    }

    #[test]
    fn test_cost_comparison_prefers_lower_total() {
        let model = CostModel::new();
        // High CPU, low IO
        let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
        // Low CPU, high IO
        let io_heavy = Cost::cpu(10.0).with_io(20.0);

        // IO is weighted 10x, so io_heavy = 10 + 200 = 210, cpu_heavy = 100 + 10 = 110
        assert!(cpu_heavy.total() < io_heavy.total());
        assert_eq!(
            model.cheaper(&cpu_heavy, &io_heavy).total(),
            cpu_heavy.total()
        );
    }

    #[test]
    fn test_cost_model_sort_with_keys() {
        let model = CostModel::new();
        let sort_single = SortOp {
            keys: vec![sort_key("a", SortOrder::Ascending)],
            input: Box::new(LogicalOperator::Empty),
        };
        let sort_multi = SortOp {
            keys: vec![
                sort_key("a", SortOrder::Ascending),
                sort_key("b", SortOrder::Descending),
            ],
            input: Box::new(LogicalOperator::Empty),
        };

        let cost_single = model.sort_cost(&sort_single, 1000.0);
        let cost_multi = model.sort_cost(&sort_multi, 1000.0);

        // More sort keys = more comparisons
        assert!(cost_multi.cpu > cost_single.cpu);
    }

    #[test]
    fn test_cost_model_empty_operator() {
        let model = CostModel::new();
        let cost = model.estimate(&LogicalOperator::Empty, 0.0);
        assert!(cost.total().abs() < EPS);
    }

    #[test]
    fn test_cost_model_default() {
        let model = CostModel::default();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: None,
            input: None,
        };
        let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
        assert!(cost.total() > 0.0);
    }

    #[test]
    fn test_leapfrog_join_cost() {
        let model = CostModel::new();

        // Three-way join (triangle pattern)
        let cardinalities = [1000.0; 3];
        let cost = model.leapfrog_join_cost(3, &cardinalities, 100.0);

        // Should have CPU cost for materialization and intersection
        assert!(cost.cpu > 0.0);
        // Should have memory cost for trie storage
        assert!(cost.memory > 0.0);

        // A larger expected output means more seek and output work.
        let more_output = model.leapfrog_join_cost(3, &cardinalities, 1000.0);
        assert!(more_output.cpu > cost.cpu);
    }

    #[test]
    fn test_leapfrog_join_cost_empty() {
        let model = CostModel::new();
        let cost = model.leapfrog_join_cost(0, &[], 0.0);
        assert!(cost.total().abs() < EPS);
    }

    #[test]
    fn test_prefer_leapfrog_join_for_triangles() {
        let model = CostModel::new();

        // Compare costs for triangle pattern
        let cardinalities = [10000.0; 3];
        let output = 1000.0;

        let leapfrog_cost = model.leapfrog_join_cost(3, &cardinalities, output);

        // Leapfrog should have reasonable cost for triangle patterns
        assert!(leapfrog_cost.cpu > 0.0);
        assert!(leapfrog_cost.memory > 0.0);

        // The actual preference depends on the cost parameters; only require
        // that the decision is deterministic.
        let prefer = model.prefer_leapfrog_join(3, &cardinalities, output);
        assert_eq!(prefer, model.prefer_leapfrog_join(3, &cardinalities, output));
    }

    #[test]
    fn test_prefer_leapfrog_join_binary_case() {
        let model = CostModel::new();

        // Binary join should NOT prefer leapfrog (need 3+ relations)
        let cardinalities = [1000.0, 1000.0];
        let prefer = model.prefer_leapfrog_join(2, &cardinalities, 500.0);
        assert!(!prefer, "Binary joins should use hash join, not leapfrog");
    }

    #[test]
    fn test_prefer_leapfrog_join_requires_three_relations() {
        let model = CostModel::new();

        // Three cardinalities but only two declared relations: still rejected.
        let cardinalities = [1000.0; 3];
        assert!(!model.prefer_leapfrog_join(2, &cardinalities, 500.0));
    }

    #[test]
    fn test_factorized_benefit_single_hop() {
        let model = CostModel::new();

        // Single hop: no factorization benefit
        let benefit = model.factorized_benefit(10.0, 1);
        assert!(
            (benefit - 1.0).abs() < EPS,
            "Single hop should have no benefit"
        );
    }

    #[test]
    fn test_factorized_benefit_multi_hop() {
        let model = CostModel::new();

        // Multi-hop with high fanout
        let benefit = model.factorized_benefit(10.0, 3);

        // The benefit is a reduction factor: positive and capped at 1.0.
        assert!(benefit <= 1.0, "Benefit should be <= 1.0");
        assert!(benefit > 0.0, "Benefit should be positive");
    }

    #[test]
    fn test_factorized_benefit_low_fanout() {
        let model = CostModel::new();

        // Low fanout: minimal benefit, but still a valid reduction factor.
        let benefit = model.factorized_benefit(1.5, 2);
        assert!(benefit <= 1.0, "Benefit must never exceed 1.0");
        assert!(benefit > 0.0, "Benefit should be positive");
    }

    #[test]
    fn test_factorized_benefit_is_bounded() {
        let model = CostModel::new();
        for &fanout in &[0.5, 1.0, 1.5, 2.0, 10.0, 100.0] {
            for hops in 0..6 {
                let benefit = model.factorized_benefit(fanout, hops);
                assert!(
                    benefit > 0.0 && benefit <= 1.0,
                    "fanout={fanout}, hops={hops}: benefit={benefit}"
                );
            }
        }
    }
}