// grafeo_engine/query/optimizer/cost.rs
1//! Cost model for query optimization.
2//!
3//! Provides cost estimates for logical operators to guide optimization decisions.
4
5use crate::query::plan::{
6    AggregateOp, DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinOp, JoinType, LeftJoinOp,
7    LimitOp, LogicalOperator, MultiWayJoinOp, NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp,
8    VectorJoinOp, VectorScanOp,
9};
10use std::collections::HashMap;
11
/// Cost of an operation.
///
/// Represents the estimated resource consumption of executing an operator.
/// All components are abstract work units, combined into a single scalar by
/// [`Cost::total`] for plan comparison.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Cost {
    /// Estimated CPU cycles / work units.
    pub cpu: f64,
    /// Estimated I/O operations (page reads).
    pub io: f64,
    /// Estimated memory usage in bytes.
    pub memory: f64,
    /// Network cost (for distributed queries).
    pub network: f64,
}

impl Cost {
    /// Creates a zero cost.
    #[must_use]
    pub fn zero() -> Self {
        Self {
            cpu: 0.0,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Creates a cost consisting only of CPU work units.
    #[must_use]
    pub fn cpu(cpu: f64) -> Self {
        Self {
            cpu,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Sets the I/O component.
    ///
    /// Note: this *overwrites* any previously set I/O cost — it does not
    /// accumulate. (The previous doc said "Adds", which was misleading.)
    #[must_use]
    pub fn with_io(mut self, io: f64) -> Self {
        self.io = io;
        self
    }

    /// Sets the memory component.
    ///
    /// Note: this *overwrites* any previously set memory cost — it does not
    /// accumulate.
    #[must_use]
    pub fn with_memory(mut self, memory: f64) -> Self {
        self.memory = memory;
        self
    }

    /// Returns the total weighted cost.
    ///
    /// Uses default weights: CPU=1.0, IO=10.0, Memory=0.1, Network=100.0
    #[must_use]
    pub fn total(&self) -> f64 {
        self.cpu + self.io * 10.0 + self.memory * 0.1 + self.network * 100.0
    }

    /// Returns the total cost with custom weights.
    ///
    /// Note: the network component is intentionally *not* included here
    /// (there is no network weight parameter); callers that care about
    /// network cost should use [`Cost::total`] or account for it separately.
    #[must_use]
    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
        self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
    }
}

impl std::ops::Add for Cost {
    type Output = Self;

    /// Component-wise sum of two costs.
    fn add(self, other: Self) -> Self {
        Self {
            cpu: self.cpu + other.cpu,
            io: self.io + other.io,
            memory: self.memory + other.memory,
            network: self.network + other.network,
        }
    }
}

impl std::ops::AddAssign for Cost {
    /// Component-wise in-place accumulation.
    fn add_assign(&mut self, other: Self) {
        self.cpu += other.cpu;
        self.io += other.io;
        self.memory += other.memory;
        self.network += other.network;
    }
}
100
/// Cost model for estimating operator costs.
///
/// Default constants are calibrated relative to each other:
/// - Tuple scan is the baseline (1x)
/// - Hash lookup is ~3x (hash computation + potential cache miss)
/// - Sort comparison is ~2x (key extraction + comparison)
/// - Distance computation is ~10x (vector math)
///
/// Statistics fields (`edge_type_degrees`, `label_cardinalities`, totals)
/// default to empty/zero and are populated via the `with_*` builder methods;
/// all estimators fall back to generic heuristics when stats are missing.
pub struct CostModel {
    /// Cost per tuple processed by CPU (baseline unit).
    cpu_tuple_cost: f64,
    /// Cost per hash table lookup (~3x tuple cost: hash + cache miss).
    hash_lookup_cost: f64,
    /// Cost per comparison in sorting (~2x tuple cost: key extract + cmp).
    sort_comparison_cost: f64,
    /// Average tuple size in bytes (for IO estimation).
    avg_tuple_size: f64,
    /// Page size in bytes.
    page_size: f64,
    /// Global average edge fanout (fallback when per-type stats unavailable).
    avg_fanout: f64,
    /// Per-edge-type degree stats: (avg_out_degree, avg_in_degree).
    edge_type_degrees: HashMap<String, (f64, f64)>,
    /// Per-label node counts for accurate scan IO estimation.
    label_cardinalities: HashMap<String, u64>,
    /// Total node count in the graph.
    total_nodes: u64,
    /// Total edge count in the graph.
    total_edges: u64,
}
130
131impl CostModel {
132    /// Creates a new cost model with calibrated default parameters.
133    #[must_use]
134    pub fn new() -> Self {
135        Self {
136            cpu_tuple_cost: 0.01,
137            hash_lookup_cost: 0.03,
138            sort_comparison_cost: 0.02,
139            avg_tuple_size: 100.0,
140            page_size: 8192.0,
141            avg_fanout: 10.0,
142            edge_type_degrees: HashMap::new(),
143            label_cardinalities: HashMap::new(),
144            total_nodes: 0,
145            total_edges: 0,
146        }
147    }
148
149    /// Sets the global average fanout from graph statistics.
150    #[must_use]
151    pub fn with_avg_fanout(mut self, avg_fanout: f64) -> Self {
152        self.avg_fanout = if avg_fanout > 0.0 { avg_fanout } else { 10.0 };
153        self
154    }
155
156    /// Sets per-edge-type degree statistics for accurate expand cost estimation.
157    ///
158    /// Each entry maps edge type name to `(avg_out_degree, avg_in_degree)`.
159    #[must_use]
160    pub fn with_edge_type_degrees(mut self, degrees: HashMap<String, (f64, f64)>) -> Self {
161        self.edge_type_degrees = degrees;
162        self
163    }
164
165    /// Sets per-label node counts for accurate scan IO estimation.
166    #[must_use]
167    pub fn with_label_cardinalities(mut self, cardinalities: HashMap<String, u64>) -> Self {
168        self.label_cardinalities = cardinalities;
169        self
170    }
171
172    /// Sets graph-level totals for cost estimation.
173    #[must_use]
174    pub fn with_graph_totals(mut self, total_nodes: u64, total_edges: u64) -> Self {
175        self.total_nodes = total_nodes;
176        self.total_edges = total_edges;
177        self
178    }
179
180    /// Returns the fanout for a specific expand operation.
181    ///
182    /// Uses per-edge-type degree stats when available, falling back to the
183    /// global average fanout. For multiple edge types, sums per-type fanouts.
184    fn fanout_for_expand(&self, expand: &ExpandOp) -> f64 {
185        if expand.edge_types.is_empty() {
186            return self.avg_fanout;
187        }
188
189        let mut total_fanout = 0.0;
190        let mut all_found = true;
191
192        for edge_type in &expand.edge_types {
193            if let Some(&(out_deg, in_deg)) = self.edge_type_degrees.get(edge_type) {
194                total_fanout += match expand.direction {
195                    ExpandDirection::Outgoing => out_deg,
196                    ExpandDirection::Incoming => in_deg,
197                    ExpandDirection::Both => out_deg + in_deg,
198                };
199            } else {
200                all_found = false;
201                break;
202            }
203        }
204
205        if all_found {
206            total_fanout
207        } else {
208            self.avg_fanout
209        }
210    }
211
    /// Estimates the cost of a logical operator.
    ///
    /// `cardinality` is the operator's estimated OUTPUT row count, supplied
    /// by the caller. Only this operator's own cost is returned — children
    /// are not recursed into here; use `estimate_tree` for whole-plan cost.
    #[must_use]
    pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
        match op {
            LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
            LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
            LogicalOperator::Project(project) => self.project_cost(project, cardinality),
            LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
            LogicalOperator::Join(join) => self.join_cost(join, cardinality),
            LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
            LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
            LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
            LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
            LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
            LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
            LogicalOperator::Empty => Cost::zero(),
            LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
            LogicalOperator::VectorJoin(join) => self.vector_join_cost(join, cardinality),
            LogicalOperator::MultiWayJoin(mwj) => self.multi_way_join_cost(mwj, cardinality),
            LogicalOperator::LeftJoin(lj) => {
                // LeftJoin needs per-side cardinalities; derive them from the
                // model's own statistics since no external estimator is
                // available in this entry point.
                let left_card = self.estimate_child_cardinality(&lj.left);
                let right_card = self.estimate_child_cardinality(&lj.right);
                self.left_join_cost(lj, cardinality, left_card, right_card)
            }
            // Fallback for operator kinds without a dedicated model:
            // charge plain per-tuple CPU work.
            _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
        }
    }
239
240    /// Estimates the cost of a node scan.
241    ///
242    /// When label statistics are available, uses the actual label cardinality
243    /// for IO estimation (pages to read) rather than the optimizer's cardinality
244    /// estimate, which may already account for filter selectivity.
245    fn node_scan_cost(&self, scan: &NodeScanOp, cardinality: f64) -> Cost {
246        // IO cost: based on how many nodes we actually need to scan from storage
247        let scan_size = if let Some(label) = &scan.label {
248            self.label_cardinalities
249                .get(label)
250                .map_or(cardinality, |&count| count as f64)
251        } else if self.total_nodes > 0 {
252            self.total_nodes as f64
253        } else {
254            cardinality
255        };
256        let pages = (scan_size * self.avg_tuple_size) / self.page_size;
257        // CPU cost: only pay for rows that pass the label filter
258        Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
259    }
260
261    /// Estimates the cost of a filter operation.
262    fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
263        // Filter cost is just predicate evaluation per tuple
264        Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
265    }
266
267    /// Estimates the cost of a projection.
268    fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
269        // Cost depends on number of expressions evaluated
270        let expr_count = project.projections.len() as f64;
271        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
272    }
273
274    /// Estimates the cost of an expand operation.
275    ///
276    /// Uses per-edge-type degree stats when available, otherwise falls back
277    /// to the global average fanout.
278    fn expand_cost(&self, expand: &ExpandOp, cardinality: f64) -> Cost {
279        let fanout = self.fanout_for_expand(expand);
280        // Adjacency list lookup per input row
281        let lookup_cost = cardinality * self.hash_lookup_cost;
282        // Process each expanded output tuple
283        let output_cost = cardinality * fanout * self.cpu_tuple_cost;
284        Cost::cpu(lookup_cost + output_cost)
285    }
286
287    /// Estimates the cost of a join operation.
288    ///
289    /// When child cardinalities are known (from recursive estimation), uses them
290    /// directly for build/probe cost. Otherwise falls back to sqrt approximation.
291    fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
292        self.join_cost_with_children(join, cardinality, None, None)
293    }
294
295    /// Estimates join cost using actual child cardinalities.
296    fn join_cost_with_children(
297        &self,
298        join: &JoinOp,
299        cardinality: f64,
300        left_cardinality: Option<f64>,
301        right_cardinality: Option<f64>,
302    ) -> Cost {
303        match join.join_type {
304            JoinType::Cross => Cost::cpu(cardinality * self.cpu_tuple_cost),
305            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
306                // Hash join: build the smaller side, probe with the larger
307                let build_cardinality = left_cardinality.unwrap_or_else(|| cardinality.sqrt());
308                let probe_cardinality = right_cardinality.unwrap_or_else(|| cardinality.sqrt());
309
310                let build_cost = build_cardinality * self.hash_lookup_cost;
311                let memory_cost = build_cardinality * self.avg_tuple_size;
312                let probe_cost = probe_cardinality * self.hash_lookup_cost;
313                let output_cost = cardinality * self.cpu_tuple_cost;
314
315                Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
316            }
317            JoinType::Semi | JoinType::Anti => {
318                let build_cardinality = left_cardinality.unwrap_or_else(|| cardinality.sqrt());
319                let probe_cardinality = right_cardinality.unwrap_or_else(|| cardinality.sqrt());
320
321                let build_cost = build_cardinality * self.hash_lookup_cost;
322                let probe_cost = probe_cardinality * self.hash_lookup_cost;
323
324                Cost::cpu(build_cost + probe_cost)
325                    .with_memory(build_cardinality * self.avg_tuple_size)
326            }
327        }
328    }
329
330    /// Estimates the cost of a left outer join (OPTIONAL MATCH).
331    ///
332    /// Uses the same hash join cost model as inner joins: the right side is
333    /// built into a hash table, the left side probes. Output cost includes all
334    /// left rows (some with NULL-padded right side).
335    fn left_join_cost(
336        &self,
337        _lj: &LeftJoinOp,
338        cardinality: f64,
339        left_card: f64,
340        right_card: f64,
341    ) -> Cost {
342        let build_cost = right_card * self.hash_lookup_cost;
343        let memory_cost = right_card * self.avg_tuple_size;
344        let probe_cost = left_card * self.hash_lookup_cost;
345        let output_cost = cardinality * self.cpu_tuple_cost;
346
347        Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
348    }
349
350    /// Estimates the cardinality of a child operator using available statistics.
351    ///
352    /// Used internally to compute left/right child cardinalities for joins
353    /// without requiring an external `CardinalityEstimator`. Uses the same
354    /// statistics the cost model already holds (label counts, fanout).
355    fn estimate_child_cardinality(&self, op: &LogicalOperator) -> f64 {
356        match op {
357            LogicalOperator::NodeScan(scan) => if let Some(label) = &scan.label {
358                self.label_cardinalities
359                    .get(label)
360                    .map_or(self.total_nodes as f64, |&c| c as f64)
361            } else {
362                self.total_nodes as f64
363            }
364            .max(1.0),
365            LogicalOperator::Expand(expand) => {
366                let input_card = self.estimate_child_cardinality(&expand.input);
367                let fanout = if expand.edge_types.is_empty() {
368                    self.avg_fanout
369                } else {
370                    self.fanout_for_expand(expand)
371                };
372                (input_card * fanout).max(1.0)
373            }
374            LogicalOperator::Filter(filter) => {
375                // Approximate: filters reduce cardinality by 10%
376                (self.estimate_child_cardinality(&filter.input) * 0.1).max(1.0)
377            }
378            LogicalOperator::Return(ret) => self.estimate_child_cardinality(&ret.input),
379            LogicalOperator::Limit(limit) => {
380                let input = self.estimate_child_cardinality(&limit.input);
381                // LimitOp.count may be a parameter; use input as upper bound
382                input.min(100.0)
383            }
384            _ => (self.total_nodes as f64).max(1.0),
385        }
386    }
387
388    /// Estimates the cost of a multi-way (leapfrog) join.
389    ///
390    /// Delegates to `leapfrog_join_cost` using per-input cardinality estimates
391    /// derived from the output cardinality divided equally among inputs.
392    fn multi_way_join_cost(&self, mwj: &MultiWayJoinOp, cardinality: f64) -> Cost {
393        let n = mwj.inputs.len();
394        if n == 0 {
395            return Cost::zero();
396        }
397        // Approximate per-input cardinalities: assume each input contributes
398        // cardinality^(1/n) rows (AGM-style uniform assumption)
399        let per_input = cardinality.powf(1.0 / n as f64).max(1.0);
400        let cardinalities: Vec<f64> = (0..n).map(|_| per_input).collect();
401        self.leapfrog_join_cost(n, &cardinalities, cardinality)
402    }
403
404    /// Estimates the cost of an aggregation.
405    fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
406        // Hash aggregation cost
407        let hash_cost = cardinality * self.hash_lookup_cost;
408
409        // Aggregate function evaluation
410        let agg_count = agg.aggregates.len() as f64;
411        let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
412
413        // Memory for hash table (estimated distinct groups)
414        let distinct_groups = (cardinality / 10.0).max(1.0); // Assume 10% distinct
415        let memory_cost = distinct_groups * self.avg_tuple_size;
416
417        Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
418    }
419
420    /// Estimates the cost of a sort operation.
421    fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
422        if cardinality <= 1.0 {
423            return Cost::zero();
424        }
425
426        // Sort is O(n log n) comparisons
427        let comparisons = cardinality * cardinality.log2();
428        let key_count = sort.keys.len() as f64;
429
430        // Memory for sorting (full input materialization)
431        let memory_cost = cardinality * self.avg_tuple_size;
432
433        Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
434    }
435
436    /// Estimates the cost of a distinct operation.
437    fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
438        // Hash-based distinct
439        let hash_cost = cardinality * self.hash_lookup_cost;
440        let memory_cost = cardinality * self.avg_tuple_size * 0.5; // Assume 50% distinct
441
442        Cost::cpu(hash_cost).with_memory(memory_cost)
443    }
444
445    /// Estimates the cost of a limit operation.
446    fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
447        // Limit is very cheap - just counting
448        Cost::cpu(limit.count.estimate() * self.cpu_tuple_cost * 0.1)
449    }
450
451    /// Estimates the cost of a skip operation.
452    fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
453        // Skip requires scanning through skipped rows
454        Cost::cpu(skip.count.estimate() * self.cpu_tuple_cost)
455    }
456
457    /// Estimates the cost of a return operation.
458    fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
459        // Return materializes results
460        let expr_count = ret.items.len() as f64;
461        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
462    }
463
464    /// Estimates the cost of a vector scan operation.
465    ///
466    /// HNSW index search is O(log N) per query, while brute-force is O(N).
467    /// This estimates the HNSW case with ef search parameter.
468    fn vector_scan_cost(&self, scan: &VectorScanOp, cardinality: f64) -> Cost {
469        // k determines output cardinality
470        let k = scan.k as f64;
471
472        // HNSW search cost: O(ef * log(N)) distance computations
473        // Assume ef = 64 (default), N = cardinality
474        let ef = 64.0;
475        let n = cardinality.max(1.0);
476        let search_cost = if scan.index_name.is_some() {
477            // HNSW: O(ef * log N)
478            ef * n.ln() * self.cpu_tuple_cost * 10.0 // Distance computation is ~10x regular tuple
479        } else {
480            // Brute-force: O(N)
481            n * self.cpu_tuple_cost * 10.0
482        };
483
484        // Memory for candidate heap
485        let memory = k * self.avg_tuple_size * 2.0;
486
487        Cost::cpu(search_cost).with_memory(memory)
488    }
489
    /// Estimates the cost of a vector join operation.
    ///
    /// Vector join performs a k-NN search for each input row. `cardinality`
    /// is treated as the OUTPUT cardinality (input rows * k), so the input
    /// row count is recovered by dividing by k.
    fn vector_join_cost(&self, join: &VectorJoinOp, cardinality: f64) -> Cost {
        let k = join.k as f64;

        // Each input row triggers a vector search.
        // NOTE(review): both branches use the output cardinality as a proxy
        // for N, the size of the searched vector collection, since no
        // collection statistics are available here — confirm this is the
        // intended approximation.
        let per_row_search_cost = if join.index_name.is_some() {
            // HNSW: O(ef * log N)
            let ef = 64.0;
            let n = cardinality.max(1.0);
            ef * n.ln() * self.cpu_tuple_cost * 10.0
        } else {
            // Brute-force: O(N) per input row
            cardinality * self.cpu_tuple_cost * 10.0
        };

        // Total cost: input_rows * per-row search cost.
        let input_cardinality = (cardinality / k).max(1.0);
        let total_search_cost = input_cardinality * per_row_search_cost;

        // Memory for the materialized result set.
        let memory = cardinality * self.avg_tuple_size;

        Cost::cpu(total_search_cost).with_memory(memory)
    }
518
    /// Estimates the total cost of an operator tree recursively.
    ///
    /// Walks the entire plan tree, computing per-operator cost at each level
    /// using the cardinality estimator for accurate child cardinalities.
    /// Returns the sum of all operator costs in the tree.
    #[must_use]
    pub fn estimate_tree(
        &self,
        op: &LogicalOperator,
        card_estimator: &super::CardinalityEstimator,
    ) -> Cost {
        self.estimate_tree_inner(op, card_estimator)
    }

    /// Recursive worker for [`CostModel::estimate_tree`].
    ///
    /// Each arm adds the operator's own cost to the recursively computed
    /// cost of its children; leaf operators (scans, `Empty`) contribute only
    /// their own cost. Unlike `estimate`, joins here use the estimator's
    /// per-side cardinalities.
    fn estimate_tree_inner(
        &self,
        op: &LogicalOperator,
        card_est: &super::CardinalityEstimator,
    ) -> Cost {
        // Output cardinality of this operator, per the external estimator.
        let cardinality = card_est.estimate(op);

        match op {
            LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
            LogicalOperator::Filter(filter) => {
                let child_cost = self.estimate_tree_inner(&filter.input, card_est);
                child_cost + self.filter_cost(filter, cardinality)
            }
            LogicalOperator::Project(project) => {
                let child_cost = self.estimate_tree_inner(&project.input, card_est);
                child_cost + self.project_cost(project, cardinality)
            }
            LogicalOperator::Expand(expand) => {
                let child_cost = self.estimate_tree_inner(&expand.input, card_est);
                child_cost + self.expand_cost(expand, cardinality)
            }
            LogicalOperator::Join(join) => {
                let left_cost = self.estimate_tree_inner(&join.left, card_est);
                let right_cost = self.estimate_tree_inner(&join.right, card_est);
                let left_card = card_est.estimate(&join.left);
                let right_card = card_est.estimate(&join.right);
                let join_cost = self.join_cost_with_children(
                    join,
                    cardinality,
                    Some(left_card),
                    Some(right_card),
                );
                left_cost + right_cost + join_cost
            }
            LogicalOperator::LeftJoin(lj) => {
                let left_cost = self.estimate_tree_inner(&lj.left, card_est);
                let right_cost = self.estimate_tree_inner(&lj.right, card_est);
                let left_card = card_est.estimate(&lj.left);
                let right_card = card_est.estimate(&lj.right);
                let join_cost = self.left_join_cost(lj, cardinality, left_card, right_card);
                left_cost + right_cost + join_cost
            }
            LogicalOperator::Aggregate(agg) => {
                let child_cost = self.estimate_tree_inner(&agg.input, card_est);
                child_cost + self.aggregate_cost(agg, cardinality)
            }
            LogicalOperator::Sort(sort) => {
                let child_cost = self.estimate_tree_inner(&sort.input, card_est);
                child_cost + self.sort_cost(sort, cardinality)
            }
            LogicalOperator::Distinct(distinct) => {
                let child_cost = self.estimate_tree_inner(&distinct.input, card_est);
                child_cost + self.distinct_cost(distinct, cardinality)
            }
            LogicalOperator::Limit(limit) => {
                let child_cost = self.estimate_tree_inner(&limit.input, card_est);
                child_cost + self.limit_cost(limit, cardinality)
            }
            LogicalOperator::Skip(skip) => {
                let child_cost = self.estimate_tree_inner(&skip.input, card_est);
                child_cost + self.skip_cost(skip, cardinality)
            }
            LogicalOperator::Return(ret) => {
                let child_cost = self.estimate_tree_inner(&ret.input, card_est);
                child_cost + self.return_cost(ret, cardinality)
            }
            LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
            LogicalOperator::VectorJoin(join) => {
                let child_cost = self.estimate_tree_inner(&join.input, card_est);
                child_cost + self.vector_join_cost(join, cardinality)
            }
            LogicalOperator::MultiWayJoin(mwj) => {
                let mut children_cost = Cost::zero();
                for input in &mwj.inputs {
                    children_cost += self.estimate_tree_inner(input, card_est);
                }
                children_cost + self.multi_way_join_cost(mwj, cardinality)
            }
            LogicalOperator::Empty => Cost::zero(),
            // NOTE(review): unmodeled operators charge only their own CPU and
            // do NOT recurse into their children, so any subtree beneath them
            // is costed as zero — confirm this is acceptable when adding new
            // operator kinds.
            _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
        }
    }
615
616    /// Compares two costs and returns the cheaper one.
617    #[must_use]
618    pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
619        if a.total() <= b.total() { a } else { b }
620    }
621
622    /// Estimates the cost of a worst-case optimal join (WCOJ/leapfrog join).
623    ///
624    /// WCOJ is optimal for cyclic patterns like triangles. Traditional binary
625    /// hash joins are O(N²) for triangles; WCOJ achieves O(N^1.5) by processing
626    /// all relations simultaneously using sorted iterators.
627    ///
628    /// # Arguments
629    /// * `num_relations` - Number of relations participating in the join
630    /// * `cardinalities` - Cardinality of each input relation
631    /// * `output_cardinality` - Expected output cardinality
632    ///
633    /// # Cost Model
634    /// - Materialization: O(sum of cardinalities) to build trie indexes
635    /// - Intersection: O(output * log(min_cardinality)) for leapfrog seek operations
636    /// - Memory: Trie storage for all inputs
637    #[must_use]
638    pub fn leapfrog_join_cost(
639        &self,
640        num_relations: usize,
641        cardinalities: &[f64],
642        output_cardinality: f64,
643    ) -> Cost {
644        if cardinalities.is_empty() {
645            return Cost::zero();
646        }
647
648        let total_input: f64 = cardinalities.iter().sum();
649        let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
650
651        // Materialization phase: build trie indexes for each input
652        let materialize_cost = total_input * self.cpu_tuple_cost * 2.0; // Sorting + trie building
653
654        // Intersection phase: leapfrog seeks are O(log n) per relation
655        let seek_cost = if min_card > 1.0 {
656            output_cardinality * (num_relations as f64) * min_card.log2() * self.hash_lookup_cost
657        } else {
658            output_cardinality * self.cpu_tuple_cost
659        };
660
661        // Output materialization
662        let output_cost = output_cardinality * self.cpu_tuple_cost;
663
664        // Memory: trie storage (roughly 2x input size for sorted index)
665        let memory = total_input * self.avg_tuple_size * 2.0;
666
667        Cost::cpu(materialize_cost + seek_cost + output_cost).with_memory(memory)
668    }
669
670    /// Compares hash join cost vs leapfrog join cost for a cyclic pattern.
671    ///
672    /// Returns true if leapfrog (WCOJ) is estimated to be cheaper.
673    #[must_use]
674    pub fn prefer_leapfrog_join(
675        &self,
676        num_relations: usize,
677        cardinalities: &[f64],
678        output_cardinality: f64,
679    ) -> bool {
680        if num_relations < 3 || cardinalities.len() < 3 {
681            // Leapfrog is only beneficial for multi-way joins (3+)
682            return false;
683        }
684
685        let leapfrog_cost =
686            self.leapfrog_join_cost(num_relations, cardinalities, output_cardinality);
687
688        // Estimate cascade of binary hash joins
689        // For N relations, we need N-1 joins
690        // Each join produces intermediate results that feed the next
691        let mut hash_cascade_cost = Cost::zero();
692        let mut intermediate_cardinality = cardinalities[0];
693
694        for card in &cardinalities[1..] {
695            // Hash join cost: build + probe + output
696            let join_output = (intermediate_cardinality * card).sqrt(); // Estimated selectivity
697            let join = JoinOp {
698                left: Box::new(LogicalOperator::Empty),
699                right: Box::new(LogicalOperator::Empty),
700                join_type: JoinType::Inner,
701                conditions: vec![],
702            };
703            hash_cascade_cost += self.join_cost(&join, join_output);
704            intermediate_cardinality = join_output;
705        }
706
707        leapfrog_cost.total() < hash_cascade_cost.total()
708    }
709
710    /// Estimates cost for factorized execution (compressed intermediate results).
711    ///
712    /// Factorized execution avoids materializing full cross products by keeping
713    /// results in a compressed "factorized" form. This is beneficial for multi-hop
714    /// traversals where intermediate results can explode.
715    ///
716    /// Returns the reduction factor (1.0 = no benefit, lower = more compression).
717    #[must_use]
718    pub fn factorized_benefit(&self, avg_fanout: f64, num_hops: usize) -> f64 {
719        if num_hops <= 1 || avg_fanout <= 1.0 {
720            return 1.0; // No benefit for single hop or low fanout
721        }
722
723        // Factorized representation compresses repeated prefixes
724        // Compression ratio improves with higher fanout and more hops
725        // Full materialization: fanout^hops
726        // Factorized: sum(fanout^i for i in 1..=hops) ≈ fanout^(hops+1) / (fanout - 1)
727
728        // reason: hop count is always small (< 100)
729        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
730        let hops_i32 = num_hops as i32;
731        let full_size = avg_fanout.powi(hops_i32);
732        let factorized_size = if avg_fanout > 1.0 {
733            (avg_fanout.powi(hops_i32 + 1) - 1.0) / (avg_fanout - 1.0)
734        } else {
735            num_hops as f64
736        };
737
738        (factorized_size / full_size).min(1.0)
739    }
740}
741
742impl Default for CostModel {
743    fn default() -> Self {
744        Self::new()
745    }
746}
747
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::plan::{
        AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
        PathMode, Projection, ReturnItem, SortOrder,
    };

    // ---- Cost arithmetic -------------------------------------------------

    // `Cost + Cost` sums each component independently.
    #[test]
    fn test_cost_addition() {
        let a = Cost::cpu(10.0).with_io(5.0);
        let b = Cost::cpu(20.0).with_memory(100.0);
        let c = a + b;

        assert!((c.cpu - 30.0).abs() < 0.001);
        assert!((c.io - 5.0).abs() < 0.001);
        assert!((c.memory - 100.0).abs() < 0.001);
    }

    // Default weights: CPU=1.0, IO=10.0, Memory=0.1, Network=100.0.
    #[test]
    fn test_cost_total() {
        let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
        // Total = 10 + 1*10 + 100*0.1 = 10 + 10 + 10 = 30
        assert!((cost.total() - 30.0).abs() < 0.001);
    }

    // ---- Per-operator cost estimates -------------------------------------

    // A labeled scan incurs both CPU (per tuple) and IO (page reads).
    #[test]
    fn test_cost_model_node_scan() {
        let model = CostModel::new();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: Some("Person".to_string()),
            input: None,
        };
        let cost = model.node_scan_cost(&scan, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.io > 0.0);
    }

    // Sort cost must grow with input cardinality (O(n log n) model).
    #[test]
    fn test_cost_model_sort() {
        let model = CostModel::new();
        let sort = SortOp {
            keys: vec![],
            input: Box::new(LogicalOperator::Empty),
        };

        let cost_100 = model.sort_cost(&sort, 100.0);
        let cost_1000 = model.sort_cost(&sort, 1000.0);

        // Sorting 1000 rows should be more expensive than 100 rows
        assert!(cost_1000.total() > cost_100.total());
    }

    // `Cost::zero()` must be the additive identity across all components.
    #[test]
    fn test_cost_zero() {
        let cost = Cost::zero();
        assert!((cost.cpu).abs() < 0.001);
        assert!((cost.io).abs() < 0.001);
        assert!((cost.memory).abs() < 0.001);
        assert!((cost.network).abs() < 0.001);
        assert!((cost.total()).abs() < 0.001);
    }

    // `+=` accumulates component-wise, same as `+`.
    #[test]
    fn test_cost_add_assign() {
        let mut cost = Cost::cpu(10.0);
        cost += Cost::cpu(5.0).with_io(2.0);
        assert!((cost.cpu - 15.0).abs() < 0.001);
        assert!((cost.io - 2.0).abs() < 0.001);
    }

    // Caller-supplied weights override the defaults used by `total()`.
    #[test]
    fn test_cost_total_weighted() {
        let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
        // With custom weights: cpu*2 + io*5 + mem*0.5 = 20 + 10 + 50 = 80
        let total = cost.total_weighted(2.0, 5.0, 0.5);
        assert!((total - 80.0).abs() < 0.001);
    }

    // Filters evaluate predicates in memory: CPU-only, no IO.
    #[test]
    fn test_cost_model_filter() {
        let model = CostModel::new();
        let filter = FilterOp {
            predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
            input: Box::new(LogicalOperator::Empty),
            pushdown_hint: None,
        };
        let cost = model.filter_cost(&filter, 1000.0);

        // Filter cost is CPU only
        assert!(cost.cpu > 0.0);
        assert!((cost.io).abs() < 0.001);
    }

    // Projection cost is driven by the number of projected expressions.
    #[test]
    fn test_cost_model_project() {
        let model = CostModel::new();
        let project = ProjectOp {
            projections: vec![
                Projection {
                    expression: LogicalExpression::Variable("a".to_string()),
                    alias: None,
                },
                Projection {
                    expression: LogicalExpression::Variable("b".to_string()),
                    alias: None,
                },
            ],
            input: Box::new(LogicalOperator::Empty),
            pass_through_input: false,
        };
        let cost = model.project_cost(&project, 1000.0);

        // Cost should scale with number of projections
        assert!(cost.cpu > 0.0);
    }

    // Single-hop expand without statistics still produces a positive CPU cost.
    #[test]
    fn test_cost_model_expand() {
        let model = CostModel::new();
        let expand = ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: None,
            direction: ExpandDirection::Outgoing,
            edge_types: vec![],
            min_hops: 1,
            max_hops: Some(1),
            input: Box::new(LogicalOperator::Empty),
            path_alias: None,
            path_mode: PathMode::Walk,
        };
        let cost = model.expand_cost(&expand, 1000.0);

        // Expand involves hash lookups and output generation
        assert!(cost.cpu > 0.0);
    }

    // Per-edge-type (out_degree, in_degree) statistics should drive fanout,
    // making high-fanout expansions cost more in the expected ordering.
    #[test]
    fn test_cost_model_expand_with_edge_type_stats() {
        let mut degrees = std::collections::HashMap::new();
        degrees.insert("KNOWS".to_string(), (5.0, 5.0)); // Symmetric
        degrees.insert("WORKS_AT".to_string(), (1.0, 50.0)); // Many-to-one

        let model = CostModel::new().with_edge_type_degrees(degrees);

        // Outgoing KNOWS: fanout = 5
        let knows_out = ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: None,
            direction: ExpandDirection::Outgoing,
            edge_types: vec!["KNOWS".to_string()],
            min_hops: 1,
            max_hops: Some(1),
            input: Box::new(LogicalOperator::Empty),
            path_alias: None,
            path_mode: PathMode::Walk,
        };
        let cost_knows = model.expand_cost(&knows_out, 1000.0);

        // Outgoing WORKS_AT: fanout = 1 (each person works at one company)
        let works_out = ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: None,
            direction: ExpandDirection::Outgoing,
            edge_types: vec!["WORKS_AT".to_string()],
            min_hops: 1,
            max_hops: Some(1),
            input: Box::new(LogicalOperator::Empty),
            path_alias: None,
            path_mode: PathMode::Walk,
        };
        let cost_works = model.expand_cost(&works_out, 1000.0);

        // KNOWS (fanout=5) should be more expensive than WORKS_AT (fanout=1)
        assert!(
            cost_knows.cpu > cost_works.cpu,
            "KNOWS(5) should cost more than WORKS_AT(1)"
        );

        // Incoming WORKS_AT: fanout = 50 (company has many employees)
        let works_in = ExpandOp {
            from_variable: "c".to_string(),
            to_variable: "p".to_string(),
            edge_variable: None,
            direction: ExpandDirection::Incoming,
            edge_types: vec!["WORKS_AT".to_string()],
            min_hops: 1,
            max_hops: Some(1),
            input: Box::new(LogicalOperator::Empty),
            path_alias: None,
            path_mode: PathMode::Walk,
        };
        let cost_works_in = model.expand_cost(&works_in, 1000.0);

        // Incoming WORKS_AT (fanout=50) should be most expensive
        assert!(
            cost_works_in.cpu > cost_knows.cpu,
            "Incoming WORKS_AT(50) should cost more than KNOWS(5)"
        );
    }

    // Edge types without collected degree stats fall back to the global fanout.
    #[test]
    fn test_cost_model_expand_unknown_edge_type_uses_global_fanout() {
        let model = CostModel::new().with_avg_fanout(7.0);
        let expand = ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: None,
            direction: ExpandDirection::Outgoing,
            edge_types: vec!["UNKNOWN_TYPE".to_string()],
            min_hops: 1,
            max_hops: Some(1),
            input: Box::new(LogicalOperator::Empty),
            path_alias: None,
            path_mode: PathMode::Walk,
        };
        let cost_unknown = model.expand_cost(&expand, 1000.0);

        // Without edge type (uses global fanout too)
        let expand_no_type = ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: None,
            direction: ExpandDirection::Outgoing,
            edge_types: vec![],
            min_hops: 1,
            max_hops: Some(1),
            input: Box::new(LogicalOperator::Empty),
            path_alias: None,
            path_mode: PathMode::Walk,
        };
        let cost_no_type = model.expand_cost(&expand_no_type, 1000.0);

        // Both should use global fanout = 7, so costs should be equal
        assert!(
            (cost_unknown.cpu - cost_no_type.cpu).abs() < 0.001,
            "Unknown edge type should use global fanout"
        );
    }

    // Inner hash join: build side consumes memory, both sides consume CPU.
    #[test]
    fn test_cost_model_hash_join() {
        let model = CostModel::new();
        let join = JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type: JoinType::Inner,
            conditions: vec![JoinCondition {
                left: LogicalExpression::Variable("a".to_string()),
                right: LogicalExpression::Variable("b".to_string()),
            }],
        };
        let cost = model.join_cost(&join, 10000.0);

        // Hash join has CPU cost and memory cost
        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    // Cross joins have no conditions; smoke-check a large output cardinality.
    #[test]
    fn test_cost_model_cross_join() {
        let model = CostModel::new();
        let join = JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type: JoinType::Cross,
            conditions: vec![],
        };
        let cost = model.join_cost(&join, 1000000.0);

        // Cross join is expensive
        assert!(cost.cpu > 0.0);
    }

    // Semi and inner joins both produce positive CPU costs; exact ordering
    // between them is model-dependent, so only positivity is asserted here.
    #[test]
    fn test_cost_model_semi_join() {
        let model = CostModel::new();
        let join = JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type: JoinType::Semi,
            conditions: vec![],
        };
        let cost_semi = model.join_cost(&join, 1000.0);

        let inner_join = JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type: JoinType::Inner,
            conditions: vec![],
        };
        let cost_inner = model.join_cost(&inner_join, 1000.0);

        // Semi join can be cheaper than inner join
        assert!(cost_semi.cpu > 0.0);
        assert!(cost_inner.cpu > 0.0);
    }

    // Hash aggregation: hashing CPU plus memory for the group table.
    #[test]
    fn test_cost_model_aggregate() {
        let model = CostModel::new();
        let agg = AggregateOp {
            group_by: vec![],
            aggregates: vec![
                AggregateExpr {
                    function: AggregateFunction::Count,
                    expression: None,
                    expression2: None,
                    distinct: false,
                    alias: Some("cnt".to_string()),
                    percentile: None,
                    separator: None,
                },
                AggregateExpr {
                    function: AggregateFunction::Sum,
                    expression: Some(LogicalExpression::Variable("x".to_string())),
                    expression2: None,
                    distinct: false,
                    alias: Some("total".to_string()),
                    percentile: None,
                    separator: None,
                },
            ],
            input: Box::new(LogicalOperator::Empty),
            having: None,
        };
        let cost = model.aggregate_cost(&agg, 1000.0);

        // Aggregation has hash cost and memory cost
        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    // Distinct is modeled as hash-set deduplication: CPU plus memory.
    #[test]
    fn test_cost_model_distinct() {
        let model = CostModel::new();
        let distinct = DistinctOp {
            input: Box::new(LogicalOperator::Empty),
            columns: None,
        };
        let cost = model.distinct_cost(&distinct, 1000.0);

        // Distinct uses hash set
        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    // Limit only forwards the first rows, so its cost stays near zero.
    #[test]
    fn test_cost_model_limit() {
        let model = CostModel::new();
        let limit = LimitOp {
            count: 10.into(),
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.limit_cost(&limit, 1000.0);

        // Limit is very cheap
        assert!(cost.cpu > 0.0);
        assert!(cost.cpu < 1.0); // Should be minimal
    }

    // Skip still pays to iterate past the skipped rows.
    #[test]
    fn test_cost_model_skip() {
        let model = CostModel::new();
        let skip = SkipOp {
            count: 100.into(),
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.skip_cost(&skip, 1000.0);

        // Skip must scan through skipped rows
        assert!(cost.cpu > 0.0);
    }

    // Return materializes each item per output row.
    #[test]
    fn test_cost_model_return() {
        let model = CostModel::new();
        let ret = ReturnOp {
            items: vec![
                ReturnItem {
                    expression: LogicalExpression::Variable("a".to_string()),
                    alias: None,
                },
                ReturnItem {
                    expression: LogicalExpression::Variable("b".to_string()),
                    alias: None,
                },
            ],
            distinct: false,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.return_cost(&ret, 1000.0);

        // Return materializes results
        assert!(cost.cpu > 0.0);
    }

    // ---- Cost comparison helpers ------------------------------------------

    // `cheaper` is symmetric: whichever argument order, the lower total wins.
    #[test]
    fn test_cost_cheaper() {
        let model = CostModel::new();
        let cheap = Cost::cpu(10.0);
        let expensive = Cost::cpu(100.0);

        assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
        assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
    }

    // Comparison uses the weighted total, so IO-heavy plans lose to CPU-heavy
    // ones once the 10x IO weight is applied.
    #[test]
    fn test_cost_comparison_prefers_lower_total() {
        let model = CostModel::new();
        // High CPU, low IO
        let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
        // Low CPU, high IO
        let io_heavy = Cost::cpu(10.0).with_io(20.0);

        // IO is weighted 10x, so io_heavy = 10 + 200 = 210, cpu_heavy = 100 + 10 = 110
        assert!(cpu_heavy.total() < io_heavy.total());
        assert_eq!(
            model.cheaper(&cpu_heavy, &io_heavy).total(),
            cpu_heavy.total()
        );
    }

    // Each additional sort key adds comparison work per element.
    #[test]
    fn test_cost_model_sort_with_keys() {
        let model = CostModel::new();
        let sort_single = SortOp {
            keys: vec![crate::query::plan::SortKey {
                expression: LogicalExpression::Variable("a".to_string()),
                order: SortOrder::Ascending,
                nulls: None,
            }],
            input: Box::new(LogicalOperator::Empty),
        };
        let sort_multi = SortOp {
            keys: vec![
                crate::query::plan::SortKey {
                    expression: LogicalExpression::Variable("a".to_string()),
                    order: SortOrder::Ascending,
                    nulls: None,
                },
                crate::query::plan::SortKey {
                    expression: LogicalExpression::Variable("b".to_string()),
                    order: SortOrder::Descending,
                    nulls: None,
                },
            ],
            input: Box::new(LogicalOperator::Empty),
        };

        let cost_single = model.sort_cost(&sort_single, 1000.0);
        let cost_multi = model.sort_cost(&sort_multi, 1000.0);

        // More sort keys = more comparisons
        assert!(cost_multi.cpu > cost_single.cpu);
    }

    // The Empty operator produces nothing and must cost nothing.
    #[test]
    fn test_cost_model_empty_operator() {
        let model = CostModel::new();
        let cost = model.estimate(&LogicalOperator::Empty, 0.0);
        assert!((cost.total()).abs() < 0.001);
    }

    // `Default` must behave like `CostModel::new()`.
    #[test]
    fn test_cost_model_default() {
        let model = CostModel::default();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: None,
            input: None,
        };
        let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
        assert!(cost.total() > 0.0);
    }

    // ---- Worst-case-optimal (leapfrog) join costing -----------------------

    // Triangle pattern: trie materialization (memory) plus intersection (CPU).
    #[test]
    fn test_leapfrog_join_cost() {
        let model = CostModel::new();

        // Three-way join (triangle pattern)
        let cardinalities = vec![1000.0, 1000.0, 1000.0];
        let cost = model.leapfrog_join_cost(3, &cardinalities, 100.0);

        // Should have CPU cost for materialization and intersection
        assert!(cost.cpu > 0.0);
        // Should have memory cost for trie storage
        assert!(cost.memory > 0.0);
    }

    // Degenerate input (no relations) costs nothing.
    #[test]
    fn test_leapfrog_join_cost_empty() {
        let model = CostModel::new();
        let cost = model.leapfrog_join_cost(0, &[], 0.0);
        assert!((cost.total()).abs() < 0.001);
    }

    // Smoke test: the preference decision is computable for a triangle
    // pattern without panicking; the actual boolean is parameter-dependent.
    #[test]
    fn test_prefer_leapfrog_join_for_triangles() {
        let model = CostModel::new();

        // Compare costs for triangle pattern
        let cardinalities = vec![10000.0, 10000.0, 10000.0];
        let output = 1000.0;

        let leapfrog_cost = model.leapfrog_join_cost(3, &cardinalities, output);

        // Leapfrog should have reasonable cost for triangle patterns
        assert!(leapfrog_cost.cpu > 0.0);
        assert!(leapfrog_cost.memory > 0.0);

        // The prefer_leapfrog_join method compares against hash cascade
        // Actual preference depends on specific cost parameters
        let _prefer = model.prefer_leapfrog_join(3, &cardinalities, output);
        // Test that it returns a boolean (doesn't panic)
    }

    // Leapfrog only pays off for 3+ relations; binary joins must decline.
    #[test]
    fn test_prefer_leapfrog_join_binary_case() {
        let model = CostModel::new();

        // Binary join should NOT prefer leapfrog (need 3+ relations)
        let cardinalities = vec![1000.0, 1000.0];
        let prefer = model.prefer_leapfrog_join(2, &cardinalities, 500.0);
        assert!(!prefer, "Binary joins should use hash join, not leapfrog");
    }

    // ---- Factorized-execution benefit -------------------------------------

    #[test]
    fn test_factorized_benefit_single_hop() {
        let model = CostModel::new();

        // Single hop: no factorization benefit
        let benefit = model.factorized_benefit(10.0, 1);
        assert!(
            (benefit - 1.0).abs() < 0.001,
            "Single hop should have no benefit"
        );
    }

    #[test]
    fn test_factorized_benefit_multi_hop() {
        let model = CostModel::new();

        // Multi-hop with high fanout
        let benefit = model.factorized_benefit(10.0, 3);

        // The factorized_benefit returns a ratio capped at 1.0
        // For high fanout, factorized size / full size approaches 1/fanout
        // which is beneficial but the formula gives a value <= 1.0
        assert!(benefit <= 1.0, "Benefit should be <= 1.0");
        assert!(benefit > 0.0, "Benefit should be positive");
    }

    #[test]
    fn test_factorized_benefit_low_fanout() {
        let model = CostModel::new();

        // Low fanout: minimal benefit
        let benefit = model.factorized_benefit(1.5, 2);
        assert!(
            benefit <= 1.0,
            "Low fanout still benefits from factorization"
        );
    }

    // ---- Statistics-driven costing ----------------------------------------

    // IO for a labeled scan should track the label's cardinality statistic.
    #[test]
    fn test_node_scan_uses_label_cardinality_for_io() {
        let mut label_cards = std::collections::HashMap::new();
        label_cards.insert("Person".to_string(), 500_u64);
        label_cards.insert("Company".to_string(), 50_u64);

        let model = CostModel::new()
            .with_label_cardinalities(label_cards)
            .with_graph_totals(550, 1000);

        let person_scan = NodeScanOp {
            variable: "n".to_string(),
            label: Some("Person".to_string()),
            input: None,
        };
        let company_scan = NodeScanOp {
            variable: "n".to_string(),
            label: Some("Company".to_string()),
            input: None,
        };

        let person_cost = model.node_scan_cost(&person_scan, 500.0);
        let company_cost = model.node_scan_cost(&company_scan, 50.0);

        // Person scan reads 10x more pages than Company scan
        assert!(
            person_cost.io > company_cost.io * 5.0,
            "Person ({}) should have much higher IO than Company ({})",
            person_cost.io,
            company_cost.io
        );
    }

    // Unlabeled scans estimate pages from the graph's total node count.
    // Expected pages assume ~100 bytes per node and 8192-byte pages.
    #[test]
    fn test_node_scan_unlabeled_uses_total_nodes() {
        let model = CostModel::new().with_graph_totals(10_000, 50_000);

        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: None,
            input: None,
        };

        let cost = model.node_scan_cost(&scan, 10_000.0);
        let expected_pages = (10_000.0 * 100.0) / 8192.0;
        assert!(
            (cost.io - expected_pages).abs() < 0.1,
            "Unlabeled scan should use total_nodes for IO: got {}, expected {}",
            cost.io,
            expected_pages
        );
    }

    // Passing real child cardinalities should change the estimate versus the
    // sqrt-of-output fallback used when children are unknown.
    #[test]
    fn test_join_cost_with_actual_child_cardinalities() {
        let model = CostModel::new();
        let join = JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type: JoinType::Inner,
            conditions: vec![JoinCondition {
                left: LogicalExpression::Variable("a".to_string()),
                right: LogicalExpression::Variable("b".to_string()),
            }],
        };

        // With actual child cardinalities (100 left, 10000 right)
        let cost_actual = model.join_cost_with_children(&join, 500.0, Some(100.0), Some(10_000.0));

        // With sqrt fallback (sqrt(500) ~ 22.4 for both sides)
        let cost_sqrt = model.join_cost(&join, 500.0);

        // Actual build side (100) is larger than sqrt(500) ~ 22.4, so
        // actual cost should be higher for build, but the probe side (10000)
        // should dominate
        assert!(
            cost_actual.cpu > cost_sqrt.cpu,
            "Actual child cardinalities ({}) should produce different cost than sqrt fallback ({})",
            cost_actual.cpu,
            cost_sqrt.cpu
        );
    }

    // Multiple edge types sum their per-type fanouts.
    #[test]
    fn test_expand_multi_edge_types() {
        let mut degrees = std::collections::HashMap::new();
        degrees.insert("KNOWS".to_string(), (5.0, 5.0));
        degrees.insert("FOLLOWS".to_string(), (20.0, 100.0));

        let model = CostModel::new().with_edge_type_degrees(degrees);

        // Multi-type outgoing: KNOWS(5) + FOLLOWS(20) = 25
        let multi_expand = ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: None,
            direction: ExpandDirection::Outgoing,
            edge_types: vec!["KNOWS".to_string(), "FOLLOWS".to_string()],
            min_hops: 1,
            max_hops: Some(1),
            input: Box::new(LogicalOperator::Empty),
            path_alias: None,
            path_mode: PathMode::Walk,
        };
        let multi_cost = model.expand_cost(&multi_expand, 100.0);

        // Single type: KNOWS(5) only
        let single_expand = ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: None,
            direction: ExpandDirection::Outgoing,
            edge_types: vec!["KNOWS".to_string()],
            min_hops: 1,
            max_hops: Some(1),
            input: Box::new(LogicalOperator::Empty),
            path_alias: None,
            path_mode: PathMode::Walk,
        };
        let single_cost = model.expand_cost(&single_expand, 100.0);

        // Multi-type (fanout=25) should be more expensive than single (fanout=5)
        assert!(
            multi_cost.cpu > single_cost.cpu * 3.0,
            "Multi-type fanout ({}) should be much higher than single-type ({})",
            multi_cost.cpu,
            single_cost.cpu
        );
    }

    // Whole-plan costing should exceed costing only the root operator,
    // because it recursively includes every child's cost.
    #[test]
    fn test_recursive_tree_cost() {
        use crate::query::optimizer::CardinalityEstimator;

        let mut label_cards = std::collections::HashMap::new();
        label_cards.insert("Person".to_string(), 1000_u64);

        let model = CostModel::new()
            .with_label_cardinalities(label_cards)
            .with_graph_totals(1000, 5000)
            .with_avg_fanout(5.0);

        let mut card_est = CardinalityEstimator::new();
        card_est.add_table_stats("Person", crate::query::optimizer::TableStats::new(1000));

        // Build a plan: NodeScan -> Filter -> Return
        let plan = LogicalOperator::Return(ReturnOp {
            items: vec![ReturnItem {
                expression: LogicalExpression::Variable("n".to_string()),
                alias: None,
            }],
            distinct: false,
            input: Box::new(LogicalOperator::Filter(FilterOp {
                predicate: LogicalExpression::Binary {
                    left: Box::new(LogicalExpression::Property {
                        variable: "n".to_string(),
                        property: "age".to_string(),
                    }),
                    op: crate::query::plan::BinaryOp::Gt,
                    right: Box::new(LogicalExpression::Literal(
                        grafeo_common::types::Value::Int64(30),
                    )),
                },
                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
                    variable: "n".to_string(),
                    label: Some("Person".to_string()),
                    input: None,
                })),
                pushdown_hint: None,
            })),
        });

        let tree_cost = model.estimate_tree(&plan, &card_est);

        // Tree cost should include scan IO + filter CPU + return CPU
        assert!(tree_cost.cpu > 0.0, "Tree should have CPU cost");
        assert!(tree_cost.io > 0.0, "Tree should have IO cost from scan");

        // Compare with single-operator estimate (only costs the root)
        let root_only_card = card_est.estimate(&plan);
        let root_only_cost = model.estimate(&plan, root_only_card);

        // Tree cost should be strictly higher because it includes child costs
        assert!(
            tree_cost.total() > root_only_cost.total(),
            "Recursive tree cost ({}) should exceed root-only cost ({})",
            tree_cost.total(),
            root_only_cost.total()
        );
    }

    // When the cardinality argument matches the label statistic, models with
    // and without statistics should agree on IO.
    #[test]
    fn test_statistics_driven_vs_default_cost() {
        let default_model = CostModel::new();

        let mut label_cards = std::collections::HashMap::new();
        label_cards.insert("Person".to_string(), 100_u64);
        let stats_model = CostModel::new()
            .with_label_cardinalities(label_cards)
            .with_graph_totals(100, 500);

        // Scan a small label: statistics model knows it's only 100 nodes
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: Some("Person".to_string()),
            input: None,
        };

        let default_cost = default_model.node_scan_cost(&scan, 100.0);
        let stats_cost = stats_model.node_scan_cost(&scan, 100.0);

        // With statistics, IO cost is based on actual label size (100 nodes)
        // Without statistics, IO cost uses cardinality parameter (also 100 here)
        // They should be equal in this case since cardinality matches label size
        assert!(
            (default_cost.io - stats_cost.io).abs() < 0.1,
            "When cardinality matches label size, costs should be similar"
        );
    }

    // ---- Branch-coverage edge cases ---------------------------------------

    #[test]
    fn test_leapfrog_join_cost_unit_min_cardinality() {
        let model = CostModel::new();
        // min_card == 1.0 exercises the else branch in seek_cost (output * cpu_tuple_cost)
        let cost = model.leapfrog_join_cost(3, &[1.0, 100.0, 200.0], 50.0);
        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_prefer_leapfrog_join_cardinalities_below_three() {
        let model = CostModel::new();
        // num_relations >= 3 but fewer cardinality entries: second guard fires
        assert!(!model.prefer_leapfrog_join(3, &[100.0, 200.0], 50.0));
        assert!(!model.prefer_leapfrog_join(5, &[], 10.0));
    }

    #[test]
    fn test_factorized_benefit_zero_hops() {
        let model = CostModel::new();
        // Zero hops falls through the `num_hops <= 1` guard
        assert_eq!(model.factorized_benefit(10.0, 0), 1.0);
    }

    #[test]
    fn test_factorized_benefit_unit_fanout_guard() {
        let model = CostModel::new();
        // fanout exactly 1.0: no benefit regardless of hop count
        assert_eq!(model.factorized_benefit(1.0, 5), 1.0);
    }
}