// grafeo_engine/query/optimizer/cost.rs

1//! Cost model for query optimization.
2//!
3//! Provides cost estimates for logical operators to guide optimization decisions.
4
5use crate::query::plan::{
6    AggregateOp, DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
7    LogicalOperator, MultiWayJoinOp, NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp, VectorJoinOp,
8    VectorScanOp,
9};
10use std::collections::HashMap;
11
/// Cost of an operation.
///
/// Represents the estimated resource consumption of executing an operator.
/// All components are abstract units; [`Cost::total`] folds them into a
/// single comparable scalar using fixed weights.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct Cost {
    /// Estimated CPU cycles / work units.
    pub cpu: f64,
    /// Estimated I/O operations (page reads).
    pub io: f64,
    /// Estimated memory usage in bytes.
    pub memory: f64,
    /// Network cost (for distributed queries).
    pub network: f64,
}

impl Cost {
    /// Creates a zero cost (equivalent to `Cost::default()`).
    #[must_use]
    pub fn zero() -> Self {
        Self::default()
    }

    /// Creates a cost with only a CPU component.
    #[must_use]
    pub fn cpu(cpu: f64) -> Self {
        Self {
            cpu,
            ..Self::default()
        }
    }

    /// Sets the I/O component (builder style).
    #[must_use]
    pub fn with_io(mut self, io: f64) -> Self {
        self.io = io;
        self
    }

    /// Sets the memory component (builder style).
    #[must_use]
    pub fn with_memory(mut self, memory: f64) -> Self {
        self.memory = memory;
        self
    }

    /// Returns the total weighted cost.
    ///
    /// Uses default weights: CPU=1.0, IO=10.0, Memory=0.1, Network=100.0
    #[must_use]
    pub fn total(&self) -> f64 {
        self.cpu + self.io * 10.0 + self.memory * 0.1 + self.network * 100.0
    }

    /// Returns the total cost with custom weights.
    ///
    /// NOTE: the network component is intentionally NOT included here; use
    /// [`Cost::total`] when network cost should participate in the total.
    #[must_use]
    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
        self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
    }
}

impl std::ops::Add for Cost {
    type Output = Self;

    /// Component-wise sum, delegating to `AddAssign` to keep one source of truth.
    fn add(mut self, other: Self) -> Self {
        self += other;
        self
    }
}

impl std::ops::AddAssign for Cost {
    /// Component-wise accumulation.
    fn add_assign(&mut self, other: Self) {
        self.cpu += other.cpu;
        self.io += other.io;
        self.memory += other.memory;
        self.network += other.network;
    }
}
100
/// Cost model for estimating operator costs.
///
/// Default constants are calibrated relative to each other:
/// - Tuple scan is the baseline (1x)
/// - Hash lookup is ~3x (hash computation + potential cache miss)
/// - Sort comparison is ~2x (key extraction + comparison)
/// - Distance computation is ~10x (vector math)
///
/// Optional graph statistics (fanout, per-type degrees, label counts,
/// totals) can be attached via the `with_*` builders to sharpen estimates.
pub struct CostModel {
    /// Cost per tuple processed by CPU (baseline unit).
    cpu_tuple_cost: f64,
    /// Cost per hash table lookup (~3x tuple cost: hash + cache miss).
    hash_lookup_cost: f64,
    /// Cost per comparison in sorting (~2x tuple cost: key extract + cmp).
    sort_comparison_cost: f64,
    /// Average tuple size in bytes (used for both IO and memory estimates).
    avg_tuple_size: f64,
    /// Storage page size in bytes (converts bytes scanned into page reads).
    page_size: f64,
    /// Global average edge fanout (fallback when per-type stats unavailable).
    avg_fanout: f64,
    /// Per-edge-type degree stats: `(avg_out_degree, avg_in_degree)`.
    edge_type_degrees: HashMap<String, (f64, f64)>,
    /// Per-label node counts for accurate scan IO estimation.
    label_cardinalities: HashMap<String, u64>,
    /// Total node count in the graph (0 when statistics are unset).
    total_nodes: u64,
    /// Total edge count in the graph (0 when statistics are unset).
    total_edges: u64,
}
130
131impl CostModel {
132    /// Creates a new cost model with calibrated default parameters.
133    #[must_use]
134    pub fn new() -> Self {
135        Self {
136            cpu_tuple_cost: 0.01,
137            hash_lookup_cost: 0.03,
138            sort_comparison_cost: 0.02,
139            avg_tuple_size: 100.0,
140            page_size: 8192.0,
141            avg_fanout: 10.0,
142            edge_type_degrees: HashMap::new(),
143            label_cardinalities: HashMap::new(),
144            total_nodes: 0,
145            total_edges: 0,
146        }
147    }
148
149    /// Sets the global average fanout from graph statistics.
150    #[must_use]
151    pub fn with_avg_fanout(mut self, avg_fanout: f64) -> Self {
152        self.avg_fanout = if avg_fanout > 0.0 { avg_fanout } else { 10.0 };
153        self
154    }
155
156    /// Sets per-edge-type degree statistics for accurate expand cost estimation.
157    ///
158    /// Each entry maps edge type name to `(avg_out_degree, avg_in_degree)`.
159    #[must_use]
160    pub fn with_edge_type_degrees(mut self, degrees: HashMap<String, (f64, f64)>) -> Self {
161        self.edge_type_degrees = degrees;
162        self
163    }
164
165    /// Sets per-label node counts for accurate scan IO estimation.
166    #[must_use]
167    pub fn with_label_cardinalities(mut self, cardinalities: HashMap<String, u64>) -> Self {
168        self.label_cardinalities = cardinalities;
169        self
170    }
171
172    /// Sets graph-level totals for cost estimation.
173    #[must_use]
174    pub fn with_graph_totals(mut self, total_nodes: u64, total_edges: u64) -> Self {
175        self.total_nodes = total_nodes;
176        self.total_edges = total_edges;
177        self
178    }
179
180    /// Returns the fanout for a specific expand operation.
181    ///
182    /// Uses per-edge-type degree stats when available, falling back to the
183    /// global average fanout. For multiple edge types, sums per-type fanouts.
184    fn fanout_for_expand(&self, expand: &ExpandOp) -> f64 {
185        if expand.edge_types.is_empty() {
186            return self.avg_fanout;
187        }
188
189        let mut total_fanout = 0.0;
190        let mut all_found = true;
191
192        for edge_type in &expand.edge_types {
193            if let Some(&(out_deg, in_deg)) = self.edge_type_degrees.get(edge_type) {
194                total_fanout += match expand.direction {
195                    ExpandDirection::Outgoing => out_deg,
196                    ExpandDirection::Incoming => in_deg,
197                    ExpandDirection::Both => out_deg + in_deg,
198                };
199            } else {
200                all_found = false;
201                break;
202            }
203        }
204
205        if all_found {
206            total_fanout
207        } else {
208            self.avg_fanout
209        }
210    }
211
    /// Estimates the cost of a logical operator.
    ///
    /// `cardinality` is the estimated *output* row count of `op`. Only this
    /// operator's own cost is returned — child costs are not included; use
    /// `estimate_tree` for a recursive total over a whole plan.
    #[must_use]
    pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
        match op {
            LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
            LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
            LogicalOperator::Project(project) => self.project_cost(project, cardinality),
            LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
            LogicalOperator::Join(join) => self.join_cost(join, cardinality),
            LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
            LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
            LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
            LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
            LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
            LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
            LogicalOperator::Empty => Cost::zero(),
            LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
            LogicalOperator::VectorJoin(join) => self.vector_join_cost(join, cardinality),
            LogicalOperator::MultiWayJoin(mwj) => self.multi_way_join_cost(mwj, cardinality),
            // Fallback for operators without a dedicated model: linear CPU pass.
            _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
        }
    }
234
235    /// Estimates the cost of a node scan.
236    ///
237    /// When label statistics are available, uses the actual label cardinality
238    /// for IO estimation (pages to read) rather than the optimizer's cardinality
239    /// estimate, which may already account for filter selectivity.
240    fn node_scan_cost(&self, scan: &NodeScanOp, cardinality: f64) -> Cost {
241        // IO cost: based on how many nodes we actually need to scan from storage
242        let scan_size = if let Some(label) = &scan.label {
243            self.label_cardinalities
244                .get(label)
245                .map_or(cardinality, |&count| count as f64)
246        } else if self.total_nodes > 0 {
247            self.total_nodes as f64
248        } else {
249            cardinality
250        };
251        let pages = (scan_size * self.avg_tuple_size) / self.page_size;
252        // CPU cost: only pay for rows that pass the label filter
253        Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
254    }
255
256    /// Estimates the cost of a filter operation.
257    fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
258        // Filter cost is just predicate evaluation per tuple
259        Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
260    }
261
262    /// Estimates the cost of a projection.
263    fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
264        // Cost depends on number of expressions evaluated
265        let expr_count = project.projections.len() as f64;
266        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
267    }
268
269    /// Estimates the cost of an expand operation.
270    ///
271    /// Uses per-edge-type degree stats when available, otherwise falls back
272    /// to the global average fanout.
273    fn expand_cost(&self, expand: &ExpandOp, cardinality: f64) -> Cost {
274        let fanout = self.fanout_for_expand(expand);
275        // Adjacency list lookup per input row
276        let lookup_cost = cardinality * self.hash_lookup_cost;
277        // Process each expanded output tuple
278        let output_cost = cardinality * fanout * self.cpu_tuple_cost;
279        Cost::cpu(lookup_cost + output_cost)
280    }
281
282    /// Estimates the cost of a join operation.
283    ///
284    /// When child cardinalities are known (from recursive estimation), uses them
285    /// directly for build/probe cost. Otherwise falls back to sqrt approximation.
286    fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
287        self.join_cost_with_children(join, cardinality, None, None)
288    }
289
290    /// Estimates join cost using actual child cardinalities.
291    fn join_cost_with_children(
292        &self,
293        join: &JoinOp,
294        cardinality: f64,
295        left_cardinality: Option<f64>,
296        right_cardinality: Option<f64>,
297    ) -> Cost {
298        match join.join_type {
299            JoinType::Cross => Cost::cpu(cardinality * self.cpu_tuple_cost),
300            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
301                // Hash join: build the smaller side, probe with the larger
302                let build_cardinality = left_cardinality.unwrap_or_else(|| cardinality.sqrt());
303                let probe_cardinality = right_cardinality.unwrap_or_else(|| cardinality.sqrt());
304
305                let build_cost = build_cardinality * self.hash_lookup_cost;
306                let memory_cost = build_cardinality * self.avg_tuple_size;
307                let probe_cost = probe_cardinality * self.hash_lookup_cost;
308                let output_cost = cardinality * self.cpu_tuple_cost;
309
310                Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
311            }
312            JoinType::Semi | JoinType::Anti => {
313                let build_cardinality = left_cardinality.unwrap_or_else(|| cardinality.sqrt());
314                let probe_cardinality = right_cardinality.unwrap_or_else(|| cardinality.sqrt());
315
316                let build_cost = build_cardinality * self.hash_lookup_cost;
317                let probe_cost = probe_cardinality * self.hash_lookup_cost;
318
319                Cost::cpu(build_cost + probe_cost)
320                    .with_memory(build_cardinality * self.avg_tuple_size)
321            }
322        }
323    }
324
325    /// Estimates the cost of a multi-way (leapfrog) join.
326    ///
327    /// Delegates to `leapfrog_join_cost` using per-input cardinality estimates
328    /// derived from the output cardinality divided equally among inputs.
329    fn multi_way_join_cost(&self, mwj: &MultiWayJoinOp, cardinality: f64) -> Cost {
330        let n = mwj.inputs.len();
331        if n == 0 {
332            return Cost::zero();
333        }
334        // Approximate per-input cardinalities: assume each input contributes
335        // cardinality^(1/n) rows (AGM-style uniform assumption)
336        let per_input = cardinality.powf(1.0 / n as f64).max(1.0);
337        let cardinalities: Vec<f64> = (0..n).map(|_| per_input).collect();
338        self.leapfrog_join_cost(n, &cardinalities, cardinality)
339    }
340
341    /// Estimates the cost of an aggregation.
342    fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
343        // Hash aggregation cost
344        let hash_cost = cardinality * self.hash_lookup_cost;
345
346        // Aggregate function evaluation
347        let agg_count = agg.aggregates.len() as f64;
348        let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
349
350        // Memory for hash table (estimated distinct groups)
351        let distinct_groups = (cardinality / 10.0).max(1.0); // Assume 10% distinct
352        let memory_cost = distinct_groups * self.avg_tuple_size;
353
354        Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
355    }
356
357    /// Estimates the cost of a sort operation.
358    fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
359        if cardinality <= 1.0 {
360            return Cost::zero();
361        }
362
363        // Sort is O(n log n) comparisons
364        let comparisons = cardinality * cardinality.log2();
365        let key_count = sort.keys.len() as f64;
366
367        // Memory for sorting (full input materialization)
368        let memory_cost = cardinality * self.avg_tuple_size;
369
370        Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
371    }
372
373    /// Estimates the cost of a distinct operation.
374    fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
375        // Hash-based distinct
376        let hash_cost = cardinality * self.hash_lookup_cost;
377        let memory_cost = cardinality * self.avg_tuple_size * 0.5; // Assume 50% distinct
378
379        Cost::cpu(hash_cost).with_memory(memory_cost)
380    }
381
382    /// Estimates the cost of a limit operation.
383    fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
384        // Limit is very cheap - just counting
385        Cost::cpu(limit.count.estimate() * self.cpu_tuple_cost * 0.1)
386    }
387
388    /// Estimates the cost of a skip operation.
389    fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
390        // Skip requires scanning through skipped rows
391        Cost::cpu(skip.count.estimate() * self.cpu_tuple_cost)
392    }
393
394    /// Estimates the cost of a return operation.
395    fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
396        // Return materializes results
397        let expr_count = ret.items.len() as f64;
398        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
399    }
400
401    /// Estimates the cost of a vector scan operation.
402    ///
403    /// HNSW index search is O(log N) per query, while brute-force is O(N).
404    /// This estimates the HNSW case with ef search parameter.
405    fn vector_scan_cost(&self, scan: &VectorScanOp, cardinality: f64) -> Cost {
406        // k determines output cardinality
407        let k = scan.k as f64;
408
409        // HNSW search cost: O(ef * log(N)) distance computations
410        // Assume ef = 64 (default), N = cardinality
411        let ef = 64.0;
412        let n = cardinality.max(1.0);
413        let search_cost = if scan.index_name.is_some() {
414            // HNSW: O(ef * log N)
415            ef * n.ln() * self.cpu_tuple_cost * 10.0 // Distance computation is ~10x regular tuple
416        } else {
417            // Brute-force: O(N)
418            n * self.cpu_tuple_cost * 10.0
419        };
420
421        // Memory for candidate heap
422        let memory = k * self.avg_tuple_size * 2.0;
423
424        Cost::cpu(search_cost).with_memory(memory)
425    }
426
427    /// Estimates the cost of a vector join operation.
428    ///
429    /// Vector join performs k-NN search for each input row.
430    fn vector_join_cost(&self, join: &VectorJoinOp, cardinality: f64) -> Cost {
431        let k = join.k as f64;
432
433        // Each input row triggers a vector search
434        // Assume brute-force for hybrid queries (no index specified typically)
435        let per_row_search_cost = if join.index_name.is_some() {
436            // HNSW: O(ef * log N)
437            let ef = 64.0;
438            let n = cardinality.max(1.0);
439            ef * n.ln() * self.cpu_tuple_cost * 10.0
440        } else {
441            // Brute-force: O(N) per input row
442            cardinality * self.cpu_tuple_cost * 10.0
443        };
444
445        // Total cost: input_rows * search_cost
446        // For vector join, cardinality is typically input cardinality * k
447        let input_cardinality = (cardinality / k).max(1.0);
448        let total_search_cost = input_cardinality * per_row_search_cost;
449
450        // Memory for results
451        let memory = cardinality * self.avg_tuple_size;
452
453        Cost::cpu(total_search_cost).with_memory(memory)
454    }
455
    /// Estimates the total cost of an operator tree recursively.
    ///
    /// Walks the entire plan tree, computing per-operator cost at each level
    /// using the cardinality estimator for accurate child cardinalities.
    /// Returns the sum of all operator costs in the tree.
    #[must_use]
    pub fn estimate_tree(
        &self,
        op: &LogicalOperator,
        card_estimator: &super::CardinalityEstimator,
    ) -> Cost {
        // Thin public wrapper; all work happens in `estimate_tree_inner`.
        self.estimate_tree_inner(op, card_estimator)
    }
469
    /// Recursive worker for [`CostModel::estimate_tree`].
    ///
    /// Estimates this operator's output cardinality, recurses into its
    /// input(s), and returns child cost(s) plus this operator's own cost.
    /// Leaf operators (scans, `Empty`) recurse no further.
    fn estimate_tree_inner(
        &self,
        op: &LogicalOperator,
        card_est: &super::CardinalityEstimator,
    ) -> Cost {
        // Output cardinality of this operator, used for its own cost.
        let cardinality = card_est.estimate(op);

        match op {
            // Leaf: no children to recurse into.
            LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
            LogicalOperator::Filter(filter) => {
                let child_cost = self.estimate_tree_inner(&filter.input, card_est);
                child_cost + self.filter_cost(filter, cardinality)
            }
            LogicalOperator::Project(project) => {
                let child_cost = self.estimate_tree_inner(&project.input, card_est);
                child_cost + self.project_cost(project, cardinality)
            }
            LogicalOperator::Expand(expand) => {
                let child_cost = self.estimate_tree_inner(&expand.input, card_est);
                child_cost + self.expand_cost(expand, cardinality)
            }
            // Binary join: recurse into both sides and feed their actual
            // cardinality estimates into the join cost model.
            LogicalOperator::Join(join) => {
                let left_cost = self.estimate_tree_inner(&join.left, card_est);
                let right_cost = self.estimate_tree_inner(&join.right, card_est);
                let left_card = card_est.estimate(&join.left);
                let right_card = card_est.estimate(&join.right);
                let join_cost = self.join_cost_with_children(
                    join,
                    cardinality,
                    Some(left_card),
                    Some(right_card),
                );
                left_cost + right_cost + join_cost
            }
            LogicalOperator::Aggregate(agg) => {
                let child_cost = self.estimate_tree_inner(&agg.input, card_est);
                child_cost + self.aggregate_cost(agg, cardinality)
            }
            LogicalOperator::Sort(sort) => {
                let child_cost = self.estimate_tree_inner(&sort.input, card_est);
                child_cost + self.sort_cost(sort, cardinality)
            }
            LogicalOperator::Distinct(distinct) => {
                let child_cost = self.estimate_tree_inner(&distinct.input, card_est);
                child_cost + self.distinct_cost(distinct, cardinality)
            }
            LogicalOperator::Limit(limit) => {
                let child_cost = self.estimate_tree_inner(&limit.input, card_est);
                child_cost + self.limit_cost(limit, cardinality)
            }
            LogicalOperator::Skip(skip) => {
                let child_cost = self.estimate_tree_inner(&skip.input, card_est);
                child_cost + self.skip_cost(skip, cardinality)
            }
            LogicalOperator::Return(ret) => {
                let child_cost = self.estimate_tree_inner(&ret.input, card_est);
                child_cost + self.return_cost(ret, cardinality)
            }
            // Leaf: vector index scan has no plan-tree children.
            LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
            LogicalOperator::VectorJoin(join) => {
                let child_cost = self.estimate_tree_inner(&join.input, card_est);
                child_cost + self.vector_join_cost(join, cardinality)
            }
            // N-ary join: sum all input subtrees before the join itself.
            LogicalOperator::MultiWayJoin(mwj) => {
                let mut children_cost = Cost::zero();
                for input in &mwj.inputs {
                    children_cost += self.estimate_tree_inner(input, card_est);
                }
                children_cost + self.multi_way_join_cost(mwj, cardinality)
            }
            LogicalOperator::Empty => Cost::zero(),
            // Fallback for operators without a dedicated model: linear CPU pass.
            _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
        }
    }
544
545    /// Compares two costs and returns the cheaper one.
546    #[must_use]
547    pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
548        if a.total() <= b.total() { a } else { b }
549    }
550
551    /// Estimates the cost of a worst-case optimal join (WCOJ/leapfrog join).
552    ///
553    /// WCOJ is optimal for cyclic patterns like triangles. Traditional binary
554    /// hash joins are O(N²) for triangles; WCOJ achieves O(N^1.5) by processing
555    /// all relations simultaneously using sorted iterators.
556    ///
557    /// # Arguments
558    /// * `num_relations` - Number of relations participating in the join
559    /// * `cardinalities` - Cardinality of each input relation
560    /// * `output_cardinality` - Expected output cardinality
561    ///
562    /// # Cost Model
563    /// - Materialization: O(sum of cardinalities) to build trie indexes
564    /// - Intersection: O(output * log(min_cardinality)) for leapfrog seek operations
565    /// - Memory: Trie storage for all inputs
566    #[must_use]
567    pub fn leapfrog_join_cost(
568        &self,
569        num_relations: usize,
570        cardinalities: &[f64],
571        output_cardinality: f64,
572    ) -> Cost {
573        if cardinalities.is_empty() {
574            return Cost::zero();
575        }
576
577        let total_input: f64 = cardinalities.iter().sum();
578        let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
579
580        // Materialization phase: build trie indexes for each input
581        let materialize_cost = total_input * self.cpu_tuple_cost * 2.0; // Sorting + trie building
582
583        // Intersection phase: leapfrog seeks are O(log n) per relation
584        let seek_cost = if min_card > 1.0 {
585            output_cardinality * (num_relations as f64) * min_card.log2() * self.hash_lookup_cost
586        } else {
587            output_cardinality * self.cpu_tuple_cost
588        };
589
590        // Output materialization
591        let output_cost = output_cardinality * self.cpu_tuple_cost;
592
593        // Memory: trie storage (roughly 2x input size for sorted index)
594        let memory = total_input * self.avg_tuple_size * 2.0;
595
596        Cost::cpu(materialize_cost + seek_cost + output_cost).with_memory(memory)
597    }
598
599    /// Compares hash join cost vs leapfrog join cost for a cyclic pattern.
600    ///
601    /// Returns true if leapfrog (WCOJ) is estimated to be cheaper.
602    #[must_use]
603    pub fn prefer_leapfrog_join(
604        &self,
605        num_relations: usize,
606        cardinalities: &[f64],
607        output_cardinality: f64,
608    ) -> bool {
609        if num_relations < 3 || cardinalities.len() < 3 {
610            // Leapfrog is only beneficial for multi-way joins (3+)
611            return false;
612        }
613
614        let leapfrog_cost =
615            self.leapfrog_join_cost(num_relations, cardinalities, output_cardinality);
616
617        // Estimate cascade of binary hash joins
618        // For N relations, we need N-1 joins
619        // Each join produces intermediate results that feed the next
620        let mut hash_cascade_cost = Cost::zero();
621        let mut intermediate_cardinality = cardinalities[0];
622
623        for card in &cardinalities[1..] {
624            // Hash join cost: build + probe + output
625            let join_output = (intermediate_cardinality * card).sqrt(); // Estimated selectivity
626            let join = JoinOp {
627                left: Box::new(LogicalOperator::Empty),
628                right: Box::new(LogicalOperator::Empty),
629                join_type: JoinType::Inner,
630                conditions: vec![],
631            };
632            hash_cascade_cost += self.join_cost(&join, join_output);
633            intermediate_cardinality = join_output;
634        }
635
636        leapfrog_cost.total() < hash_cascade_cost.total()
637    }
638
639    /// Estimates cost for factorized execution (compressed intermediate results).
640    ///
641    /// Factorized execution avoids materializing full cross products by keeping
642    /// results in a compressed "factorized" form. This is beneficial for multi-hop
643    /// traversals where intermediate results can explode.
644    ///
645    /// Returns the reduction factor (1.0 = no benefit, lower = more compression).
646    #[must_use]
647    pub fn factorized_benefit(&self, avg_fanout: f64, num_hops: usize) -> f64 {
648        if num_hops <= 1 || avg_fanout <= 1.0 {
649            return 1.0; // No benefit for single hop or low fanout
650        }
651
652        // Factorized representation compresses repeated prefixes
653        // Compression ratio improves with higher fanout and more hops
654        // Full materialization: fanout^hops
655        // Factorized: sum(fanout^i for i in 1..=hops) ≈ fanout^(hops+1) / (fanout - 1)
656
657        let full_size = avg_fanout.powi(num_hops as i32);
658        let factorized_size = if avg_fanout > 1.0 {
659            (avg_fanout.powi(num_hops as i32 + 1) - 1.0) / (avg_fanout - 1.0)
660        } else {
661            num_hops as f64
662        };
663
664        (factorized_size / full_size).min(1.0)
665    }
666}
667
668impl Default for CostModel {
669    fn default() -> Self {
670        Self::new()
671    }
672}
673
674#[cfg(test)]
675mod tests {
676    use super::*;
677    use crate::query::plan::{
678        AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
679        PathMode, Projection, ReturnItem, SortOrder,
680    };
681
682    #[test]
683    fn test_cost_addition() {
684        let a = Cost::cpu(10.0).with_io(5.0);
685        let b = Cost::cpu(20.0).with_memory(100.0);
686        let c = a + b;
687
688        assert!((c.cpu - 30.0).abs() < 0.001);
689        assert!((c.io - 5.0).abs() < 0.001);
690        assert!((c.memory - 100.0).abs() < 0.001);
691    }
692
693    #[test]
694    fn test_cost_total() {
695        let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
696        // Total = 10 + 1*10 + 100*0.1 = 10 + 10 + 10 = 30
697        assert!((cost.total() - 30.0).abs() < 0.001);
698    }
699
700    #[test]
701    fn test_cost_model_node_scan() {
702        let model = CostModel::new();
703        let scan = NodeScanOp {
704            variable: "n".to_string(),
705            label: Some("Person".to_string()),
706            input: None,
707        };
708        let cost = model.node_scan_cost(&scan, 1000.0);
709
710        assert!(cost.cpu > 0.0);
711        assert!(cost.io > 0.0);
712    }
713
714    #[test]
715    fn test_cost_model_sort() {
716        let model = CostModel::new();
717        let sort = SortOp {
718            keys: vec![],
719            input: Box::new(LogicalOperator::Empty),
720        };
721
722        let cost_100 = model.sort_cost(&sort, 100.0);
723        let cost_1000 = model.sort_cost(&sort, 1000.0);
724
725        // Sorting 1000 rows should be more expensive than 100 rows
726        assert!(cost_1000.total() > cost_100.total());
727    }
728
729    #[test]
730    fn test_cost_zero() {
731        let cost = Cost::zero();
732        assert!((cost.cpu).abs() < 0.001);
733        assert!((cost.io).abs() < 0.001);
734        assert!((cost.memory).abs() < 0.001);
735        assert!((cost.network).abs() < 0.001);
736        assert!((cost.total()).abs() < 0.001);
737    }
738
739    #[test]
740    fn test_cost_add_assign() {
741        let mut cost = Cost::cpu(10.0);
742        cost += Cost::cpu(5.0).with_io(2.0);
743        assert!((cost.cpu - 15.0).abs() < 0.001);
744        assert!((cost.io - 2.0).abs() < 0.001);
745    }
746
747    #[test]
748    fn test_cost_total_weighted() {
749        let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
750        // With custom weights: cpu*2 + io*5 + mem*0.5 = 20 + 10 + 50 = 80
751        let total = cost.total_weighted(2.0, 5.0, 0.5);
752        assert!((total - 80.0).abs() < 0.001);
753    }
754
755    #[test]
756    fn test_cost_model_filter() {
757        let model = CostModel::new();
758        let filter = FilterOp {
759            predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
760            input: Box::new(LogicalOperator::Empty),
761            pushdown_hint: None,
762        };
763        let cost = model.filter_cost(&filter, 1000.0);
764
765        // Filter cost is CPU only
766        assert!(cost.cpu > 0.0);
767        assert!((cost.io).abs() < 0.001);
768    }
769
770    #[test]
771    fn test_cost_model_project() {
772        let model = CostModel::new();
773        let project = ProjectOp {
774            projections: vec![
775                Projection {
776                    expression: LogicalExpression::Variable("a".to_string()),
777                    alias: None,
778                },
779                Projection {
780                    expression: LogicalExpression::Variable("b".to_string()),
781                    alias: None,
782                },
783            ],
784            input: Box::new(LogicalOperator::Empty),
785            pass_through_input: false,
786        };
787        let cost = model.project_cost(&project, 1000.0);
788
789        // Cost should scale with number of projections
790        assert!(cost.cpu > 0.0);
791    }
792
793    #[test]
794    fn test_cost_model_expand() {
795        let model = CostModel::new();
796        let expand = ExpandOp {
797            from_variable: "a".to_string(),
798            to_variable: "b".to_string(),
799            edge_variable: None,
800            direction: ExpandDirection::Outgoing,
801            edge_types: vec![],
802            min_hops: 1,
803            max_hops: Some(1),
804            input: Box::new(LogicalOperator::Empty),
805            path_alias: None,
806            path_mode: PathMode::Walk,
807        };
808        let cost = model.expand_cost(&expand, 1000.0);
809
810        // Expand involves hash lookups and output generation
811        assert!(cost.cpu > 0.0);
812    }
813
814    #[test]
815    fn test_cost_model_expand_with_edge_type_stats() {
816        let mut degrees = std::collections::HashMap::new();
817        degrees.insert("KNOWS".to_string(), (5.0, 5.0)); // Symmetric
818        degrees.insert("WORKS_AT".to_string(), (1.0, 50.0)); // Many-to-one
819
820        let model = CostModel::new().with_edge_type_degrees(degrees);
821
822        // Outgoing KNOWS: fanout = 5
823        let knows_out = ExpandOp {
824            from_variable: "a".to_string(),
825            to_variable: "b".to_string(),
826            edge_variable: None,
827            direction: ExpandDirection::Outgoing,
828            edge_types: vec!["KNOWS".to_string()],
829            min_hops: 1,
830            max_hops: Some(1),
831            input: Box::new(LogicalOperator::Empty),
832            path_alias: None,
833            path_mode: PathMode::Walk,
834        };
835        let cost_knows = model.expand_cost(&knows_out, 1000.0);
836
837        // Outgoing WORKS_AT: fanout = 1 (each person works at one company)
838        let works_out = ExpandOp {
839            from_variable: "a".to_string(),
840            to_variable: "b".to_string(),
841            edge_variable: None,
842            direction: ExpandDirection::Outgoing,
843            edge_types: vec!["WORKS_AT".to_string()],
844            min_hops: 1,
845            max_hops: Some(1),
846            input: Box::new(LogicalOperator::Empty),
847            path_alias: None,
848            path_mode: PathMode::Walk,
849        };
850        let cost_works = model.expand_cost(&works_out, 1000.0);
851
852        // KNOWS (fanout=5) should be more expensive than WORKS_AT (fanout=1)
853        assert!(
854            cost_knows.cpu > cost_works.cpu,
855            "KNOWS(5) should cost more than WORKS_AT(1)"
856        );
857
858        // Incoming WORKS_AT: fanout = 50 (company has many employees)
859        let works_in = ExpandOp {
860            from_variable: "c".to_string(),
861            to_variable: "p".to_string(),
862            edge_variable: None,
863            direction: ExpandDirection::Incoming,
864            edge_types: vec!["WORKS_AT".to_string()],
865            min_hops: 1,
866            max_hops: Some(1),
867            input: Box::new(LogicalOperator::Empty),
868            path_alias: None,
869            path_mode: PathMode::Walk,
870        };
871        let cost_works_in = model.expand_cost(&works_in, 1000.0);
872
873        // Incoming WORKS_AT (fanout=50) should be most expensive
874        assert!(
875            cost_works_in.cpu > cost_knows.cpu,
876            "Incoming WORKS_AT(50) should cost more than KNOWS(5)"
877        );
878    }
879
880    #[test]
881    fn test_cost_model_expand_unknown_edge_type_uses_global_fanout() {
882        let model = CostModel::new().with_avg_fanout(7.0);
883        let expand = ExpandOp {
884            from_variable: "a".to_string(),
885            to_variable: "b".to_string(),
886            edge_variable: None,
887            direction: ExpandDirection::Outgoing,
888            edge_types: vec!["UNKNOWN_TYPE".to_string()],
889            min_hops: 1,
890            max_hops: Some(1),
891            input: Box::new(LogicalOperator::Empty),
892            path_alias: None,
893            path_mode: PathMode::Walk,
894        };
895        let cost_unknown = model.expand_cost(&expand, 1000.0);
896
897        // Without edge type (uses global fanout too)
898        let expand_no_type = ExpandOp {
899            from_variable: "a".to_string(),
900            to_variable: "b".to_string(),
901            edge_variable: None,
902            direction: ExpandDirection::Outgoing,
903            edge_types: vec![],
904            min_hops: 1,
905            max_hops: Some(1),
906            input: Box::new(LogicalOperator::Empty),
907            path_alias: None,
908            path_mode: PathMode::Walk,
909        };
910        let cost_no_type = model.expand_cost(&expand_no_type, 1000.0);
911
912        // Both should use global fanout = 7, so costs should be equal
913        assert!(
914            (cost_unknown.cpu - cost_no_type.cpu).abs() < 0.001,
915            "Unknown edge type should use global fanout"
916        );
917    }
918
919    #[test]
920    fn test_cost_model_hash_join() {
921        let model = CostModel::new();
922        let join = JoinOp {
923            left: Box::new(LogicalOperator::Empty),
924            right: Box::new(LogicalOperator::Empty),
925            join_type: JoinType::Inner,
926            conditions: vec![JoinCondition {
927                left: LogicalExpression::Variable("a".to_string()),
928                right: LogicalExpression::Variable("b".to_string()),
929            }],
930        };
931        let cost = model.join_cost(&join, 10000.0);
932
933        // Hash join has CPU cost and memory cost
934        assert!(cost.cpu > 0.0);
935        assert!(cost.memory > 0.0);
936    }
937
938    #[test]
939    fn test_cost_model_cross_join() {
940        let model = CostModel::new();
941        let join = JoinOp {
942            left: Box::new(LogicalOperator::Empty),
943            right: Box::new(LogicalOperator::Empty),
944            join_type: JoinType::Cross,
945            conditions: vec![],
946        };
947        let cost = model.join_cost(&join, 1000000.0);
948
949        // Cross join is expensive
950        assert!(cost.cpu > 0.0);
951    }
952
953    #[test]
954    fn test_cost_model_semi_join() {
955        let model = CostModel::new();
956        let join = JoinOp {
957            left: Box::new(LogicalOperator::Empty),
958            right: Box::new(LogicalOperator::Empty),
959            join_type: JoinType::Semi,
960            conditions: vec![],
961        };
962        let cost_semi = model.join_cost(&join, 1000.0);
963
964        let inner_join = JoinOp {
965            left: Box::new(LogicalOperator::Empty),
966            right: Box::new(LogicalOperator::Empty),
967            join_type: JoinType::Inner,
968            conditions: vec![],
969        };
970        let cost_inner = model.join_cost(&inner_join, 1000.0);
971
972        // Semi join can be cheaper than inner join
973        assert!(cost_semi.cpu > 0.0);
974        assert!(cost_inner.cpu > 0.0);
975    }
976
977    #[test]
978    fn test_cost_model_aggregate() {
979        let model = CostModel::new();
980        let agg = AggregateOp {
981            group_by: vec![],
982            aggregates: vec![
983                AggregateExpr {
984                    function: AggregateFunction::Count,
985                    expression: None,
986                    expression2: None,
987                    distinct: false,
988                    alias: Some("cnt".to_string()),
989                    percentile: None,
990                    separator: None,
991                },
992                AggregateExpr {
993                    function: AggregateFunction::Sum,
994                    expression: Some(LogicalExpression::Variable("x".to_string())),
995                    expression2: None,
996                    distinct: false,
997                    alias: Some("total".to_string()),
998                    percentile: None,
999                    separator: None,
1000                },
1001            ],
1002            input: Box::new(LogicalOperator::Empty),
1003            having: None,
1004        };
1005        let cost = model.aggregate_cost(&agg, 1000.0);
1006
1007        // Aggregation has hash cost and memory cost
1008        assert!(cost.cpu > 0.0);
1009        assert!(cost.memory > 0.0);
1010    }
1011
1012    #[test]
1013    fn test_cost_model_distinct() {
1014        let model = CostModel::new();
1015        let distinct = DistinctOp {
1016            input: Box::new(LogicalOperator::Empty),
1017            columns: None,
1018        };
1019        let cost = model.distinct_cost(&distinct, 1000.0);
1020
1021        // Distinct uses hash set
1022        assert!(cost.cpu > 0.0);
1023        assert!(cost.memory > 0.0);
1024    }
1025
1026    #[test]
1027    fn test_cost_model_limit() {
1028        let model = CostModel::new();
1029        let limit = LimitOp {
1030            count: 10.into(),
1031            input: Box::new(LogicalOperator::Empty),
1032        };
1033        let cost = model.limit_cost(&limit, 1000.0);
1034
1035        // Limit is very cheap
1036        assert!(cost.cpu > 0.0);
1037        assert!(cost.cpu < 1.0); // Should be minimal
1038    }
1039
1040    #[test]
1041    fn test_cost_model_skip() {
1042        let model = CostModel::new();
1043        let skip = SkipOp {
1044            count: 100.into(),
1045            input: Box::new(LogicalOperator::Empty),
1046        };
1047        let cost = model.skip_cost(&skip, 1000.0);
1048
1049        // Skip must scan through skipped rows
1050        assert!(cost.cpu > 0.0);
1051    }
1052
1053    #[test]
1054    fn test_cost_model_return() {
1055        let model = CostModel::new();
1056        let ret = ReturnOp {
1057            items: vec![
1058                ReturnItem {
1059                    expression: LogicalExpression::Variable("a".to_string()),
1060                    alias: None,
1061                },
1062                ReturnItem {
1063                    expression: LogicalExpression::Variable("b".to_string()),
1064                    alias: None,
1065                },
1066            ],
1067            distinct: false,
1068            input: Box::new(LogicalOperator::Empty),
1069        };
1070        let cost = model.return_cost(&ret, 1000.0);
1071
1072        // Return materializes results
1073        assert!(cost.cpu > 0.0);
1074    }
1075
1076    #[test]
1077    fn test_cost_cheaper() {
1078        let model = CostModel::new();
1079        let cheap = Cost::cpu(10.0);
1080        let expensive = Cost::cpu(100.0);
1081
1082        assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
1083        assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
1084    }
1085
1086    #[test]
1087    fn test_cost_comparison_prefers_lower_total() {
1088        let model = CostModel::new();
1089        // High CPU, low IO
1090        let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
1091        // Low CPU, high IO
1092        let io_heavy = Cost::cpu(10.0).with_io(20.0);
1093
1094        // IO is weighted 10x, so io_heavy = 10 + 200 = 210, cpu_heavy = 100 + 10 = 110
1095        assert!(cpu_heavy.total() < io_heavy.total());
1096        assert_eq!(
1097            model.cheaper(&cpu_heavy, &io_heavy).total(),
1098            cpu_heavy.total()
1099        );
1100    }
1101
1102    #[test]
1103    fn test_cost_model_sort_with_keys() {
1104        let model = CostModel::new();
1105        let sort_single = SortOp {
1106            keys: vec![crate::query::plan::SortKey {
1107                expression: LogicalExpression::Variable("a".to_string()),
1108                order: SortOrder::Ascending,
1109                nulls: None,
1110            }],
1111            input: Box::new(LogicalOperator::Empty),
1112        };
1113        let sort_multi = SortOp {
1114            keys: vec![
1115                crate::query::plan::SortKey {
1116                    expression: LogicalExpression::Variable("a".to_string()),
1117                    order: SortOrder::Ascending,
1118                    nulls: None,
1119                },
1120                crate::query::plan::SortKey {
1121                    expression: LogicalExpression::Variable("b".to_string()),
1122                    order: SortOrder::Descending,
1123                    nulls: None,
1124                },
1125            ],
1126            input: Box::new(LogicalOperator::Empty),
1127        };
1128
1129        let cost_single = model.sort_cost(&sort_single, 1000.0);
1130        let cost_multi = model.sort_cost(&sort_multi, 1000.0);
1131
1132        // More sort keys = more comparisons
1133        assert!(cost_multi.cpu > cost_single.cpu);
1134    }
1135
1136    #[test]
1137    fn test_cost_model_empty_operator() {
1138        let model = CostModel::new();
1139        let cost = model.estimate(&LogicalOperator::Empty, 0.0);
1140        assert!((cost.total()).abs() < 0.001);
1141    }
1142
1143    #[test]
1144    fn test_cost_model_default() {
1145        let model = CostModel::default();
1146        let scan = NodeScanOp {
1147            variable: "n".to_string(),
1148            label: None,
1149            input: None,
1150        };
1151        let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
1152        assert!(cost.total() > 0.0);
1153    }
1154
1155    #[test]
1156    fn test_leapfrog_join_cost() {
1157        let model = CostModel::new();
1158
1159        // Three-way join (triangle pattern)
1160        let cardinalities = vec![1000.0, 1000.0, 1000.0];
1161        let cost = model.leapfrog_join_cost(3, &cardinalities, 100.0);
1162
1163        // Should have CPU cost for materialization and intersection
1164        assert!(cost.cpu > 0.0);
1165        // Should have memory cost for trie storage
1166        assert!(cost.memory > 0.0);
1167    }
1168
1169    #[test]
1170    fn test_leapfrog_join_cost_empty() {
1171        let model = CostModel::new();
1172        let cost = model.leapfrog_join_cost(0, &[], 0.0);
1173        assert!((cost.total()).abs() < 0.001);
1174    }
1175
1176    #[test]
1177    fn test_prefer_leapfrog_join_for_triangles() {
1178        let model = CostModel::new();
1179
1180        // Compare costs for triangle pattern
1181        let cardinalities = vec![10000.0, 10000.0, 10000.0];
1182        let output = 1000.0;
1183
1184        let leapfrog_cost = model.leapfrog_join_cost(3, &cardinalities, output);
1185
1186        // Leapfrog should have reasonable cost for triangle patterns
1187        assert!(leapfrog_cost.cpu > 0.0);
1188        assert!(leapfrog_cost.memory > 0.0);
1189
1190        // The prefer_leapfrog_join method compares against hash cascade
1191        // Actual preference depends on specific cost parameters
1192        let _prefer = model.prefer_leapfrog_join(3, &cardinalities, output);
1193        // Test that it returns a boolean (doesn't panic)
1194    }
1195
1196    #[test]
1197    fn test_prefer_leapfrog_join_binary_case() {
1198        let model = CostModel::new();
1199
1200        // Binary join should NOT prefer leapfrog (need 3+ relations)
1201        let cardinalities = vec![1000.0, 1000.0];
1202        let prefer = model.prefer_leapfrog_join(2, &cardinalities, 500.0);
1203        assert!(!prefer, "Binary joins should use hash join, not leapfrog");
1204    }
1205
1206    #[test]
1207    fn test_factorized_benefit_single_hop() {
1208        let model = CostModel::new();
1209
1210        // Single hop: no factorization benefit
1211        let benefit = model.factorized_benefit(10.0, 1);
1212        assert!(
1213            (benefit - 1.0).abs() < 0.001,
1214            "Single hop should have no benefit"
1215        );
1216    }
1217
1218    #[test]
1219    fn test_factorized_benefit_multi_hop() {
1220        let model = CostModel::new();
1221
1222        // Multi-hop with high fanout
1223        let benefit = model.factorized_benefit(10.0, 3);
1224
1225        // The factorized_benefit returns a ratio capped at 1.0
1226        // For high fanout, factorized size / full size approaches 1/fanout
1227        // which is beneficial but the formula gives a value <= 1.0
1228        assert!(benefit <= 1.0, "Benefit should be <= 1.0");
1229        assert!(benefit > 0.0, "Benefit should be positive");
1230    }
1231
1232    #[test]
1233    fn test_factorized_benefit_low_fanout() {
1234        let model = CostModel::new();
1235
1236        // Low fanout: minimal benefit
1237        let benefit = model.factorized_benefit(1.5, 2);
1238        assert!(
1239            benefit <= 1.0,
1240            "Low fanout still benefits from factorization"
1241        );
1242    }
1243
1244    #[test]
1245    fn test_node_scan_uses_label_cardinality_for_io() {
1246        let mut label_cards = std::collections::HashMap::new();
1247        label_cards.insert("Person".to_string(), 500_u64);
1248        label_cards.insert("Company".to_string(), 50_u64);
1249
1250        let model = CostModel::new()
1251            .with_label_cardinalities(label_cards)
1252            .with_graph_totals(550, 1000);
1253
1254        let person_scan = NodeScanOp {
1255            variable: "n".to_string(),
1256            label: Some("Person".to_string()),
1257            input: None,
1258        };
1259        let company_scan = NodeScanOp {
1260            variable: "n".to_string(),
1261            label: Some("Company".to_string()),
1262            input: None,
1263        };
1264
1265        let person_cost = model.node_scan_cost(&person_scan, 500.0);
1266        let company_cost = model.node_scan_cost(&company_scan, 50.0);
1267
1268        // Person scan reads 10x more pages than Company scan
1269        assert!(
1270            person_cost.io > company_cost.io * 5.0,
1271            "Person ({}) should have much higher IO than Company ({})",
1272            person_cost.io,
1273            company_cost.io
1274        );
1275    }
1276
1277    #[test]
1278    fn test_node_scan_unlabeled_uses_total_nodes() {
1279        let model = CostModel::new().with_graph_totals(10_000, 50_000);
1280
1281        let scan = NodeScanOp {
1282            variable: "n".to_string(),
1283            label: None,
1284            input: None,
1285        };
1286
1287        let cost = model.node_scan_cost(&scan, 10_000.0);
1288        let expected_pages = (10_000.0 * 100.0) / 8192.0;
1289        assert!(
1290            (cost.io - expected_pages).abs() < 0.1,
1291            "Unlabeled scan should use total_nodes for IO: got {}, expected {}",
1292            cost.io,
1293            expected_pages
1294        );
1295    }
1296
1297    #[test]
1298    fn test_join_cost_with_actual_child_cardinalities() {
1299        let model = CostModel::new();
1300        let join = JoinOp {
1301            left: Box::new(LogicalOperator::Empty),
1302            right: Box::new(LogicalOperator::Empty),
1303            join_type: JoinType::Inner,
1304            conditions: vec![JoinCondition {
1305                left: LogicalExpression::Variable("a".to_string()),
1306                right: LogicalExpression::Variable("b".to_string()),
1307            }],
1308        };
1309
1310        // With actual child cardinalities (100 left, 10000 right)
1311        let cost_actual = model.join_cost_with_children(&join, 500.0, Some(100.0), Some(10_000.0));
1312
1313        // With sqrt fallback (sqrt(500) ~ 22.4 for both sides)
1314        let cost_sqrt = model.join_cost(&join, 500.0);
1315
1316        // Actual build side (100) is larger than sqrt(500) ~ 22.4, so
1317        // actual cost should be higher for build, but the probe side (10000)
1318        // should dominate
1319        assert!(
1320            cost_actual.cpu > cost_sqrt.cpu,
1321            "Actual child cardinalities ({}) should produce different cost than sqrt fallback ({})",
1322            cost_actual.cpu,
1323            cost_sqrt.cpu
1324        );
1325    }
1326
1327    #[test]
1328    fn test_expand_multi_edge_types() {
1329        let mut degrees = std::collections::HashMap::new();
1330        degrees.insert("KNOWS".to_string(), (5.0, 5.0));
1331        degrees.insert("FOLLOWS".to_string(), (20.0, 100.0));
1332
1333        let model = CostModel::new().with_edge_type_degrees(degrees);
1334
1335        // Multi-type outgoing: KNOWS(5) + FOLLOWS(20) = 25
1336        let multi_expand = ExpandOp {
1337            from_variable: "a".to_string(),
1338            to_variable: "b".to_string(),
1339            edge_variable: None,
1340            direction: ExpandDirection::Outgoing,
1341            edge_types: vec!["KNOWS".to_string(), "FOLLOWS".to_string()],
1342            min_hops: 1,
1343            max_hops: Some(1),
1344            input: Box::new(LogicalOperator::Empty),
1345            path_alias: None,
1346            path_mode: PathMode::Walk,
1347        };
1348        let multi_cost = model.expand_cost(&multi_expand, 100.0);
1349
1350        // Single type: KNOWS(5) only
1351        let single_expand = ExpandOp {
1352            from_variable: "a".to_string(),
1353            to_variable: "b".to_string(),
1354            edge_variable: None,
1355            direction: ExpandDirection::Outgoing,
1356            edge_types: vec!["KNOWS".to_string()],
1357            min_hops: 1,
1358            max_hops: Some(1),
1359            input: Box::new(LogicalOperator::Empty),
1360            path_alias: None,
1361            path_mode: PathMode::Walk,
1362        };
1363        let single_cost = model.expand_cost(&single_expand, 100.0);
1364
1365        // Multi-type (fanout=25) should be more expensive than single (fanout=5)
1366        assert!(
1367            multi_cost.cpu > single_cost.cpu * 3.0,
1368            "Multi-type fanout ({}) should be much higher than single-type ({})",
1369            multi_cost.cpu,
1370            single_cost.cpu
1371        );
1372    }
1373
1374    #[test]
1375    fn test_recursive_tree_cost() {
1376        use crate::query::optimizer::CardinalityEstimator;
1377
1378        let mut label_cards = std::collections::HashMap::new();
1379        label_cards.insert("Person".to_string(), 1000_u64);
1380
1381        let model = CostModel::new()
1382            .with_label_cardinalities(label_cards)
1383            .with_graph_totals(1000, 5000)
1384            .with_avg_fanout(5.0);
1385
1386        let mut card_est = CardinalityEstimator::new();
1387        card_est.add_table_stats("Person", crate::query::optimizer::TableStats::new(1000));
1388
1389        // Build a plan: NodeScan -> Filter -> Return
1390        let plan = LogicalOperator::Return(ReturnOp {
1391            items: vec![ReturnItem {
1392                expression: LogicalExpression::Variable("n".to_string()),
1393                alias: None,
1394            }],
1395            distinct: false,
1396            input: Box::new(LogicalOperator::Filter(FilterOp {
1397                predicate: LogicalExpression::Binary {
1398                    left: Box::new(LogicalExpression::Property {
1399                        variable: "n".to_string(),
1400                        property: "age".to_string(),
1401                    }),
1402                    op: crate::query::plan::BinaryOp::Gt,
1403                    right: Box::new(LogicalExpression::Literal(
1404                        grafeo_common::types::Value::Int64(30),
1405                    )),
1406                },
1407                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1408                    variable: "n".to_string(),
1409                    label: Some("Person".to_string()),
1410                    input: None,
1411                })),
1412                pushdown_hint: None,
1413            })),
1414        });
1415
1416        let tree_cost = model.estimate_tree(&plan, &card_est);
1417
1418        // Tree cost should include scan IO + filter CPU + return CPU
1419        assert!(tree_cost.cpu > 0.0, "Tree should have CPU cost");
1420        assert!(tree_cost.io > 0.0, "Tree should have IO cost from scan");
1421
1422        // Compare with single-operator estimate (only costs the root)
1423        let root_only_card = card_est.estimate(&plan);
1424        let root_only_cost = model.estimate(&plan, root_only_card);
1425
1426        // Tree cost should be strictly higher because it includes child costs
1427        assert!(
1428            tree_cost.total() > root_only_cost.total(),
1429            "Recursive tree cost ({}) should exceed root-only cost ({})",
1430            tree_cost.total(),
1431            root_only_cost.total()
1432        );
1433    }
1434
1435    #[test]
1436    fn test_statistics_driven_vs_default_cost() {
1437        let default_model = CostModel::new();
1438
1439        let mut label_cards = std::collections::HashMap::new();
1440        label_cards.insert("Person".to_string(), 100_u64);
1441        let stats_model = CostModel::new()
1442            .with_label_cardinalities(label_cards)
1443            .with_graph_totals(100, 500);
1444
1445        // Scan a small label: statistics model knows it's only 100 nodes
1446        let scan = NodeScanOp {
1447            variable: "n".to_string(),
1448            label: Some("Person".to_string()),
1449            input: None,
1450        };
1451
1452        let default_cost = default_model.node_scan_cost(&scan, 100.0);
1453        let stats_cost = stats_model.node_scan_cost(&scan, 100.0);
1454
1455        // With statistics, IO cost is based on actual label size (100 nodes)
1456        // Without statistics, IO cost uses cardinality parameter (also 100 here)
1457        // They should be equal in this case since cardinality matches label size
1458        assert!(
1459            (default_cost.io - stats_cost.io).abs() < 0.1,
1460            "When cardinality matches label size, costs should be similar"
1461        );
1462    }
1463}