// grafeo_engine/query/optimizer/cost.rs

//! Cost model for query optimization.
//!
//! Provides cost estimates for logical operators to guide optimization decisions.

5use crate::query::plan::{
6    AggregateOp, DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
7    LogicalOperator, MultiWayJoinOp, NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp, VectorJoinOp,
8    VectorScanOp,
9};
10
/// Cost of an operation.
///
/// Represents the estimated resource consumption of executing an operator,
/// broken down by resource class so callers can weight them differently.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Cost {
    /// Estimated CPU cycles / work units.
    pub cpu: f64,
    /// Estimated I/O operations (page reads).
    pub io: f64,
    /// Estimated memory usage in bytes.
    pub memory: f64,
    /// Network cost (for distributed queries).
    pub network: f64,
}

impl Cost {
    /// Creates a zero cost (all components 0.0).
    #[must_use]
    pub fn zero() -> Self {
        Self {
            cpu: 0.0,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Creates a cost consisting only of CPU work units.
    #[must_use]
    pub fn cpu(cpu: f64) -> Self {
        Self {
            cpu,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Sets the I/O component (builder-style; overwrites, does not accumulate).
    #[must_use]
    pub fn with_io(mut self, io: f64) -> Self {
        self.io = io;
        self
    }

    /// Sets the memory component (builder-style; overwrites, does not accumulate).
    #[must_use]
    pub fn with_memory(mut self, memory: f64) -> Self {
        self.memory = memory;
        self
    }

    /// Returns the total weighted cost.
    ///
    /// Uses default weights: CPU=1.0, IO=10.0, Memory=0.1, Network=100.0
    #[must_use]
    pub fn total(&self) -> f64 {
        self.cpu + self.io * 10.0 + self.memory * 0.1 + self.network * 100.0
    }

    /// Returns the total cost with custom weights.
    ///
    /// NOTE: unlike [`Cost::total`], this deliberately takes no network weight
    /// and therefore ignores the `network` component entirely; callers that
    /// care about distributed cost must use `total()` or weight it themselves.
    #[must_use]
    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
        self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
    }
}

impl std::ops::Add for Cost {
    type Output = Self;

    /// Component-wise sum of two costs.
    fn add(self, other: Self) -> Self {
        Self {
            cpu: self.cpu + other.cpu,
            io: self.io + other.io,
            memory: self.memory + other.memory,
            network: self.network + other.network,
        }
    }
}

impl std::ops::AddAssign for Cost {
    /// Component-wise in-place accumulation.
    fn add_assign(&mut self, other: Self) {
        self.cpu += other.cpu;
        self.io += other.io;
        self.memory += other.memory;
        self.network += other.network;
    }
}
99
/// Cost model for estimating operator costs.
///
/// Default constants are calibrated relative to each other:
/// - Tuple scan is the baseline (1x)
/// - Hash lookup is ~3x (hash computation + potential cache miss)
/// - Sort comparison is ~2x (key extraction + comparison)
/// - Distance computation is ~10x (vector math)
pub struct CostModel {
    /// Cost per tuple processed by CPU (baseline unit).
    cpu_tuple_cost: f64,
    /// Cost per hash table lookup (~3x tuple cost: hash + cache miss).
    hash_lookup_cost: f64,
    /// Cost per comparison in sorting (~2x tuple cost: key extract + cmp).
    sort_comparison_cost: f64,
    /// Average tuple size in bytes (for IO estimation).
    avg_tuple_size: f64,
    /// Page size in bytes.
    page_size: f64,
    /// Global average edge fanout (fallback when per-type stats unavailable).
    avg_fanout: f64,
    /// Per-edge-type degree stats: (avg_out_degree, avg_in_degree).
    edge_type_degrees: std::collections::HashMap<String, (f64, f64)>,
}
123
124impl CostModel {
125    /// Creates a new cost model with calibrated default parameters.
126    #[must_use]
127    pub fn new() -> Self {
128        Self {
129            cpu_tuple_cost: 0.01,
130            hash_lookup_cost: 0.03,
131            sort_comparison_cost: 0.02,
132            avg_tuple_size: 100.0,
133            page_size: 8192.0,
134            avg_fanout: 10.0,
135            edge_type_degrees: std::collections::HashMap::new(),
136        }
137    }
138
139    /// Sets the global average fanout from graph statistics.
140    #[must_use]
141    pub fn with_avg_fanout(mut self, avg_fanout: f64) -> Self {
142        self.avg_fanout = if avg_fanout > 0.0 { avg_fanout } else { 10.0 };
143        self
144    }
145
146    /// Sets per-edge-type degree statistics for accurate expand cost estimation.
147    ///
148    /// Each entry maps edge type name to `(avg_out_degree, avg_in_degree)`.
149    #[must_use]
150    pub fn with_edge_type_degrees(
151        mut self,
152        degrees: std::collections::HashMap<String, (f64, f64)>,
153    ) -> Self {
154        self.edge_type_degrees = degrees;
155        self
156    }
157
158    /// Returns the fanout for a specific expand operation.
159    ///
160    /// Uses per-edge-type degree stats when available, falling back to the
161    /// global average fanout.
162    fn fanout_for_expand(&self, expand: &ExpandOp) -> f64 {
163        if expand.edge_types.len() == 1
164            && let Some(&(out_deg, in_deg)) = self.edge_type_degrees.get(&expand.edge_types[0])
165        {
166            return match expand.direction {
167                ExpandDirection::Outgoing => out_deg,
168                ExpandDirection::Incoming => in_deg,
169                ExpandDirection::Both => out_deg + in_deg,
170            };
171        }
172        self.avg_fanout
173    }
174
175    /// Estimates the cost of a logical operator.
176    #[must_use]
177    pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
178        match op {
179            LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
180            LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
181            LogicalOperator::Project(project) => self.project_cost(project, cardinality),
182            LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
183            LogicalOperator::Join(join) => self.join_cost(join, cardinality),
184            LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
185            LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
186            LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
187            LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
188            LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
189            LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
190            LogicalOperator::Empty => Cost::zero(),
191            LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
192            LogicalOperator::VectorJoin(join) => self.vector_join_cost(join, cardinality),
193            LogicalOperator::MultiWayJoin(mwj) => self.multi_way_join_cost(mwj, cardinality),
194            _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
195        }
196    }
197
198    /// Estimates the cost of a node scan.
199    fn node_scan_cost(&self, _scan: &NodeScanOp, cardinality: f64) -> Cost {
200        let pages = (cardinality * self.avg_tuple_size) / self.page_size;
201        Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
202    }
203
204    /// Estimates the cost of a filter operation.
205    fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
206        // Filter cost is just predicate evaluation per tuple
207        Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
208    }
209
210    /// Estimates the cost of a projection.
211    fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
212        // Cost depends on number of expressions evaluated
213        let expr_count = project.projections.len() as f64;
214        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
215    }
216
217    /// Estimates the cost of an expand operation.
218    ///
219    /// Uses per-edge-type degree stats when available, otherwise falls back
220    /// to the global average fanout.
221    fn expand_cost(&self, expand: &ExpandOp, cardinality: f64) -> Cost {
222        let fanout = self.fanout_for_expand(expand);
223        // Adjacency list lookup per input row
224        let lookup_cost = cardinality * self.hash_lookup_cost;
225        // Process each expanded output tuple
226        let output_cost = cardinality * fanout * self.cpu_tuple_cost;
227        Cost::cpu(lookup_cost + output_cost)
228    }
229
230    /// Estimates the cost of a join operation.
231    fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
232        // Cost depends on join type
233        match join.join_type {
234            JoinType::Cross => {
235                // Cross join is O(n * m)
236                Cost::cpu(cardinality * self.cpu_tuple_cost)
237            }
238            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
239                // Hash join: build phase + probe phase
240                // Assume left side is build, right side is probe
241                let build_cardinality = cardinality.sqrt(); // Rough estimate
242                let probe_cardinality = cardinality.sqrt();
243
244                // Build hash table
245                let build_cost = build_cardinality * self.hash_lookup_cost;
246                let memory_cost = build_cardinality * self.avg_tuple_size;
247
248                // Probe hash table
249                let probe_cost = probe_cardinality * self.hash_lookup_cost;
250
251                // Output cost
252                let output_cost = cardinality * self.cpu_tuple_cost;
253
254                Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
255            }
256            JoinType::Semi | JoinType::Anti => {
257                // Semi/anti joins are typically cheaper
258                let build_cardinality = cardinality.sqrt();
259                let probe_cardinality = cardinality.sqrt();
260
261                let build_cost = build_cardinality * self.hash_lookup_cost;
262                let probe_cost = probe_cardinality * self.hash_lookup_cost;
263
264                Cost::cpu(build_cost + probe_cost)
265                    .with_memory(build_cardinality * self.avg_tuple_size)
266            }
267        }
268    }
269
270    /// Estimates the cost of a multi-way (leapfrog) join.
271    ///
272    /// Delegates to `leapfrog_join_cost` using per-input cardinality estimates
273    /// derived from the output cardinality divided equally among inputs.
274    fn multi_way_join_cost(&self, mwj: &MultiWayJoinOp, cardinality: f64) -> Cost {
275        let n = mwj.inputs.len();
276        if n == 0 {
277            return Cost::zero();
278        }
279        // Approximate per-input cardinalities: assume each input contributes
280        // cardinality^(1/n) rows (AGM-style uniform assumption)
281        let per_input = cardinality.powf(1.0 / n as f64).max(1.0);
282        let cardinalities: Vec<f64> = (0..n).map(|_| per_input).collect();
283        self.leapfrog_join_cost(n, &cardinalities, cardinality)
284    }
285
286    /// Estimates the cost of an aggregation.
287    fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
288        // Hash aggregation cost
289        let hash_cost = cardinality * self.hash_lookup_cost;
290
291        // Aggregate function evaluation
292        let agg_count = agg.aggregates.len() as f64;
293        let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
294
295        // Memory for hash table (estimated distinct groups)
296        let distinct_groups = (cardinality / 10.0).max(1.0); // Assume 10% distinct
297        let memory_cost = distinct_groups * self.avg_tuple_size;
298
299        Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
300    }
301
302    /// Estimates the cost of a sort operation.
303    fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
304        if cardinality <= 1.0 {
305            return Cost::zero();
306        }
307
308        // Sort is O(n log n) comparisons
309        let comparisons = cardinality * cardinality.log2();
310        let key_count = sort.keys.len() as f64;
311
312        // Memory for sorting (full input materialization)
313        let memory_cost = cardinality * self.avg_tuple_size;
314
315        Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
316    }
317
318    /// Estimates the cost of a distinct operation.
319    fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
320        // Hash-based distinct
321        let hash_cost = cardinality * self.hash_lookup_cost;
322        let memory_cost = cardinality * self.avg_tuple_size * 0.5; // Assume 50% distinct
323
324        Cost::cpu(hash_cost).with_memory(memory_cost)
325    }
326
327    /// Estimates the cost of a limit operation.
328    fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
329        // Limit is very cheap - just counting
330        Cost::cpu(limit.count.estimate() * self.cpu_tuple_cost * 0.1)
331    }
332
333    /// Estimates the cost of a skip operation.
334    fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
335        // Skip requires scanning through skipped rows
336        Cost::cpu(skip.count.estimate() * self.cpu_tuple_cost)
337    }
338
339    /// Estimates the cost of a return operation.
340    fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
341        // Return materializes results
342        let expr_count = ret.items.len() as f64;
343        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
344    }
345
346    /// Estimates the cost of a vector scan operation.
347    ///
348    /// HNSW index search is O(log N) per query, while brute-force is O(N).
349    /// This estimates the HNSW case with ef search parameter.
350    fn vector_scan_cost(&self, scan: &VectorScanOp, cardinality: f64) -> Cost {
351        // k determines output cardinality
352        let k = scan.k as f64;
353
354        // HNSW search cost: O(ef * log(N)) distance computations
355        // Assume ef = 64 (default), N = cardinality
356        let ef = 64.0;
357        let n = cardinality.max(1.0);
358        let search_cost = if scan.index_name.is_some() {
359            // HNSW: O(ef * log N)
360            ef * n.ln() * self.cpu_tuple_cost * 10.0 // Distance computation is ~10x regular tuple
361        } else {
362            // Brute-force: O(N)
363            n * self.cpu_tuple_cost * 10.0
364        };
365
366        // Memory for candidate heap
367        let memory = k * self.avg_tuple_size * 2.0;
368
369        Cost::cpu(search_cost).with_memory(memory)
370    }
371
372    /// Estimates the cost of a vector join operation.
373    ///
374    /// Vector join performs k-NN search for each input row.
375    fn vector_join_cost(&self, join: &VectorJoinOp, cardinality: f64) -> Cost {
376        let k = join.k as f64;
377
378        // Each input row triggers a vector search
379        // Assume brute-force for hybrid queries (no index specified typically)
380        let per_row_search_cost = if join.index_name.is_some() {
381            // HNSW: O(ef * log N)
382            let ef = 64.0;
383            let n = cardinality.max(1.0);
384            ef * n.ln() * self.cpu_tuple_cost * 10.0
385        } else {
386            // Brute-force: O(N) per input row
387            cardinality * self.cpu_tuple_cost * 10.0
388        };
389
390        // Total cost: input_rows * search_cost
391        // For vector join, cardinality is typically input cardinality * k
392        let input_cardinality = (cardinality / k).max(1.0);
393        let total_search_cost = input_cardinality * per_row_search_cost;
394
395        // Memory for results
396        let memory = cardinality * self.avg_tuple_size;
397
398        Cost::cpu(total_search_cost).with_memory(memory)
399    }
400
401    /// Compares two costs and returns the cheaper one.
402    #[must_use]
403    pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
404        if a.total() <= b.total() { a } else { b }
405    }
406
407    /// Estimates the cost of a worst-case optimal join (WCOJ/leapfrog join).
408    ///
409    /// WCOJ is optimal for cyclic patterns like triangles. Traditional binary
410    /// hash joins are O(N²) for triangles; WCOJ achieves O(N^1.5) by processing
411    /// all relations simultaneously using sorted iterators.
412    ///
413    /// # Arguments
414    /// * `num_relations` - Number of relations participating in the join
415    /// * `cardinalities` - Cardinality of each input relation
416    /// * `output_cardinality` - Expected output cardinality
417    ///
418    /// # Cost Model
419    /// - Materialization: O(sum of cardinalities) to build trie indexes
420    /// - Intersection: O(output * log(min_cardinality)) for leapfrog seek operations
421    /// - Memory: Trie storage for all inputs
422    #[must_use]
423    pub fn leapfrog_join_cost(
424        &self,
425        num_relations: usize,
426        cardinalities: &[f64],
427        output_cardinality: f64,
428    ) -> Cost {
429        if cardinalities.is_empty() {
430            return Cost::zero();
431        }
432
433        let total_input: f64 = cardinalities.iter().sum();
434        let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
435
436        // Materialization phase: build trie indexes for each input
437        let materialize_cost = total_input * self.cpu_tuple_cost * 2.0; // Sorting + trie building
438
439        // Intersection phase: leapfrog seeks are O(log n) per relation
440        let seek_cost = if min_card > 1.0 {
441            output_cardinality * (num_relations as f64) * min_card.log2() * self.hash_lookup_cost
442        } else {
443            output_cardinality * self.cpu_tuple_cost
444        };
445
446        // Output materialization
447        let output_cost = output_cardinality * self.cpu_tuple_cost;
448
449        // Memory: trie storage (roughly 2x input size for sorted index)
450        let memory = total_input * self.avg_tuple_size * 2.0;
451
452        Cost::cpu(materialize_cost + seek_cost + output_cost).with_memory(memory)
453    }
454
455    /// Compares hash join cost vs leapfrog join cost for a cyclic pattern.
456    ///
457    /// Returns true if leapfrog (WCOJ) is estimated to be cheaper.
458    #[must_use]
459    pub fn prefer_leapfrog_join(
460        &self,
461        num_relations: usize,
462        cardinalities: &[f64],
463        output_cardinality: f64,
464    ) -> bool {
465        if num_relations < 3 || cardinalities.len() < 3 {
466            // Leapfrog is only beneficial for multi-way joins (3+)
467            return false;
468        }
469
470        let leapfrog_cost =
471            self.leapfrog_join_cost(num_relations, cardinalities, output_cardinality);
472
473        // Estimate cascade of binary hash joins
474        // For N relations, we need N-1 joins
475        // Each join produces intermediate results that feed the next
476        let mut hash_cascade_cost = Cost::zero();
477        let mut intermediate_cardinality = cardinalities[0];
478
479        for card in &cardinalities[1..] {
480            // Hash join cost: build + probe + output
481            let join_output = (intermediate_cardinality * card).sqrt(); // Estimated selectivity
482            let join = JoinOp {
483                left: Box::new(LogicalOperator::Empty),
484                right: Box::new(LogicalOperator::Empty),
485                join_type: JoinType::Inner,
486                conditions: vec![],
487            };
488            hash_cascade_cost += self.join_cost(&join, join_output);
489            intermediate_cardinality = join_output;
490        }
491
492        leapfrog_cost.total() < hash_cascade_cost.total()
493    }
494
495    /// Estimates cost for factorized execution (compressed intermediate results).
496    ///
497    /// Factorized execution avoids materializing full cross products by keeping
498    /// results in a compressed "factorized" form. This is beneficial for multi-hop
499    /// traversals where intermediate results can explode.
500    ///
501    /// Returns the reduction factor (1.0 = no benefit, lower = more compression).
502    #[must_use]
503    pub fn factorized_benefit(&self, avg_fanout: f64, num_hops: usize) -> f64 {
504        if num_hops <= 1 || avg_fanout <= 1.0 {
505            return 1.0; // No benefit for single hop or low fanout
506        }
507
508        // Factorized representation compresses repeated prefixes
509        // Compression ratio improves with higher fanout and more hops
510        // Full materialization: fanout^hops
511        // Factorized: sum(fanout^i for i in 1..=hops) ≈ fanout^(hops+1) / (fanout - 1)
512
513        let full_size = avg_fanout.powi(num_hops as i32);
514        let factorized_size = if avg_fanout > 1.0 {
515            (avg_fanout.powi(num_hops as i32 + 1) - 1.0) / (avg_fanout - 1.0)
516        } else {
517            num_hops as f64
518        };
519
520        (factorized_size / full_size).min(1.0)
521    }
522}
523
524impl Default for CostModel {
525    fn default() -> Self {
526        Self::new()
527    }
528}
529
530#[cfg(test)]
531mod tests {
532    use super::*;
533    use crate::query::plan::{
534        AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
535        PathMode, Projection, ReturnItem, SortOrder,
536    };
537
538    #[test]
539    fn test_cost_addition() {
540        let a = Cost::cpu(10.0).with_io(5.0);
541        let b = Cost::cpu(20.0).with_memory(100.0);
542        let c = a + b;
543
544        assert!((c.cpu - 30.0).abs() < 0.001);
545        assert!((c.io - 5.0).abs() < 0.001);
546        assert!((c.memory - 100.0).abs() < 0.001);
547    }
548
549    #[test]
550    fn test_cost_total() {
551        let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
552        // Total = 10 + 1*10 + 100*0.1 = 10 + 10 + 10 = 30
553        assert!((cost.total() - 30.0).abs() < 0.001);
554    }
555
556    #[test]
557    fn test_cost_model_node_scan() {
558        let model = CostModel::new();
559        let scan = NodeScanOp {
560            variable: "n".to_string(),
561            label: Some("Person".to_string()),
562            input: None,
563        };
564        let cost = model.node_scan_cost(&scan, 1000.0);
565
566        assert!(cost.cpu > 0.0);
567        assert!(cost.io > 0.0);
568    }
569
570    #[test]
571    fn test_cost_model_sort() {
572        let model = CostModel::new();
573        let sort = SortOp {
574            keys: vec![],
575            input: Box::new(LogicalOperator::Empty),
576        };
577
578        let cost_100 = model.sort_cost(&sort, 100.0);
579        let cost_1000 = model.sort_cost(&sort, 1000.0);
580
581        // Sorting 1000 rows should be more expensive than 100 rows
582        assert!(cost_1000.total() > cost_100.total());
583    }
584
585    #[test]
586    fn test_cost_zero() {
587        let cost = Cost::zero();
588        assert!((cost.cpu).abs() < 0.001);
589        assert!((cost.io).abs() < 0.001);
590        assert!((cost.memory).abs() < 0.001);
591        assert!((cost.network).abs() < 0.001);
592        assert!((cost.total()).abs() < 0.001);
593    }
594
595    #[test]
596    fn test_cost_add_assign() {
597        let mut cost = Cost::cpu(10.0);
598        cost += Cost::cpu(5.0).with_io(2.0);
599        assert!((cost.cpu - 15.0).abs() < 0.001);
600        assert!((cost.io - 2.0).abs() < 0.001);
601    }
602
603    #[test]
604    fn test_cost_total_weighted() {
605        let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
606        // With custom weights: cpu*2 + io*5 + mem*0.5 = 20 + 10 + 50 = 80
607        let total = cost.total_weighted(2.0, 5.0, 0.5);
608        assert!((total - 80.0).abs() < 0.001);
609    }
610
611    #[test]
612    fn test_cost_model_filter() {
613        let model = CostModel::new();
614        let filter = FilterOp {
615            predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
616            input: Box::new(LogicalOperator::Empty),
617            pushdown_hint: None,
618        };
619        let cost = model.filter_cost(&filter, 1000.0);
620
621        // Filter cost is CPU only
622        assert!(cost.cpu > 0.0);
623        assert!((cost.io).abs() < 0.001);
624    }
625
626    #[test]
627    fn test_cost_model_project() {
628        let model = CostModel::new();
629        let project = ProjectOp {
630            projections: vec![
631                Projection {
632                    expression: LogicalExpression::Variable("a".to_string()),
633                    alias: None,
634                },
635                Projection {
636                    expression: LogicalExpression::Variable("b".to_string()),
637                    alias: None,
638                },
639            ],
640            input: Box::new(LogicalOperator::Empty),
641        };
642        let cost = model.project_cost(&project, 1000.0);
643
644        // Cost should scale with number of projections
645        assert!(cost.cpu > 0.0);
646    }
647
648    #[test]
649    fn test_cost_model_expand() {
650        let model = CostModel::new();
651        let expand = ExpandOp {
652            from_variable: "a".to_string(),
653            to_variable: "b".to_string(),
654            edge_variable: None,
655            direction: ExpandDirection::Outgoing,
656            edge_types: vec![],
657            min_hops: 1,
658            max_hops: Some(1),
659            input: Box::new(LogicalOperator::Empty),
660            path_alias: None,
661            path_mode: PathMode::Walk,
662        };
663        let cost = model.expand_cost(&expand, 1000.0);
664
665        // Expand involves hash lookups and output generation
666        assert!(cost.cpu > 0.0);
667    }
668
669    #[test]
670    fn test_cost_model_expand_with_edge_type_stats() {
671        let mut degrees = std::collections::HashMap::new();
672        degrees.insert("KNOWS".to_string(), (5.0, 5.0)); // Symmetric
673        degrees.insert("WORKS_AT".to_string(), (1.0, 50.0)); // Many-to-one
674
675        let model = CostModel::new().with_edge_type_degrees(degrees);
676
677        // Outgoing KNOWS: fanout = 5
678        let knows_out = ExpandOp {
679            from_variable: "a".to_string(),
680            to_variable: "b".to_string(),
681            edge_variable: None,
682            direction: ExpandDirection::Outgoing,
683            edge_types: vec!["KNOWS".to_string()],
684            min_hops: 1,
685            max_hops: Some(1),
686            input: Box::new(LogicalOperator::Empty),
687            path_alias: None,
688            path_mode: PathMode::Walk,
689        };
690        let cost_knows = model.expand_cost(&knows_out, 1000.0);
691
692        // Outgoing WORKS_AT: fanout = 1 (each person works at one company)
693        let works_out = ExpandOp {
694            from_variable: "a".to_string(),
695            to_variable: "b".to_string(),
696            edge_variable: None,
697            direction: ExpandDirection::Outgoing,
698            edge_types: vec!["WORKS_AT".to_string()],
699            min_hops: 1,
700            max_hops: Some(1),
701            input: Box::new(LogicalOperator::Empty),
702            path_alias: None,
703            path_mode: PathMode::Walk,
704        };
705        let cost_works = model.expand_cost(&works_out, 1000.0);
706
707        // KNOWS (fanout=5) should be more expensive than WORKS_AT (fanout=1)
708        assert!(
709            cost_knows.cpu > cost_works.cpu,
710            "KNOWS(5) should cost more than WORKS_AT(1)"
711        );
712
713        // Incoming WORKS_AT: fanout = 50 (company has many employees)
714        let works_in = ExpandOp {
715            from_variable: "c".to_string(),
716            to_variable: "p".to_string(),
717            edge_variable: None,
718            direction: ExpandDirection::Incoming,
719            edge_types: vec!["WORKS_AT".to_string()],
720            min_hops: 1,
721            max_hops: Some(1),
722            input: Box::new(LogicalOperator::Empty),
723            path_alias: None,
724            path_mode: PathMode::Walk,
725        };
726        let cost_works_in = model.expand_cost(&works_in, 1000.0);
727
728        // Incoming WORKS_AT (fanout=50) should be most expensive
729        assert!(
730            cost_works_in.cpu > cost_knows.cpu,
731            "Incoming WORKS_AT(50) should cost more than KNOWS(5)"
732        );
733    }
734
735    #[test]
736    fn test_cost_model_expand_unknown_edge_type_uses_global_fanout() {
737        let model = CostModel::new().with_avg_fanout(7.0);
738        let expand = ExpandOp {
739            from_variable: "a".to_string(),
740            to_variable: "b".to_string(),
741            edge_variable: None,
742            direction: ExpandDirection::Outgoing,
743            edge_types: vec!["UNKNOWN_TYPE".to_string()],
744            min_hops: 1,
745            max_hops: Some(1),
746            input: Box::new(LogicalOperator::Empty),
747            path_alias: None,
748            path_mode: PathMode::Walk,
749        };
750        let cost_unknown = model.expand_cost(&expand, 1000.0);
751
752        // Without edge type (uses global fanout too)
753        let expand_no_type = ExpandOp {
754            from_variable: "a".to_string(),
755            to_variable: "b".to_string(),
756            edge_variable: None,
757            direction: ExpandDirection::Outgoing,
758            edge_types: vec![],
759            min_hops: 1,
760            max_hops: Some(1),
761            input: Box::new(LogicalOperator::Empty),
762            path_alias: None,
763            path_mode: PathMode::Walk,
764        };
765        let cost_no_type = model.expand_cost(&expand_no_type, 1000.0);
766
767        // Both should use global fanout = 7, so costs should be equal
768        assert!(
769            (cost_unknown.cpu - cost_no_type.cpu).abs() < 0.001,
770            "Unknown edge type should use global fanout"
771        );
772    }
773
774    #[test]
775    fn test_cost_model_hash_join() {
776        let model = CostModel::new();
777        let join = JoinOp {
778            left: Box::new(LogicalOperator::Empty),
779            right: Box::new(LogicalOperator::Empty),
780            join_type: JoinType::Inner,
781            conditions: vec![JoinCondition {
782                left: LogicalExpression::Variable("a".to_string()),
783                right: LogicalExpression::Variable("b".to_string()),
784            }],
785        };
786        let cost = model.join_cost(&join, 10000.0);
787
788        // Hash join has CPU cost and memory cost
789        assert!(cost.cpu > 0.0);
790        assert!(cost.memory > 0.0);
791    }
792
793    #[test]
794    fn test_cost_model_cross_join() {
795        let model = CostModel::new();
796        let join = JoinOp {
797            left: Box::new(LogicalOperator::Empty),
798            right: Box::new(LogicalOperator::Empty),
799            join_type: JoinType::Cross,
800            conditions: vec![],
801        };
802        let cost = model.join_cost(&join, 1000000.0);
803
804        // Cross join is expensive
805        assert!(cost.cpu > 0.0);
806    }
807
808    #[test]
809    fn test_cost_model_semi_join() {
810        let model = CostModel::new();
811        let join = JoinOp {
812            left: Box::new(LogicalOperator::Empty),
813            right: Box::new(LogicalOperator::Empty),
814            join_type: JoinType::Semi,
815            conditions: vec![],
816        };
817        let cost_semi = model.join_cost(&join, 1000.0);
818
819        let inner_join = JoinOp {
820            left: Box::new(LogicalOperator::Empty),
821            right: Box::new(LogicalOperator::Empty),
822            join_type: JoinType::Inner,
823            conditions: vec![],
824        };
825        let cost_inner = model.join_cost(&inner_join, 1000.0);
826
827        // Semi join can be cheaper than inner join
828        assert!(cost_semi.cpu > 0.0);
829        assert!(cost_inner.cpu > 0.0);
830    }
831
832    #[test]
833    fn test_cost_model_aggregate() {
834        let model = CostModel::new();
835        let agg = AggregateOp {
836            group_by: vec![],
837            aggregates: vec![
838                AggregateExpr {
839                    function: AggregateFunction::Count,
840                    expression: None,
841                    expression2: None,
842                    distinct: false,
843                    alias: Some("cnt".to_string()),
844                    percentile: None,
845                    separator: None,
846                },
847                AggregateExpr {
848                    function: AggregateFunction::Sum,
849                    expression: Some(LogicalExpression::Variable("x".to_string())),
850                    expression2: None,
851                    distinct: false,
852                    alias: Some("total".to_string()),
853                    percentile: None,
854                    separator: None,
855                },
856            ],
857            input: Box::new(LogicalOperator::Empty),
858            having: None,
859        };
860        let cost = model.aggregate_cost(&agg, 1000.0);
861
862        // Aggregation has hash cost and memory cost
863        assert!(cost.cpu > 0.0);
864        assert!(cost.memory > 0.0);
865    }
866
867    #[test]
868    fn test_cost_model_distinct() {
869        let model = CostModel::new();
870        let distinct = DistinctOp {
871            input: Box::new(LogicalOperator::Empty),
872            columns: None,
873        };
874        let cost = model.distinct_cost(&distinct, 1000.0);
875
876        // Distinct uses hash set
877        assert!(cost.cpu > 0.0);
878        assert!(cost.memory > 0.0);
879    }
880
881    #[test]
882    fn test_cost_model_limit() {
883        let model = CostModel::new();
884        let limit = LimitOp {
885            count: 10.into(),
886            input: Box::new(LogicalOperator::Empty),
887        };
888        let cost = model.limit_cost(&limit, 1000.0);
889
890        // Limit is very cheap
891        assert!(cost.cpu > 0.0);
892        assert!(cost.cpu < 1.0); // Should be minimal
893    }
894
895    #[test]
896    fn test_cost_model_skip() {
897        let model = CostModel::new();
898        let skip = SkipOp {
899            count: 100.into(),
900            input: Box::new(LogicalOperator::Empty),
901        };
902        let cost = model.skip_cost(&skip, 1000.0);
903
904        // Skip must scan through skipped rows
905        assert!(cost.cpu > 0.0);
906    }
907
908    #[test]
909    fn test_cost_model_return() {
910        let model = CostModel::new();
911        let ret = ReturnOp {
912            items: vec![
913                ReturnItem {
914                    expression: LogicalExpression::Variable("a".to_string()),
915                    alias: None,
916                },
917                ReturnItem {
918                    expression: LogicalExpression::Variable("b".to_string()),
919                    alias: None,
920                },
921            ],
922            distinct: false,
923            input: Box::new(LogicalOperator::Empty),
924        };
925        let cost = model.return_cost(&ret, 1000.0);
926
927        // Return materializes results
928        assert!(cost.cpu > 0.0);
929    }
930
931    #[test]
932    fn test_cost_cheaper() {
933        let model = CostModel::new();
934        let cheap = Cost::cpu(10.0);
935        let expensive = Cost::cpu(100.0);
936
937        assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
938        assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
939    }
940
941    #[test]
942    fn test_cost_comparison_prefers_lower_total() {
943        let model = CostModel::new();
944        // High CPU, low IO
945        let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
946        // Low CPU, high IO
947        let io_heavy = Cost::cpu(10.0).with_io(20.0);
948
949        // IO is weighted 10x, so io_heavy = 10 + 200 = 210, cpu_heavy = 100 + 10 = 110
950        assert!(cpu_heavy.total() < io_heavy.total());
951        assert_eq!(
952            model.cheaper(&cpu_heavy, &io_heavy).total(),
953            cpu_heavy.total()
954        );
955    }
956
957    #[test]
958    fn test_cost_model_sort_with_keys() {
959        let model = CostModel::new();
960        let sort_single = SortOp {
961            keys: vec![crate::query::plan::SortKey {
962                expression: LogicalExpression::Variable("a".to_string()),
963                order: SortOrder::Ascending,
964                nulls: None,
965            }],
966            input: Box::new(LogicalOperator::Empty),
967        };
968        let sort_multi = SortOp {
969            keys: vec![
970                crate::query::plan::SortKey {
971                    expression: LogicalExpression::Variable("a".to_string()),
972                    order: SortOrder::Ascending,
973                    nulls: None,
974                },
975                crate::query::plan::SortKey {
976                    expression: LogicalExpression::Variable("b".to_string()),
977                    order: SortOrder::Descending,
978                    nulls: None,
979                },
980            ],
981            input: Box::new(LogicalOperator::Empty),
982        };
983
984        let cost_single = model.sort_cost(&sort_single, 1000.0);
985        let cost_multi = model.sort_cost(&sort_multi, 1000.0);
986
987        // More sort keys = more comparisons
988        assert!(cost_multi.cpu > cost_single.cpu);
989    }
990
991    #[test]
992    fn test_cost_model_empty_operator() {
993        let model = CostModel::new();
994        let cost = model.estimate(&LogicalOperator::Empty, 0.0);
995        assert!((cost.total()).abs() < 0.001);
996    }
997
998    #[test]
999    fn test_cost_model_default() {
1000        let model = CostModel::default();
1001        let scan = NodeScanOp {
1002            variable: "n".to_string(),
1003            label: None,
1004            input: None,
1005        };
1006        let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
1007        assert!(cost.total() > 0.0);
1008    }
1009
1010    #[test]
1011    fn test_leapfrog_join_cost() {
1012        let model = CostModel::new();
1013
1014        // Three-way join (triangle pattern)
1015        let cardinalities = vec![1000.0, 1000.0, 1000.0];
1016        let cost = model.leapfrog_join_cost(3, &cardinalities, 100.0);
1017
1018        // Should have CPU cost for materialization and intersection
1019        assert!(cost.cpu > 0.0);
1020        // Should have memory cost for trie storage
1021        assert!(cost.memory > 0.0);
1022    }
1023
1024    #[test]
1025    fn test_leapfrog_join_cost_empty() {
1026        let model = CostModel::new();
1027        let cost = model.leapfrog_join_cost(0, &[], 0.0);
1028        assert!((cost.total()).abs() < 0.001);
1029    }
1030
1031    #[test]
1032    fn test_prefer_leapfrog_join_for_triangles() {
1033        let model = CostModel::new();
1034
1035        // Compare costs for triangle pattern
1036        let cardinalities = vec![10000.0, 10000.0, 10000.0];
1037        let output = 1000.0;
1038
1039        let leapfrog_cost = model.leapfrog_join_cost(3, &cardinalities, output);
1040
1041        // Leapfrog should have reasonable cost for triangle patterns
1042        assert!(leapfrog_cost.cpu > 0.0);
1043        assert!(leapfrog_cost.memory > 0.0);
1044
1045        // The prefer_leapfrog_join method compares against hash cascade
1046        // Actual preference depends on specific cost parameters
1047        let _prefer = model.prefer_leapfrog_join(3, &cardinalities, output);
1048        // Test that it returns a boolean (doesn't panic)
1049    }
1050
1051    #[test]
1052    fn test_prefer_leapfrog_join_binary_case() {
1053        let model = CostModel::new();
1054
1055        // Binary join should NOT prefer leapfrog (need 3+ relations)
1056        let cardinalities = vec![1000.0, 1000.0];
1057        let prefer = model.prefer_leapfrog_join(2, &cardinalities, 500.0);
1058        assert!(!prefer, "Binary joins should use hash join, not leapfrog");
1059    }
1060
1061    #[test]
1062    fn test_factorized_benefit_single_hop() {
1063        let model = CostModel::new();
1064
1065        // Single hop: no factorization benefit
1066        let benefit = model.factorized_benefit(10.0, 1);
1067        assert!(
1068            (benefit - 1.0).abs() < 0.001,
1069            "Single hop should have no benefit"
1070        );
1071    }
1072
1073    #[test]
1074    fn test_factorized_benefit_multi_hop() {
1075        let model = CostModel::new();
1076
1077        // Multi-hop with high fanout
1078        let benefit = model.factorized_benefit(10.0, 3);
1079
1080        // The factorized_benefit returns a ratio capped at 1.0
1081        // For high fanout, factorized size / full size approaches 1/fanout
1082        // which is beneficial but the formula gives a value <= 1.0
1083        assert!(benefit <= 1.0, "Benefit should be <= 1.0");
1084        assert!(benefit > 0.0, "Benefit should be positive");
1085    }
1086
1087    #[test]
1088    fn test_factorized_benefit_low_fanout() {
1089        let model = CostModel::new();
1090
1091        // Low fanout: minimal benefit
1092        let benefit = model.factorized_benefit(1.5, 2);
1093        assert!(
1094            benefit <= 1.0,
1095            "Low fanout still benefits from factorization"
1096        );
1097    }
1098}