// grafeo_engine/query/optimizer/cost.rs

1//! Cost model for query optimization.
2//!
3//! Provides cost estimates for logical operators to guide optimization decisions.
4
5use crate::query::plan::{
6    AggregateOp, DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinOp, JoinType, LeftJoinOp,
7    LimitOp, LogicalOperator, MultiWayJoinOp, NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp,
8    VectorJoinOp, VectorScanOp,
9};
10use std::collections::HashMap;
11
/// Cost of an operation.
///
/// Represents the estimated resource consumption of executing an operator.
/// The four components are combined into a single comparable scalar via
/// [`Cost::total`] (default weights) or [`Cost::total_weighted`]; costs of
/// operators in a plan compose with `+` / `+=`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Cost {
    /// Estimated CPU cycles / work units.
    pub cpu: f64,
    /// Estimated I/O operations (page reads).
    pub io: f64,
    /// Estimated memory usage in bytes.
    pub memory: f64,
    /// Network cost (for distributed queries).
    pub network: f64,
}
26
27impl Cost {
28    /// Creates a zero cost.
29    #[must_use]
30    pub fn zero() -> Self {
31        Self {
32            cpu: 0.0,
33            io: 0.0,
34            memory: 0.0,
35            network: 0.0,
36        }
37    }
38
39    /// Creates a cost from CPU work units.
40    #[must_use]
41    pub fn cpu(cpu: f64) -> Self {
42        Self {
43            cpu,
44            io: 0.0,
45            memory: 0.0,
46            network: 0.0,
47        }
48    }
49
50    /// Adds I/O cost.
51    #[must_use]
52    pub fn with_io(mut self, io: f64) -> Self {
53        self.io = io;
54        self
55    }
56
57    /// Adds memory cost.
58    #[must_use]
59    pub fn with_memory(mut self, memory: f64) -> Self {
60        self.memory = memory;
61        self
62    }
63
64    /// Returns the total weighted cost.
65    ///
66    /// Uses default weights: CPU=1.0, IO=10.0, Memory=0.1, Network=100.0
67    #[must_use]
68    pub fn total(&self) -> f64 {
69        self.cpu + self.io * 10.0 + self.memory * 0.1 + self.network * 100.0
70    }
71
72    /// Returns the total cost with custom weights.
73    #[must_use]
74    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
75        self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
76    }
77}
78
79impl std::ops::Add for Cost {
80    type Output = Self;
81
82    fn add(self, other: Self) -> Self {
83        Self {
84            cpu: self.cpu + other.cpu,
85            io: self.io + other.io,
86            memory: self.memory + other.memory,
87            network: self.network + other.network,
88        }
89    }
90}
91
92impl std::ops::AddAssign for Cost {
93    fn add_assign(&mut self, other: Self) {
94        self.cpu += other.cpu;
95        self.io += other.io;
96        self.memory += other.memory;
97        self.network += other.network;
98    }
99}
100
/// Cost model for estimating operator costs.
///
/// Construct with [`CostModel::new`] and refine with the `with_*` builder
/// methods as graph statistics become available.
///
/// Default constants are calibrated relative to each other:
/// - Tuple scan is the baseline (1x)
/// - Hash lookup is ~3x (hash computation + potential cache miss)
/// - Sort comparison is ~2x (key extraction + comparison)
/// - Distance computation is ~10x (vector math)
pub struct CostModel {
    /// Cost per tuple processed by CPU (baseline unit).
    cpu_tuple_cost: f64,
    /// Cost per hash table lookup (~3x tuple cost: hash + cache miss).
    hash_lookup_cost: f64,
    /// Cost per comparison in sorting (~2x tuple cost: key extract + cmp).
    sort_comparison_cost: f64,
    /// Average tuple size in bytes (for IO estimation).
    avg_tuple_size: f64,
    /// Page size in bytes.
    page_size: f64,
    /// Global average edge fanout (fallback when per-type stats unavailable).
    avg_fanout: f64,
    /// Per-edge-type degree stats: (avg_out_degree, avg_in_degree).
    edge_type_degrees: HashMap<String, (f64, f64)>,
    /// Per-label node counts for accurate scan IO estimation.
    label_cardinalities: HashMap<String, u64>,
    /// Total node count in the graph (0 means "unknown").
    total_nodes: u64,
    /// Total edge count in the graph (0 means "unknown").
    total_edges: u64,
}
130
131impl CostModel {
132    /// Creates a new cost model with calibrated default parameters.
133    #[must_use]
134    pub fn new() -> Self {
135        Self {
136            cpu_tuple_cost: 0.01,
137            hash_lookup_cost: 0.03,
138            sort_comparison_cost: 0.02,
139            avg_tuple_size: 100.0,
140            page_size: 8192.0,
141            avg_fanout: 10.0,
142            edge_type_degrees: HashMap::new(),
143            label_cardinalities: HashMap::new(),
144            total_nodes: 0,
145            total_edges: 0,
146        }
147    }
148
149    /// Sets the global average fanout from graph statistics.
150    #[must_use]
151    pub fn with_avg_fanout(mut self, avg_fanout: f64) -> Self {
152        self.avg_fanout = if avg_fanout > 0.0 { avg_fanout } else { 10.0 };
153        self
154    }
155
156    /// Sets per-edge-type degree statistics for accurate expand cost estimation.
157    ///
158    /// Each entry maps edge type name to `(avg_out_degree, avg_in_degree)`.
159    #[must_use]
160    pub fn with_edge_type_degrees(mut self, degrees: HashMap<String, (f64, f64)>) -> Self {
161        self.edge_type_degrees = degrees;
162        self
163    }
164
165    /// Sets per-label node counts for accurate scan IO estimation.
166    #[must_use]
167    pub fn with_label_cardinalities(mut self, cardinalities: HashMap<String, u64>) -> Self {
168        self.label_cardinalities = cardinalities;
169        self
170    }
171
172    /// Sets graph-level totals for cost estimation.
173    #[must_use]
174    pub fn with_graph_totals(mut self, total_nodes: u64, total_edges: u64) -> Self {
175        self.total_nodes = total_nodes;
176        self.total_edges = total_edges;
177        self
178    }
179
180    /// Returns the fanout for a specific expand operation.
181    ///
182    /// Uses per-edge-type degree stats when available, falling back to the
183    /// global average fanout. For multiple edge types, sums per-type fanouts.
184    fn fanout_for_expand(&self, expand: &ExpandOp) -> f64 {
185        if expand.edge_types.is_empty() {
186            return self.avg_fanout;
187        }
188
189        let mut total_fanout = 0.0;
190        let mut all_found = true;
191
192        for edge_type in &expand.edge_types {
193            if let Some(&(out_deg, in_deg)) = self.edge_type_degrees.get(edge_type) {
194                total_fanout += match expand.direction {
195                    ExpandDirection::Outgoing => out_deg,
196                    ExpandDirection::Incoming => in_deg,
197                    ExpandDirection::Both => out_deg + in_deg,
198                };
199            } else {
200                all_found = false;
201                break;
202            }
203        }
204
205        if all_found {
206            total_fanout
207        } else {
208            self.avg_fanout
209        }
210    }
211
212    /// Estimates the cost of a logical operator.
213    #[must_use]
214    pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
215        match op {
216            LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
217            LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
218            LogicalOperator::Project(project) => self.project_cost(project, cardinality),
219            LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
220            LogicalOperator::Join(join) => self.join_cost(join, cardinality),
221            LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
222            LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
223            LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
224            LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
225            LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
226            LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
227            LogicalOperator::Empty => Cost::zero(),
228            LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
229            LogicalOperator::VectorJoin(join) => self.vector_join_cost(join, cardinality),
230            LogicalOperator::MultiWayJoin(mwj) => self.multi_way_join_cost(mwj, cardinality),
231            LogicalOperator::LeftJoin(lj) => {
232                self.left_join_cost(lj, cardinality, cardinality.sqrt(), cardinality.sqrt())
233            }
234            _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
235        }
236    }
237
238    /// Estimates the cost of a node scan.
239    ///
240    /// When label statistics are available, uses the actual label cardinality
241    /// for IO estimation (pages to read) rather than the optimizer's cardinality
242    /// estimate, which may already account for filter selectivity.
243    fn node_scan_cost(&self, scan: &NodeScanOp, cardinality: f64) -> Cost {
244        // IO cost: based on how many nodes we actually need to scan from storage
245        let scan_size = if let Some(label) = &scan.label {
246            self.label_cardinalities
247                .get(label)
248                .map_or(cardinality, |&count| count as f64)
249        } else if self.total_nodes > 0 {
250            self.total_nodes as f64
251        } else {
252            cardinality
253        };
254        let pages = (scan_size * self.avg_tuple_size) / self.page_size;
255        // CPU cost: only pay for rows that pass the label filter
256        Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
257    }
258
259    /// Estimates the cost of a filter operation.
260    fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
261        // Filter cost is just predicate evaluation per tuple
262        Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
263    }
264
265    /// Estimates the cost of a projection.
266    fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
267        // Cost depends on number of expressions evaluated
268        let expr_count = project.projections.len() as f64;
269        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
270    }
271
272    /// Estimates the cost of an expand operation.
273    ///
274    /// Uses per-edge-type degree stats when available, otherwise falls back
275    /// to the global average fanout.
276    fn expand_cost(&self, expand: &ExpandOp, cardinality: f64) -> Cost {
277        let fanout = self.fanout_for_expand(expand);
278        // Adjacency list lookup per input row
279        let lookup_cost = cardinality * self.hash_lookup_cost;
280        // Process each expanded output tuple
281        let output_cost = cardinality * fanout * self.cpu_tuple_cost;
282        Cost::cpu(lookup_cost + output_cost)
283    }
284
    /// Estimates the cost of a join operation.
    ///
    /// When child cardinalities are known (from recursive estimation), uses them
    /// directly for build/probe cost. Otherwise falls back to sqrt approximation.
    fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
        // No child estimates at this call site; `join_cost_with_children`
        // falls back to sqrt(cardinality) for both sides.
        self.join_cost_with_children(join, cardinality, None, None)
    }
292
293    /// Estimates join cost using actual child cardinalities.
294    fn join_cost_with_children(
295        &self,
296        join: &JoinOp,
297        cardinality: f64,
298        left_cardinality: Option<f64>,
299        right_cardinality: Option<f64>,
300    ) -> Cost {
301        match join.join_type {
302            JoinType::Cross => Cost::cpu(cardinality * self.cpu_tuple_cost),
303            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
304                // Hash join: build the smaller side, probe with the larger
305                let build_cardinality = left_cardinality.unwrap_or_else(|| cardinality.sqrt());
306                let probe_cardinality = right_cardinality.unwrap_or_else(|| cardinality.sqrt());
307
308                let build_cost = build_cardinality * self.hash_lookup_cost;
309                let memory_cost = build_cardinality * self.avg_tuple_size;
310                let probe_cost = probe_cardinality * self.hash_lookup_cost;
311                let output_cost = cardinality * self.cpu_tuple_cost;
312
313                Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
314            }
315            JoinType::Semi | JoinType::Anti => {
316                let build_cardinality = left_cardinality.unwrap_or_else(|| cardinality.sqrt());
317                let probe_cardinality = right_cardinality.unwrap_or_else(|| cardinality.sqrt());
318
319                let build_cost = build_cardinality * self.hash_lookup_cost;
320                let probe_cost = probe_cardinality * self.hash_lookup_cost;
321
322                Cost::cpu(build_cost + probe_cost)
323                    .with_memory(build_cardinality * self.avg_tuple_size)
324            }
325        }
326    }
327
328    /// Estimates the cost of a left outer join (OPTIONAL MATCH).
329    ///
330    /// Uses the same hash join cost model as inner joins: the right side is
331    /// built into a hash table, the left side probes. Output cost includes all
332    /// left rows (some with NULL-padded right side).
333    fn left_join_cost(
334        &self,
335        _lj: &LeftJoinOp,
336        cardinality: f64,
337        left_card: f64,
338        right_card: f64,
339    ) -> Cost {
340        let build_cost = right_card * self.hash_lookup_cost;
341        let memory_cost = right_card * self.avg_tuple_size;
342        let probe_cost = left_card * self.hash_lookup_cost;
343        let output_cost = cardinality * self.cpu_tuple_cost;
344
345        Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
346    }
347
348    /// Estimates the cost of a multi-way (leapfrog) join.
349    ///
350    /// Delegates to `leapfrog_join_cost` using per-input cardinality estimates
351    /// derived from the output cardinality divided equally among inputs.
352    fn multi_way_join_cost(&self, mwj: &MultiWayJoinOp, cardinality: f64) -> Cost {
353        let n = mwj.inputs.len();
354        if n == 0 {
355            return Cost::zero();
356        }
357        // Approximate per-input cardinalities: assume each input contributes
358        // cardinality^(1/n) rows (AGM-style uniform assumption)
359        let per_input = cardinality.powf(1.0 / n as f64).max(1.0);
360        let cardinalities: Vec<f64> = (0..n).map(|_| per_input).collect();
361        self.leapfrog_join_cost(n, &cardinalities, cardinality)
362    }
363
364    /// Estimates the cost of an aggregation.
365    fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
366        // Hash aggregation cost
367        let hash_cost = cardinality * self.hash_lookup_cost;
368
369        // Aggregate function evaluation
370        let agg_count = agg.aggregates.len() as f64;
371        let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
372
373        // Memory for hash table (estimated distinct groups)
374        let distinct_groups = (cardinality / 10.0).max(1.0); // Assume 10% distinct
375        let memory_cost = distinct_groups * self.avg_tuple_size;
376
377        Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
378    }
379
380    /// Estimates the cost of a sort operation.
381    fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
382        if cardinality <= 1.0 {
383            return Cost::zero();
384        }
385
386        // Sort is O(n log n) comparisons
387        let comparisons = cardinality * cardinality.log2();
388        let key_count = sort.keys.len() as f64;
389
390        // Memory for sorting (full input materialization)
391        let memory_cost = cardinality * self.avg_tuple_size;
392
393        Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
394    }
395
396    /// Estimates the cost of a distinct operation.
397    fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
398        // Hash-based distinct
399        let hash_cost = cardinality * self.hash_lookup_cost;
400        let memory_cost = cardinality * self.avg_tuple_size * 0.5; // Assume 50% distinct
401
402        Cost::cpu(hash_cost).with_memory(memory_cost)
403    }
404
405    /// Estimates the cost of a limit operation.
406    fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
407        // Limit is very cheap - just counting
408        Cost::cpu(limit.count.estimate() * self.cpu_tuple_cost * 0.1)
409    }
410
411    /// Estimates the cost of a skip operation.
412    fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
413        // Skip requires scanning through skipped rows
414        Cost::cpu(skip.count.estimate() * self.cpu_tuple_cost)
415    }
416
417    /// Estimates the cost of a return operation.
418    fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
419        // Return materializes results
420        let expr_count = ret.items.len() as f64;
421        Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
422    }
423
424    /// Estimates the cost of a vector scan operation.
425    ///
426    /// HNSW index search is O(log N) per query, while brute-force is O(N).
427    /// This estimates the HNSW case with ef search parameter.
428    fn vector_scan_cost(&self, scan: &VectorScanOp, cardinality: f64) -> Cost {
429        // k determines output cardinality
430        let k = scan.k as f64;
431
432        // HNSW search cost: O(ef * log(N)) distance computations
433        // Assume ef = 64 (default), N = cardinality
434        let ef = 64.0;
435        let n = cardinality.max(1.0);
436        let search_cost = if scan.index_name.is_some() {
437            // HNSW: O(ef * log N)
438            ef * n.ln() * self.cpu_tuple_cost * 10.0 // Distance computation is ~10x regular tuple
439        } else {
440            // Brute-force: O(N)
441            n * self.cpu_tuple_cost * 10.0
442        };
443
444        // Memory for candidate heap
445        let memory = k * self.avg_tuple_size * 2.0;
446
447        Cost::cpu(search_cost).with_memory(memory)
448    }
449
450    /// Estimates the cost of a vector join operation.
451    ///
452    /// Vector join performs k-NN search for each input row.
453    fn vector_join_cost(&self, join: &VectorJoinOp, cardinality: f64) -> Cost {
454        let k = join.k as f64;
455
456        // Each input row triggers a vector search
457        // Assume brute-force for hybrid queries (no index specified typically)
458        let per_row_search_cost = if join.index_name.is_some() {
459            // HNSW: O(ef * log N)
460            let ef = 64.0;
461            let n = cardinality.max(1.0);
462            ef * n.ln() * self.cpu_tuple_cost * 10.0
463        } else {
464            // Brute-force: O(N) per input row
465            cardinality * self.cpu_tuple_cost * 10.0
466        };
467
468        // Total cost: input_rows * search_cost
469        // For vector join, cardinality is typically input cardinality * k
470        let input_cardinality = (cardinality / k).max(1.0);
471        let total_search_cost = input_cardinality * per_row_search_cost;
472
473        // Memory for results
474        let memory = cardinality * self.avg_tuple_size;
475
476        Cost::cpu(total_search_cost).with_memory(memory)
477    }
478
    /// Estimates the total cost of an operator tree recursively.
    ///
    /// Walks the entire plan tree, computing per-operator cost at each level
    /// using the cardinality estimator for accurate child cardinalities.
    /// Returns the sum of all operator costs in the tree.
    #[must_use]
    pub fn estimate_tree(
        &self,
        op: &LogicalOperator,
        card_estimator: &super::CardinalityEstimator,
    ) -> Cost {
        // Thin public wrapper around the private recursion.
        self.estimate_tree_inner(op, card_estimator)
    }
492
    /// Recursive worker for [`CostModel::estimate_tree`].
    ///
    /// Estimates this operator's output cardinality, computes its own cost,
    /// and adds the recursively-computed cost of every child operator.
    fn estimate_tree_inner(
        &self,
        op: &LogicalOperator,
        card_est: &super::CardinalityEstimator,
    ) -> Cost {
        // Output cardinality of this operator, fed to every per-operator model.
        let cardinality = card_est.estimate(op);

        match op {
            // Leaf: node scans have no child to recurse into.
            LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
            LogicalOperator::Filter(filter) => {
                let child_cost = self.estimate_tree_inner(&filter.input, card_est);
                child_cost + self.filter_cost(filter, cardinality)
            }
            LogicalOperator::Project(project) => {
                let child_cost = self.estimate_tree_inner(&project.input, card_est);
                child_cost + self.project_cost(project, cardinality)
            }
            LogicalOperator::Expand(expand) => {
                let child_cost = self.estimate_tree_inner(&expand.input, card_est);
                child_cost + self.expand_cost(expand, cardinality)
            }
            LogicalOperator::Join(join) => {
                // Binary join: recurse into both sides and pass their
                // cardinalities so build/probe costs use real estimates
                // instead of the sqrt fallback.
                let left_cost = self.estimate_tree_inner(&join.left, card_est);
                let right_cost = self.estimate_tree_inner(&join.right, card_est);
                let left_card = card_est.estimate(&join.left);
                let right_card = card_est.estimate(&join.right);
                let join_cost = self.join_cost_with_children(
                    join,
                    cardinality,
                    Some(left_card),
                    Some(right_card),
                );
                left_cost + right_cost + join_cost
            }
            LogicalOperator::LeftJoin(lj) => {
                let left_cost = self.estimate_tree_inner(&lj.left, card_est);
                let right_cost = self.estimate_tree_inner(&lj.right, card_est);
                let left_card = card_est.estimate(&lj.left);
                let right_card = card_est.estimate(&lj.right);
                let join_cost = self.left_join_cost(lj, cardinality, left_card, right_card);
                left_cost + right_cost + join_cost
            }
            LogicalOperator::Aggregate(agg) => {
                let child_cost = self.estimate_tree_inner(&agg.input, card_est);
                child_cost + self.aggregate_cost(agg, cardinality)
            }
            LogicalOperator::Sort(sort) => {
                let child_cost = self.estimate_tree_inner(&sort.input, card_est);
                child_cost + self.sort_cost(sort, cardinality)
            }
            LogicalOperator::Distinct(distinct) => {
                let child_cost = self.estimate_tree_inner(&distinct.input, card_est);
                child_cost + self.distinct_cost(distinct, cardinality)
            }
            LogicalOperator::Limit(limit) => {
                let child_cost = self.estimate_tree_inner(&limit.input, card_est);
                child_cost + self.limit_cost(limit, cardinality)
            }
            LogicalOperator::Skip(skip) => {
                let child_cost = self.estimate_tree_inner(&skip.input, card_est);
                child_cost + self.skip_cost(skip, cardinality)
            }
            LogicalOperator::Return(ret) => {
                let child_cost = self.estimate_tree_inner(&ret.input, card_est);
                child_cost + self.return_cost(ret, cardinality)
            }
            // Leaf: vector scans read directly from the index/storage.
            LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
            LogicalOperator::VectorJoin(join) => {
                let child_cost = self.estimate_tree_inner(&join.input, card_est);
                child_cost + self.vector_join_cost(join, cardinality)
            }
            LogicalOperator::MultiWayJoin(mwj) => {
                // N-ary join: accumulate the cost of every input subtree.
                let mut children_cost = Cost::zero();
                for input in &mwj.inputs {
                    children_cost += self.estimate_tree_inner(input, card_est);
                }
                children_cost + self.multi_way_join_cost(mwj, cardinality)
            }
            LogicalOperator::Empty => Cost::zero(),
            // Fallback for operators without a dedicated model.
            _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
        }
    }
575
576    /// Compares two costs and returns the cheaper one.
577    #[must_use]
578    pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
579        if a.total() <= b.total() { a } else { b }
580    }
581
582    /// Estimates the cost of a worst-case optimal join (WCOJ/leapfrog join).
583    ///
584    /// WCOJ is optimal for cyclic patterns like triangles. Traditional binary
585    /// hash joins are O(N²) for triangles; WCOJ achieves O(N^1.5) by processing
586    /// all relations simultaneously using sorted iterators.
587    ///
588    /// # Arguments
589    /// * `num_relations` - Number of relations participating in the join
590    /// * `cardinalities` - Cardinality of each input relation
591    /// * `output_cardinality` - Expected output cardinality
592    ///
593    /// # Cost Model
594    /// - Materialization: O(sum of cardinalities) to build trie indexes
595    /// - Intersection: O(output * log(min_cardinality)) for leapfrog seek operations
596    /// - Memory: Trie storage for all inputs
597    #[must_use]
598    pub fn leapfrog_join_cost(
599        &self,
600        num_relations: usize,
601        cardinalities: &[f64],
602        output_cardinality: f64,
603    ) -> Cost {
604        if cardinalities.is_empty() {
605            return Cost::zero();
606        }
607
608        let total_input: f64 = cardinalities.iter().sum();
609        let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
610
611        // Materialization phase: build trie indexes for each input
612        let materialize_cost = total_input * self.cpu_tuple_cost * 2.0; // Sorting + trie building
613
614        // Intersection phase: leapfrog seeks are O(log n) per relation
615        let seek_cost = if min_card > 1.0 {
616            output_cardinality * (num_relations as f64) * min_card.log2() * self.hash_lookup_cost
617        } else {
618            output_cardinality * self.cpu_tuple_cost
619        };
620
621        // Output materialization
622        let output_cost = output_cardinality * self.cpu_tuple_cost;
623
624        // Memory: trie storage (roughly 2x input size for sorted index)
625        let memory = total_input * self.avg_tuple_size * 2.0;
626
627        Cost::cpu(materialize_cost + seek_cost + output_cost).with_memory(memory)
628    }
629
630    /// Compares hash join cost vs leapfrog join cost for a cyclic pattern.
631    ///
632    /// Returns true if leapfrog (WCOJ) is estimated to be cheaper.
633    #[must_use]
634    pub fn prefer_leapfrog_join(
635        &self,
636        num_relations: usize,
637        cardinalities: &[f64],
638        output_cardinality: f64,
639    ) -> bool {
640        if num_relations < 3 || cardinalities.len() < 3 {
641            // Leapfrog is only beneficial for multi-way joins (3+)
642            return false;
643        }
644
645        let leapfrog_cost =
646            self.leapfrog_join_cost(num_relations, cardinalities, output_cardinality);
647
648        // Estimate cascade of binary hash joins
649        // For N relations, we need N-1 joins
650        // Each join produces intermediate results that feed the next
651        let mut hash_cascade_cost = Cost::zero();
652        let mut intermediate_cardinality = cardinalities[0];
653
654        for card in &cardinalities[1..] {
655            // Hash join cost: build + probe + output
656            let join_output = (intermediate_cardinality * card).sqrt(); // Estimated selectivity
657            let join = JoinOp {
658                left: Box::new(LogicalOperator::Empty),
659                right: Box::new(LogicalOperator::Empty),
660                join_type: JoinType::Inner,
661                conditions: vec![],
662            };
663            hash_cascade_cost += self.join_cost(&join, join_output);
664            intermediate_cardinality = join_output;
665        }
666
667        leapfrog_cost.total() < hash_cascade_cost.total()
668    }
669
670    /// Estimates cost for factorized execution (compressed intermediate results).
671    ///
672    /// Factorized execution avoids materializing full cross products by keeping
673    /// results in a compressed "factorized" form. This is beneficial for multi-hop
674    /// traversals where intermediate results can explode.
675    ///
676    /// Returns the reduction factor (1.0 = no benefit, lower = more compression).
677    #[must_use]
678    pub fn factorized_benefit(&self, avg_fanout: f64, num_hops: usize) -> f64 {
679        if num_hops <= 1 || avg_fanout <= 1.0 {
680            return 1.0; // No benefit for single hop or low fanout
681        }
682
683        // Factorized representation compresses repeated prefixes
684        // Compression ratio improves with higher fanout and more hops
685        // Full materialization: fanout^hops
686        // Factorized: sum(fanout^i for i in 1..=hops) ≈ fanout^(hops+1) / (fanout - 1)
687
688        let full_size = avg_fanout.powi(num_hops as i32);
689        let factorized_size = if avg_fanout > 1.0 {
690            (avg_fanout.powi(num_hops as i32 + 1) - 1.0) / (avg_fanout - 1.0)
691        } else {
692            num_hops as f64
693        };
694
695        (factorized_size / full_size).min(1.0)
696    }
697}
698
impl Default for CostModel {
    /// Equivalent to [`CostModel::new`]: calibrated defaults, no statistics.
    fn default() -> Self {
        Self::new()
    }
}
704
705#[cfg(test)]
706mod tests {
707    use super::*;
708    use crate::query::plan::{
709        AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
710        PathMode, Projection, ReturnItem, SortOrder,
711    };
712
713    #[test]
714    fn test_cost_addition() {
715        let a = Cost::cpu(10.0).with_io(5.0);
716        let b = Cost::cpu(20.0).with_memory(100.0);
717        let c = a + b;
718
719        assert!((c.cpu - 30.0).abs() < 0.001);
720        assert!((c.io - 5.0).abs() < 0.001);
721        assert!((c.memory - 100.0).abs() < 0.001);
722    }
723
724    #[test]
725    fn test_cost_total() {
726        let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
727        // Total = 10 + 1*10 + 100*0.1 = 10 + 10 + 10 = 30
728        assert!((cost.total() - 30.0).abs() < 0.001);
729    }
730
731    #[test]
732    fn test_cost_model_node_scan() {
733        let model = CostModel::new();
734        let scan = NodeScanOp {
735            variable: "n".to_string(),
736            label: Some("Person".to_string()),
737            input: None,
738        };
739        let cost = model.node_scan_cost(&scan, 1000.0);
740
741        assert!(cost.cpu > 0.0);
742        assert!(cost.io > 0.0);
743    }
744
745    #[test]
746    fn test_cost_model_sort() {
747        let model = CostModel::new();
748        let sort = SortOp {
749            keys: vec![],
750            input: Box::new(LogicalOperator::Empty),
751        };
752
753        let cost_100 = model.sort_cost(&sort, 100.0);
754        let cost_1000 = model.sort_cost(&sort, 1000.0);
755
756        // Sorting 1000 rows should be more expensive than 100 rows
757        assert!(cost_1000.total() > cost_100.total());
758    }
759
760    #[test]
761    fn test_cost_zero() {
762        let cost = Cost::zero();
763        assert!((cost.cpu).abs() < 0.001);
764        assert!((cost.io).abs() < 0.001);
765        assert!((cost.memory).abs() < 0.001);
766        assert!((cost.network).abs() < 0.001);
767        assert!((cost.total()).abs() < 0.001);
768    }
769
770    #[test]
771    fn test_cost_add_assign() {
772        let mut cost = Cost::cpu(10.0);
773        cost += Cost::cpu(5.0).with_io(2.0);
774        assert!((cost.cpu - 15.0).abs() < 0.001);
775        assert!((cost.io - 2.0).abs() < 0.001);
776    }
777
778    #[test]
779    fn test_cost_total_weighted() {
780        let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
781        // With custom weights: cpu*2 + io*5 + mem*0.5 = 20 + 10 + 50 = 80
782        let total = cost.total_weighted(2.0, 5.0, 0.5);
783        assert!((total - 80.0).abs() < 0.001);
784    }
785
786    #[test]
787    fn test_cost_model_filter() {
788        let model = CostModel::new();
789        let filter = FilterOp {
790            predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
791            input: Box::new(LogicalOperator::Empty),
792            pushdown_hint: None,
793        };
794        let cost = model.filter_cost(&filter, 1000.0);
795
796        // Filter cost is CPU only
797        assert!(cost.cpu > 0.0);
798        assert!((cost.io).abs() < 0.001);
799    }
800
801    #[test]
802    fn test_cost_model_project() {
803        let model = CostModel::new();
804        let project = ProjectOp {
805            projections: vec![
806                Projection {
807                    expression: LogicalExpression::Variable("a".to_string()),
808                    alias: None,
809                },
810                Projection {
811                    expression: LogicalExpression::Variable("b".to_string()),
812                    alias: None,
813                },
814            ],
815            input: Box::new(LogicalOperator::Empty),
816            pass_through_input: false,
817        };
818        let cost = model.project_cost(&project, 1000.0);
819
820        // Cost should scale with number of projections
821        assert!(cost.cpu > 0.0);
822    }
823
824    #[test]
825    fn test_cost_model_expand() {
826        let model = CostModel::new();
827        let expand = ExpandOp {
828            from_variable: "a".to_string(),
829            to_variable: "b".to_string(),
830            edge_variable: None,
831            direction: ExpandDirection::Outgoing,
832            edge_types: vec![],
833            min_hops: 1,
834            max_hops: Some(1),
835            input: Box::new(LogicalOperator::Empty),
836            path_alias: None,
837            path_mode: PathMode::Walk,
838        };
839        let cost = model.expand_cost(&expand, 1000.0);
840
841        // Expand involves hash lookups and output generation
842        assert!(cost.cpu > 0.0);
843    }
844
845    #[test]
846    fn test_cost_model_expand_with_edge_type_stats() {
847        let mut degrees = std::collections::HashMap::new();
848        degrees.insert("KNOWS".to_string(), (5.0, 5.0)); // Symmetric
849        degrees.insert("WORKS_AT".to_string(), (1.0, 50.0)); // Many-to-one
850
851        let model = CostModel::new().with_edge_type_degrees(degrees);
852
853        // Outgoing KNOWS: fanout = 5
854        let knows_out = ExpandOp {
855            from_variable: "a".to_string(),
856            to_variable: "b".to_string(),
857            edge_variable: None,
858            direction: ExpandDirection::Outgoing,
859            edge_types: vec!["KNOWS".to_string()],
860            min_hops: 1,
861            max_hops: Some(1),
862            input: Box::new(LogicalOperator::Empty),
863            path_alias: None,
864            path_mode: PathMode::Walk,
865        };
866        let cost_knows = model.expand_cost(&knows_out, 1000.0);
867
868        // Outgoing WORKS_AT: fanout = 1 (each person works at one company)
869        let works_out = ExpandOp {
870            from_variable: "a".to_string(),
871            to_variable: "b".to_string(),
872            edge_variable: None,
873            direction: ExpandDirection::Outgoing,
874            edge_types: vec!["WORKS_AT".to_string()],
875            min_hops: 1,
876            max_hops: Some(1),
877            input: Box::new(LogicalOperator::Empty),
878            path_alias: None,
879            path_mode: PathMode::Walk,
880        };
881        let cost_works = model.expand_cost(&works_out, 1000.0);
882
883        // KNOWS (fanout=5) should be more expensive than WORKS_AT (fanout=1)
884        assert!(
885            cost_knows.cpu > cost_works.cpu,
886            "KNOWS(5) should cost more than WORKS_AT(1)"
887        );
888
889        // Incoming WORKS_AT: fanout = 50 (company has many employees)
890        let works_in = ExpandOp {
891            from_variable: "c".to_string(),
892            to_variable: "p".to_string(),
893            edge_variable: None,
894            direction: ExpandDirection::Incoming,
895            edge_types: vec!["WORKS_AT".to_string()],
896            min_hops: 1,
897            max_hops: Some(1),
898            input: Box::new(LogicalOperator::Empty),
899            path_alias: None,
900            path_mode: PathMode::Walk,
901        };
902        let cost_works_in = model.expand_cost(&works_in, 1000.0);
903
904        // Incoming WORKS_AT (fanout=50) should be most expensive
905        assert!(
906            cost_works_in.cpu > cost_knows.cpu,
907            "Incoming WORKS_AT(50) should cost more than KNOWS(5)"
908        );
909    }
910
911    #[test]
912    fn test_cost_model_expand_unknown_edge_type_uses_global_fanout() {
913        let model = CostModel::new().with_avg_fanout(7.0);
914        let expand = ExpandOp {
915            from_variable: "a".to_string(),
916            to_variable: "b".to_string(),
917            edge_variable: None,
918            direction: ExpandDirection::Outgoing,
919            edge_types: vec!["UNKNOWN_TYPE".to_string()],
920            min_hops: 1,
921            max_hops: Some(1),
922            input: Box::new(LogicalOperator::Empty),
923            path_alias: None,
924            path_mode: PathMode::Walk,
925        };
926        let cost_unknown = model.expand_cost(&expand, 1000.0);
927
928        // Without edge type (uses global fanout too)
929        let expand_no_type = ExpandOp {
930            from_variable: "a".to_string(),
931            to_variable: "b".to_string(),
932            edge_variable: None,
933            direction: ExpandDirection::Outgoing,
934            edge_types: vec![],
935            min_hops: 1,
936            max_hops: Some(1),
937            input: Box::new(LogicalOperator::Empty),
938            path_alias: None,
939            path_mode: PathMode::Walk,
940        };
941        let cost_no_type = model.expand_cost(&expand_no_type, 1000.0);
942
943        // Both should use global fanout = 7, so costs should be equal
944        assert!(
945            (cost_unknown.cpu - cost_no_type.cpu).abs() < 0.001,
946            "Unknown edge type should use global fanout"
947        );
948    }
949
950    #[test]
951    fn test_cost_model_hash_join() {
952        let model = CostModel::new();
953        let join = JoinOp {
954            left: Box::new(LogicalOperator::Empty),
955            right: Box::new(LogicalOperator::Empty),
956            join_type: JoinType::Inner,
957            conditions: vec![JoinCondition {
958                left: LogicalExpression::Variable("a".to_string()),
959                right: LogicalExpression::Variable("b".to_string()),
960            }],
961        };
962        let cost = model.join_cost(&join, 10000.0);
963
964        // Hash join has CPU cost and memory cost
965        assert!(cost.cpu > 0.0);
966        assert!(cost.memory > 0.0);
967    }
968
969    #[test]
970    fn test_cost_model_cross_join() {
971        let model = CostModel::new();
972        let join = JoinOp {
973            left: Box::new(LogicalOperator::Empty),
974            right: Box::new(LogicalOperator::Empty),
975            join_type: JoinType::Cross,
976            conditions: vec![],
977        };
978        let cost = model.join_cost(&join, 1000000.0);
979
980        // Cross join is expensive
981        assert!(cost.cpu > 0.0);
982    }
983
984    #[test]
985    fn test_cost_model_semi_join() {
986        let model = CostModel::new();
987        let join = JoinOp {
988            left: Box::new(LogicalOperator::Empty),
989            right: Box::new(LogicalOperator::Empty),
990            join_type: JoinType::Semi,
991            conditions: vec![],
992        };
993        let cost_semi = model.join_cost(&join, 1000.0);
994
995        let inner_join = JoinOp {
996            left: Box::new(LogicalOperator::Empty),
997            right: Box::new(LogicalOperator::Empty),
998            join_type: JoinType::Inner,
999            conditions: vec![],
1000        };
1001        let cost_inner = model.join_cost(&inner_join, 1000.0);
1002
1003        // Semi join can be cheaper than inner join
1004        assert!(cost_semi.cpu > 0.0);
1005        assert!(cost_inner.cpu > 0.0);
1006    }
1007
1008    #[test]
1009    fn test_cost_model_aggregate() {
1010        let model = CostModel::new();
1011        let agg = AggregateOp {
1012            group_by: vec![],
1013            aggregates: vec![
1014                AggregateExpr {
1015                    function: AggregateFunction::Count,
1016                    expression: None,
1017                    expression2: None,
1018                    distinct: false,
1019                    alias: Some("cnt".to_string()),
1020                    percentile: None,
1021                    separator: None,
1022                },
1023                AggregateExpr {
1024                    function: AggregateFunction::Sum,
1025                    expression: Some(LogicalExpression::Variable("x".to_string())),
1026                    expression2: None,
1027                    distinct: false,
1028                    alias: Some("total".to_string()),
1029                    percentile: None,
1030                    separator: None,
1031                },
1032            ],
1033            input: Box::new(LogicalOperator::Empty),
1034            having: None,
1035        };
1036        let cost = model.aggregate_cost(&agg, 1000.0);
1037
1038        // Aggregation has hash cost and memory cost
1039        assert!(cost.cpu > 0.0);
1040        assert!(cost.memory > 0.0);
1041    }
1042
1043    #[test]
1044    fn test_cost_model_distinct() {
1045        let model = CostModel::new();
1046        let distinct = DistinctOp {
1047            input: Box::new(LogicalOperator::Empty),
1048            columns: None,
1049        };
1050        let cost = model.distinct_cost(&distinct, 1000.0);
1051
1052        // Distinct uses hash set
1053        assert!(cost.cpu > 0.0);
1054        assert!(cost.memory > 0.0);
1055    }
1056
1057    #[test]
1058    fn test_cost_model_limit() {
1059        let model = CostModel::new();
1060        let limit = LimitOp {
1061            count: 10.into(),
1062            input: Box::new(LogicalOperator::Empty),
1063        };
1064        let cost = model.limit_cost(&limit, 1000.0);
1065
1066        // Limit is very cheap
1067        assert!(cost.cpu > 0.0);
1068        assert!(cost.cpu < 1.0); // Should be minimal
1069    }
1070
1071    #[test]
1072    fn test_cost_model_skip() {
1073        let model = CostModel::new();
1074        let skip = SkipOp {
1075            count: 100.into(),
1076            input: Box::new(LogicalOperator::Empty),
1077        };
1078        let cost = model.skip_cost(&skip, 1000.0);
1079
1080        // Skip must scan through skipped rows
1081        assert!(cost.cpu > 0.0);
1082    }
1083
1084    #[test]
1085    fn test_cost_model_return() {
1086        let model = CostModel::new();
1087        let ret = ReturnOp {
1088            items: vec![
1089                ReturnItem {
1090                    expression: LogicalExpression::Variable("a".to_string()),
1091                    alias: None,
1092                },
1093                ReturnItem {
1094                    expression: LogicalExpression::Variable("b".to_string()),
1095                    alias: None,
1096                },
1097            ],
1098            distinct: false,
1099            input: Box::new(LogicalOperator::Empty),
1100        };
1101        let cost = model.return_cost(&ret, 1000.0);
1102
1103        // Return materializes results
1104        assert!(cost.cpu > 0.0);
1105    }
1106
1107    #[test]
1108    fn test_cost_cheaper() {
1109        let model = CostModel::new();
1110        let cheap = Cost::cpu(10.0);
1111        let expensive = Cost::cpu(100.0);
1112
1113        assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
1114        assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
1115    }
1116
1117    #[test]
1118    fn test_cost_comparison_prefers_lower_total() {
1119        let model = CostModel::new();
1120        // High CPU, low IO
1121        let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
1122        // Low CPU, high IO
1123        let io_heavy = Cost::cpu(10.0).with_io(20.0);
1124
1125        // IO is weighted 10x, so io_heavy = 10 + 200 = 210, cpu_heavy = 100 + 10 = 110
1126        assert!(cpu_heavy.total() < io_heavy.total());
1127        assert_eq!(
1128            model.cheaper(&cpu_heavy, &io_heavy).total(),
1129            cpu_heavy.total()
1130        );
1131    }
1132
1133    #[test]
1134    fn test_cost_model_sort_with_keys() {
1135        let model = CostModel::new();
1136        let sort_single = SortOp {
1137            keys: vec![crate::query::plan::SortKey {
1138                expression: LogicalExpression::Variable("a".to_string()),
1139                order: SortOrder::Ascending,
1140                nulls: None,
1141            }],
1142            input: Box::new(LogicalOperator::Empty),
1143        };
1144        let sort_multi = SortOp {
1145            keys: vec![
1146                crate::query::plan::SortKey {
1147                    expression: LogicalExpression::Variable("a".to_string()),
1148                    order: SortOrder::Ascending,
1149                    nulls: None,
1150                },
1151                crate::query::plan::SortKey {
1152                    expression: LogicalExpression::Variable("b".to_string()),
1153                    order: SortOrder::Descending,
1154                    nulls: None,
1155                },
1156            ],
1157            input: Box::new(LogicalOperator::Empty),
1158        };
1159
1160        let cost_single = model.sort_cost(&sort_single, 1000.0);
1161        let cost_multi = model.sort_cost(&sort_multi, 1000.0);
1162
1163        // More sort keys = more comparisons
1164        assert!(cost_multi.cpu > cost_single.cpu);
1165    }
1166
1167    #[test]
1168    fn test_cost_model_empty_operator() {
1169        let model = CostModel::new();
1170        let cost = model.estimate(&LogicalOperator::Empty, 0.0);
1171        assert!((cost.total()).abs() < 0.001);
1172    }
1173
1174    #[test]
1175    fn test_cost_model_default() {
1176        let model = CostModel::default();
1177        let scan = NodeScanOp {
1178            variable: "n".to_string(),
1179            label: None,
1180            input: None,
1181        };
1182        let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
1183        assert!(cost.total() > 0.0);
1184    }
1185
1186    #[test]
1187    fn test_leapfrog_join_cost() {
1188        let model = CostModel::new();
1189
1190        // Three-way join (triangle pattern)
1191        let cardinalities = vec![1000.0, 1000.0, 1000.0];
1192        let cost = model.leapfrog_join_cost(3, &cardinalities, 100.0);
1193
1194        // Should have CPU cost for materialization and intersection
1195        assert!(cost.cpu > 0.0);
1196        // Should have memory cost for trie storage
1197        assert!(cost.memory > 0.0);
1198    }
1199
1200    #[test]
1201    fn test_leapfrog_join_cost_empty() {
1202        let model = CostModel::new();
1203        let cost = model.leapfrog_join_cost(0, &[], 0.0);
1204        assert!((cost.total()).abs() < 0.001);
1205    }
1206
1207    #[test]
1208    fn test_prefer_leapfrog_join_for_triangles() {
1209        let model = CostModel::new();
1210
1211        // Compare costs for triangle pattern
1212        let cardinalities = vec![10000.0, 10000.0, 10000.0];
1213        let output = 1000.0;
1214
1215        let leapfrog_cost = model.leapfrog_join_cost(3, &cardinalities, output);
1216
1217        // Leapfrog should have reasonable cost for triangle patterns
1218        assert!(leapfrog_cost.cpu > 0.0);
1219        assert!(leapfrog_cost.memory > 0.0);
1220
1221        // The prefer_leapfrog_join method compares against hash cascade
1222        // Actual preference depends on specific cost parameters
1223        let _prefer = model.prefer_leapfrog_join(3, &cardinalities, output);
1224        // Test that it returns a boolean (doesn't panic)
1225    }
1226
1227    #[test]
1228    fn test_prefer_leapfrog_join_binary_case() {
1229        let model = CostModel::new();
1230
1231        // Binary join should NOT prefer leapfrog (need 3+ relations)
1232        let cardinalities = vec![1000.0, 1000.0];
1233        let prefer = model.prefer_leapfrog_join(2, &cardinalities, 500.0);
1234        assert!(!prefer, "Binary joins should use hash join, not leapfrog");
1235    }
1236
1237    #[test]
1238    fn test_factorized_benefit_single_hop() {
1239        let model = CostModel::new();
1240
1241        // Single hop: no factorization benefit
1242        let benefit = model.factorized_benefit(10.0, 1);
1243        assert!(
1244            (benefit - 1.0).abs() < 0.001,
1245            "Single hop should have no benefit"
1246        );
1247    }
1248
1249    #[test]
1250    fn test_factorized_benefit_multi_hop() {
1251        let model = CostModel::new();
1252
1253        // Multi-hop with high fanout
1254        let benefit = model.factorized_benefit(10.0, 3);
1255
1256        // The factorized_benefit returns a ratio capped at 1.0
1257        // For high fanout, factorized size / full size approaches 1/fanout
1258        // which is beneficial but the formula gives a value <= 1.0
1259        assert!(benefit <= 1.0, "Benefit should be <= 1.0");
1260        assert!(benefit > 0.0, "Benefit should be positive");
1261    }
1262
1263    #[test]
1264    fn test_factorized_benefit_low_fanout() {
1265        let model = CostModel::new();
1266
1267        // Low fanout: minimal benefit
1268        let benefit = model.factorized_benefit(1.5, 2);
1269        assert!(
1270            benefit <= 1.0,
1271            "Low fanout still benefits from factorization"
1272        );
1273    }
1274
1275    #[test]
1276    fn test_node_scan_uses_label_cardinality_for_io() {
1277        let mut label_cards = std::collections::HashMap::new();
1278        label_cards.insert("Person".to_string(), 500_u64);
1279        label_cards.insert("Company".to_string(), 50_u64);
1280
1281        let model = CostModel::new()
1282            .with_label_cardinalities(label_cards)
1283            .with_graph_totals(550, 1000);
1284
1285        let person_scan = NodeScanOp {
1286            variable: "n".to_string(),
1287            label: Some("Person".to_string()),
1288            input: None,
1289        };
1290        let company_scan = NodeScanOp {
1291            variable: "n".to_string(),
1292            label: Some("Company".to_string()),
1293            input: None,
1294        };
1295
1296        let person_cost = model.node_scan_cost(&person_scan, 500.0);
1297        let company_cost = model.node_scan_cost(&company_scan, 50.0);
1298
1299        // Person scan reads 10x more pages than Company scan
1300        assert!(
1301            person_cost.io > company_cost.io * 5.0,
1302            "Person ({}) should have much higher IO than Company ({})",
1303            person_cost.io,
1304            company_cost.io
1305        );
1306    }
1307
1308    #[test]
1309    fn test_node_scan_unlabeled_uses_total_nodes() {
1310        let model = CostModel::new().with_graph_totals(10_000, 50_000);
1311
1312        let scan = NodeScanOp {
1313            variable: "n".to_string(),
1314            label: None,
1315            input: None,
1316        };
1317
1318        let cost = model.node_scan_cost(&scan, 10_000.0);
1319        let expected_pages = (10_000.0 * 100.0) / 8192.0;
1320        assert!(
1321            (cost.io - expected_pages).abs() < 0.1,
1322            "Unlabeled scan should use total_nodes for IO: got {}, expected {}",
1323            cost.io,
1324            expected_pages
1325        );
1326    }
1327
1328    #[test]
1329    fn test_join_cost_with_actual_child_cardinalities() {
1330        let model = CostModel::new();
1331        let join = JoinOp {
1332            left: Box::new(LogicalOperator::Empty),
1333            right: Box::new(LogicalOperator::Empty),
1334            join_type: JoinType::Inner,
1335            conditions: vec![JoinCondition {
1336                left: LogicalExpression::Variable("a".to_string()),
1337                right: LogicalExpression::Variable("b".to_string()),
1338            }],
1339        };
1340
1341        // With actual child cardinalities (100 left, 10000 right)
1342        let cost_actual = model.join_cost_with_children(&join, 500.0, Some(100.0), Some(10_000.0));
1343
1344        // With sqrt fallback (sqrt(500) ~ 22.4 for both sides)
1345        let cost_sqrt = model.join_cost(&join, 500.0);
1346
1347        // Actual build side (100) is larger than sqrt(500) ~ 22.4, so
1348        // actual cost should be higher for build, but the probe side (10000)
1349        // should dominate
1350        assert!(
1351            cost_actual.cpu > cost_sqrt.cpu,
1352            "Actual child cardinalities ({}) should produce different cost than sqrt fallback ({})",
1353            cost_actual.cpu,
1354            cost_sqrt.cpu
1355        );
1356    }
1357
1358    #[test]
1359    fn test_expand_multi_edge_types() {
1360        let mut degrees = std::collections::HashMap::new();
1361        degrees.insert("KNOWS".to_string(), (5.0, 5.0));
1362        degrees.insert("FOLLOWS".to_string(), (20.0, 100.0));
1363
1364        let model = CostModel::new().with_edge_type_degrees(degrees);
1365
1366        // Multi-type outgoing: KNOWS(5) + FOLLOWS(20) = 25
1367        let multi_expand = ExpandOp {
1368            from_variable: "a".to_string(),
1369            to_variable: "b".to_string(),
1370            edge_variable: None,
1371            direction: ExpandDirection::Outgoing,
1372            edge_types: vec!["KNOWS".to_string(), "FOLLOWS".to_string()],
1373            min_hops: 1,
1374            max_hops: Some(1),
1375            input: Box::new(LogicalOperator::Empty),
1376            path_alias: None,
1377            path_mode: PathMode::Walk,
1378        };
1379        let multi_cost = model.expand_cost(&multi_expand, 100.0);
1380
1381        // Single type: KNOWS(5) only
1382        let single_expand = ExpandOp {
1383            from_variable: "a".to_string(),
1384            to_variable: "b".to_string(),
1385            edge_variable: None,
1386            direction: ExpandDirection::Outgoing,
1387            edge_types: vec!["KNOWS".to_string()],
1388            min_hops: 1,
1389            max_hops: Some(1),
1390            input: Box::new(LogicalOperator::Empty),
1391            path_alias: None,
1392            path_mode: PathMode::Walk,
1393        };
1394        let single_cost = model.expand_cost(&single_expand, 100.0);
1395
1396        // Multi-type (fanout=25) should be more expensive than single (fanout=5)
1397        assert!(
1398            multi_cost.cpu > single_cost.cpu * 3.0,
1399            "Multi-type fanout ({}) should be much higher than single-type ({})",
1400            multi_cost.cpu,
1401            single_cost.cpu
1402        );
1403    }
1404
1405    #[test]
1406    fn test_recursive_tree_cost() {
1407        use crate::query::optimizer::CardinalityEstimator;
1408
1409        let mut label_cards = std::collections::HashMap::new();
1410        label_cards.insert("Person".to_string(), 1000_u64);
1411
1412        let model = CostModel::new()
1413            .with_label_cardinalities(label_cards)
1414            .with_graph_totals(1000, 5000)
1415            .with_avg_fanout(5.0);
1416
1417        let mut card_est = CardinalityEstimator::new();
1418        card_est.add_table_stats("Person", crate::query::optimizer::TableStats::new(1000));
1419
1420        // Build a plan: NodeScan -> Filter -> Return
1421        let plan = LogicalOperator::Return(ReturnOp {
1422            items: vec![ReturnItem {
1423                expression: LogicalExpression::Variable("n".to_string()),
1424                alias: None,
1425            }],
1426            distinct: false,
1427            input: Box::new(LogicalOperator::Filter(FilterOp {
1428                predicate: LogicalExpression::Binary {
1429                    left: Box::new(LogicalExpression::Property {
1430                        variable: "n".to_string(),
1431                        property: "age".to_string(),
1432                    }),
1433                    op: crate::query::plan::BinaryOp::Gt,
1434                    right: Box::new(LogicalExpression::Literal(
1435                        grafeo_common::types::Value::Int64(30),
1436                    )),
1437                },
1438                input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1439                    variable: "n".to_string(),
1440                    label: Some("Person".to_string()),
1441                    input: None,
1442                })),
1443                pushdown_hint: None,
1444            })),
1445        });
1446
1447        let tree_cost = model.estimate_tree(&plan, &card_est);
1448
1449        // Tree cost should include scan IO + filter CPU + return CPU
1450        assert!(tree_cost.cpu > 0.0, "Tree should have CPU cost");
1451        assert!(tree_cost.io > 0.0, "Tree should have IO cost from scan");
1452
1453        // Compare with single-operator estimate (only costs the root)
1454        let root_only_card = card_est.estimate(&plan);
1455        let root_only_cost = model.estimate(&plan, root_only_card);
1456
1457        // Tree cost should be strictly higher because it includes child costs
1458        assert!(
1459            tree_cost.total() > root_only_cost.total(),
1460            "Recursive tree cost ({}) should exceed root-only cost ({})",
1461            tree_cost.total(),
1462            root_only_cost.total()
1463        );
1464    }
1465
1466    #[test]
1467    fn test_statistics_driven_vs_default_cost() {
1468        let default_model = CostModel::new();
1469
1470        let mut label_cards = std::collections::HashMap::new();
1471        label_cards.insert("Person".to_string(), 100_u64);
1472        let stats_model = CostModel::new()
1473            .with_label_cardinalities(label_cards)
1474            .with_graph_totals(100, 500);
1475
1476        // Scan a small label: statistics model knows it's only 100 nodes
1477        let scan = NodeScanOp {
1478            variable: "n".to_string(),
1479            label: Some("Person".to_string()),
1480            input: None,
1481        };
1482
1483        let default_cost = default_model.node_scan_cost(&scan, 100.0);
1484        let stats_cost = stats_model.node_scan_cost(&scan, 100.0);
1485
1486        // With statistics, IO cost is based on actual label size (100 nodes)
1487        // Without statistics, IO cost uses cardinality parameter (also 100 here)
1488        // They should be equal in this case since cardinality matches label size
1489        assert!(
1490            (default_cost.io - stats_cost.io).abs() < 0.1,
1491            "When cardinality matches label size, costs should be similar"
1492        );
1493    }
1494}