Skip to main content

grafeo_engine/query/optimizer/
cost.rs

1//! Cost model for query optimization.
2//!
3//! Provides cost estimates for logical operators to guide optimization decisions.
4
5use crate::query::plan::{
6    AggregateOp, DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
7    LogicalOperator, NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp, VectorJoinOp, VectorScanOp,
8};
9
/// Cost of an operation.
///
/// Represents the estimated resource consumption of executing an operator,
/// split into CPU, I/O, memory and network components. The components are
/// collapsed into a single comparable scalar by [`Cost::total`] (default
/// weights) or [`Cost::total_weighted`] (caller-supplied weights).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Cost {
    /// Estimated CPU cycles / work units.
    pub cpu: f64,
    /// Estimated I/O operations (page reads).
    pub io: f64,
    /// Estimated memory usage in bytes.
    pub memory: f64,
    /// Network cost (for distributed queries).
    pub network: f64,
}

impl Cost {
    /// Weight of the CPU component in [`Cost::total`].
    const CPU_WEIGHT: f64 = 1.0;
    /// Weight of the I/O component in [`Cost::total`].
    const IO_WEIGHT: f64 = 10.0;
    /// Weight of the memory component in [`Cost::total`].
    const MEMORY_WEIGHT: f64 = 0.1;
    /// Weight of the network component in [`Cost::total`].
    const NETWORK_WEIGHT: f64 = 100.0;

    /// Creates a zero cost.
    #[must_use]
    pub fn zero() -> Self {
        Self {
            cpu: 0.0,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Creates a cost consisting only of CPU work units.
    #[must_use]
    pub fn cpu(cpu: f64) -> Self {
        Self {
            cpu,
            ..Self::zero()
        }
    }

    /// Sets the I/O component, replacing any previously set value.
    ///
    /// Note: this does not accumulate; use `+` / `+=` to combine costs.
    #[must_use]
    pub fn with_io(mut self, io: f64) -> Self {
        self.io = io;
        self
    }

    /// Sets the memory component, replacing any previously set value.
    ///
    /// Note: this does not accumulate; use `+` / `+=` to combine costs.
    #[must_use]
    pub fn with_memory(mut self, memory: f64) -> Self {
        self.memory = memory;
        self
    }

    /// Sets the network component, replacing any previously set value.
    #[must_use]
    pub fn with_network(mut self, network: f64) -> Self {
        self.network = network;
        self
    }

    /// Returns the total weighted cost.
    ///
    /// Uses default weights: CPU=1.0, IO=10.0, Memory=0.1, Network=100.0
    #[must_use]
    pub fn total(&self) -> f64 {
        self.cpu * Self::CPU_WEIGHT
            + self.io * Self::IO_WEIGHT
            + self.memory * Self::MEMORY_WEIGHT
            + self.network * Self::NETWORK_WEIGHT
    }

    /// Returns the total cost with custom CPU, I/O and memory weights.
    ///
    /// The network component is **not** included in this sum (unlike
    /// [`Cost::total`]), so results are only comparable with each other.
    #[must_use]
    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
        self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
    }
}

impl std::ops::Add for Cost {
    type Output = Self;

    /// Component-wise sum of two costs.
    fn add(mut self, other: Self) -> Self {
        self += other;
        self
    }
}

impl std::ops::AddAssign for Cost {
    /// Adds every component of `other` into `self`.
    fn add_assign(&mut self, other: Self) {
        self.cpu += other.cpu;
        self.io += other.io;
        self.memory += other.memory;
        self.network += other.network;
    }
}
98
/// Cost model for estimating operator costs.
///
/// Default constants are calibrated relative to each other:
/// - Tuple scan is the baseline (1x)
/// - Hash lookup is ~3x (hash computation + potential cache miss)
/// - Sort comparison is ~2x (key extraction + comparison)
/// - Distance computation is ~10x (vector math)
///
/// The model is plain data, so it derives `Debug` (for optimizer
/// diagnostics) and `Clone` (so it can be copied into per-query planners).
#[derive(Debug, Clone)]
pub struct CostModel {
    /// Cost per tuple processed by CPU (baseline unit).
    cpu_tuple_cost: f64,
    /// Cost per hash table lookup (~3x tuple cost: hash + cache miss).
    hash_lookup_cost: f64,
    /// Cost per comparison in sorting (~2x tuple cost: key extract + cmp).
    sort_comparison_cost: f64,
    /// Average tuple size in bytes (for IO estimation).
    avg_tuple_size: f64,
    /// Page size in bytes.
    page_size: f64,
    /// Global average edge fanout (fallback when per-type stats unavailable).
    avg_fanout: f64,
    /// Per-edge-type degree stats: (avg_out_degree, avg_in_degree).
    edge_type_degrees: std::collections::HashMap<String, (f64, f64)>,
}
122
123impl CostModel {
124    /// Creates a new cost model with calibrated default parameters.
125    #[must_use]
126    pub fn new() -> Self {
127        Self {
128            cpu_tuple_cost: 0.01,
129            hash_lookup_cost: 0.03,
130            sort_comparison_cost: 0.02,
131            avg_tuple_size: 100.0,
132            page_size: 8192.0,
133            avg_fanout: 10.0,
134            edge_type_degrees: std::collections::HashMap::new(),
135        }
136    }
137
138    /// Sets the global average fanout from graph statistics.
139    #[must_use]
140    pub fn with_avg_fanout(mut self, avg_fanout: f64) -> Self {
141        self.avg_fanout = if avg_fanout > 0.0 { avg_fanout } else { 10.0 };
142        self
143    }
144
145    /// Sets per-edge-type degree statistics for accurate expand cost estimation.
146    ///
147    /// Each entry maps edge type name to `(avg_out_degree, avg_in_degree)`.
148    #[must_use]
149    pub fn with_edge_type_degrees(
150        mut self,
151        degrees: std::collections::HashMap<String, (f64, f64)>,
152    ) -> Self {
153        self.edge_type_degrees = degrees;
154        self
155    }
156
157    /// Returns the fanout for a specific expand operation.
158    ///
159    /// Uses per-edge-type degree stats when available, falling back to the
160    /// global average fanout.
161    fn fanout_for_expand(&self, expand: &ExpandOp) -> f64 {
162        if expand.edge_types.len() == 1
163            && let Some(&(out_deg, in_deg)) = self.edge_type_degrees.get(&expand.edge_types[0])
164        {
165            return match expand.direction {
166                ExpandDirection::Outgoing => out_deg,
167                ExpandDirection::Incoming => in_deg,
168                ExpandDirection::Both => out_deg + in_deg,
169            };
170        }
171        self.avg_fanout
172    }
173
174    /// Estimates the cost of a logical operator.
175    #[must_use]
176    pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
177        match op {
178            LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
179            LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
180            LogicalOperator::Project(project) => self.project_cost(project, cardinality),
181            LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
182            LogicalOperator::Join(join) => self.join_cost(join, cardinality),
183            LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
184            LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
185            LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
186            LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
187            LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
188            LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
189            LogicalOperator::Empty => Cost::zero(),
190            LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
191            LogicalOperator::VectorJoin(join) => self.vector_join_cost(join, cardinality),
192            _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
193        }
194    }
195
196    /// Estimates the cost of a node scan.
197    fn node_scan_cost(&self, _scan: &NodeScanOp, cardinality: f64) -> Cost {
198        let pages = (cardinality * self.avg_tuple_size) / self.page_size;
199        Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
200    }
201
202    /// Estimates the cost of a filter operation.
203    fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
204        // Filter cost is just predicate evaluation per tuple
205        Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
206    }
207
208    /// Estimates the cost of a projection.
209    fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
210        // Cost depends on number of expressions evaluated
211        let expr_count = project.projections.len() as f64;
212        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
213    }
214
215    /// Estimates the cost of an expand operation.
216    ///
217    /// Uses per-edge-type degree stats when available, otherwise falls back
218    /// to the global average fanout.
219    fn expand_cost(&self, expand: &ExpandOp, cardinality: f64) -> Cost {
220        let fanout = self.fanout_for_expand(expand);
221        // Adjacency list lookup per input row
222        let lookup_cost = cardinality * self.hash_lookup_cost;
223        // Process each expanded output tuple
224        let output_cost = cardinality * fanout * self.cpu_tuple_cost;
225        Cost::cpu(lookup_cost + output_cost)
226    }
227
228    /// Estimates the cost of a join operation.
229    fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
230        // Cost depends on join type
231        match join.join_type {
232            JoinType::Cross => {
233                // Cross join is O(n * m)
234                Cost::cpu(cardinality * self.cpu_tuple_cost)
235            }
236            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
237                // Hash join: build phase + probe phase
238                // Assume left side is build, right side is probe
239                let build_cardinality = cardinality.sqrt(); // Rough estimate
240                let probe_cardinality = cardinality.sqrt();
241
242                // Build hash table
243                let build_cost = build_cardinality * self.hash_lookup_cost;
244                let memory_cost = build_cardinality * self.avg_tuple_size;
245
246                // Probe hash table
247                let probe_cost = probe_cardinality * self.hash_lookup_cost;
248
249                // Output cost
250                let output_cost = cardinality * self.cpu_tuple_cost;
251
252                Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
253            }
254            JoinType::Semi | JoinType::Anti => {
255                // Semi/anti joins are typically cheaper
256                let build_cardinality = cardinality.sqrt();
257                let probe_cardinality = cardinality.sqrt();
258
259                let build_cost = build_cardinality * self.hash_lookup_cost;
260                let probe_cost = probe_cardinality * self.hash_lookup_cost;
261
262                Cost::cpu(build_cost + probe_cost)
263                    .with_memory(build_cardinality * self.avg_tuple_size)
264            }
265        }
266    }
267
268    /// Estimates the cost of an aggregation.
269    fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
270        // Hash aggregation cost
271        let hash_cost = cardinality * self.hash_lookup_cost;
272
273        // Aggregate function evaluation
274        let agg_count = agg.aggregates.len() as f64;
275        let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
276
277        // Memory for hash table (estimated distinct groups)
278        let distinct_groups = (cardinality / 10.0).max(1.0); // Assume 10% distinct
279        let memory_cost = distinct_groups * self.avg_tuple_size;
280
281        Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
282    }
283
284    /// Estimates the cost of a sort operation.
285    fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
286        if cardinality <= 1.0 {
287            return Cost::zero();
288        }
289
290        // Sort is O(n log n) comparisons
291        let comparisons = cardinality * cardinality.log2();
292        let key_count = sort.keys.len() as f64;
293
294        // Memory for sorting (full input materialization)
295        let memory_cost = cardinality * self.avg_tuple_size;
296
297        Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
298    }
299
300    /// Estimates the cost of a distinct operation.
301    fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
302        // Hash-based distinct
303        let hash_cost = cardinality * self.hash_lookup_cost;
304        let memory_cost = cardinality * self.avg_tuple_size * 0.5; // Assume 50% distinct
305
306        Cost::cpu(hash_cost).with_memory(memory_cost)
307    }
308
309    /// Estimates the cost of a limit operation.
310    fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
311        // Limit is very cheap - just counting
312        Cost::cpu(limit.count as f64 * self.cpu_tuple_cost * 0.1)
313    }
314
315    /// Estimates the cost of a skip operation.
316    fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
317        // Skip requires scanning through skipped rows
318        Cost::cpu(skip.count as f64 * self.cpu_tuple_cost)
319    }
320
321    /// Estimates the cost of a return operation.
322    fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
323        // Return materializes results
324        let expr_count = ret.items.len() as f64;
325        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
326    }
327
328    /// Estimates the cost of a vector scan operation.
329    ///
330    /// HNSW index search is O(log N) per query, while brute-force is O(N).
331    /// This estimates the HNSW case with ef search parameter.
332    fn vector_scan_cost(&self, scan: &VectorScanOp, cardinality: f64) -> Cost {
333        // k determines output cardinality
334        let k = scan.k as f64;
335
336        // HNSW search cost: O(ef * log(N)) distance computations
337        // Assume ef = 64 (default), N = cardinality
338        let ef = 64.0;
339        let n = cardinality.max(1.0);
340        let search_cost = if scan.index_name.is_some() {
341            // HNSW: O(ef * log N)
342            ef * n.ln() * self.cpu_tuple_cost * 10.0 // Distance computation is ~10x regular tuple
343        } else {
344            // Brute-force: O(N)
345            n * self.cpu_tuple_cost * 10.0
346        };
347
348        // Memory for candidate heap
349        let memory = k * self.avg_tuple_size * 2.0;
350
351        Cost::cpu(search_cost).with_memory(memory)
352    }
353
354    /// Estimates the cost of a vector join operation.
355    ///
356    /// Vector join performs k-NN search for each input row.
357    fn vector_join_cost(&self, join: &VectorJoinOp, cardinality: f64) -> Cost {
358        let k = join.k as f64;
359
360        // Each input row triggers a vector search
361        // Assume brute-force for hybrid queries (no index specified typically)
362        let per_row_search_cost = if join.index_name.is_some() {
363            // HNSW: O(ef * log N)
364            let ef = 64.0;
365            let n = cardinality.max(1.0);
366            ef * n.ln() * self.cpu_tuple_cost * 10.0
367        } else {
368            // Brute-force: O(N) per input row
369            cardinality * self.cpu_tuple_cost * 10.0
370        };
371
372        // Total cost: input_rows * search_cost
373        // For vector join, cardinality is typically input cardinality * k
374        let input_cardinality = (cardinality / k).max(1.0);
375        let total_search_cost = input_cardinality * per_row_search_cost;
376
377        // Memory for results
378        let memory = cardinality * self.avg_tuple_size;
379
380        Cost::cpu(total_search_cost).with_memory(memory)
381    }
382
383    /// Compares two costs and returns the cheaper one.
384    #[must_use]
385    pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
386        if a.total() <= b.total() { a } else { b }
387    }
388
389    /// Estimates the cost of a worst-case optimal join (WCOJ/leapfrog join).
390    ///
391    /// WCOJ is optimal for cyclic patterns like triangles. Traditional binary
392    /// hash joins are O(N²) for triangles; WCOJ achieves O(N^1.5) by processing
393    /// all relations simultaneously using sorted iterators.
394    ///
395    /// # Arguments
396    /// * `num_relations` - Number of relations participating in the join
397    /// * `cardinalities` - Cardinality of each input relation
398    /// * `output_cardinality` - Expected output cardinality
399    ///
400    /// # Cost Model
401    /// - Materialization: O(sum of cardinalities) to build trie indexes
402    /// - Intersection: O(output * log(min_cardinality)) for leapfrog seek operations
403    /// - Memory: Trie storage for all inputs
404    #[must_use]
405    pub fn leapfrog_join_cost(
406        &self,
407        num_relations: usize,
408        cardinalities: &[f64],
409        output_cardinality: f64,
410    ) -> Cost {
411        if cardinalities.is_empty() {
412            return Cost::zero();
413        }
414
415        let total_input: f64 = cardinalities.iter().sum();
416        let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
417
418        // Materialization phase: build trie indexes for each input
419        let materialize_cost = total_input * self.cpu_tuple_cost * 2.0; // Sorting + trie building
420
421        // Intersection phase: leapfrog seeks are O(log n) per relation
422        let seek_cost = if min_card > 1.0 {
423            output_cardinality * (num_relations as f64) * min_card.log2() * self.hash_lookup_cost
424        } else {
425            output_cardinality * self.cpu_tuple_cost
426        };
427
428        // Output materialization
429        let output_cost = output_cardinality * self.cpu_tuple_cost;
430
431        // Memory: trie storage (roughly 2x input size for sorted index)
432        let memory = total_input * self.avg_tuple_size * 2.0;
433
434        Cost::cpu(materialize_cost + seek_cost + output_cost).with_memory(memory)
435    }
436
437    /// Compares hash join cost vs leapfrog join cost for a cyclic pattern.
438    ///
439    /// Returns true if leapfrog (WCOJ) is estimated to be cheaper.
440    #[must_use]
441    pub fn prefer_leapfrog_join(
442        &self,
443        num_relations: usize,
444        cardinalities: &[f64],
445        output_cardinality: f64,
446    ) -> bool {
447        if num_relations < 3 || cardinalities.len() < 3 {
448            // Leapfrog is only beneficial for multi-way joins (3+)
449            return false;
450        }
451
452        let leapfrog_cost =
453            self.leapfrog_join_cost(num_relations, cardinalities, output_cardinality);
454
455        // Estimate cascade of binary hash joins
456        // For N relations, we need N-1 joins
457        // Each join produces intermediate results that feed the next
458        let mut hash_cascade_cost = Cost::zero();
459        let mut intermediate_cardinality = cardinalities[0];
460
461        for card in &cardinalities[1..] {
462            // Hash join cost: build + probe + output
463            let join_output = (intermediate_cardinality * card).sqrt(); // Estimated selectivity
464            let join = JoinOp {
465                left: Box::new(LogicalOperator::Empty),
466                right: Box::new(LogicalOperator::Empty),
467                join_type: JoinType::Inner,
468                conditions: vec![],
469            };
470            hash_cascade_cost += self.join_cost(&join, join_output);
471            intermediate_cardinality = join_output;
472        }
473
474        leapfrog_cost.total() < hash_cascade_cost.total()
475    }
476
477    /// Estimates cost for factorized execution (compressed intermediate results).
478    ///
479    /// Factorized execution avoids materializing full cross products by keeping
480    /// results in a compressed "factorized" form. This is beneficial for multi-hop
481    /// traversals where intermediate results can explode.
482    ///
483    /// Returns the reduction factor (1.0 = no benefit, lower = more compression).
484    #[must_use]
485    pub fn factorized_benefit(&self, avg_fanout: f64, num_hops: usize) -> f64 {
486        if num_hops <= 1 || avg_fanout <= 1.0 {
487            return 1.0; // No benefit for single hop or low fanout
488        }
489
490        // Factorized representation compresses repeated prefixes
491        // Compression ratio improves with higher fanout and more hops
492        // Full materialization: fanout^hops
493        // Factorized: sum(fanout^i for i in 1..=hops) ≈ fanout^(hops+1) / (fanout - 1)
494
495        let full_size = avg_fanout.powi(num_hops as i32);
496        let factorized_size = if avg_fanout > 1.0 {
497            (avg_fanout.powi(num_hops as i32 + 1) - 1.0) / (avg_fanout - 1.0)
498        } else {
499            num_hops as f64
500        };
501
502        (factorized_size / full_size).min(1.0)
503    }
504}
505
506impl Default for CostModel {
507    fn default() -> Self {
508        Self::new()
509    }
510}
511
512#[cfg(test)]
513mod tests {
514    use super::*;
515    use crate::query::plan::{
516        AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
517        PathMode, Projection, ReturnItem, SortOrder,
518    };
519
520    #[test]
521    fn test_cost_addition() {
522        let a = Cost::cpu(10.0).with_io(5.0);
523        let b = Cost::cpu(20.0).with_memory(100.0);
524        let c = a + b;
525
526        assert!((c.cpu - 30.0).abs() < 0.001);
527        assert!((c.io - 5.0).abs() < 0.001);
528        assert!((c.memory - 100.0).abs() < 0.001);
529    }
530
531    #[test]
532    fn test_cost_total() {
533        let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
534        // Total = 10 + 1*10 + 100*0.1 = 10 + 10 + 10 = 30
535        assert!((cost.total() - 30.0).abs() < 0.001);
536    }
537
538    #[test]
539    fn test_cost_model_node_scan() {
540        let model = CostModel::new();
541        let scan = NodeScanOp {
542            variable: "n".to_string(),
543            label: Some("Person".to_string()),
544            input: None,
545        };
546        let cost = model.node_scan_cost(&scan, 1000.0);
547
548        assert!(cost.cpu > 0.0);
549        assert!(cost.io > 0.0);
550    }
551
552    #[test]
553    fn test_cost_model_sort() {
554        let model = CostModel::new();
555        let sort = SortOp {
556            keys: vec![],
557            input: Box::new(LogicalOperator::Empty),
558        };
559
560        let cost_100 = model.sort_cost(&sort, 100.0);
561        let cost_1000 = model.sort_cost(&sort, 1000.0);
562
563        // Sorting 1000 rows should be more expensive than 100 rows
564        assert!(cost_1000.total() > cost_100.total());
565    }
566
567    #[test]
568    fn test_cost_zero() {
569        let cost = Cost::zero();
570        assert!((cost.cpu).abs() < 0.001);
571        assert!((cost.io).abs() < 0.001);
572        assert!((cost.memory).abs() < 0.001);
573        assert!((cost.network).abs() < 0.001);
574        assert!((cost.total()).abs() < 0.001);
575    }
576
577    #[test]
578    fn test_cost_add_assign() {
579        let mut cost = Cost::cpu(10.0);
580        cost += Cost::cpu(5.0).with_io(2.0);
581        assert!((cost.cpu - 15.0).abs() < 0.001);
582        assert!((cost.io - 2.0).abs() < 0.001);
583    }
584
585    #[test]
586    fn test_cost_total_weighted() {
587        let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
588        // With custom weights: cpu*2 + io*5 + mem*0.5 = 20 + 10 + 50 = 80
589        let total = cost.total_weighted(2.0, 5.0, 0.5);
590        assert!((total - 80.0).abs() < 0.001);
591    }
592
593    #[test]
594    fn test_cost_model_filter() {
595        let model = CostModel::new();
596        let filter = FilterOp {
597            predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
598            input: Box::new(LogicalOperator::Empty),
599        };
600        let cost = model.filter_cost(&filter, 1000.0);
601
602        // Filter cost is CPU only
603        assert!(cost.cpu > 0.0);
604        assert!((cost.io).abs() < 0.001);
605    }
606
607    #[test]
608    fn test_cost_model_project() {
609        let model = CostModel::new();
610        let project = ProjectOp {
611            projections: vec![
612                Projection {
613                    expression: LogicalExpression::Variable("a".to_string()),
614                    alias: None,
615                },
616                Projection {
617                    expression: LogicalExpression::Variable("b".to_string()),
618                    alias: None,
619                },
620            ],
621            input: Box::new(LogicalOperator::Empty),
622        };
623        let cost = model.project_cost(&project, 1000.0);
624
625        // Cost should scale with number of projections
626        assert!(cost.cpu > 0.0);
627    }
628
629    #[test]
630    fn test_cost_model_expand() {
631        let model = CostModel::new();
632        let expand = ExpandOp {
633            from_variable: "a".to_string(),
634            to_variable: "b".to_string(),
635            edge_variable: None,
636            direction: ExpandDirection::Outgoing,
637            edge_types: vec![],
638            min_hops: 1,
639            max_hops: Some(1),
640            input: Box::new(LogicalOperator::Empty),
641            path_alias: None,
642            path_mode: PathMode::Walk,
643        };
644        let cost = model.expand_cost(&expand, 1000.0);
645
646        // Expand involves hash lookups and output generation
647        assert!(cost.cpu > 0.0);
648    }
649
650    #[test]
651    fn test_cost_model_expand_with_edge_type_stats() {
652        let mut degrees = std::collections::HashMap::new();
653        degrees.insert("KNOWS".to_string(), (5.0, 5.0)); // Symmetric
654        degrees.insert("WORKS_AT".to_string(), (1.0, 50.0)); // Many-to-one
655
656        let model = CostModel::new().with_edge_type_degrees(degrees);
657
658        // Outgoing KNOWS: fanout = 5
659        let knows_out = ExpandOp {
660            from_variable: "a".to_string(),
661            to_variable: "b".to_string(),
662            edge_variable: None,
663            direction: ExpandDirection::Outgoing,
664            edge_types: vec!["KNOWS".to_string()],
665            min_hops: 1,
666            max_hops: Some(1),
667            input: Box::new(LogicalOperator::Empty),
668            path_alias: None,
669            path_mode: PathMode::Walk,
670        };
671        let cost_knows = model.expand_cost(&knows_out, 1000.0);
672
673        // Outgoing WORKS_AT: fanout = 1 (each person works at one company)
674        let works_out = ExpandOp {
675            from_variable: "a".to_string(),
676            to_variable: "b".to_string(),
677            edge_variable: None,
678            direction: ExpandDirection::Outgoing,
679            edge_types: vec!["WORKS_AT".to_string()],
680            min_hops: 1,
681            max_hops: Some(1),
682            input: Box::new(LogicalOperator::Empty),
683            path_alias: None,
684            path_mode: PathMode::Walk,
685        };
686        let cost_works = model.expand_cost(&works_out, 1000.0);
687
688        // KNOWS (fanout=5) should be more expensive than WORKS_AT (fanout=1)
689        assert!(
690            cost_knows.cpu > cost_works.cpu,
691            "KNOWS(5) should cost more than WORKS_AT(1)"
692        );
693
694        // Incoming WORKS_AT: fanout = 50 (company has many employees)
695        let works_in = ExpandOp {
696            from_variable: "c".to_string(),
697            to_variable: "p".to_string(),
698            edge_variable: None,
699            direction: ExpandDirection::Incoming,
700            edge_types: vec!["WORKS_AT".to_string()],
701            min_hops: 1,
702            max_hops: Some(1),
703            input: Box::new(LogicalOperator::Empty),
704            path_alias: None,
705            path_mode: PathMode::Walk,
706        };
707        let cost_works_in = model.expand_cost(&works_in, 1000.0);
708
709        // Incoming WORKS_AT (fanout=50) should be most expensive
710        assert!(
711            cost_works_in.cpu > cost_knows.cpu,
712            "Incoming WORKS_AT(50) should cost more than KNOWS(5)"
713        );
714    }
715
716    #[test]
717    fn test_cost_model_expand_unknown_edge_type_uses_global_fanout() {
718        let model = CostModel::new().with_avg_fanout(7.0);
719        let expand = ExpandOp {
720            from_variable: "a".to_string(),
721            to_variable: "b".to_string(),
722            edge_variable: None,
723            direction: ExpandDirection::Outgoing,
724            edge_types: vec!["UNKNOWN_TYPE".to_string()],
725            min_hops: 1,
726            max_hops: Some(1),
727            input: Box::new(LogicalOperator::Empty),
728            path_alias: None,
729            path_mode: PathMode::Walk,
730        };
731        let cost_unknown = model.expand_cost(&expand, 1000.0);
732
733        // Without edge type (uses global fanout too)
734        let expand_no_type = ExpandOp {
735            from_variable: "a".to_string(),
736            to_variable: "b".to_string(),
737            edge_variable: None,
738            direction: ExpandDirection::Outgoing,
739            edge_types: vec![],
740            min_hops: 1,
741            max_hops: Some(1),
742            input: Box::new(LogicalOperator::Empty),
743            path_alias: None,
744            path_mode: PathMode::Walk,
745        };
746        let cost_no_type = model.expand_cost(&expand_no_type, 1000.0);
747
748        // Both should use global fanout = 7, so costs should be equal
749        assert!(
750            (cost_unknown.cpu - cost_no_type.cpu).abs() < 0.001,
751            "Unknown edge type should use global fanout"
752        );
753    }
754
755    #[test]
756    fn test_cost_model_hash_join() {
757        let model = CostModel::new();
758        let join = JoinOp {
759            left: Box::new(LogicalOperator::Empty),
760            right: Box::new(LogicalOperator::Empty),
761            join_type: JoinType::Inner,
762            conditions: vec![JoinCondition {
763                left: LogicalExpression::Variable("a".to_string()),
764                right: LogicalExpression::Variable("b".to_string()),
765            }],
766        };
767        let cost = model.join_cost(&join, 10000.0);
768
769        // Hash join has CPU cost and memory cost
770        assert!(cost.cpu > 0.0);
771        assert!(cost.memory > 0.0);
772    }
773
774    #[test]
775    fn test_cost_model_cross_join() {
776        let model = CostModel::new();
777        let join = JoinOp {
778            left: Box::new(LogicalOperator::Empty),
779            right: Box::new(LogicalOperator::Empty),
780            join_type: JoinType::Cross,
781            conditions: vec![],
782        };
783        let cost = model.join_cost(&join, 1000000.0);
784
785        // Cross join is expensive
786        assert!(cost.cpu > 0.0);
787    }
788
789    #[test]
790    fn test_cost_model_semi_join() {
791        let model = CostModel::new();
792        let join = JoinOp {
793            left: Box::new(LogicalOperator::Empty),
794            right: Box::new(LogicalOperator::Empty),
795            join_type: JoinType::Semi,
796            conditions: vec![],
797        };
798        let cost_semi = model.join_cost(&join, 1000.0);
799
800        let inner_join = JoinOp {
801            left: Box::new(LogicalOperator::Empty),
802            right: Box::new(LogicalOperator::Empty),
803            join_type: JoinType::Inner,
804            conditions: vec![],
805        };
806        let cost_inner = model.join_cost(&inner_join, 1000.0);
807
808        // Semi join can be cheaper than inner join
809        assert!(cost_semi.cpu > 0.0);
810        assert!(cost_inner.cpu > 0.0);
811    }
812
813    #[test]
814    fn test_cost_model_aggregate() {
815        let model = CostModel::new();
816        let agg = AggregateOp {
817            group_by: vec![],
818            aggregates: vec![
819                AggregateExpr {
820                    function: AggregateFunction::Count,
821                    expression: None,
822                    distinct: false,
823                    alias: Some("cnt".to_string()),
824                    percentile: None,
825                },
826                AggregateExpr {
827                    function: AggregateFunction::Sum,
828                    expression: Some(LogicalExpression::Variable("x".to_string())),
829                    distinct: false,
830                    alias: Some("total".to_string()),
831                    percentile: None,
832                },
833            ],
834            input: Box::new(LogicalOperator::Empty),
835            having: None,
836        };
837        let cost = model.aggregate_cost(&agg, 1000.0);
838
839        // Aggregation has hash cost and memory cost
840        assert!(cost.cpu > 0.0);
841        assert!(cost.memory > 0.0);
842    }
843
844    #[test]
845    fn test_cost_model_distinct() {
846        let model = CostModel::new();
847        let distinct = DistinctOp {
848            input: Box::new(LogicalOperator::Empty),
849            columns: None,
850        };
851        let cost = model.distinct_cost(&distinct, 1000.0);
852
853        // Distinct uses hash set
854        assert!(cost.cpu > 0.0);
855        assert!(cost.memory > 0.0);
856    }
857
858    #[test]
859    fn test_cost_model_limit() {
860        let model = CostModel::new();
861        let limit = LimitOp {
862            count: 10,
863            input: Box::new(LogicalOperator::Empty),
864        };
865        let cost = model.limit_cost(&limit, 1000.0);
866
867        // Limit is very cheap
868        assert!(cost.cpu > 0.0);
869        assert!(cost.cpu < 1.0); // Should be minimal
870    }
871
872    #[test]
873    fn test_cost_model_skip() {
874        let model = CostModel::new();
875        let skip = SkipOp {
876            count: 100,
877            input: Box::new(LogicalOperator::Empty),
878        };
879        let cost = model.skip_cost(&skip, 1000.0);
880
881        // Skip must scan through skipped rows
882        assert!(cost.cpu > 0.0);
883    }
884
885    #[test]
886    fn test_cost_model_return() {
887        let model = CostModel::new();
888        let ret = ReturnOp {
889            items: vec![
890                ReturnItem {
891                    expression: LogicalExpression::Variable("a".to_string()),
892                    alias: None,
893                },
894                ReturnItem {
895                    expression: LogicalExpression::Variable("b".to_string()),
896                    alias: None,
897                },
898            ],
899            distinct: false,
900            input: Box::new(LogicalOperator::Empty),
901        };
902        let cost = model.return_cost(&ret, 1000.0);
903
904        // Return materializes results
905        assert!(cost.cpu > 0.0);
906    }
907
908    #[test]
909    fn test_cost_cheaper() {
910        let model = CostModel::new();
911        let cheap = Cost::cpu(10.0);
912        let expensive = Cost::cpu(100.0);
913
914        assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
915        assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
916    }
917
918    #[test]
919    fn test_cost_comparison_prefers_lower_total() {
920        let model = CostModel::new();
921        // High CPU, low IO
922        let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
923        // Low CPU, high IO
924        let io_heavy = Cost::cpu(10.0).with_io(20.0);
925
926        // IO is weighted 10x, so io_heavy = 10 + 200 = 210, cpu_heavy = 100 + 10 = 110
927        assert!(cpu_heavy.total() < io_heavy.total());
928        assert_eq!(
929            model.cheaper(&cpu_heavy, &io_heavy).total(),
930            cpu_heavy.total()
931        );
932    }
933
934    #[test]
935    fn test_cost_model_sort_with_keys() {
936        let model = CostModel::new();
937        let sort_single = SortOp {
938            keys: vec![crate::query::plan::SortKey {
939                expression: LogicalExpression::Variable("a".to_string()),
940                order: SortOrder::Ascending,
941            }],
942            input: Box::new(LogicalOperator::Empty),
943        };
944        let sort_multi = SortOp {
945            keys: vec![
946                crate::query::plan::SortKey {
947                    expression: LogicalExpression::Variable("a".to_string()),
948                    order: SortOrder::Ascending,
949                },
950                crate::query::plan::SortKey {
951                    expression: LogicalExpression::Variable("b".to_string()),
952                    order: SortOrder::Descending,
953                },
954            ],
955            input: Box::new(LogicalOperator::Empty),
956        };
957
958        let cost_single = model.sort_cost(&sort_single, 1000.0);
959        let cost_multi = model.sort_cost(&sort_multi, 1000.0);
960
961        // More sort keys = more comparisons
962        assert!(cost_multi.cpu > cost_single.cpu);
963    }
964
965    #[test]
966    fn test_cost_model_empty_operator() {
967        let model = CostModel::new();
968        let cost = model.estimate(&LogicalOperator::Empty, 0.0);
969        assert!((cost.total()).abs() < 0.001);
970    }
971
972    #[test]
973    fn test_cost_model_default() {
974        let model = CostModel::default();
975        let scan = NodeScanOp {
976            variable: "n".to_string(),
977            label: None,
978            input: None,
979        };
980        let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
981        assert!(cost.total() > 0.0);
982    }
983
984    #[test]
985    fn test_leapfrog_join_cost() {
986        let model = CostModel::new();
987
988        // Three-way join (triangle pattern)
989        let cardinalities = vec![1000.0, 1000.0, 1000.0];
990        let cost = model.leapfrog_join_cost(3, &cardinalities, 100.0);
991
992        // Should have CPU cost for materialization and intersection
993        assert!(cost.cpu > 0.0);
994        // Should have memory cost for trie storage
995        assert!(cost.memory > 0.0);
996    }
997
998    #[test]
999    fn test_leapfrog_join_cost_empty() {
1000        let model = CostModel::new();
1001        let cost = model.leapfrog_join_cost(0, &[], 0.0);
1002        assert!((cost.total()).abs() < 0.001);
1003    }
1004
1005    #[test]
1006    fn test_prefer_leapfrog_join_for_triangles() {
1007        let model = CostModel::new();
1008
1009        // Compare costs for triangle pattern
1010        let cardinalities = vec![10000.0, 10000.0, 10000.0];
1011        let output = 1000.0;
1012
1013        let leapfrog_cost = model.leapfrog_join_cost(3, &cardinalities, output);
1014
1015        // Leapfrog should have reasonable cost for triangle patterns
1016        assert!(leapfrog_cost.cpu > 0.0);
1017        assert!(leapfrog_cost.memory > 0.0);
1018
1019        // The prefer_leapfrog_join method compares against hash cascade
1020        // Actual preference depends on specific cost parameters
1021        let _prefer = model.prefer_leapfrog_join(3, &cardinalities, output);
1022        // Test that it returns a boolean (doesn't panic)
1023    }
1024
1025    #[test]
1026    fn test_prefer_leapfrog_join_binary_case() {
1027        let model = CostModel::new();
1028
1029        // Binary join should NOT prefer leapfrog (need 3+ relations)
1030        let cardinalities = vec![1000.0, 1000.0];
1031        let prefer = model.prefer_leapfrog_join(2, &cardinalities, 500.0);
1032        assert!(!prefer, "Binary joins should use hash join, not leapfrog");
1033    }
1034
1035    #[test]
1036    fn test_factorized_benefit_single_hop() {
1037        let model = CostModel::new();
1038
1039        // Single hop: no factorization benefit
1040        let benefit = model.factorized_benefit(10.0, 1);
1041        assert!(
1042            (benefit - 1.0).abs() < 0.001,
1043            "Single hop should have no benefit"
1044        );
1045    }
1046
1047    #[test]
1048    fn test_factorized_benefit_multi_hop() {
1049        let model = CostModel::new();
1050
1051        // Multi-hop with high fanout
1052        let benefit = model.factorized_benefit(10.0, 3);
1053
1054        // The factorized_benefit returns a ratio capped at 1.0
1055        // For high fanout, factorized size / full size approaches 1/fanout
1056        // which is beneficial but the formula gives a value <= 1.0
1057        assert!(benefit <= 1.0, "Benefit should be <= 1.0");
1058        assert!(benefit > 0.0, "Benefit should be positive");
1059    }
1060
1061    #[test]
1062    fn test_factorized_benefit_low_fanout() {
1063        let model = CostModel::new();
1064
1065        // Low fanout: minimal benefit
1066        let benefit = model.factorized_benefit(1.5, 2);
1067        assert!(
1068            benefit <= 1.0,
1069            "Low fanout still benefits from factorization"
1070        );
1071    }
1072}