//! Cost model for query optimization.
//!
//! Provides cost estimates for logical operators to guide optimization decisions.
5use crate::query::plan::{
6    AggregateOp, DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinOp, JoinType, LeftJoinOp,
7    LimitOp, LogicalOperator, MultiWayJoinOp, NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp,
8    TextScanOp, VectorJoinOp, VectorScanOp,
9};
10use std::collections::HashMap;
11
/// Cost of an operation.
///
/// Represents the estimated resource consumption of executing an operator.
/// `Default` yields the all-zero cost, identical to [`Cost::zero`].
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct Cost {
    /// Estimated CPU cycles / work units.
    pub cpu: f64,
    /// Estimated I/O operations (page reads).
    pub io: f64,
    /// Estimated memory usage in bytes.
    pub memory: f64,
    /// Network cost (for distributed queries).
    pub network: f64,
}
26
27impl Cost {
28    /// Creates a zero cost.
29    #[must_use]
30    pub fn zero() -> Self {
31        Self {
32            cpu: 0.0,
33            io: 0.0,
34            memory: 0.0,
35            network: 0.0,
36        }
37    }
38
39    /// Creates a cost from CPU work units.
40    #[must_use]
41    pub fn cpu(cpu: f64) -> Self {
42        Self {
43            cpu,
44            io: 0.0,
45            memory: 0.0,
46            network: 0.0,
47        }
48    }
49
50    /// Adds I/O cost.
51    #[must_use]
52    pub fn with_io(mut self, io: f64) -> Self {
53        self.io = io;
54        self
55    }
56
57    /// Adds memory cost.
58    #[must_use]
59    pub fn with_memory(mut self, memory: f64) -> Self {
60        self.memory = memory;
61        self
62    }
63
64    /// Returns the total weighted cost.
65    ///
66    /// Uses default weights: CPU=1.0, IO=10.0, Memory=0.1, Network=100.0
67    #[must_use]
68    pub fn total(&self) -> f64 {
69        self.cpu + self.io * 10.0 + self.memory * 0.1 + self.network * 100.0
70    }
71
72    /// Returns the total cost with custom weights.
73    #[must_use]
74    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
75        self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
76    }
77}
78
79impl std::ops::Add for Cost {
80    type Output = Self;
81
82    fn add(self, other: Self) -> Self {
83        Self {
84            cpu: self.cpu + other.cpu,
85            io: self.io + other.io,
86            memory: self.memory + other.memory,
87            network: self.network + other.network,
88        }
89    }
90}
91
92impl std::ops::AddAssign for Cost {
93    fn add_assign(&mut self, other: Self) {
94        self.cpu += other.cpu;
95        self.io += other.io;
96        self.memory += other.memory;
97        self.network += other.network;
98    }
99}
100
/// Cost model for estimating operator costs.
///
/// Default constants are calibrated relative to each other:
/// - Tuple scan is the baseline (1x)
/// - Hash lookup is ~3x (hash computation + potential cache miss)
/// - Sort comparison is ~2x (key extraction + comparison)
/// - Distance computation is ~10x (vector math)
#[derive(Debug, Clone)]
pub struct CostModel {
    /// Cost per tuple processed by CPU (baseline unit).
    cpu_tuple_cost: f64,
    /// Cost per hash table lookup (~3x tuple cost: hash + cache miss).
    hash_lookup_cost: f64,
    /// Cost per comparison in sorting (~2x tuple cost: key extract + cmp).
    sort_comparison_cost: f64,
    /// Average tuple size in bytes (for IO estimation).
    avg_tuple_size: f64,
    /// Page size in bytes.
    page_size: f64,
    /// Global average edge fanout (fallback when per-type stats unavailable).
    avg_fanout: f64,
    /// Per-edge-type degree stats: (avg_out_degree, avg_in_degree).
    edge_type_degrees: HashMap<String, (f64, f64)>,
    /// Per-label node counts for accurate scan IO estimation.
    label_cardinalities: HashMap<String, u64>,
    /// Total node count in the graph.
    total_nodes: u64,
    /// Total edge count in the graph.
    total_edges: u64,
}
130
131impl CostModel {
132    /// Creates a new cost model with calibrated default parameters.
133    #[must_use]
134    pub fn new() -> Self {
135        Self {
136            cpu_tuple_cost: 0.01,
137            hash_lookup_cost: 0.03,
138            sort_comparison_cost: 0.02,
139            avg_tuple_size: 100.0,
140            page_size: 8192.0,
141            avg_fanout: 10.0,
142            edge_type_degrees: HashMap::new(),
143            label_cardinalities: HashMap::new(),
144            total_nodes: 0,
145            total_edges: 0,
146        }
147    }
148
149    /// Sets the global average fanout from graph statistics.
150    #[must_use]
151    pub fn with_avg_fanout(mut self, avg_fanout: f64) -> Self {
152        self.avg_fanout = if avg_fanout > 0.0 { avg_fanout } else { 10.0 };
153        self
154    }
155
156    /// Sets per-edge-type degree statistics for accurate expand cost estimation.
157    ///
158    /// Each entry maps edge type name to `(avg_out_degree, avg_in_degree)`.
159    #[must_use]
160    pub fn with_edge_type_degrees(mut self, degrees: HashMap<String, (f64, f64)>) -> Self {
161        self.edge_type_degrees = degrees;
162        self
163    }
164
165    /// Sets per-label node counts for accurate scan IO estimation.
166    #[must_use]
167    pub fn with_label_cardinalities(mut self, cardinalities: HashMap<String, u64>) -> Self {
168        self.label_cardinalities = cardinalities;
169        self
170    }
171
172    /// Sets graph-level totals for cost estimation.
173    #[must_use]
174    pub fn with_graph_totals(mut self, total_nodes: u64, total_edges: u64) -> Self {
175        self.total_nodes = total_nodes;
176        self.total_edges = total_edges;
177        self
178    }
179
180    /// Returns the fanout for a specific expand operation.
181    ///
182    /// Uses per-edge-type degree stats when available, falling back to the
183    /// global average fanout. For multiple edge types, sums per-type fanouts.
184    fn fanout_for_expand(&self, expand: &ExpandOp) -> f64 {
185        if expand.edge_types.is_empty() {
186            return self.avg_fanout;
187        }
188
189        let mut total_fanout = 0.0;
190        let mut all_found = true;
191
192        for edge_type in &expand.edge_types {
193            if let Some(&(out_deg, in_deg)) = self.edge_type_degrees.get(edge_type) {
194                total_fanout += match expand.direction {
195                    ExpandDirection::Outgoing => out_deg,
196                    ExpandDirection::Incoming => in_deg,
197                    ExpandDirection::Both => out_deg + in_deg,
198                };
199            } else {
200                all_found = false;
201                break;
202            }
203        }
204
205        if all_found {
206            total_fanout
207        } else {
208            self.avg_fanout
209        }
210    }
211
212    /// Estimates the cost of a logical operator.
213    #[must_use]
214    pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
215        match op {
216            LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
217            LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
218            LogicalOperator::Project(project) => self.project_cost(project, cardinality),
219            LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
220            LogicalOperator::Join(join) => self.join_cost(join, cardinality),
221            LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
222            LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
223            LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
224            LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
225            LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
226            LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
227            LogicalOperator::Empty => Cost::zero(),
228            LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
229            LogicalOperator::VectorJoin(join) => self.vector_join_cost(join, cardinality),
230            LogicalOperator::MultiWayJoin(mwj) => self.multi_way_join_cost(mwj, cardinality),
231            LogicalOperator::LeftJoin(lj) => {
232                let left_card = self.estimate_child_cardinality(&lj.left);
233                let right_card = self.estimate_child_cardinality(&lj.right);
234                self.left_join_cost(lj, cardinality, left_card, right_card)
235            }
236            LogicalOperator::TextScan(scan) => self.text_scan_cost(scan, cardinality),
237            _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
238        }
239    }
240
241    /// Estimates the cost of a node scan.
242    ///
243    /// When label statistics are available, uses the actual label cardinality
244    /// for IO estimation (pages to read) rather than the optimizer's cardinality
245    /// estimate, which may already account for filter selectivity.
246    fn node_scan_cost(&self, scan: &NodeScanOp, cardinality: f64) -> Cost {
247        // IO cost: based on how many nodes we actually need to scan from storage
248        let scan_size = if let Some(label) = &scan.label {
249            self.label_cardinalities
250                .get(label)
251                .map_or(cardinality, |&count| count as f64)
252        } else if self.total_nodes > 0 {
253            self.total_nodes as f64
254        } else {
255            cardinality
256        };
257        let pages = (scan_size * self.avg_tuple_size) / self.page_size;
258        // CPU cost: only pay for rows that pass the label filter
259        Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
260    }
261
262    /// Estimates the cost of a filter operation.
263    fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
264        // Filter cost is just predicate evaluation per tuple
265        Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
266    }
267
268    /// Estimates the cost of a projection.
269    fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
270        // Cost depends on number of expressions evaluated
271        let expr_count = project.projections.len() as f64;
272        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
273    }
274
275    /// Estimates the cost of an expand operation.
276    ///
277    /// Uses per-edge-type degree stats when available, otherwise falls back
278    /// to the global average fanout.
279    fn expand_cost(&self, expand: &ExpandOp, cardinality: f64) -> Cost {
280        let fanout = self.fanout_for_expand(expand);
281        // Adjacency list lookup per input row
282        let lookup_cost = cardinality * self.hash_lookup_cost;
283        // Process each expanded output tuple
284        let output_cost = cardinality * fanout * self.cpu_tuple_cost;
285        Cost::cpu(lookup_cost + output_cost)
286    }
287
288    /// Estimates the cost of a join operation.
289    ///
290    /// When child cardinalities are known (from recursive estimation), uses them
291    /// directly for build/probe cost. Otherwise falls back to sqrt approximation.
292    fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
293        self.join_cost_with_children(join, cardinality, None, None)
294    }
295
296    /// Estimates join cost using actual child cardinalities.
297    fn join_cost_with_children(
298        &self,
299        join: &JoinOp,
300        cardinality: f64,
301        left_cardinality: Option<f64>,
302        right_cardinality: Option<f64>,
303    ) -> Cost {
304        match join.join_type {
305            JoinType::Cross => Cost::cpu(cardinality * self.cpu_tuple_cost),
306            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
307                // Hash join: build the smaller side, probe with the larger
308                let build_cardinality = left_cardinality.unwrap_or_else(|| cardinality.sqrt());
309                let probe_cardinality = right_cardinality.unwrap_or_else(|| cardinality.sqrt());
310
311                let build_cost = build_cardinality * self.hash_lookup_cost;
312                let memory_cost = build_cardinality * self.avg_tuple_size;
313                let probe_cost = probe_cardinality * self.hash_lookup_cost;
314                let output_cost = cardinality * self.cpu_tuple_cost;
315
316                Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
317            }
318            JoinType::Semi | JoinType::Anti => {
319                let build_cardinality = left_cardinality.unwrap_or_else(|| cardinality.sqrt());
320                let probe_cardinality = right_cardinality.unwrap_or_else(|| cardinality.sqrt());
321
322                let build_cost = build_cardinality * self.hash_lookup_cost;
323                let probe_cost = probe_cardinality * self.hash_lookup_cost;
324
325                Cost::cpu(build_cost + probe_cost)
326                    .with_memory(build_cardinality * self.avg_tuple_size)
327            }
328        }
329    }
330
331    /// Estimates the cost of a left outer join (OPTIONAL MATCH).
332    ///
333    /// Uses the same hash join cost model as inner joins: the right side is
334    /// built into a hash table, the left side probes. Output cost includes all
335    /// left rows (some with NULL-padded right side).
336    fn left_join_cost(
337        &self,
338        _lj: &LeftJoinOp,
339        cardinality: f64,
340        left_card: f64,
341        right_card: f64,
342    ) -> Cost {
343        let build_cost = right_card * self.hash_lookup_cost;
344        let memory_cost = right_card * self.avg_tuple_size;
345        let probe_cost = left_card * self.hash_lookup_cost;
346        let output_cost = cardinality * self.cpu_tuple_cost;
347
348        Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
349    }
350
351    /// Estimates the cardinality of a child operator using available statistics.
352    ///
353    /// Used internally to compute left/right child cardinalities for joins
354    /// without requiring an external `CardinalityEstimator`. Uses the same
355    /// statistics the cost model already holds (label counts, fanout).
356    fn estimate_child_cardinality(&self, op: &LogicalOperator) -> f64 {
357        match op {
358            LogicalOperator::NodeScan(scan) => if let Some(label) = &scan.label {
359                self.label_cardinalities
360                    .get(label)
361                    .map_or(self.total_nodes as f64, |&c| c as f64)
362            } else {
363                self.total_nodes as f64
364            }
365            .max(1.0),
366            LogicalOperator::Expand(expand) => {
367                let input_card = self.estimate_child_cardinality(&expand.input);
368                let fanout = if expand.edge_types.is_empty() {
369                    self.avg_fanout
370                } else {
371                    self.fanout_for_expand(expand)
372                };
373                (input_card * fanout).max(1.0)
374            }
375            LogicalOperator::Filter(filter) => {
376                // Approximate: filters reduce cardinality by 10%
377                (self.estimate_child_cardinality(&filter.input) * 0.1).max(1.0)
378            }
379            LogicalOperator::Return(ret) => self.estimate_child_cardinality(&ret.input),
380            LogicalOperator::Limit(limit) => {
381                let input = self.estimate_child_cardinality(&limit.input);
382                // LimitOp.count may be a parameter; use input as upper bound
383                input.min(100.0)
384            }
385            _ => (self.total_nodes as f64).max(1.0),
386        }
387    }
388
389    /// Estimates the cost of a multi-way (leapfrog) join.
390    ///
391    /// Delegates to `leapfrog_join_cost` using per-input cardinality estimates
392    /// derived from the output cardinality divided equally among inputs.
393    fn multi_way_join_cost(&self, mwj: &MultiWayJoinOp, cardinality: f64) -> Cost {
394        let n = mwj.inputs.len();
395        if n == 0 {
396            return Cost::zero();
397        }
398        // Approximate per-input cardinalities: assume each input contributes
399        // cardinality^(1/n) rows (AGM-style uniform assumption)
400        let per_input = cardinality.powf(1.0 / n as f64).max(1.0);
401        let cardinalities: Vec<f64> = (0..n).map(|_| per_input).collect();
402        self.leapfrog_join_cost(n, &cardinalities, cardinality)
403    }
404
405    /// Estimates the cost of an aggregation.
406    fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
407        // Hash aggregation cost
408        let hash_cost = cardinality * self.hash_lookup_cost;
409
410        // Aggregate function evaluation
411        let agg_count = agg.aggregates.len() as f64;
412        let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
413
414        // Memory for hash table (estimated distinct groups)
415        let distinct_groups = (cardinality / 10.0).max(1.0); // Assume 10% distinct
416        let memory_cost = distinct_groups * self.avg_tuple_size;
417
418        Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
419    }
420
421    /// Estimates the cost of a sort operation.
422    fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
423        if cardinality <= 1.0 {
424            return Cost::zero();
425        }
426
427        // Sort is O(n log n) comparisons
428        let comparisons = cardinality * cardinality.log2();
429        let key_count = sort.keys.len() as f64;
430
431        // Memory for sorting (full input materialization)
432        let memory_cost = cardinality * self.avg_tuple_size;
433
434        Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
435    }
436
437    /// Estimates the cost of a distinct operation.
438    fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
439        // Hash-based distinct
440        let hash_cost = cardinality * self.hash_lookup_cost;
441        let memory_cost = cardinality * self.avg_tuple_size * 0.5; // Assume 50% distinct
442
443        Cost::cpu(hash_cost).with_memory(memory_cost)
444    }
445
446    /// Estimates the cost of a limit operation.
447    fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
448        // Limit is very cheap - just counting
449        Cost::cpu(limit.count.estimate() * self.cpu_tuple_cost * 0.1)
450    }
451
452    /// Estimates the cost of a skip operation.
453    fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
454        // Skip requires scanning through skipped rows
455        Cost::cpu(skip.count.estimate() * self.cpu_tuple_cost)
456    }
457
458    /// Estimates the cost of a return operation.
459    fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
460        // Return materializes results
461        let expr_count = ret.items.len() as f64;
462        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
463    }
464
465    /// Estimates the cost of a vector scan operation.
466    ///
467    /// HNSW index search is O(log N) per query, while brute-force is O(N).
468    /// This estimates the HNSW case with ef search parameter.
469    fn vector_scan_cost(&self, scan: &VectorScanOp, cardinality: f64) -> Cost {
470        let k = scan.k.unwrap_or(0) as f64;
471        let n = cardinality.max(1.0);
472
473        // HNSW search cost: O(ef * log(N)) distance computations
474        // Brute-force: O(N) distance computations
475        let ef = 64.0;
476        let search_cost = if scan.index_name.is_some() {
477            ef * n.ln() * self.cpu_tuple_cost * 10.0
478        } else {
479            n * self.cpu_tuple_cost * 10.0
480        };
481
482        // Memory for candidate heap (threshold mode uses cardinality estimate)
483        let output_rows = if k > 0.0 { k } else { cardinality };
484        let memory = output_rows * self.avg_tuple_size * 2.0;
485
486        Cost::cpu(search_cost).with_memory(memory)
487    }
488
489    /// Estimates the cost of a text scan operation.
490    ///
491    /// BM25 scoring requires iterating over postings lists proportional to the
492    /// corpus size, with scoring being ~5x more expensive than a plain tuple
493    /// comparison. The index is in-memory so I/O cost is zero.
494    fn text_scan_cost(&self, scan: &TextScanOp, cardinality: f64) -> Cost {
495        // Use actual label row count for corpus size (BM25 scans the postings).
496        let corpus_size = self
497            .label_cardinalities
498            .get(&scan.label)
499            .copied()
500            .map_or(cardinality, |c| c as f64);
501        // BM25 scoring is ~5x a simple tuple comparison
502        let cpu = corpus_size * self.cpu_tuple_cost * 5.0;
503        Cost::cpu(cpu).with_memory(cardinality * self.avg_tuple_size)
504    }
505
506    /// Estimates the cost of a vector join operation.
507    ///
508    /// Vector join performs k-NN search for each input row.
509    fn vector_join_cost(&self, join: &VectorJoinOp, cardinality: f64) -> Cost {
510        let k = join.k as f64;
511
512        // Each input row triggers a vector search
513        // Assume brute-force for hybrid queries (no index specified typically)
514        let per_row_search_cost = if join.index_name.is_some() {
515            // HNSW: O(ef * log N)
516            let ef = 64.0;
517            let n = cardinality.max(1.0);
518            ef * n.ln() * self.cpu_tuple_cost * 10.0
519        } else {
520            // Brute-force: O(N) per input row
521            cardinality * self.cpu_tuple_cost * 10.0
522        };
523
524        // Total cost: input_rows * search_cost
525        // For vector join, cardinality is typically input cardinality * k
526        let input_cardinality = (cardinality / k).max(1.0);
527        let total_search_cost = input_cardinality * per_row_search_cost;
528
529        // Memory for results
530        let memory = cardinality * self.avg_tuple_size;
531
532        Cost::cpu(total_search_cost).with_memory(memory)
533    }
534
535    /// Estimates the total cost of an operator tree recursively.
536    ///
537    /// Walks the entire plan tree, computing per-operator cost at each level
538    /// using the cardinality estimator for accurate child cardinalities.
539    /// Returns the sum of all operator costs in the tree.
540    #[must_use]
541    pub fn estimate_tree(
542        &self,
543        op: &LogicalOperator,
544        card_estimator: &super::CardinalityEstimator,
545    ) -> Cost {
546        self.estimate_tree_inner(op, card_estimator)
547    }
548
549    fn estimate_tree_inner(
550        &self,
551        op: &LogicalOperator,
552        card_est: &super::CardinalityEstimator,
553    ) -> Cost {
554        let cardinality = card_est.estimate(op);
555
556        match op {
557            LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
558            LogicalOperator::Filter(filter) => {
559                let child_cost = self.estimate_tree_inner(&filter.input, card_est);
560                child_cost + self.filter_cost(filter, cardinality)
561            }
562            LogicalOperator::Project(project) => {
563                let child_cost = self.estimate_tree_inner(&project.input, card_est);
564                child_cost + self.project_cost(project, cardinality)
565            }
566            LogicalOperator::Expand(expand) => {
567                let child_cost = self.estimate_tree_inner(&expand.input, card_est);
568                child_cost + self.expand_cost(expand, cardinality)
569            }
570            LogicalOperator::Join(join) => {
571                let left_cost = self.estimate_tree_inner(&join.left, card_est);
572                let right_cost = self.estimate_tree_inner(&join.right, card_est);
573                let left_card = card_est.estimate(&join.left);
574                let right_card = card_est.estimate(&join.right);
575                let join_cost = self.join_cost_with_children(
576                    join,
577                    cardinality,
578                    Some(left_card),
579                    Some(right_card),
580                );
581                left_cost + right_cost + join_cost
582            }
583            LogicalOperator::LeftJoin(lj) => {
584                let left_cost = self.estimate_tree_inner(&lj.left, card_est);
585                let right_cost = self.estimate_tree_inner(&lj.right, card_est);
586                let left_card = card_est.estimate(&lj.left);
587                let right_card = card_est.estimate(&lj.right);
588                let join_cost = self.left_join_cost(lj, cardinality, left_card, right_card);
589                left_cost + right_cost + join_cost
590            }
591            LogicalOperator::Aggregate(agg) => {
592                let child_cost = self.estimate_tree_inner(&agg.input, card_est);
593                child_cost + self.aggregate_cost(agg, cardinality)
594            }
595            LogicalOperator::Sort(sort) => {
596                let child_cost = self.estimate_tree_inner(&sort.input, card_est);
597                child_cost + self.sort_cost(sort, cardinality)
598            }
599            LogicalOperator::Distinct(distinct) => {
600                let child_cost = self.estimate_tree_inner(&distinct.input, card_est);
601                child_cost + self.distinct_cost(distinct, cardinality)
602            }
603            LogicalOperator::Limit(limit) => {
604                let child_cost = self.estimate_tree_inner(&limit.input, card_est);
605                child_cost + self.limit_cost(limit, cardinality)
606            }
607            LogicalOperator::Skip(skip) => {
608                let child_cost = self.estimate_tree_inner(&skip.input, card_est);
609                child_cost + self.skip_cost(skip, cardinality)
610            }
611            LogicalOperator::Return(ret) => {
612                let child_cost = self.estimate_tree_inner(&ret.input, card_est);
613                child_cost + self.return_cost(ret, cardinality)
614            }
615            LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
616            LogicalOperator::VectorJoin(join) => {
617                let child_cost = self.estimate_tree_inner(&join.input, card_est);
618                child_cost + self.vector_join_cost(join, cardinality)
619            }
620            LogicalOperator::MultiWayJoin(mwj) => {
621                let mut children_cost = Cost::zero();
622                for input in &mwj.inputs {
623                    children_cost += self.estimate_tree_inner(input, card_est);
624                }
625                children_cost + self.multi_way_join_cost(mwj, cardinality)
626            }
627            LogicalOperator::Empty => Cost::zero(),
628            LogicalOperator::TextScan(scan) => self.text_scan_cost(scan, cardinality),
629            _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
630        }
631    }
632
633    /// Compares two costs and returns the cheaper one.
634    #[must_use]
635    pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
636        if a.total() <= b.total() { a } else { b }
637    }
638
639    /// Estimates the cost of a worst-case optimal join (WCOJ/leapfrog join).
640    ///
641    /// WCOJ is optimal for cyclic patterns like triangles. Traditional binary
642    /// hash joins are O(N²) for triangles; WCOJ achieves O(N^1.5) by processing
643    /// all relations simultaneously using sorted iterators.
644    ///
645    /// # Arguments
646    /// * `num_relations` - Number of relations participating in the join
647    /// * `cardinalities` - Cardinality of each input relation
648    /// * `output_cardinality` - Expected output cardinality
649    ///
650    /// # Cost Model
651    /// - Materialization: O(sum of cardinalities) to build trie indexes
652    /// - Intersection: O(output * log(min_cardinality)) for leapfrog seek operations
653    /// - Memory: Trie storage for all inputs
654    #[must_use]
655    pub fn leapfrog_join_cost(
656        &self,
657        num_relations: usize,
658        cardinalities: &[f64],
659        output_cardinality: f64,
660    ) -> Cost {
661        if cardinalities.is_empty() {
662            return Cost::zero();
663        }
664
665        let total_input: f64 = cardinalities.iter().sum();
666        let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
667
668        // Materialization phase: build trie indexes for each input
669        let materialize_cost = total_input * self.cpu_tuple_cost * 2.0; // Sorting + trie building
670
671        // Intersection phase: leapfrog seeks are O(log n) per relation
672        let seek_cost = if min_card > 1.0 {
673            output_cardinality * (num_relations as f64) * min_card.log2() * self.hash_lookup_cost
674        } else {
675            output_cardinality * self.cpu_tuple_cost
676        };
677
678        // Output materialization
679        let output_cost = output_cardinality * self.cpu_tuple_cost;
680
681        // Memory: trie storage (roughly 2x input size for sorted index)
682        let memory = total_input * self.avg_tuple_size * 2.0;
683
684        Cost::cpu(materialize_cost + seek_cost + output_cost).with_memory(memory)
685    }
686
687    /// Compares hash join cost vs leapfrog join cost for a cyclic pattern.
688    ///
689    /// Returns true if leapfrog (WCOJ) is estimated to be cheaper.
690    #[must_use]
691    pub fn prefer_leapfrog_join(
692        &self,
693        num_relations: usize,
694        cardinalities: &[f64],
695        output_cardinality: f64,
696    ) -> bool {
697        if num_relations < 3 || cardinalities.len() < 3 {
698            // Leapfrog is only beneficial for multi-way joins (3+)
699            return false;
700        }
701
702        let leapfrog_cost =
703            self.leapfrog_join_cost(num_relations, cardinalities, output_cardinality);
704
705        // Estimate cascade of binary hash joins
706        // For N relations, we need N-1 joins
707        // Each join produces intermediate results that feed the next
708        let mut hash_cascade_cost = Cost::zero();
709        let mut intermediate_cardinality = cardinalities[0];
710
711        for card in &cardinalities[1..] {
712            // Hash join cost: build + probe + output
713            let join_output = (intermediate_cardinality * card).sqrt(); // Estimated selectivity
714            let join = JoinOp {
715                left: Box::new(LogicalOperator::Empty),
716                right: Box::new(LogicalOperator::Empty),
717                join_type: JoinType::Inner,
718                conditions: vec![],
719            };
720            hash_cascade_cost += self.join_cost(&join, join_output);
721            intermediate_cardinality = join_output;
722        }
723
724        leapfrog_cost.total() < hash_cascade_cost.total()
725    }
726
727    /// Estimates cost for factorized execution (compressed intermediate results).
728    ///
729    /// Factorized execution avoids materializing full cross products by keeping
730    /// results in a compressed "factorized" form. This is beneficial for multi-hop
731    /// traversals where intermediate results can explode.
732    ///
733    /// Returns the reduction factor (1.0 = no benefit, lower = more compression).
734    #[must_use]
735    pub fn factorized_benefit(&self, avg_fanout: f64, num_hops: usize) -> f64 {
736        if num_hops <= 1 || avg_fanout <= 1.0 {
737            return 1.0; // No benefit for single hop or low fanout
738        }
739
740        // Factorized representation compresses repeated prefixes
741        // Compression ratio improves with higher fanout and more hops
742        // Full materialization: fanout^hops
743        // Factorized: sum(fanout^i for i in 1..=hops) ≈ fanout^(hops+1) / (fanout - 1)
744
745        // reason: hop count is always small (< 100)
746        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
747        let hops_i32 = num_hops as i32;
748        let full_size = avg_fanout.powi(hops_i32);
749        let factorized_size = if avg_fanout > 1.0 {
750            (avg_fanout.powi(hops_i32 + 1) - 1.0) / (avg_fanout - 1.0)
751        } else {
752            num_hops as f64
753        };
754
755        (factorized_size / full_size).min(1.0)
756    }
757}
758
759impl Default for CostModel {
760    fn default() -> Self {
761        Self::new()
762    }
763}
764
765#[cfg(test)]
766mod tests {
767    use super::*;
768    use crate::query::plan::{
769        AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
770        PathMode, Projection, ReturnItem, SortOrder,
771    };
772
773    #[test]
774    fn test_cost_addition() {
775        let a = Cost::cpu(10.0).with_io(5.0);
776        let b = Cost::cpu(20.0).with_memory(100.0);
777        let c = a + b;
778
779        assert!((c.cpu - 30.0).abs() < 0.001);
780        assert!((c.io - 5.0).abs() < 0.001);
781        assert!((c.memory - 100.0).abs() < 0.001);
782    }
783
784    #[test]
785    fn test_cost_total() {
786        let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
787        // Total = 10 + 1*10 + 100*0.1 = 10 + 10 + 10 = 30
788        assert!((cost.total() - 30.0).abs() < 0.001);
789    }
790
791    #[test]
792    fn test_cost_model_node_scan() {
793        let model = CostModel::new();
794        let scan = NodeScanOp {
795            variable: "n".to_string(),
796            label: Some("Person".to_string()),
797            input: None,
798        };
799        let cost = model.node_scan_cost(&scan, 1000.0);
800
801        assert!(cost.cpu > 0.0);
802        assert!(cost.io > 0.0);
803    }
804
805    #[test]
806    fn test_cost_model_sort() {
807        let model = CostModel::new();
808        let sort = SortOp {
809            keys: vec![],
810            input: Box::new(LogicalOperator::Empty),
811        };
812
813        let cost_100 = model.sort_cost(&sort, 100.0);
814        let cost_1000 = model.sort_cost(&sort, 1000.0);
815
816        // Sorting 1000 rows should be more expensive than 100 rows
817        assert!(cost_1000.total() > cost_100.total());
818    }
819
820    #[test]
821    fn test_cost_zero() {
822        let cost = Cost::zero();
823        assert!((cost.cpu).abs() < 0.001);
824        assert!((cost.io).abs() < 0.001);
825        assert!((cost.memory).abs() < 0.001);
826        assert!((cost.network).abs() < 0.001);
827        assert!((cost.total()).abs() < 0.001);
828    }
829
830    #[test]
831    fn test_cost_add_assign() {
832        let mut cost = Cost::cpu(10.0);
833        cost += Cost::cpu(5.0).with_io(2.0);
834        assert!((cost.cpu - 15.0).abs() < 0.001);
835        assert!((cost.io - 2.0).abs() < 0.001);
836    }
837
838    #[test]
839    fn test_cost_total_weighted() {
840        let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
841        // With custom weights: cpu*2 + io*5 + mem*0.5 = 20 + 10 + 50 = 80
842        let total = cost.total_weighted(2.0, 5.0, 0.5);
843        assert!((total - 80.0).abs() < 0.001);
844    }
845
846    #[test]
847    fn test_cost_model_filter() {
848        let model = CostModel::new();
849        let filter = FilterOp {
850            predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
851            input: Box::new(LogicalOperator::Empty),
852            pushdown_hint: None,
853        };
854        let cost = model.filter_cost(&filter, 1000.0);
855
856        // Filter cost is CPU only
857        assert!(cost.cpu > 0.0);
858        assert!((cost.io).abs() < 0.001);
859    }
860
861    #[test]
862    fn test_cost_model_project() {
863        let model = CostModel::new();
864        let project = ProjectOp {
865            projections: vec![
866                Projection {
867                    expression: LogicalExpression::Variable("a".to_string()),
868                    alias: None,
869                },
870                Projection {
871                    expression: LogicalExpression::Variable("b".to_string()),
872                    alias: None,
873                },
874            ],
875            input: Box::new(LogicalOperator::Empty),
876            pass_through_input: false,
877        };
878        let cost = model.project_cost(&project, 1000.0);
879
880        // Cost should scale with number of projections
881        assert!(cost.cpu > 0.0);
882    }
883
884    #[test]
885    fn test_cost_model_expand() {
886        let model = CostModel::new();
887        let expand = ExpandOp {
888            from_variable: "a".to_string(),
889            to_variable: "b".to_string(),
890            edge_variable: None,
891            direction: ExpandDirection::Outgoing,
892            edge_types: vec![],
893            min_hops: 1,
894            max_hops: Some(1),
895            input: Box::new(LogicalOperator::Empty),
896            path_alias: None,
897            path_mode: PathMode::Walk,
898        };
899        let cost = model.expand_cost(&expand, 1000.0);
900
901        // Expand involves hash lookups and output generation
902        assert!(cost.cpu > 0.0);
903    }
904
905    #[test]
906    fn test_cost_model_expand_with_edge_type_stats() {
907        let mut degrees = std::collections::HashMap::new();
908        degrees.insert("KNOWS".to_string(), (5.0, 5.0)); // Symmetric
909        degrees.insert("WORKS_AT".to_string(), (1.0, 50.0)); // Many-to-one
910
911        let model = CostModel::new().with_edge_type_degrees(degrees);
912
913        // Outgoing KNOWS: fanout = 5
914        let knows_out = ExpandOp {
915            from_variable: "a".to_string(),
916            to_variable: "b".to_string(),
917            edge_variable: None,
918            direction: ExpandDirection::Outgoing,
919            edge_types: vec!["KNOWS".to_string()],
920            min_hops: 1,
921            max_hops: Some(1),
922            input: Box::new(LogicalOperator::Empty),
923            path_alias: None,
924            path_mode: PathMode::Walk,
925        };
926        let cost_knows = model.expand_cost(&knows_out, 1000.0);
927
928        // Outgoing WORKS_AT: fanout = 1 (each person works at one company)
929        let works_out = ExpandOp {
930            from_variable: "a".to_string(),
931            to_variable: "b".to_string(),
932            edge_variable: None,
933            direction: ExpandDirection::Outgoing,
934            edge_types: vec!["WORKS_AT".to_string()],
935            min_hops: 1,
936            max_hops: Some(1),
937            input: Box::new(LogicalOperator::Empty),
938            path_alias: None,
939            path_mode: PathMode::Walk,
940        };
941        let cost_works = model.expand_cost(&works_out, 1000.0);
942
943        // KNOWS (fanout=5) should be more expensive than WORKS_AT (fanout=1)
944        assert!(
945            cost_knows.cpu > cost_works.cpu,
946            "KNOWS(5) should cost more than WORKS_AT(1)"
947        );
948
949        // Incoming WORKS_AT: fanout = 50 (company has many employees)
950        let works_in = ExpandOp {
951            from_variable: "c".to_string(),
952            to_variable: "p".to_string(),
953            edge_variable: None,
954            direction: ExpandDirection::Incoming,
955            edge_types: vec!["WORKS_AT".to_string()],
956            min_hops: 1,
957            max_hops: Some(1),
958            input: Box::new(LogicalOperator::Empty),
959            path_alias: None,
960            path_mode: PathMode::Walk,
961        };
962        let cost_works_in = model.expand_cost(&works_in, 1000.0);
963
964        // Incoming WORKS_AT (fanout=50) should be most expensive
965        assert!(
966            cost_works_in.cpu > cost_knows.cpu,
967            "Incoming WORKS_AT(50) should cost more than KNOWS(5)"
968        );
969    }
970
971    #[test]
972    fn test_cost_model_expand_unknown_edge_type_uses_global_fanout() {
973        let model = CostModel::new().with_avg_fanout(7.0);
974        let expand = ExpandOp {
975            from_variable: "a".to_string(),
976            to_variable: "b".to_string(),
977            edge_variable: None,
978            direction: ExpandDirection::Outgoing,
979            edge_types: vec!["UNKNOWN_TYPE".to_string()],
980            min_hops: 1,
981            max_hops: Some(1),
982            input: Box::new(LogicalOperator::Empty),
983            path_alias: None,
984            path_mode: PathMode::Walk,
985        };
986        let cost_unknown = model.expand_cost(&expand, 1000.0);
987
988        // Without edge type (uses global fanout too)
989        let expand_no_type = ExpandOp {
990            from_variable: "a".to_string(),
991            to_variable: "b".to_string(),
992            edge_variable: None,
993            direction: ExpandDirection::Outgoing,
994            edge_types: vec![],
995            min_hops: 1,
996            max_hops: Some(1),
997            input: Box::new(LogicalOperator::Empty),
998            path_alias: None,
999            path_mode: PathMode::Walk,
1000        };
1001        let cost_no_type = model.expand_cost(&expand_no_type, 1000.0);
1002
1003        // Both should use global fanout = 7, so costs should be equal
1004        assert!(
1005            (cost_unknown.cpu - cost_no_type.cpu).abs() < 0.001,
1006            "Unknown edge type should use global fanout"
1007        );
1008    }
1009
1010    #[test]
1011    fn test_cost_model_hash_join() {
1012        let model = CostModel::new();
1013        let join = JoinOp {
1014            left: Box::new(LogicalOperator::Empty),
1015            right: Box::new(LogicalOperator::Empty),
1016            join_type: JoinType::Inner,
1017            conditions: vec![JoinCondition {
1018                left: LogicalExpression::Variable("a".to_string()),
1019                right: LogicalExpression::Variable("b".to_string()),
1020            }],
1021        };
1022        let cost = model.join_cost(&join, 10000.0);
1023
1024        // Hash join has CPU cost and memory cost
1025        assert!(cost.cpu > 0.0);
1026        assert!(cost.memory > 0.0);
1027    }
1028
1029    #[test]
1030    fn test_cost_model_cross_join() {
1031        let model = CostModel::new();
1032        let join = JoinOp {
1033            left: Box::new(LogicalOperator::Empty),
1034            right: Box::new(LogicalOperator::Empty),
1035            join_type: JoinType::Cross,
1036            conditions: vec![],
1037        };
1038        let cost = model.join_cost(&join, 1000000.0);
1039
1040        // Cross join is expensive
1041        assert!(cost.cpu > 0.0);
1042    }
1043
1044    #[test]
1045    fn test_cost_model_semi_join() {
1046        let model = CostModel::new();
1047        let join = JoinOp {
1048            left: Box::new(LogicalOperator::Empty),
1049            right: Box::new(LogicalOperator::Empty),
1050            join_type: JoinType::Semi,
1051            conditions: vec![],
1052        };
1053        let cost_semi = model.join_cost(&join, 1000.0);
1054
1055        let inner_join = JoinOp {
1056            left: Box::new(LogicalOperator::Empty),
1057            right: Box::new(LogicalOperator::Empty),
1058            join_type: JoinType::Inner,
1059            conditions: vec![],
1060        };
1061        let cost_inner = model.join_cost(&inner_join, 1000.0);
1062
1063        // Semi join can be cheaper than inner join
1064        assert!(cost_semi.cpu > 0.0);
1065        assert!(cost_inner.cpu > 0.0);
1066    }
1067
1068    #[test]
1069    fn test_cost_model_aggregate() {
1070        let model = CostModel::new();
1071        let agg = AggregateOp {
1072            group_by: vec![],
1073            aggregates: vec![
1074                AggregateExpr {
1075                    function: AggregateFunction::Count,
1076                    expression: None,
1077                    expression2: None,
1078                    distinct: false,
1079                    alias: Some("cnt".to_string()),
1080                    percentile: None,
1081                    separator: None,
1082                },
1083                AggregateExpr {
1084                    function: AggregateFunction::Sum,
1085                    expression: Some(LogicalExpression::Variable("x".to_string())),
1086                    expression2: None,
1087                    distinct: false,
1088                    alias: Some("total".to_string()),
1089                    percentile: None,
1090                    separator: None,
1091                },
1092            ],
1093            input: Box::new(LogicalOperator::Empty),
1094            having: None,
1095        };
1096        let cost = model.aggregate_cost(&agg, 1000.0);
1097
1098        // Aggregation has hash cost and memory cost
1099        assert!(cost.cpu > 0.0);
1100        assert!(cost.memory > 0.0);
1101    }
1102
1103    #[test]
1104    fn test_cost_model_distinct() {
1105        let model = CostModel::new();
1106        let distinct = DistinctOp {
1107            input: Box::new(LogicalOperator::Empty),
1108            columns: None,
1109        };
1110        let cost = model.distinct_cost(&distinct, 1000.0);
1111
1112        // Distinct uses hash set
1113        assert!(cost.cpu > 0.0);
1114        assert!(cost.memory > 0.0);
1115    }
1116
1117    #[test]
1118    fn test_cost_model_limit() {
1119        let model = CostModel::new();
1120        let limit = LimitOp {
1121            count: 10.into(),
1122            input: Box::new(LogicalOperator::Empty),
1123        };
1124        let cost = model.limit_cost(&limit, 1000.0);
1125
1126        // Limit is very cheap
1127        assert!(cost.cpu > 0.0);
1128        assert!(cost.cpu < 1.0); // Should be minimal
1129    }
1130
1131    #[test]
1132    fn test_cost_model_skip() {
1133        let model = CostModel::new();
1134        let skip = SkipOp {
1135            count: 100.into(),
1136            input: Box::new(LogicalOperator::Empty),
1137        };
1138        let cost = model.skip_cost(&skip, 1000.0);
1139
1140        // Skip must scan through skipped rows
1141        assert!(cost.cpu > 0.0);
1142    }
1143
1144    #[test]
1145    fn test_cost_model_return() {
1146        let model = CostModel::new();
1147        let ret = ReturnOp {
1148            items: vec![
1149                ReturnItem {
1150                    expression: LogicalExpression::Variable("a".to_string()),
1151                    alias: None,
1152                },
1153                ReturnItem {
1154                    expression: LogicalExpression::Variable("b".to_string()),
1155                    alias: None,
1156                },
1157            ],
1158            distinct: false,
1159            input: Box::new(LogicalOperator::Empty),
1160        };
1161        let cost = model.return_cost(&ret, 1000.0);
1162
1163        // Return materializes results
1164        assert!(cost.cpu > 0.0);
1165    }
1166
1167    #[test]
1168    fn test_cost_cheaper() {
1169        let model = CostModel::new();
1170        let cheap = Cost::cpu(10.0);
1171        let expensive = Cost::cpu(100.0);
1172
1173        assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
1174        assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
1175    }
1176
1177    #[test]
1178    fn test_cost_comparison_prefers_lower_total() {
1179        let model = CostModel::new();
1180        // High CPU, low IO
1181        let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
1182        // Low CPU, high IO
1183        let io_heavy = Cost::cpu(10.0).with_io(20.0);
1184
1185        // IO is weighted 10x, so io_heavy = 10 + 200 = 210, cpu_heavy = 100 + 10 = 110
1186        assert!(cpu_heavy.total() < io_heavy.total());
1187        assert_eq!(
1188            model.cheaper(&cpu_heavy, &io_heavy).total(),
1189            cpu_heavy.total()
1190        );
1191    }
1192
1193    #[test]
1194    fn test_cost_model_sort_with_keys() {
1195        let model = CostModel::new();
1196        let sort_single = SortOp {
1197            keys: vec![crate::query::plan::SortKey {
1198                expression: LogicalExpression::Variable("a".to_string()),
1199                order: SortOrder::Ascending,
1200                nulls: None,
1201            }],
1202            input: Box::new(LogicalOperator::Empty),
1203        };
1204        let sort_multi = SortOp {
1205            keys: vec![
1206                crate::query::plan::SortKey {
1207                    expression: LogicalExpression::Variable("a".to_string()),
1208                    order: SortOrder::Ascending,
1209                    nulls: None,
1210                },
1211                crate::query::plan::SortKey {
1212                    expression: LogicalExpression::Variable("b".to_string()),
1213                    order: SortOrder::Descending,
1214                    nulls: None,
1215                },
1216            ],
1217            input: Box::new(LogicalOperator::Empty),
1218        };
1219
1220        let cost_single = model.sort_cost(&sort_single, 1000.0);
1221        let cost_multi = model.sort_cost(&sort_multi, 1000.0);
1222
1223        // More sort keys = more comparisons
1224        assert!(cost_multi.cpu > cost_single.cpu);
1225    }
1226
1227    #[test]
1228    fn test_cost_model_empty_operator() {
1229        let model = CostModel::new();
1230        let cost = model.estimate(&LogicalOperator::Empty, 0.0);
1231        assert!((cost.total()).abs() < 0.001);
1232    }
1233
1234    #[test]
1235    fn test_cost_model_default() {
1236        let model = CostModel::default();
1237        let scan = NodeScanOp {
1238            variable: "n".to_string(),
1239            label: None,
1240            input: None,
1241        };
1242        let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
1243        assert!(cost.total() > 0.0);
1244    }
1245
1246    #[test]
1247    fn test_leapfrog_join_cost() {
1248        let model = CostModel::new();
1249
1250        // Three-way join (triangle pattern)
1251        let cardinalities = vec![1000.0, 1000.0, 1000.0];
1252        let cost = model.leapfrog_join_cost(3, &cardinalities, 100.0);
1253
1254        // Should have CPU cost for materialization and intersection
1255        assert!(cost.cpu > 0.0);
1256        // Should have memory cost for trie storage
1257        assert!(cost.memory > 0.0);
1258    }
1259
1260    #[test]
1261    fn test_leapfrog_join_cost_empty() {
1262        let model = CostModel::new();
1263        let cost = model.leapfrog_join_cost(0, &[], 0.0);
1264        assert!((cost.total()).abs() < 0.001);
1265    }
1266
1267    #[test]
1268    fn test_prefer_leapfrog_join_for_triangles() {
1269        let model = CostModel::new();
1270
1271        // Compare costs for triangle pattern
1272        let cardinalities = vec![10000.0, 10000.0, 10000.0];
1273        let output = 1000.0;
1274
1275        let leapfrog_cost = model.leapfrog_join_cost(3, &cardinalities, output);
1276
1277        // Leapfrog should have reasonable cost for triangle patterns
1278        assert!(leapfrog_cost.cpu > 0.0);
1279        assert!(leapfrog_cost.memory > 0.0);
1280
1281        // The prefer_leapfrog_join method compares against hash cascade
1282        // Actual preference depends on specific cost parameters
1283        let _prefer = model.prefer_leapfrog_join(3, &cardinalities, output);
1284        // Test that it returns a boolean (doesn't panic)
1285    }
1286
1287    #[test]
1288    fn test_prefer_leapfrog_join_binary_case() {
1289        let model = CostModel::new();
1290
1291        // Binary join should NOT prefer leapfrog (need 3+ relations)
1292        let cardinalities = vec![1000.0, 1000.0];
1293        let prefer = model.prefer_leapfrog_join(2, &cardinalities, 500.0);
1294        assert!(!prefer, "Binary joins should use hash join, not leapfrog");
1295    }
1296
1297    #[test]
1298    fn test_factorized_benefit_single_hop() {
1299        let model = CostModel::new();
1300
1301        // Single hop: no factorization benefit
1302        let benefit = model.factorized_benefit(10.0, 1);
1303        assert!(
1304            (benefit - 1.0).abs() < 0.001,
1305            "Single hop should have no benefit"
1306        );
1307    }
1308
1309    #[test]
1310    fn test_factorized_benefit_multi_hop() {
1311        let model = CostModel::new();
1312
1313        // Multi-hop with high fanout
1314        let benefit = model.factorized_benefit(10.0, 3);
1315
1316        // The factorized_benefit returns a ratio capped at 1.0
1317        // For high fanout, factorized size / full size approaches 1/fanout
1318        // which is beneficial but the formula gives a value <= 1.0
1319        assert!(benefit <= 1.0, "Benefit should be <= 1.0");
1320        assert!(benefit > 0.0, "Benefit should be positive");
1321    }
1322
1323    #[test]
1324    fn test_factorized_benefit_low_fanout() {
1325        let model = CostModel::new();
1326
1327        // Low fanout: minimal benefit
1328        let benefit = model.factorized_benefit(1.5, 2);
1329        assert!(
1330            benefit <= 1.0,
1331            "Low fanout still benefits from factorization"
1332        );
1333    }
1334
1335    #[test]
1336    fn test_node_scan_uses_label_cardinality_for_io() {
1337        let mut label_cards = std::collections::HashMap::new();
1338        label_cards.insert("Person".to_string(), 500_u64);
1339        label_cards.insert("Company".to_string(), 50_u64);
1340
1341        let model = CostModel::new()
1342            .with_label_cardinalities(label_cards)
1343            .with_graph_totals(550, 1000);
1344
1345        let person_scan = NodeScanOp {
1346            variable: "n".to_string(),
1347            label: Some("Person".to_string()),
1348            input: None,
1349        };
1350        let company_scan = NodeScanOp {
1351            variable: "n".to_string(),
1352            label: Some("Company".to_string()),
1353            input: None,
1354        };
1355
1356        let person_cost = model.node_scan_cost(&person_scan, 500.0);
1357        let company_cost = model.node_scan_cost(&company_scan, 50.0);
1358
1359        // Person scan reads 10x more pages than Company scan
1360        assert!(
1361            person_cost.io > company_cost.io * 5.0,
1362            "Person ({}) should have much higher IO than Company ({})",
1363            person_cost.io,
1364            company_cost.io
1365        );
1366    }
1367
1368    #[test]
1369    fn test_node_scan_unlabeled_uses_total_nodes() {
1370        let model = CostModel::new().with_graph_totals(10_000, 50_000);
1371
1372        let scan = NodeScanOp {
1373            variable: "n".to_string(),
1374            label: None,
1375            input: None,
1376        };
1377
1378        let cost = model.node_scan_cost(&scan, 10_000.0);
1379        let expected_pages = (10_000.0 * 100.0) / 8192.0;
1380        assert!(
1381            (cost.io - expected_pages).abs() < 0.1,
1382            "Unlabeled scan should use total_nodes for IO: got {}, expected {}",
1383            cost.io,
1384            expected_pages
1385        );
1386    }
1387
1388    #[test]
1389    fn test_join_cost_with_actual_child_cardinalities() {
1390        let model = CostModel::new();
1391        let join = JoinOp {
1392            left: Box::new(LogicalOperator::Empty),
1393            right: Box::new(LogicalOperator::Empty),
1394            join_type: JoinType::Inner,
1395            conditions: vec![JoinCondition {
1396                left: LogicalExpression::Variable("a".to_string()),
1397                right: LogicalExpression::Variable("b".to_string()),
1398            }],
1399        };
1400
1401        // With actual child cardinalities (100 left, 10000 right)
1402        let cost_actual = model.join_cost_with_children(&join, 500.0, Some(100.0), Some(10_000.0));
1403
1404        // With sqrt fallback (sqrt(500) ~ 22.4 for both sides)
1405        let cost_sqrt = model.join_cost(&join, 500.0);
1406
1407        // Actual build side (100) is larger than sqrt(500) ~ 22.4, so
1408        // actual cost should be higher for build, but the probe side (10000)
1409        // should dominate
1410        assert!(
1411            cost_actual.cpu > cost_sqrt.cpu,
1412            "Actual child cardinalities ({}) should produce different cost than sqrt fallback ({})",
1413            cost_actual.cpu,
1414            cost_sqrt.cpu
1415        );
1416    }
1417
1418    #[test]
1419    fn test_expand_multi_edge_types() {
1420        let mut degrees = std::collections::HashMap::new();
1421        degrees.insert("KNOWS".to_string(), (5.0, 5.0));
1422        degrees.insert("FOLLOWS".to_string(), (20.0, 100.0));
1423
1424        let model = CostModel::new().with_edge_type_degrees(degrees);
1425
1426        // Multi-type outgoing: KNOWS(5) + FOLLOWS(20) = 25
1427        let multi_expand = ExpandOp {
1428            from_variable: "a".to_string(),
1429            to_variable: "b".to_string(),
1430            edge_variable: None,
1431            direction: ExpandDirection::Outgoing,
1432            edge_types: vec!["KNOWS".to_string(), "FOLLOWS".to_string()],
1433            min_hops: 1,
1434            max_hops: Some(1),
1435            input: Box::new(LogicalOperator::Empty),
1436            path_alias: None,
1437            path_mode: PathMode::Walk,
1438        };
1439        let multi_cost = model.expand_cost(&multi_expand, 100.0);
1440
1441        // Single type: KNOWS(5) only
1442        let single_expand = ExpandOp {
1443            from_variable: "a".to_string(),
1444            to_variable: "b".to_string(),
1445            edge_variable: None,
1446            direction: ExpandDirection::Outgoing,
1447            edge_types: vec!["KNOWS".to_string()],
1448            min_hops: 1,
1449            max_hops: Some(1),
1450            input: Box::new(LogicalOperator::Empty),
1451            path_alias: None,
1452            path_mode: PathMode::Walk,
1453        };
1454        let single_cost = model.expand_cost(&single_expand, 100.0);
1455
1456        // Multi-type (fanout=25) should be more expensive than single (fanout=5)
1457        assert!(
1458            multi_cost.cpu > single_cost.cpu * 3.0,
1459            "Multi-type fanout ({}) should be much higher than single-type ({})",
1460            multi_cost.cpu,
1461            single_cost.cpu
1462        );
1463    }
1464
1465    #[test]
1466    fn test_recursive_tree_cost() {
1467        use crate::query::optimizer::CardinalityEstimator;
1468
1469        let mut label_cards = std::collections::HashMap::new();
1470        label_cards.insert("Person".to_string(), 1000_u64);
1471
1472        let model = CostModel::new()
1473            .with_label_cardinalities(label_cards)
1474            .with_graph_totals(1000, 5000)
1475            .with_avg_fanout(5.0);
1476
1477        let mut card_est = CardinalityEstimator::new();
1478        card_est.add_table_stats("Person", crate::query::optimizer::TableStats::new(1000));
1479
1480        // Build a plan: NodeScan -> Filter -> Return
1481        let plan = LogicalOperator::Return(ReturnOp {
1482            items: vec![ReturnItem {
1483                expression: LogicalExpression::Variable("n".to_string()),
1484                alias: None,
1485            }],
1486            distinct: false,
1487            input: Box::new(LogicalOperator::Filter(FilterOp {
1488                predicate: LogicalExpression::Binary {
1489                    left: Box::new(LogicalExpression::Property {
1490                        variable: "n".to_string(),
1491                        property: "age".to_string(),
1492                    }),
1493                    op: crate::query::plan::BinaryOp::Gt,
1494                    right: Box::new(LogicalExpression::Literal(
1495                        grafeo_common::types::Value::Int64(30),
1496                    )),
1497                },
1498                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1499                    variable: "n".to_string(),
1500                    label: Some("Person".to_string()),
1501                    input: None,
1502                })),
1503                pushdown_hint: None,
1504            })),
1505        });
1506
1507        let tree_cost = model.estimate_tree(&plan, &card_est);
1508
1509        // Tree cost should include scan IO + filter CPU + return CPU
1510        assert!(tree_cost.cpu > 0.0, "Tree should have CPU cost");
1511        assert!(tree_cost.io > 0.0, "Tree should have IO cost from scan");
1512
1513        // Compare with single-operator estimate (only costs the root)
1514        let root_only_card = card_est.estimate(&plan);
1515        let root_only_cost = model.estimate(&plan, root_only_card);
1516
1517        // Tree cost should be strictly higher because it includes child costs
1518        assert!(
1519            tree_cost.total() > root_only_cost.total(),
1520            "Recursive tree cost ({}) should exceed root-only cost ({})",
1521            tree_cost.total(),
1522            root_only_cost.total()
1523        );
1524    }
1525
1526    #[test]
1527    fn test_statistics_driven_vs_default_cost() {
1528        let default_model = CostModel::new();
1529
1530        let mut label_cards = std::collections::HashMap::new();
1531        label_cards.insert("Person".to_string(), 100_u64);
1532        let stats_model = CostModel::new()
1533            .with_label_cardinalities(label_cards)
1534            .with_graph_totals(100, 500);
1535
1536        // Scan a small label: statistics model knows it's only 100 nodes
1537        let scan = NodeScanOp {
1538            variable: "n".to_string(),
1539            label: Some("Person".to_string()),
1540            input: None,
1541        };
1542
1543        let default_cost = default_model.node_scan_cost(&scan, 100.0);
1544        let stats_cost = stats_model.node_scan_cost(&scan, 100.0);
1545
1546        // With statistics, IO cost is based on actual label size (100 nodes)
1547        // Without statistics, IO cost uses cardinality parameter (also 100 here)
1548        // They should be equal in this case since cardinality matches label size
1549        assert!(
1550            (default_cost.io - stats_cost.io).abs() < 0.1,
1551            "When cardinality matches label size, costs should be similar"
1552        );
1553    }
1554
1555    #[test]
1556    fn test_leapfrog_join_cost_unit_min_cardinality() {
1557        let model = CostModel::new();
1558        // min_card == 1.0 exercises the else branch in seek_cost (output * cpu_tuple_cost)
1559        let cost = model.leapfrog_join_cost(3, &[1.0, 100.0, 200.0], 50.0);
1560        assert!(cost.cpu > 0.0);
1561        assert!(cost.memory > 0.0);
1562    }
1563
1564    #[test]
1565    fn test_prefer_leapfrog_join_cardinalities_below_three() {
1566        let model = CostModel::new();
1567        // num_relations >= 3 but fewer cardinality entries: second guard fires
1568        assert!(!model.prefer_leapfrog_join(3, &[100.0, 200.0], 50.0));
1569        assert!(!model.prefer_leapfrog_join(5, &[], 10.0));
1570    }
1571
1572    #[test]
1573    fn test_factorized_benefit_zero_hops() {
1574        let model = CostModel::new();
1575        // Zero hops falls through the `num_hops <= 1` guard
1576        assert_eq!(model.factorized_benefit(10.0, 0), 1.0);
1577    }
1578
1579    #[test]
1580    fn test_factorized_benefit_unit_fanout_guard() {
1581        let model = CostModel::new();
1582        // fanout exactly 1.0: no benefit regardless of hop count
1583        assert_eq!(model.factorized_benefit(1.0, 5), 1.0);
1584    }
1585}