Skip to main content

grafeo_engine/query/optimizer/
cost.rs

1//! Cost model for query optimization.
2//!
3//! Provides cost estimates for logical operators to guide optimization decisions.
4
5use crate::query::plan::{
6    AggregateOp, DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
7    LogicalOperator, NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp, VectorJoinOp, VectorScanOp,
8};
9
/// Cost of an operation.
///
/// Represents the estimated resource consumption of executing an operator.
/// Every component defaults to zero, so `Cost::default()` and
/// [`Cost::zero`] produce the same value.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct Cost {
    /// Estimated CPU cycles / work units.
    pub cpu: f64,
    /// Estimated I/O operations (page reads).
    pub io: f64,
    /// Estimated memory usage in bytes.
    pub memory: f64,
    /// Network cost (for distributed queries).
    pub network: f64,
}

impl Cost {
    /// Creates a zero cost (all four components are `0.0`).
    #[must_use]
    pub fn zero() -> Self {
        Self::default()
    }

    /// Creates a cost from CPU work units; the other components are zero.
    #[must_use]
    pub fn cpu(cpu: f64) -> Self {
        Self {
            cpu,
            ..Self::zero()
        }
    }

    /// Sets the I/O component.
    ///
    /// The previous I/O value is replaced, not accumulated. Use `+` / `+=`
    /// to combine two costs component-wise.
    #[must_use]
    pub fn with_io(mut self, io: f64) -> Self {
        self.io = io;
        self
    }

    /// Sets the memory component.
    ///
    /// The previous memory value is replaced, not accumulated. Use `+` / `+=`
    /// to combine two costs component-wise.
    #[must_use]
    pub fn with_memory(mut self, memory: f64) -> Self {
        self.memory = memory;
        self
    }

    /// Returns the total weighted cost.
    ///
    /// Uses default weights: CPU=1.0, IO=10.0, Memory=0.1, Network=100.0
    #[must_use]
    pub fn total(&self) -> f64 {
        self.cpu + self.io * 10.0 + self.memory * 0.1 + self.network * 100.0
    }

    /// Returns the total cost with custom weights.
    ///
    /// Only the CPU, I/O and memory components take part in this sum; the
    /// network component has no weight parameter and is not included.
    #[must_use]
    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
        self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
    }
}

impl std::ops::Add for Cost {
    type Output = Self;

    /// Adds two costs component by component.
    fn add(self, other: Self) -> Self {
        Self {
            cpu: self.cpu + other.cpu,
            io: self.io + other.io,
            memory: self.memory + other.memory,
            network: self.network + other.network,
        }
    }
}

impl std::ops::AddAssign for Cost {
    /// Accumulates `other` into `self` component by component.
    fn add_assign(&mut self, other: Self) {
        // `Cost` is `Copy`, so reusing `Add` keeps both operators in lockstep.
        *self = *self + other;
    }
}
98
/// Cost model for estimating operator costs.
///
/// Default constants are calibrated relative to each other:
/// - Tuple scan is the baseline (1x)
/// - Hash lookup is ~3x (hash computation + potential cache miss)
/// - Sort comparison is ~2x (key extraction + comparison)
/// - Distance computation is ~10x (vector math); this one is not a field,
///   the vector cost estimates apply it as a fixed multiplier
#[derive(Debug, Clone)]
pub struct CostModel {
    /// Cost per tuple processed by CPU (baseline unit).
    cpu_tuple_cost: f64,
    /// Cost per hash table lookup (~3x tuple cost: hash + cache miss).
    hash_lookup_cost: f64,
    /// Cost per comparison in sorting (~2x tuple cost: key extract + cmp).
    sort_comparison_cost: f64,
    /// Average tuple size in bytes (for IO estimation).
    avg_tuple_size: f64,
    /// Page size in bytes.
    page_size: f64,
    /// Global average edge fanout (fallback when per-type stats unavailable).
    avg_fanout: f64,
    /// Per-edge-type degree stats: (avg_out_degree, avg_in_degree).
    edge_type_degrees: std::collections::HashMap<String, (f64, f64)>,
}
122
123impl CostModel {
124    /// Creates a new cost model with calibrated default parameters.
125    #[must_use]
126    pub fn new() -> Self {
127        Self {
128            cpu_tuple_cost: 0.01,
129            hash_lookup_cost: 0.03,
130            sort_comparison_cost: 0.02,
131            avg_tuple_size: 100.0,
132            page_size: 8192.0,
133            avg_fanout: 10.0,
134            edge_type_degrees: std::collections::HashMap::new(),
135        }
136    }
137
138    /// Sets the global average fanout from graph statistics.
139    #[must_use]
140    pub fn with_avg_fanout(mut self, avg_fanout: f64) -> Self {
141        self.avg_fanout = if avg_fanout > 0.0 { avg_fanout } else { 10.0 };
142        self
143    }
144
145    /// Sets per-edge-type degree statistics for accurate expand cost estimation.
146    ///
147    /// Each entry maps edge type name to `(avg_out_degree, avg_in_degree)`.
148    #[must_use]
149    pub fn with_edge_type_degrees(
150        mut self,
151        degrees: std::collections::HashMap<String, (f64, f64)>,
152    ) -> Self {
153        self.edge_type_degrees = degrees;
154        self
155    }
156
157    /// Returns the fanout for a specific expand operation.
158    ///
159    /// Uses per-edge-type degree stats when available, falling back to the
160    /// global average fanout.
161    fn fanout_for_expand(&self, expand: &ExpandOp) -> f64 {
162        if let Some(edge_type) = &expand.edge_type
163            && let Some(&(out_deg, in_deg)) = self.edge_type_degrees.get(edge_type)
164        {
165            return match expand.direction {
166                ExpandDirection::Outgoing => out_deg,
167                ExpandDirection::Incoming => in_deg,
168                ExpandDirection::Both => out_deg + in_deg,
169            };
170        }
171        self.avg_fanout
172    }
173
174    /// Estimates the cost of a logical operator.
175    #[must_use]
176    pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
177        match op {
178            LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
179            LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
180            LogicalOperator::Project(project) => self.project_cost(project, cardinality),
181            LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
182            LogicalOperator::Join(join) => self.join_cost(join, cardinality),
183            LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
184            LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
185            LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
186            LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
187            LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
188            LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
189            LogicalOperator::Empty => Cost::zero(),
190            LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
191            LogicalOperator::VectorJoin(join) => self.vector_join_cost(join, cardinality),
192            _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
193        }
194    }
195
196    /// Estimates the cost of a node scan.
197    fn node_scan_cost(&self, _scan: &NodeScanOp, cardinality: f64) -> Cost {
198        let pages = (cardinality * self.avg_tuple_size) / self.page_size;
199        Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
200    }
201
202    /// Estimates the cost of a filter operation.
203    fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
204        // Filter cost is just predicate evaluation per tuple
205        Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
206    }
207
208    /// Estimates the cost of a projection.
209    fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
210        // Cost depends on number of expressions evaluated
211        let expr_count = project.projections.len() as f64;
212        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
213    }
214
215    /// Estimates the cost of an expand operation.
216    ///
217    /// Uses per-edge-type degree stats when available, otherwise falls back
218    /// to the global average fanout.
219    fn expand_cost(&self, expand: &ExpandOp, cardinality: f64) -> Cost {
220        let fanout = self.fanout_for_expand(expand);
221        // Adjacency list lookup per input row
222        let lookup_cost = cardinality * self.hash_lookup_cost;
223        // Process each expanded output tuple
224        let output_cost = cardinality * fanout * self.cpu_tuple_cost;
225        Cost::cpu(lookup_cost + output_cost)
226    }
227
228    /// Estimates the cost of a join operation.
229    fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
230        // Cost depends on join type
231        match join.join_type {
232            JoinType::Cross => {
233                // Cross join is O(n * m)
234                Cost::cpu(cardinality * self.cpu_tuple_cost)
235            }
236            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
237                // Hash join: build phase + probe phase
238                // Assume left side is build, right side is probe
239                let build_cardinality = cardinality.sqrt(); // Rough estimate
240                let probe_cardinality = cardinality.sqrt();
241
242                // Build hash table
243                let build_cost = build_cardinality * self.hash_lookup_cost;
244                let memory_cost = build_cardinality * self.avg_tuple_size;
245
246                // Probe hash table
247                let probe_cost = probe_cardinality * self.hash_lookup_cost;
248
249                // Output cost
250                let output_cost = cardinality * self.cpu_tuple_cost;
251
252                Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
253            }
254            JoinType::Semi | JoinType::Anti => {
255                // Semi/anti joins are typically cheaper
256                let build_cardinality = cardinality.sqrt();
257                let probe_cardinality = cardinality.sqrt();
258
259                let build_cost = build_cardinality * self.hash_lookup_cost;
260                let probe_cost = probe_cardinality * self.hash_lookup_cost;
261
262                Cost::cpu(build_cost + probe_cost)
263                    .with_memory(build_cardinality * self.avg_tuple_size)
264            }
265        }
266    }
267
268    /// Estimates the cost of an aggregation.
269    fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
270        // Hash aggregation cost
271        let hash_cost = cardinality * self.hash_lookup_cost;
272
273        // Aggregate function evaluation
274        let agg_count = agg.aggregates.len() as f64;
275        let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
276
277        // Memory for hash table (estimated distinct groups)
278        let distinct_groups = (cardinality / 10.0).max(1.0); // Assume 10% distinct
279        let memory_cost = distinct_groups * self.avg_tuple_size;
280
281        Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
282    }
283
284    /// Estimates the cost of a sort operation.
285    fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
286        if cardinality <= 1.0 {
287            return Cost::zero();
288        }
289
290        // Sort is O(n log n) comparisons
291        let comparisons = cardinality * cardinality.log2();
292        let key_count = sort.keys.len() as f64;
293
294        // Memory for sorting (full input materialization)
295        let memory_cost = cardinality * self.avg_tuple_size;
296
297        Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
298    }
299
300    /// Estimates the cost of a distinct operation.
301    fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
302        // Hash-based distinct
303        let hash_cost = cardinality * self.hash_lookup_cost;
304        let memory_cost = cardinality * self.avg_tuple_size * 0.5; // Assume 50% distinct
305
306        Cost::cpu(hash_cost).with_memory(memory_cost)
307    }
308
309    /// Estimates the cost of a limit operation.
310    fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
311        // Limit is very cheap - just counting
312        Cost::cpu(limit.count as f64 * self.cpu_tuple_cost * 0.1)
313    }
314
315    /// Estimates the cost of a skip operation.
316    fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
317        // Skip requires scanning through skipped rows
318        Cost::cpu(skip.count as f64 * self.cpu_tuple_cost)
319    }
320
321    /// Estimates the cost of a return operation.
322    fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
323        // Return materializes results
324        let expr_count = ret.items.len() as f64;
325        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
326    }
327
328    /// Estimates the cost of a vector scan operation.
329    ///
330    /// HNSW index search is O(log N) per query, while brute-force is O(N).
331    /// This estimates the HNSW case with ef search parameter.
332    fn vector_scan_cost(&self, scan: &VectorScanOp, cardinality: f64) -> Cost {
333        // k determines output cardinality
334        let k = scan.k as f64;
335
336        // HNSW search cost: O(ef * log(N)) distance computations
337        // Assume ef = 64 (default), N = cardinality
338        let ef = 64.0;
339        let n = cardinality.max(1.0);
340        let search_cost = if scan.index_name.is_some() {
341            // HNSW: O(ef * log N)
342            ef * n.ln() * self.cpu_tuple_cost * 10.0 // Distance computation is ~10x regular tuple
343        } else {
344            // Brute-force: O(N)
345            n * self.cpu_tuple_cost * 10.0
346        };
347
348        // Memory for candidate heap
349        let memory = k * self.avg_tuple_size * 2.0;
350
351        Cost::cpu(search_cost).with_memory(memory)
352    }
353
354    /// Estimates the cost of a vector join operation.
355    ///
356    /// Vector join performs k-NN search for each input row.
357    fn vector_join_cost(&self, join: &VectorJoinOp, cardinality: f64) -> Cost {
358        let k = join.k as f64;
359
360        // Each input row triggers a vector search
361        // Assume brute-force for hybrid queries (no index specified typically)
362        let per_row_search_cost = if join.index_name.is_some() {
363            // HNSW: O(ef * log N)
364            let ef = 64.0;
365            let n = cardinality.max(1.0);
366            ef * n.ln() * self.cpu_tuple_cost * 10.0
367        } else {
368            // Brute-force: O(N) per input row
369            cardinality * self.cpu_tuple_cost * 10.0
370        };
371
372        // Total cost: input_rows * search_cost
373        // For vector join, cardinality is typically input cardinality * k
374        let input_cardinality = (cardinality / k).max(1.0);
375        let total_search_cost = input_cardinality * per_row_search_cost;
376
377        // Memory for results
378        let memory = cardinality * self.avg_tuple_size;
379
380        Cost::cpu(total_search_cost).with_memory(memory)
381    }
382
383    /// Compares two costs and returns the cheaper one.
384    #[must_use]
385    pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
386        if a.total() <= b.total() { a } else { b }
387    }
388
389    /// Estimates the cost of a worst-case optimal join (WCOJ/leapfrog join).
390    ///
391    /// WCOJ is optimal for cyclic patterns like triangles. Traditional binary
392    /// hash joins are O(N²) for triangles; WCOJ achieves O(N^1.5) by processing
393    /// all relations simultaneously using sorted iterators.
394    ///
395    /// # Arguments
396    /// * `num_relations` - Number of relations participating in the join
397    /// * `cardinalities` - Cardinality of each input relation
398    /// * `output_cardinality` - Expected output cardinality
399    ///
400    /// # Cost Model
401    /// - Materialization: O(sum of cardinalities) to build trie indexes
402    /// - Intersection: O(output * log(min_cardinality)) for leapfrog seek operations
403    /// - Memory: Trie storage for all inputs
404    #[must_use]
405    pub fn leapfrog_join_cost(
406        &self,
407        num_relations: usize,
408        cardinalities: &[f64],
409        output_cardinality: f64,
410    ) -> Cost {
411        if cardinalities.is_empty() {
412            return Cost::zero();
413        }
414
415        let total_input: f64 = cardinalities.iter().sum();
416        let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
417
418        // Materialization phase: build trie indexes for each input
419        let materialize_cost = total_input * self.cpu_tuple_cost * 2.0; // Sorting + trie building
420
421        // Intersection phase: leapfrog seeks are O(log n) per relation
422        let seek_cost = if min_card > 1.0 {
423            output_cardinality * (num_relations as f64) * min_card.log2() * self.hash_lookup_cost
424        } else {
425            output_cardinality * self.cpu_tuple_cost
426        };
427
428        // Output materialization
429        let output_cost = output_cardinality * self.cpu_tuple_cost;
430
431        // Memory: trie storage (roughly 2x input size for sorted index)
432        let memory = total_input * self.avg_tuple_size * 2.0;
433
434        Cost::cpu(materialize_cost + seek_cost + output_cost).with_memory(memory)
435    }
436
437    /// Compares hash join cost vs leapfrog join cost for a cyclic pattern.
438    ///
439    /// Returns true if leapfrog (WCOJ) is estimated to be cheaper.
440    #[must_use]
441    pub fn prefer_leapfrog_join(
442        &self,
443        num_relations: usize,
444        cardinalities: &[f64],
445        output_cardinality: f64,
446    ) -> bool {
447        if num_relations < 3 || cardinalities.len() < 3 {
448            // Leapfrog is only beneficial for multi-way joins (3+)
449            return false;
450        }
451
452        let leapfrog_cost =
453            self.leapfrog_join_cost(num_relations, cardinalities, output_cardinality);
454
455        // Estimate cascade of binary hash joins
456        // For N relations, we need N-1 joins
457        // Each join produces intermediate results that feed the next
458        let mut hash_cascade_cost = Cost::zero();
459        let mut intermediate_cardinality = cardinalities[0];
460
461        for card in &cardinalities[1..] {
462            // Hash join cost: build + probe + output
463            let join_output = (intermediate_cardinality * card).sqrt(); // Estimated selectivity
464            let join = JoinOp {
465                left: Box::new(LogicalOperator::Empty),
466                right: Box::new(LogicalOperator::Empty),
467                join_type: JoinType::Inner,
468                conditions: vec![],
469            };
470            hash_cascade_cost += self.join_cost(&join, join_output);
471            intermediate_cardinality = join_output;
472        }
473
474        leapfrog_cost.total() < hash_cascade_cost.total()
475    }
476
477    /// Estimates cost for factorized execution (compressed intermediate results).
478    ///
479    /// Factorized execution avoids materializing full cross products by keeping
480    /// results in a compressed "factorized" form. This is beneficial for multi-hop
481    /// traversals where intermediate results can explode.
482    ///
483    /// Returns the reduction factor (1.0 = no benefit, lower = more compression).
484    #[must_use]
485    pub fn factorized_benefit(&self, avg_fanout: f64, num_hops: usize) -> f64 {
486        if num_hops <= 1 || avg_fanout <= 1.0 {
487            return 1.0; // No benefit for single hop or low fanout
488        }
489
490        // Factorized representation compresses repeated prefixes
491        // Compression ratio improves with higher fanout and more hops
492        // Full materialization: fanout^hops
493        // Factorized: sum(fanout^i for i in 1..=hops) ≈ fanout^(hops+1) / (fanout - 1)
494
495        let full_size = avg_fanout.powi(num_hops as i32);
496        let factorized_size = if avg_fanout > 1.0 {
497            (avg_fanout.powi(num_hops as i32 + 1) - 1.0) / (avg_fanout - 1.0)
498        } else {
499            num_hops as f64
500        };
501
502        (factorized_size / full_size).min(1.0)
503    }
504}
505
506impl Default for CostModel {
507    fn default() -> Self {
508        Self::new()
509    }
510}
511
512#[cfg(test)]
513mod tests {
514    use super::*;
515    use crate::query::plan::{
516        AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
517        Projection, ReturnItem, SortOrder,
518    };
519
520    #[test]
521    fn test_cost_addition() {
522        let a = Cost::cpu(10.0).with_io(5.0);
523        let b = Cost::cpu(20.0).with_memory(100.0);
524        let c = a + b;
525
526        assert!((c.cpu - 30.0).abs() < 0.001);
527        assert!((c.io - 5.0).abs() < 0.001);
528        assert!((c.memory - 100.0).abs() < 0.001);
529    }
530
531    #[test]
532    fn test_cost_total() {
533        let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
534        // Total = 10 + 1*10 + 100*0.1 = 10 + 10 + 10 = 30
535        assert!((cost.total() - 30.0).abs() < 0.001);
536    }
537
538    #[test]
539    fn test_cost_model_node_scan() {
540        let model = CostModel::new();
541        let scan = NodeScanOp {
542            variable: "n".to_string(),
543            label: Some("Person".to_string()),
544            input: None,
545        };
546        let cost = model.node_scan_cost(&scan, 1000.0);
547
548        assert!(cost.cpu > 0.0);
549        assert!(cost.io > 0.0);
550    }
551
552    #[test]
553    fn test_cost_model_sort() {
554        let model = CostModel::new();
555        let sort = SortOp {
556            keys: vec![],
557            input: Box::new(LogicalOperator::Empty),
558        };
559
560        let cost_100 = model.sort_cost(&sort, 100.0);
561        let cost_1000 = model.sort_cost(&sort, 1000.0);
562
563        // Sorting 1000 rows should be more expensive than 100 rows
564        assert!(cost_1000.total() > cost_100.total());
565    }
566
567    #[test]
568    fn test_cost_zero() {
569        let cost = Cost::zero();
570        assert!((cost.cpu).abs() < 0.001);
571        assert!((cost.io).abs() < 0.001);
572        assert!((cost.memory).abs() < 0.001);
573        assert!((cost.network).abs() < 0.001);
574        assert!((cost.total()).abs() < 0.001);
575    }
576
577    #[test]
578    fn test_cost_add_assign() {
579        let mut cost = Cost::cpu(10.0);
580        cost += Cost::cpu(5.0).with_io(2.0);
581        assert!((cost.cpu - 15.0).abs() < 0.001);
582        assert!((cost.io - 2.0).abs() < 0.001);
583    }
584
585    #[test]
586    fn test_cost_total_weighted() {
587        let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
588        // With custom weights: cpu*2 + io*5 + mem*0.5 = 20 + 10 + 50 = 80
589        let total = cost.total_weighted(2.0, 5.0, 0.5);
590        assert!((total - 80.0).abs() < 0.001);
591    }
592
593    #[test]
594    fn test_cost_model_filter() {
595        let model = CostModel::new();
596        let filter = FilterOp {
597            predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
598            input: Box::new(LogicalOperator::Empty),
599        };
600        let cost = model.filter_cost(&filter, 1000.0);
601
602        // Filter cost is CPU only
603        assert!(cost.cpu > 0.0);
604        assert!((cost.io).abs() < 0.001);
605    }
606
607    #[test]
608    fn test_cost_model_project() {
609        let model = CostModel::new();
610        let project = ProjectOp {
611            projections: vec![
612                Projection {
613                    expression: LogicalExpression::Variable("a".to_string()),
614                    alias: None,
615                },
616                Projection {
617                    expression: LogicalExpression::Variable("b".to_string()),
618                    alias: None,
619                },
620            ],
621            input: Box::new(LogicalOperator::Empty),
622        };
623        let cost = model.project_cost(&project, 1000.0);
624
625        // Cost should scale with number of projections
626        assert!(cost.cpu > 0.0);
627    }
628
629    #[test]
630    fn test_cost_model_expand() {
631        let model = CostModel::new();
632        let expand = ExpandOp {
633            from_variable: "a".to_string(),
634            to_variable: "b".to_string(),
635            edge_variable: None,
636            direction: ExpandDirection::Outgoing,
637            edge_type: None,
638            min_hops: 1,
639            max_hops: Some(1),
640            input: Box::new(LogicalOperator::Empty),
641            path_alias: None,
642        };
643        let cost = model.expand_cost(&expand, 1000.0);
644
645        // Expand involves hash lookups and output generation
646        assert!(cost.cpu > 0.0);
647    }
648
649    #[test]
650    fn test_cost_model_expand_with_edge_type_stats() {
651        let mut degrees = std::collections::HashMap::new();
652        degrees.insert("KNOWS".to_string(), (5.0, 5.0)); // Symmetric
653        degrees.insert("WORKS_AT".to_string(), (1.0, 50.0)); // Many-to-one
654
655        let model = CostModel::new().with_edge_type_degrees(degrees);
656
657        // Outgoing KNOWS: fanout = 5
658        let knows_out = ExpandOp {
659            from_variable: "a".to_string(),
660            to_variable: "b".to_string(),
661            edge_variable: None,
662            direction: ExpandDirection::Outgoing,
663            edge_type: Some("KNOWS".to_string()),
664            min_hops: 1,
665            max_hops: Some(1),
666            input: Box::new(LogicalOperator::Empty),
667            path_alias: None,
668        };
669        let cost_knows = model.expand_cost(&knows_out, 1000.0);
670
671        // Outgoing WORKS_AT: fanout = 1 (each person works at one company)
672        let works_out = ExpandOp {
673            from_variable: "a".to_string(),
674            to_variable: "b".to_string(),
675            edge_variable: None,
676            direction: ExpandDirection::Outgoing,
677            edge_type: Some("WORKS_AT".to_string()),
678            min_hops: 1,
679            max_hops: Some(1),
680            input: Box::new(LogicalOperator::Empty),
681            path_alias: None,
682        };
683        let cost_works = model.expand_cost(&works_out, 1000.0);
684
685        // KNOWS (fanout=5) should be more expensive than WORKS_AT (fanout=1)
686        assert!(
687            cost_knows.cpu > cost_works.cpu,
688            "KNOWS(5) should cost more than WORKS_AT(1)"
689        );
690
691        // Incoming WORKS_AT: fanout = 50 (company has many employees)
692        let works_in = ExpandOp {
693            from_variable: "c".to_string(),
694            to_variable: "p".to_string(),
695            edge_variable: None,
696            direction: ExpandDirection::Incoming,
697            edge_type: Some("WORKS_AT".to_string()),
698            min_hops: 1,
699            max_hops: Some(1),
700            input: Box::new(LogicalOperator::Empty),
701            path_alias: None,
702        };
703        let cost_works_in = model.expand_cost(&works_in, 1000.0);
704
705        // Incoming WORKS_AT (fanout=50) should be most expensive
706        assert!(
707            cost_works_in.cpu > cost_knows.cpu,
708            "Incoming WORKS_AT(50) should cost more than KNOWS(5)"
709        );
710    }
711
712    #[test]
713    fn test_cost_model_expand_unknown_edge_type_uses_global_fanout() {
714        let model = CostModel::new().with_avg_fanout(7.0);
715        let expand = ExpandOp {
716            from_variable: "a".to_string(),
717            to_variable: "b".to_string(),
718            edge_variable: None,
719            direction: ExpandDirection::Outgoing,
720            edge_type: Some("UNKNOWN_TYPE".to_string()),
721            min_hops: 1,
722            max_hops: Some(1),
723            input: Box::new(LogicalOperator::Empty),
724            path_alias: None,
725        };
726        let cost_unknown = model.expand_cost(&expand, 1000.0);
727
728        // Without edge type (uses global fanout too)
729        let expand_no_type = ExpandOp {
730            from_variable: "a".to_string(),
731            to_variable: "b".to_string(),
732            edge_variable: None,
733            direction: ExpandDirection::Outgoing,
734            edge_type: None,
735            min_hops: 1,
736            max_hops: Some(1),
737            input: Box::new(LogicalOperator::Empty),
738            path_alias: None,
739        };
740        let cost_no_type = model.expand_cost(&expand_no_type, 1000.0);
741
742        // Both should use global fanout = 7, so costs should be equal
743        assert!(
744            (cost_unknown.cpu - cost_no_type.cpu).abs() < 0.001,
745            "Unknown edge type should use global fanout"
746        );
747    }
748
749    #[test]
750    fn test_cost_model_hash_join() {
751        let model = CostModel::new();
752        let join = JoinOp {
753            left: Box::new(LogicalOperator::Empty),
754            right: Box::new(LogicalOperator::Empty),
755            join_type: JoinType::Inner,
756            conditions: vec![JoinCondition {
757                left: LogicalExpression::Variable("a".to_string()),
758                right: LogicalExpression::Variable("b".to_string()),
759            }],
760        };
761        let cost = model.join_cost(&join, 10000.0);
762
763        // Hash join has CPU cost and memory cost
764        assert!(cost.cpu > 0.0);
765        assert!(cost.memory > 0.0);
766    }
767
768    #[test]
769    fn test_cost_model_cross_join() {
770        let model = CostModel::new();
771        let join = JoinOp {
772            left: Box::new(LogicalOperator::Empty),
773            right: Box::new(LogicalOperator::Empty),
774            join_type: JoinType::Cross,
775            conditions: vec![],
776        };
777        let cost = model.join_cost(&join, 1000000.0);
778
779        // Cross join is expensive
780        assert!(cost.cpu > 0.0);
781    }
782
783    #[test]
784    fn test_cost_model_semi_join() {
785        let model = CostModel::new();
786        let join = JoinOp {
787            left: Box::new(LogicalOperator::Empty),
788            right: Box::new(LogicalOperator::Empty),
789            join_type: JoinType::Semi,
790            conditions: vec![],
791        };
792        let cost_semi = model.join_cost(&join, 1000.0);
793
794        let inner_join = JoinOp {
795            left: Box::new(LogicalOperator::Empty),
796            right: Box::new(LogicalOperator::Empty),
797            join_type: JoinType::Inner,
798            conditions: vec![],
799        };
800        let cost_inner = model.join_cost(&inner_join, 1000.0);
801
802        // Semi join can be cheaper than inner join
803        assert!(cost_semi.cpu > 0.0);
804        assert!(cost_inner.cpu > 0.0);
805    }
806
807    #[test]
808    fn test_cost_model_aggregate() {
809        let model = CostModel::new();
810        let agg = AggregateOp {
811            group_by: vec![],
812            aggregates: vec![
813                AggregateExpr {
814                    function: AggregateFunction::Count,
815                    expression: None,
816                    distinct: false,
817                    alias: Some("cnt".to_string()),
818                    percentile: None,
819                },
820                AggregateExpr {
821                    function: AggregateFunction::Sum,
822                    expression: Some(LogicalExpression::Variable("x".to_string())),
823                    distinct: false,
824                    alias: Some("total".to_string()),
825                    percentile: None,
826                },
827            ],
828            input: Box::new(LogicalOperator::Empty),
829            having: None,
830        };
831        let cost = model.aggregate_cost(&agg, 1000.0);
832
833        // Aggregation has hash cost and memory cost
834        assert!(cost.cpu > 0.0);
835        assert!(cost.memory > 0.0);
836    }
837
838    #[test]
839    fn test_cost_model_distinct() {
840        let model = CostModel::new();
841        let distinct = DistinctOp {
842            input: Box::new(LogicalOperator::Empty),
843            columns: None,
844        };
845        let cost = model.distinct_cost(&distinct, 1000.0);
846
847        // Distinct uses hash set
848        assert!(cost.cpu > 0.0);
849        assert!(cost.memory > 0.0);
850    }
851
852    #[test]
853    fn test_cost_model_limit() {
854        let model = CostModel::new();
855        let limit = LimitOp {
856            count: 10,
857            input: Box::new(LogicalOperator::Empty),
858        };
859        let cost = model.limit_cost(&limit, 1000.0);
860
861        // Limit is very cheap
862        assert!(cost.cpu > 0.0);
863        assert!(cost.cpu < 1.0); // Should be minimal
864    }
865
866    #[test]
867    fn test_cost_model_skip() {
868        let model = CostModel::new();
869        let skip = SkipOp {
870            count: 100,
871            input: Box::new(LogicalOperator::Empty),
872        };
873        let cost = model.skip_cost(&skip, 1000.0);
874
875        // Skip must scan through skipped rows
876        assert!(cost.cpu > 0.0);
877    }
878
879    #[test]
880    fn test_cost_model_return() {
881        let model = CostModel::new();
882        let ret = ReturnOp {
883            items: vec![
884                ReturnItem {
885                    expression: LogicalExpression::Variable("a".to_string()),
886                    alias: None,
887                },
888                ReturnItem {
889                    expression: LogicalExpression::Variable("b".to_string()),
890                    alias: None,
891                },
892            ],
893            distinct: false,
894            input: Box::new(LogicalOperator::Empty),
895        };
896        let cost = model.return_cost(&ret, 1000.0);
897
898        // Return materializes results
899        assert!(cost.cpu > 0.0);
900    }
901
902    #[test]
903    fn test_cost_cheaper() {
904        let model = CostModel::new();
905        let cheap = Cost::cpu(10.0);
906        let expensive = Cost::cpu(100.0);
907
908        assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
909        assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
910    }
911
912    #[test]
913    fn test_cost_comparison_prefers_lower_total() {
914        let model = CostModel::new();
915        // High CPU, low IO
916        let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
917        // Low CPU, high IO
918        let io_heavy = Cost::cpu(10.0).with_io(20.0);
919
920        // IO is weighted 10x, so io_heavy = 10 + 200 = 210, cpu_heavy = 100 + 10 = 110
921        assert!(cpu_heavy.total() < io_heavy.total());
922        assert_eq!(
923            model.cheaper(&cpu_heavy, &io_heavy).total(),
924            cpu_heavy.total()
925        );
926    }
927
928    #[test]
929    fn test_cost_model_sort_with_keys() {
930        let model = CostModel::new();
931        let sort_single = SortOp {
932            keys: vec![crate::query::plan::SortKey {
933                expression: LogicalExpression::Variable("a".to_string()),
934                order: SortOrder::Ascending,
935            }],
936            input: Box::new(LogicalOperator::Empty),
937        };
938        let sort_multi = SortOp {
939            keys: vec![
940                crate::query::plan::SortKey {
941                    expression: LogicalExpression::Variable("a".to_string()),
942                    order: SortOrder::Ascending,
943                },
944                crate::query::plan::SortKey {
945                    expression: LogicalExpression::Variable("b".to_string()),
946                    order: SortOrder::Descending,
947                },
948            ],
949            input: Box::new(LogicalOperator::Empty),
950        };
951
952        let cost_single = model.sort_cost(&sort_single, 1000.0);
953        let cost_multi = model.sort_cost(&sort_multi, 1000.0);
954
955        // More sort keys = more comparisons
956        assert!(cost_multi.cpu > cost_single.cpu);
957    }
958
959    #[test]
960    fn test_cost_model_empty_operator() {
961        let model = CostModel::new();
962        let cost = model.estimate(&LogicalOperator::Empty, 0.0);
963        assert!((cost.total()).abs() < 0.001);
964    }
965
966    #[test]
967    fn test_cost_model_default() {
968        let model = CostModel::default();
969        let scan = NodeScanOp {
970            variable: "n".to_string(),
971            label: None,
972            input: None,
973        };
974        let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
975        assert!(cost.total() > 0.0);
976    }
977
978    #[test]
979    fn test_leapfrog_join_cost() {
980        let model = CostModel::new();
981
982        // Three-way join (triangle pattern)
983        let cardinalities = vec![1000.0, 1000.0, 1000.0];
984        let cost = model.leapfrog_join_cost(3, &cardinalities, 100.0);
985
986        // Should have CPU cost for materialization and intersection
987        assert!(cost.cpu > 0.0);
988        // Should have memory cost for trie storage
989        assert!(cost.memory > 0.0);
990    }
991
992    #[test]
993    fn test_leapfrog_join_cost_empty() {
994        let model = CostModel::new();
995        let cost = model.leapfrog_join_cost(0, &[], 0.0);
996        assert!((cost.total()).abs() < 0.001);
997    }
998
999    #[test]
1000    fn test_prefer_leapfrog_join_for_triangles() {
1001        let model = CostModel::new();
1002
1003        // Compare costs for triangle pattern
1004        let cardinalities = vec![10000.0, 10000.0, 10000.0];
1005        let output = 1000.0;
1006
1007        let leapfrog_cost = model.leapfrog_join_cost(3, &cardinalities, output);
1008
1009        // Leapfrog should have reasonable cost for triangle patterns
1010        assert!(leapfrog_cost.cpu > 0.0);
1011        assert!(leapfrog_cost.memory > 0.0);
1012
1013        // The prefer_leapfrog_join method compares against hash cascade
1014        // Actual preference depends on specific cost parameters
1015        let _prefer = model.prefer_leapfrog_join(3, &cardinalities, output);
1016        // Test that it returns a boolean (doesn't panic)
1017    }
1018
1019    #[test]
1020    fn test_prefer_leapfrog_join_binary_case() {
1021        let model = CostModel::new();
1022
1023        // Binary join should NOT prefer leapfrog (need 3+ relations)
1024        let cardinalities = vec![1000.0, 1000.0];
1025        let prefer = model.prefer_leapfrog_join(2, &cardinalities, 500.0);
1026        assert!(!prefer, "Binary joins should use hash join, not leapfrog");
1027    }
1028
1029    #[test]
1030    fn test_factorized_benefit_single_hop() {
1031        let model = CostModel::new();
1032
1033        // Single hop: no factorization benefit
1034        let benefit = model.factorized_benefit(10.0, 1);
1035        assert!(
1036            (benefit - 1.0).abs() < 0.001,
1037            "Single hop should have no benefit"
1038        );
1039    }
1040
1041    #[test]
1042    fn test_factorized_benefit_multi_hop() {
1043        let model = CostModel::new();
1044
1045        // Multi-hop with high fanout
1046        let benefit = model.factorized_benefit(10.0, 3);
1047
1048        // The factorized_benefit returns a ratio capped at 1.0
1049        // For high fanout, factorized size / full size approaches 1/fanout
1050        // which is beneficial but the formula gives a value <= 1.0
1051        assert!(benefit <= 1.0, "Benefit should be <= 1.0");
1052        assert!(benefit > 0.0, "Benefit should be positive");
1053    }
1054
1055    #[test]
1056    fn test_factorized_benefit_low_fanout() {
1057        let model = CostModel::new();
1058
1059        // Low fanout: minimal benefit
1060        let benefit = model.factorized_benefit(1.5, 2);
1061        assert!(
1062            benefit <= 1.0,
1063            "Low fanout still benefits from factorization"
1064        );
1065    }
1066}