1use crate::query::plan::{
6 AggregateOp, DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinOp, JoinType, LeftJoinOp,
7 LimitOp, LogicalOperator, MultiWayJoinOp, NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp,
8 TextScanOp, VectorJoinOp, VectorScanOp,
9};
10use std::collections::HashMap;
11
/// Multi-dimensional cost estimate for a query operator.
///
/// The four components are abstract, unitless numbers; [`Cost::total`] folds
/// them into one comparable value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Cost {
    /// CPU component of the estimate.
    pub cpu: f64,
    /// I/O component of the estimate.
    pub io: f64,
    /// Memory component of the estimate.
    pub memory: f64,
    /// Network component of the estimate.
    pub network: f64,
}

impl Cost {
    /// Weight applied to the I/O component by [`Cost::total`].
    const IO_WEIGHT: f64 = 10.0;
    /// Weight applied to the memory component by [`Cost::total`].
    const MEMORY_WEIGHT: f64 = 0.1;
    /// Weight applied to the network component by [`Cost::total`].
    const NETWORK_WEIGHT: f64 = 100.0;

    /// Returns a cost whose every component is zero.
    #[must_use]
    pub fn zero() -> Self {
        Self::cpu(0.0)
    }

    /// Returns a cost with the given CPU component and all others zero.
    #[must_use]
    pub fn cpu(cpu: f64) -> Self {
        Self {
            cpu,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Returns this cost with its I/O component replaced by `io`.
    #[must_use]
    pub fn with_io(self, io: f64) -> Self {
        Self { io, ..self }
    }

    /// Returns this cost with its memory component replaced by `memory`.
    #[must_use]
    pub fn with_memory(self, memory: f64) -> Self {
        Self { memory, ..self }
    }

    /// Collapses the cost into a single number using the default weights:
    /// CPU x 1, I/O x 10, memory x 0.1 and network x 100.
    #[must_use]
    pub fn total(&self) -> f64 {
        // Multiplying the CPU term by 1.0 is exact, so this matches summing
        // the four weighted terms directly, in the same order.
        self.total_weighted(1.0, Self::IO_WEIGHT, Self::MEMORY_WEIGHT)
            + self.network * Self::NETWORK_WEIGHT
    }

    /// Collapses the CPU, I/O and memory components using caller-supplied
    /// weights.
    ///
    /// The network component does not contribute to this total.
    #[must_use]
    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
        let weighted_cpu = self.cpu * cpu_weight;
        let weighted_io = self.io * io_weight;
        let weighted_memory = self.memory * mem_weight;
        weighted_cpu + weighted_io + weighted_memory
    }
}
78
79impl std::ops::Add for Cost {
80 type Output = Self;
81
82 fn add(self, other: Self) -> Self {
83 Self {
84 cpu: self.cpu + other.cpu,
85 io: self.io + other.io,
86 memory: self.memory + other.memory,
87 network: self.network + other.network,
88 }
89 }
90}
91
92impl std::ops::AddAssign for Cost {
93 fn add_assign(&mut self, other: Self) {
94 self.cpu += other.cpu;
95 self.io += other.io;
96 self.memory += other.memory;
97 self.network += other.network;
98 }
99}
100
/// Coefficients and graph statistics used to estimate operator costs.
///
/// Instances are created with `CostModel::new` and refined with the `with_*`
/// builder methods.
#[derive(Debug, Clone)]
pub struct CostModel {
    /// CPU cost charged per tuple processed.
    cpu_tuple_cost: f64,
    /// CPU cost charged per hash-table operation (build, probe or lookup).
    hash_lookup_cost: f64,
    /// CPU cost charged per comparison while sorting.
    sort_comparison_cost: f64,
    /// Average tuple size, used for memory and page-count estimates.
    avg_tuple_size: f64,
    /// Page size, used to turn scanned data volume into a page count.
    page_size: f64,
    /// Fallback fanout for expansions without usable per-type statistics.
    avg_fanout: f64,
    /// Per edge type: `(average out-degree, average in-degree)`.
    edge_type_degrees: HashMap<String, (f64, f64)>,
    /// Number of nodes carrying each label.
    label_cardinalities: HashMap<String, u64>,
    /// Total number of nodes in the graph (0 when unknown).
    total_nodes: u64,
    /// Total number of edges in the graph.
    total_edges: u64,
}
130
131impl CostModel {
132 #[must_use]
134 pub fn new() -> Self {
135 Self {
136 cpu_tuple_cost: 0.01,
137 hash_lookup_cost: 0.03,
138 sort_comparison_cost: 0.02,
139 avg_tuple_size: 100.0,
140 page_size: 8192.0,
141 avg_fanout: 10.0,
142 edge_type_degrees: HashMap::new(),
143 label_cardinalities: HashMap::new(),
144 total_nodes: 0,
145 total_edges: 0,
146 }
147 }
148
149 #[must_use]
151 pub fn with_avg_fanout(mut self, avg_fanout: f64) -> Self {
152 self.avg_fanout = if avg_fanout > 0.0 { avg_fanout } else { 10.0 };
153 self
154 }
155
156 #[must_use]
160 pub fn with_edge_type_degrees(mut self, degrees: HashMap<String, (f64, f64)>) -> Self {
161 self.edge_type_degrees = degrees;
162 self
163 }
164
165 #[must_use]
167 pub fn with_label_cardinalities(mut self, cardinalities: HashMap<String, u64>) -> Self {
168 self.label_cardinalities = cardinalities;
169 self
170 }
171
172 #[must_use]
174 pub fn with_graph_totals(mut self, total_nodes: u64, total_edges: u64) -> Self {
175 self.total_nodes = total_nodes;
176 self.total_edges = total_edges;
177 self
178 }
179
180 fn fanout_for_expand(&self, expand: &ExpandOp) -> f64 {
185 if expand.edge_types.is_empty() {
186 return self.avg_fanout;
187 }
188
189 let mut total_fanout = 0.0;
190 let mut all_found = true;
191
192 for edge_type in &expand.edge_types {
193 if let Some(&(out_deg, in_deg)) = self.edge_type_degrees.get(edge_type) {
194 total_fanout += match expand.direction {
195 ExpandDirection::Outgoing => out_deg,
196 ExpandDirection::Incoming => in_deg,
197 ExpandDirection::Both => out_deg + in_deg,
198 };
199 } else {
200 all_found = false;
201 break;
202 }
203 }
204
205 if all_found {
206 total_fanout
207 } else {
208 self.avg_fanout
209 }
210 }
211
212 #[must_use]
214 pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
215 match op {
216 LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
217 LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
218 LogicalOperator::Project(project) => self.project_cost(project, cardinality),
219 LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
220 LogicalOperator::Join(join) => self.join_cost(join, cardinality),
221 LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
222 LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
223 LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
224 LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
225 LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
226 LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
227 LogicalOperator::Empty => Cost::zero(),
228 LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
229 LogicalOperator::VectorJoin(join) => self.vector_join_cost(join, cardinality),
230 LogicalOperator::MultiWayJoin(mwj) => self.multi_way_join_cost(mwj, cardinality),
231 LogicalOperator::LeftJoin(lj) => {
232 let left_card = self.estimate_child_cardinality(&lj.left);
233 let right_card = self.estimate_child_cardinality(&lj.right);
234 self.left_join_cost(lj, cardinality, left_card, right_card)
235 }
236 LogicalOperator::TextScan(scan) => self.text_scan_cost(scan, cardinality),
237 _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
238 }
239 }
240
241 fn node_scan_cost(&self, scan: &NodeScanOp, cardinality: f64) -> Cost {
247 let scan_size = if let Some(label) = &scan.label {
249 self.label_cardinalities
250 .get(label)
251 .map_or(cardinality, |&count| count as f64)
252 } else if self.total_nodes > 0 {
253 self.total_nodes as f64
254 } else {
255 cardinality
256 };
257 let pages = (scan_size * self.avg_tuple_size) / self.page_size;
258 Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
260 }
261
262 fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
264 Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
266 }
267
268 fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
270 let expr_count = project.projections.len() as f64;
272 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
273 }
274
275 fn expand_cost(&self, expand: &ExpandOp, cardinality: f64) -> Cost {
280 let fanout = self.fanout_for_expand(expand);
281 let lookup_cost = cardinality * self.hash_lookup_cost;
283 let output_cost = cardinality * fanout * self.cpu_tuple_cost;
285 Cost::cpu(lookup_cost + output_cost)
286 }
287
288 fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
293 self.join_cost_with_children(join, cardinality, None, None)
294 }
295
296 fn join_cost_with_children(
298 &self,
299 join: &JoinOp,
300 cardinality: f64,
301 left_cardinality: Option<f64>,
302 right_cardinality: Option<f64>,
303 ) -> Cost {
304 match join.join_type {
305 JoinType::Cross => Cost::cpu(cardinality * self.cpu_tuple_cost),
306 JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
307 let build_cardinality = left_cardinality.unwrap_or_else(|| cardinality.sqrt());
309 let probe_cardinality = right_cardinality.unwrap_or_else(|| cardinality.sqrt());
310
311 let build_cost = build_cardinality * self.hash_lookup_cost;
312 let memory_cost = build_cardinality * self.avg_tuple_size;
313 let probe_cost = probe_cardinality * self.hash_lookup_cost;
314 let output_cost = cardinality * self.cpu_tuple_cost;
315
316 Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
317 }
318 JoinType::Semi | JoinType::Anti => {
319 let build_cardinality = left_cardinality.unwrap_or_else(|| cardinality.sqrt());
320 let probe_cardinality = right_cardinality.unwrap_or_else(|| cardinality.sqrt());
321
322 let build_cost = build_cardinality * self.hash_lookup_cost;
323 let probe_cost = probe_cardinality * self.hash_lookup_cost;
324
325 Cost::cpu(build_cost + probe_cost)
326 .with_memory(build_cardinality * self.avg_tuple_size)
327 }
328 }
329 }
330
331 fn left_join_cost(
337 &self,
338 _lj: &LeftJoinOp,
339 cardinality: f64,
340 left_card: f64,
341 right_card: f64,
342 ) -> Cost {
343 let build_cost = right_card * self.hash_lookup_cost;
344 let memory_cost = right_card * self.avg_tuple_size;
345 let probe_cost = left_card * self.hash_lookup_cost;
346 let output_cost = cardinality * self.cpu_tuple_cost;
347
348 Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
349 }
350
351 fn estimate_child_cardinality(&self, op: &LogicalOperator) -> f64 {
357 match op {
358 LogicalOperator::NodeScan(scan) => if let Some(label) = &scan.label {
359 self.label_cardinalities
360 .get(label)
361 .map_or(self.total_nodes as f64, |&c| c as f64)
362 } else {
363 self.total_nodes as f64
364 }
365 .max(1.0),
366 LogicalOperator::Expand(expand) => {
367 let input_card = self.estimate_child_cardinality(&expand.input);
368 let fanout = if expand.edge_types.is_empty() {
369 self.avg_fanout
370 } else {
371 self.fanout_for_expand(expand)
372 };
373 (input_card * fanout).max(1.0)
374 }
375 LogicalOperator::Filter(filter) => {
376 (self.estimate_child_cardinality(&filter.input) * 0.1).max(1.0)
378 }
379 LogicalOperator::Return(ret) => self.estimate_child_cardinality(&ret.input),
380 LogicalOperator::Limit(limit) => {
381 let input = self.estimate_child_cardinality(&limit.input);
382 input.min(100.0)
384 }
385 _ => (self.total_nodes as f64).max(1.0),
386 }
387 }
388
389 fn multi_way_join_cost(&self, mwj: &MultiWayJoinOp, cardinality: f64) -> Cost {
394 let n = mwj.inputs.len();
395 if n == 0 {
396 return Cost::zero();
397 }
398 let per_input = cardinality.powf(1.0 / n as f64).max(1.0);
401 let cardinalities: Vec<f64> = (0..n).map(|_| per_input).collect();
402 self.leapfrog_join_cost(n, &cardinalities, cardinality)
403 }
404
405 fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
407 let hash_cost = cardinality * self.hash_lookup_cost;
409
410 let agg_count = agg.aggregates.len() as f64;
412 let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
413
414 let distinct_groups = (cardinality / 10.0).max(1.0); let memory_cost = distinct_groups * self.avg_tuple_size;
417
418 Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
419 }
420
421 fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
423 if cardinality <= 1.0 {
424 return Cost::zero();
425 }
426
427 let comparisons = cardinality * cardinality.log2();
429 let key_count = sort.keys.len() as f64;
430
431 let memory_cost = cardinality * self.avg_tuple_size;
433
434 Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
435 }
436
437 fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
439 let hash_cost = cardinality * self.hash_lookup_cost;
441 let memory_cost = cardinality * self.avg_tuple_size * 0.5; Cost::cpu(hash_cost).with_memory(memory_cost)
444 }
445
446 fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
448 Cost::cpu(limit.count.estimate() * self.cpu_tuple_cost * 0.1)
450 }
451
452 fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
454 Cost::cpu(skip.count.estimate() * self.cpu_tuple_cost)
456 }
457
458 fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
460 let expr_count = ret.items.len() as f64;
462 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
463 }
464
465 fn vector_scan_cost(&self, scan: &VectorScanOp, cardinality: f64) -> Cost {
470 let k = scan.k.unwrap_or(0) as f64;
471 let n = cardinality.max(1.0);
472
473 let ef = 64.0;
476 let search_cost = if scan.index_name.is_some() {
477 ef * n.ln() * self.cpu_tuple_cost * 10.0
478 } else {
479 n * self.cpu_tuple_cost * 10.0
480 };
481
482 let output_rows = if k > 0.0 { k } else { cardinality };
484 let memory = output_rows * self.avg_tuple_size * 2.0;
485
486 Cost::cpu(search_cost).with_memory(memory)
487 }
488
489 fn text_scan_cost(&self, scan: &TextScanOp, cardinality: f64) -> Cost {
495 let corpus_size = self
497 .label_cardinalities
498 .get(&scan.label)
499 .copied()
500 .map_or(cardinality, |c| c as f64);
501 let cpu = corpus_size * self.cpu_tuple_cost * 5.0;
503 Cost::cpu(cpu).with_memory(cardinality * self.avg_tuple_size)
504 }
505
506 fn vector_join_cost(&self, join: &VectorJoinOp, cardinality: f64) -> Cost {
510 let k = join.k as f64;
511
512 let per_row_search_cost = if join.index_name.is_some() {
515 let ef = 64.0;
517 let n = cardinality.max(1.0);
518 ef * n.ln() * self.cpu_tuple_cost * 10.0
519 } else {
520 cardinality * self.cpu_tuple_cost * 10.0
522 };
523
524 let input_cardinality = (cardinality / k).max(1.0);
527 let total_search_cost = input_cardinality * per_row_search_cost;
528
529 let memory = cardinality * self.avg_tuple_size;
531
532 Cost::cpu(total_search_cost).with_memory(memory)
533 }
534
535 #[must_use]
541 pub fn estimate_tree(
542 &self,
543 op: &LogicalOperator,
544 card_estimator: &super::CardinalityEstimator,
545 ) -> Cost {
546 self.estimate_tree_inner(op, card_estimator)
547 }
548
549 fn estimate_tree_inner(
550 &self,
551 op: &LogicalOperator,
552 card_est: &super::CardinalityEstimator,
553 ) -> Cost {
554 let cardinality = card_est.estimate(op);
555
556 match op {
557 LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
558 LogicalOperator::Filter(filter) => {
559 let child_cost = self.estimate_tree_inner(&filter.input, card_est);
560 child_cost + self.filter_cost(filter, cardinality)
561 }
562 LogicalOperator::Project(project) => {
563 let child_cost = self.estimate_tree_inner(&project.input, card_est);
564 child_cost + self.project_cost(project, cardinality)
565 }
566 LogicalOperator::Expand(expand) => {
567 let child_cost = self.estimate_tree_inner(&expand.input, card_est);
568 child_cost + self.expand_cost(expand, cardinality)
569 }
570 LogicalOperator::Join(join) => {
571 let left_cost = self.estimate_tree_inner(&join.left, card_est);
572 let right_cost = self.estimate_tree_inner(&join.right, card_est);
573 let left_card = card_est.estimate(&join.left);
574 let right_card = card_est.estimate(&join.right);
575 let join_cost = self.join_cost_with_children(
576 join,
577 cardinality,
578 Some(left_card),
579 Some(right_card),
580 );
581 left_cost + right_cost + join_cost
582 }
583 LogicalOperator::LeftJoin(lj) => {
584 let left_cost = self.estimate_tree_inner(&lj.left, card_est);
585 let right_cost = self.estimate_tree_inner(&lj.right, card_est);
586 let left_card = card_est.estimate(&lj.left);
587 let right_card = card_est.estimate(&lj.right);
588 let join_cost = self.left_join_cost(lj, cardinality, left_card, right_card);
589 left_cost + right_cost + join_cost
590 }
591 LogicalOperator::Aggregate(agg) => {
592 let child_cost = self.estimate_tree_inner(&agg.input, card_est);
593 child_cost + self.aggregate_cost(agg, cardinality)
594 }
595 LogicalOperator::Sort(sort) => {
596 let child_cost = self.estimate_tree_inner(&sort.input, card_est);
597 child_cost + self.sort_cost(sort, cardinality)
598 }
599 LogicalOperator::Distinct(distinct) => {
600 let child_cost = self.estimate_tree_inner(&distinct.input, card_est);
601 child_cost + self.distinct_cost(distinct, cardinality)
602 }
603 LogicalOperator::Limit(limit) => {
604 let child_cost = self.estimate_tree_inner(&limit.input, card_est);
605 child_cost + self.limit_cost(limit, cardinality)
606 }
607 LogicalOperator::Skip(skip) => {
608 let child_cost = self.estimate_tree_inner(&skip.input, card_est);
609 child_cost + self.skip_cost(skip, cardinality)
610 }
611 LogicalOperator::Return(ret) => {
612 let child_cost = self.estimate_tree_inner(&ret.input, card_est);
613 child_cost + self.return_cost(ret, cardinality)
614 }
615 LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
616 LogicalOperator::VectorJoin(join) => {
617 let child_cost = self.estimate_tree_inner(&join.input, card_est);
618 child_cost + self.vector_join_cost(join, cardinality)
619 }
620 LogicalOperator::MultiWayJoin(mwj) => {
621 let mut children_cost = Cost::zero();
622 for input in &mwj.inputs {
623 children_cost += self.estimate_tree_inner(input, card_est);
624 }
625 children_cost + self.multi_way_join_cost(mwj, cardinality)
626 }
627 LogicalOperator::Empty => Cost::zero(),
628 LogicalOperator::TextScan(scan) => self.text_scan_cost(scan, cardinality),
629 _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
630 }
631 }
632
633 #[must_use]
635 pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
636 if a.total() <= b.total() { a } else { b }
637 }
638
639 #[must_use]
655 pub fn leapfrog_join_cost(
656 &self,
657 num_relations: usize,
658 cardinalities: &[f64],
659 output_cardinality: f64,
660 ) -> Cost {
661 if cardinalities.is_empty() {
662 return Cost::zero();
663 }
664
665 let total_input: f64 = cardinalities.iter().sum();
666 let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
667
668 let materialize_cost = total_input * self.cpu_tuple_cost * 2.0; let seek_cost = if min_card > 1.0 {
673 output_cardinality * (num_relations as f64) * min_card.log2() * self.hash_lookup_cost
674 } else {
675 output_cardinality * self.cpu_tuple_cost
676 };
677
678 let output_cost = output_cardinality * self.cpu_tuple_cost;
680
681 let memory = total_input * self.avg_tuple_size * 2.0;
683
684 Cost::cpu(materialize_cost + seek_cost + output_cost).with_memory(memory)
685 }
686
687 #[must_use]
691 pub fn prefer_leapfrog_join(
692 &self,
693 num_relations: usize,
694 cardinalities: &[f64],
695 output_cardinality: f64,
696 ) -> bool {
697 if num_relations < 3 || cardinalities.len() < 3 {
698 return false;
700 }
701
702 let leapfrog_cost =
703 self.leapfrog_join_cost(num_relations, cardinalities, output_cardinality);
704
705 let mut hash_cascade_cost = Cost::zero();
709 let mut intermediate_cardinality = cardinalities[0];
710
711 for card in &cardinalities[1..] {
712 let join_output = (intermediate_cardinality * card).sqrt(); let join = JoinOp {
715 left: Box::new(LogicalOperator::Empty),
716 right: Box::new(LogicalOperator::Empty),
717 join_type: JoinType::Inner,
718 conditions: vec![],
719 };
720 hash_cascade_cost += self.join_cost(&join, join_output);
721 intermediate_cardinality = join_output;
722 }
723
724 leapfrog_cost.total() < hash_cascade_cost.total()
725 }
726
727 #[must_use]
735 pub fn factorized_benefit(&self, avg_fanout: f64, num_hops: usize) -> f64 {
736 if num_hops <= 1 || avg_fanout <= 1.0 {
737 return 1.0; }
739
740 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
747 let hops_i32 = num_hops as i32;
748 let full_size = avg_fanout.powi(hops_i32);
749 let factorized_size = if avg_fanout > 1.0 {
750 (avg_fanout.powi(hops_i32 + 1) - 1.0) / (avg_fanout - 1.0)
751 } else {
752 num_hops as f64
753 };
754
755 (factorized_size / full_size).min(1.0)
756 }
757}
758
759impl Default for CostModel {
760 fn default() -> Self {
761 Self::new()
762 }
763}
764
765#[cfg(test)]
766mod tests {
767 use super::*;
768 use crate::query::plan::{
769 AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
770 PathMode, Projection, ReturnItem, SortOrder,
771 };
772
773 #[test]
774 fn test_cost_addition() {
775 let a = Cost::cpu(10.0).with_io(5.0);
776 let b = Cost::cpu(20.0).with_memory(100.0);
777 let c = a + b;
778
779 assert!((c.cpu - 30.0).abs() < 0.001);
780 assert!((c.io - 5.0).abs() < 0.001);
781 assert!((c.memory - 100.0).abs() < 0.001);
782 }
783
784 #[test]
785 fn test_cost_total() {
786 let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
787 assert!((cost.total() - 30.0).abs() < 0.001);
789 }
790
791 #[test]
792 fn test_cost_model_node_scan() {
793 let model = CostModel::new();
794 let scan = NodeScanOp {
795 variable: "n".to_string(),
796 label: Some("Person".to_string()),
797 input: None,
798 };
799 let cost = model.node_scan_cost(&scan, 1000.0);
800
801 assert!(cost.cpu > 0.0);
802 assert!(cost.io > 0.0);
803 }
804
805 #[test]
806 fn test_cost_model_sort() {
807 let model = CostModel::new();
808 let sort = SortOp {
809 keys: vec![],
810 input: Box::new(LogicalOperator::Empty),
811 };
812
813 let cost_100 = model.sort_cost(&sort, 100.0);
814 let cost_1000 = model.sort_cost(&sort, 1000.0);
815
816 assert!(cost_1000.total() > cost_100.total());
818 }
819
820 #[test]
821 fn test_cost_zero() {
822 let cost = Cost::zero();
823 assert!((cost.cpu).abs() < 0.001);
824 assert!((cost.io).abs() < 0.001);
825 assert!((cost.memory).abs() < 0.001);
826 assert!((cost.network).abs() < 0.001);
827 assert!((cost.total()).abs() < 0.001);
828 }
829
830 #[test]
831 fn test_cost_add_assign() {
832 let mut cost = Cost::cpu(10.0);
833 cost += Cost::cpu(5.0).with_io(2.0);
834 assert!((cost.cpu - 15.0).abs() < 0.001);
835 assert!((cost.io - 2.0).abs() < 0.001);
836 }
837
838 #[test]
839 fn test_cost_total_weighted() {
840 let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
841 let total = cost.total_weighted(2.0, 5.0, 0.5);
843 assert!((total - 80.0).abs() < 0.001);
844 }
845
846 #[test]
847 fn test_cost_model_filter() {
848 let model = CostModel::new();
849 let filter = FilterOp {
850 predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
851 input: Box::new(LogicalOperator::Empty),
852 pushdown_hint: None,
853 };
854 let cost = model.filter_cost(&filter, 1000.0);
855
856 assert!(cost.cpu > 0.0);
858 assert!((cost.io).abs() < 0.001);
859 }
860
861 #[test]
862 fn test_cost_model_project() {
863 let model = CostModel::new();
864 let project = ProjectOp {
865 projections: vec![
866 Projection {
867 expression: LogicalExpression::Variable("a".to_string()),
868 alias: None,
869 },
870 Projection {
871 expression: LogicalExpression::Variable("b".to_string()),
872 alias: None,
873 },
874 ],
875 input: Box::new(LogicalOperator::Empty),
876 pass_through_input: false,
877 };
878 let cost = model.project_cost(&project, 1000.0);
879
880 assert!(cost.cpu > 0.0);
882 }
883
884 #[test]
885 fn test_cost_model_expand() {
886 let model = CostModel::new();
887 let expand = ExpandOp {
888 from_variable: "a".to_string(),
889 to_variable: "b".to_string(),
890 edge_variable: None,
891 direction: ExpandDirection::Outgoing,
892 edge_types: vec![],
893 min_hops: 1,
894 max_hops: Some(1),
895 input: Box::new(LogicalOperator::Empty),
896 path_alias: None,
897 path_mode: PathMode::Walk,
898 };
899 let cost = model.expand_cost(&expand, 1000.0);
900
901 assert!(cost.cpu > 0.0);
903 }
904
905 #[test]
906 fn test_cost_model_expand_with_edge_type_stats() {
907 let mut degrees = std::collections::HashMap::new();
908 degrees.insert("KNOWS".to_string(), (5.0, 5.0)); degrees.insert("WORKS_AT".to_string(), (1.0, 50.0)); let model = CostModel::new().with_edge_type_degrees(degrees);
912
913 let knows_out = ExpandOp {
915 from_variable: "a".to_string(),
916 to_variable: "b".to_string(),
917 edge_variable: None,
918 direction: ExpandDirection::Outgoing,
919 edge_types: vec!["KNOWS".to_string()],
920 min_hops: 1,
921 max_hops: Some(1),
922 input: Box::new(LogicalOperator::Empty),
923 path_alias: None,
924 path_mode: PathMode::Walk,
925 };
926 let cost_knows = model.expand_cost(&knows_out, 1000.0);
927
928 let works_out = ExpandOp {
930 from_variable: "a".to_string(),
931 to_variable: "b".to_string(),
932 edge_variable: None,
933 direction: ExpandDirection::Outgoing,
934 edge_types: vec!["WORKS_AT".to_string()],
935 min_hops: 1,
936 max_hops: Some(1),
937 input: Box::new(LogicalOperator::Empty),
938 path_alias: None,
939 path_mode: PathMode::Walk,
940 };
941 let cost_works = model.expand_cost(&works_out, 1000.0);
942
943 assert!(
945 cost_knows.cpu > cost_works.cpu,
946 "KNOWS(5) should cost more than WORKS_AT(1)"
947 );
948
949 let works_in = ExpandOp {
951 from_variable: "c".to_string(),
952 to_variable: "p".to_string(),
953 edge_variable: None,
954 direction: ExpandDirection::Incoming,
955 edge_types: vec!["WORKS_AT".to_string()],
956 min_hops: 1,
957 max_hops: Some(1),
958 input: Box::new(LogicalOperator::Empty),
959 path_alias: None,
960 path_mode: PathMode::Walk,
961 };
962 let cost_works_in = model.expand_cost(&works_in, 1000.0);
963
964 assert!(
966 cost_works_in.cpu > cost_knows.cpu,
967 "Incoming WORKS_AT(50) should cost more than KNOWS(5)"
968 );
969 }
970
971 #[test]
972 fn test_cost_model_expand_unknown_edge_type_uses_global_fanout() {
973 let model = CostModel::new().with_avg_fanout(7.0);
974 let expand = ExpandOp {
975 from_variable: "a".to_string(),
976 to_variable: "b".to_string(),
977 edge_variable: None,
978 direction: ExpandDirection::Outgoing,
979 edge_types: vec!["UNKNOWN_TYPE".to_string()],
980 min_hops: 1,
981 max_hops: Some(1),
982 input: Box::new(LogicalOperator::Empty),
983 path_alias: None,
984 path_mode: PathMode::Walk,
985 };
986 let cost_unknown = model.expand_cost(&expand, 1000.0);
987
988 let expand_no_type = ExpandOp {
990 from_variable: "a".to_string(),
991 to_variable: "b".to_string(),
992 edge_variable: None,
993 direction: ExpandDirection::Outgoing,
994 edge_types: vec![],
995 min_hops: 1,
996 max_hops: Some(1),
997 input: Box::new(LogicalOperator::Empty),
998 path_alias: None,
999 path_mode: PathMode::Walk,
1000 };
1001 let cost_no_type = model.expand_cost(&expand_no_type, 1000.0);
1002
1003 assert!(
1005 (cost_unknown.cpu - cost_no_type.cpu).abs() < 0.001,
1006 "Unknown edge type should use global fanout"
1007 );
1008 }
1009
1010 #[test]
1011 fn test_cost_model_hash_join() {
1012 let model = CostModel::new();
1013 let join = JoinOp {
1014 left: Box::new(LogicalOperator::Empty),
1015 right: Box::new(LogicalOperator::Empty),
1016 join_type: JoinType::Inner,
1017 conditions: vec![JoinCondition {
1018 left: LogicalExpression::Variable("a".to_string()),
1019 right: LogicalExpression::Variable("b".to_string()),
1020 }],
1021 };
1022 let cost = model.join_cost(&join, 10000.0);
1023
1024 assert!(cost.cpu > 0.0);
1026 assert!(cost.memory > 0.0);
1027 }
1028
1029 #[test]
1030 fn test_cost_model_cross_join() {
1031 let model = CostModel::new();
1032 let join = JoinOp {
1033 left: Box::new(LogicalOperator::Empty),
1034 right: Box::new(LogicalOperator::Empty),
1035 join_type: JoinType::Cross,
1036 conditions: vec![],
1037 };
1038 let cost = model.join_cost(&join, 1000000.0);
1039
1040 assert!(cost.cpu > 0.0);
1042 }
1043
1044 #[test]
1045 fn test_cost_model_semi_join() {
1046 let model = CostModel::new();
1047 let join = JoinOp {
1048 left: Box::new(LogicalOperator::Empty),
1049 right: Box::new(LogicalOperator::Empty),
1050 join_type: JoinType::Semi,
1051 conditions: vec![],
1052 };
1053 let cost_semi = model.join_cost(&join, 1000.0);
1054
1055 let inner_join = JoinOp {
1056 left: Box::new(LogicalOperator::Empty),
1057 right: Box::new(LogicalOperator::Empty),
1058 join_type: JoinType::Inner,
1059 conditions: vec![],
1060 };
1061 let cost_inner = model.join_cost(&inner_join, 1000.0);
1062
1063 assert!(cost_semi.cpu > 0.0);
1065 assert!(cost_inner.cpu > 0.0);
1066 }
1067
1068 #[test]
1069 fn test_cost_model_aggregate() {
1070 let model = CostModel::new();
1071 let agg = AggregateOp {
1072 group_by: vec![],
1073 aggregates: vec![
1074 AggregateExpr {
1075 function: AggregateFunction::Count,
1076 expression: None,
1077 expression2: None,
1078 distinct: false,
1079 alias: Some("cnt".to_string()),
1080 percentile: None,
1081 separator: None,
1082 },
1083 AggregateExpr {
1084 function: AggregateFunction::Sum,
1085 expression: Some(LogicalExpression::Variable("x".to_string())),
1086 expression2: None,
1087 distinct: false,
1088 alias: Some("total".to_string()),
1089 percentile: None,
1090 separator: None,
1091 },
1092 ],
1093 input: Box::new(LogicalOperator::Empty),
1094 having: None,
1095 };
1096 let cost = model.aggregate_cost(&agg, 1000.0);
1097
1098 assert!(cost.cpu > 0.0);
1100 assert!(cost.memory > 0.0);
1101 }
1102
1103 #[test]
1104 fn test_cost_model_distinct() {
1105 let model = CostModel::new();
1106 let distinct = DistinctOp {
1107 input: Box::new(LogicalOperator::Empty),
1108 columns: None,
1109 };
1110 let cost = model.distinct_cost(&distinct, 1000.0);
1111
1112 assert!(cost.cpu > 0.0);
1114 assert!(cost.memory > 0.0);
1115 }
1116
1117 #[test]
1118 fn test_cost_model_limit() {
1119 let model = CostModel::new();
1120 let limit = LimitOp {
1121 count: 10.into(),
1122 input: Box::new(LogicalOperator::Empty),
1123 };
1124 let cost = model.limit_cost(&limit, 1000.0);
1125
1126 assert!(cost.cpu > 0.0);
1128 assert!(cost.cpu < 1.0); }
1130
1131 #[test]
1132 fn test_cost_model_skip() {
1133 let model = CostModel::new();
1134 let skip = SkipOp {
1135 count: 100.into(),
1136 input: Box::new(LogicalOperator::Empty),
1137 };
1138 let cost = model.skip_cost(&skip, 1000.0);
1139
1140 assert!(cost.cpu > 0.0);
1142 }
1143
1144 #[test]
1145 fn test_cost_model_return() {
1146 let model = CostModel::new();
1147 let ret = ReturnOp {
1148 items: vec![
1149 ReturnItem {
1150 expression: LogicalExpression::Variable("a".to_string()),
1151 alias: None,
1152 },
1153 ReturnItem {
1154 expression: LogicalExpression::Variable("b".to_string()),
1155 alias: None,
1156 },
1157 ],
1158 distinct: false,
1159 input: Box::new(LogicalOperator::Empty),
1160 };
1161 let cost = model.return_cost(&ret, 1000.0);
1162
1163 assert!(cost.cpu > 0.0);
1165 }
1166
1167 #[test]
1168 fn test_cost_cheaper() {
1169 let model = CostModel::new();
1170 let cheap = Cost::cpu(10.0);
1171 let expensive = Cost::cpu(100.0);
1172
1173 assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
1174 assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
1175 }
1176
1177 #[test]
1178 fn test_cost_comparison_prefers_lower_total() {
1179 let model = CostModel::new();
1180 let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
1182 let io_heavy = Cost::cpu(10.0).with_io(20.0);
1184
1185 assert!(cpu_heavy.total() < io_heavy.total());
1187 assert_eq!(
1188 model.cheaper(&cpu_heavy, &io_heavy).total(),
1189 cpu_heavy.total()
1190 );
1191 }
1192
1193 #[test]
1194 fn test_cost_model_sort_with_keys() {
1195 let model = CostModel::new();
1196 let sort_single = SortOp {
1197 keys: vec![crate::query::plan::SortKey {
1198 expression: LogicalExpression::Variable("a".to_string()),
1199 order: SortOrder::Ascending,
1200 nulls: None,
1201 }],
1202 input: Box::new(LogicalOperator::Empty),
1203 };
1204 let sort_multi = SortOp {
1205 keys: vec![
1206 crate::query::plan::SortKey {
1207 expression: LogicalExpression::Variable("a".to_string()),
1208 order: SortOrder::Ascending,
1209 nulls: None,
1210 },
1211 crate::query::plan::SortKey {
1212 expression: LogicalExpression::Variable("b".to_string()),
1213 order: SortOrder::Descending,
1214 nulls: None,
1215 },
1216 ],
1217 input: Box::new(LogicalOperator::Empty),
1218 };
1219
1220 let cost_single = model.sort_cost(&sort_single, 1000.0);
1221 let cost_multi = model.sort_cost(&sort_multi, 1000.0);
1222
1223 assert!(cost_multi.cpu > cost_single.cpu);
1225 }
1226
1227 #[test]
1228 fn test_cost_model_empty_operator() {
1229 let model = CostModel::new();
1230 let cost = model.estimate(&LogicalOperator::Empty, 0.0);
1231 assert!((cost.total()).abs() < 0.001);
1232 }
1233
1234 #[test]
1235 fn test_cost_model_default() {
1236 let model = CostModel::default();
1237 let scan = NodeScanOp {
1238 variable: "n".to_string(),
1239 label: None,
1240 input: None,
1241 };
1242 let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
1243 assert!(cost.total() > 0.0);
1244 }
1245
1246 #[test]
1247 fn test_leapfrog_join_cost() {
1248 let model = CostModel::new();
1249
1250 let cardinalities = vec![1000.0, 1000.0, 1000.0];
1252 let cost = model.leapfrog_join_cost(3, &cardinalities, 100.0);
1253
1254 assert!(cost.cpu > 0.0);
1256 assert!(cost.memory > 0.0);
1258 }
1259
1260 #[test]
1261 fn test_leapfrog_join_cost_empty() {
1262 let model = CostModel::new();
1263 let cost = model.leapfrog_join_cost(0, &[], 0.0);
1264 assert!((cost.total()).abs() < 0.001);
1265 }
1266
/// Exercises the leapfrog cost and preference APIs on a three-input case
/// (10000 rows each, 1000 output rows).
///
/// Only the cost components are asserted; the preference decision is computed
/// but intentionally not checked, so the call merely has to complete.
#[test]
fn test_prefer_leapfrog_join_for_triangles() {
    let model = CostModel::new();
    let input_sizes = [10000.0; 3];
    let output_size = 1000.0;

    let estimate = model.leapfrog_join_cost(3, &input_sizes, output_size);
    assert!(estimate.cpu > 0.0);
    assert!(estimate.memory > 0.0);

    // Result deliberately discarded: no expectation is placed on the decision.
    let _ = model.prefer_leapfrog_join(3, &input_sizes, output_size);
}
1286
/// With only two inputs, `prefer_leapfrog_join` must answer `false`.
#[test]
fn test_prefer_leapfrog_join_binary_case() {
    let two_inputs = [1000.0, 1000.0];
    let chose_leapfrog = CostModel::new().prefer_leapfrog_join(2, &two_inputs, 500.0);

    assert!(
        !chose_leapfrog,
        "Binary joins should use hash join, not leapfrog"
    );
}
1296
/// For a single hop at fanout 10, `factorized_benefit` must be within 0.001
/// of 1.0, which the test treats as "no benefit".
#[test]
fn test_factorized_benefit_single_hop() {
    let ratio = CostModel::new().factorized_benefit(10.0, 1);
    let deviation = (ratio - 1.0).abs();

    assert!(deviation < 0.001, "Single hop should have no benefit");
}
1308
/// For three hops at fanout 10, `factorized_benefit` must lie in `(0.0, 1.0]`.
#[test]
fn test_factorized_benefit_multi_hop() {
    let ratio = CostModel::new().factorized_benefit(10.0, 3);

    // Upper bound first, then positivity, each with its own failure message.
    assert!(ratio <= 1.0, "Benefit should be <= 1.0");
    assert!(ratio > 0.0, "Benefit should be positive");
}
1322
/// For two hops at a fanout of 1.5, `factorized_benefit` must not exceed 1.0.
#[test]
fn test_factorized_benefit_low_fanout() {
    let ratio = CostModel::new().factorized_benefit(1.5, 2);

    assert!(
        ratio <= 1.0,
        "Low fanout still benefits from factorization"
    );
}
1334
/// Scanning a label registered with 500 nodes must cost more than five times
/// the IO of scanning a label registered with 50 nodes.
#[test]
fn test_node_scan_uses_label_cardinality_for_io() {
    let label_cards = std::collections::HashMap::from([
        ("Person".to_string(), 500_u64),
        ("Company".to_string(), 50_u64),
    ]);
    let model = CostModel::new()
        .with_label_cardinalities(label_cards)
        .with_graph_totals(550, 1000);

    // Both scans differ only in the label they target.
    let scan_of = |label: &str| NodeScanOp {
        variable: "n".to_string(),
        label: Some(label.to_string()),
        input: None,
    };

    let person_cost = model.node_scan_cost(&scan_of("Person"), 500.0);
    let company_cost = model.node_scan_cost(&scan_of("Company"), 50.0);

    assert!(
        person_cost.io > company_cost.io * 5.0,
        "Person ({}) should have much higher IO than Company ({})",
        person_cost.io,
        company_cost.io
    );
}
1367
/// An unlabeled scan on a model configured with 10_000 total nodes must report
/// an IO cost within 0.1 of `(10_000 * 100) / 8192`.
///
/// NOTE(review): 100.0 and 8192.0 presumably mirror the model's default
/// average tuple size and page size; confirm against `CostModel::new`.
#[test]
fn test_node_scan_unlabeled_uses_total_nodes() {
    let model = CostModel::new().with_graph_totals(10_000, 50_000);
    let unlabeled = NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    };

    let expected_pages = (10_000.0 * 100.0) / 8192.0;
    let cost = model.node_scan_cost(&unlabeled, 10_000.0);

    assert!(
        (cost.io - expected_pages).abs() < 0.1,
        "Unlabeled scan should use total_nodes for IO: got {}, expected {}",
        cost.io,
        expected_pages
    );
}
1387
/// Compares the CPU cost of a single-condition inner join estimated two ways:
/// through `join_cost_with_children` with explicit child cardinalities
/// (100 and 10_000), and through `join_cost`, which receives only the output
/// cardinality (500).
///
/// The explicit-cardinality estimate is required to be strictly higher, and
/// the failure message states that condition.
#[test]
fn test_join_cost_with_actual_child_cardinalities() {
    let model = CostModel::new();
    let join = JoinOp {
        left: Box::new(LogicalOperator::Empty),
        right: Box::new(LogicalOperator::Empty),
        join_type: JoinType::Inner,
        conditions: vec![JoinCondition {
            left: LogicalExpression::Variable("a".to_string()),
            right: LogicalExpression::Variable("b".to_string()),
        }],
    };

    // Estimate with the children's cardinalities supplied explicitly.
    let cost_actual = model.join_cost_with_children(&join, 500.0, Some(100.0), Some(10_000.0));

    // Estimate from the output cardinality alone (the "sqrt fallback" named in
    // the message below; presumably child sizes are derived from the output).
    let cost_sqrt = model.join_cost(&join, 500.0);

    assert!(
        cost_actual.cpu > cost_sqrt.cpu,
        "Actual child cardinalities ({}) should produce a higher CPU cost than sqrt fallback ({})",
        cost_actual.cpu,
        cost_sqrt.cpu
    );
}
1417
/// A one-hop outgoing expand over both `KNOWS` and `FOLLOWS` must cost more
/// than three times the CPU of the same expand restricted to `KNOWS`, given
/// the per-type degree pairs registered on the model.
#[test]
fn test_expand_multi_edge_types() {
    let degrees = std::collections::HashMap::from([
        ("KNOWS".to_string(), (5.0, 5.0)),
        ("FOLLOWS".to_string(), (20.0, 100.0)),
    ]);
    let model = CostModel::new().with_edge_type_degrees(degrees);

    // The two operators under test differ only in their edge type list.
    let one_hop_outgoing = |edge_types: Vec<String>| ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types,
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(LogicalOperator::Empty),
        path_alias: None,
        path_mode: PathMode::Walk,
    };

    let multi_expand = one_hop_outgoing(vec!["KNOWS".to_string(), "FOLLOWS".to_string()]);
    let multi_cost = model.expand_cost(&multi_expand, 100.0);

    let single_expand = one_hop_outgoing(vec!["KNOWS".to_string()]);
    let single_cost = model.expand_cost(&single_expand, 100.0);

    assert!(
        multi_cost.cpu > single_cost.cpu * 3.0,
        "Multi-type fanout ({}) should be much higher than single-type ({})",
        multi_cost.cpu,
        single_cost.cpu
    );
}
1464
/// Costs the plan `Return(Filter(n.age > 30, NodeScan(Person)))` with
/// `estimate_tree` and checks that:
///
/// * the tree cost has positive CPU and IO components, and
/// * the tree total exceeds the total of `estimate` applied to the root
///   operator alone with the estimator's cardinality for the plan.
#[test]
fn test_recursive_tree_cost() {
    use crate::query::optimizer::CardinalityEstimator;

    let mut label_cards = std::collections::HashMap::new();
    label_cards.insert("Person".to_string(), 1000_u64);

    let model = CostModel::new()
        .with_label_cardinalities(label_cards)
        .with_graph_totals(1000, 5000)
        .with_avg_fanout(5.0);

    let mut card_est = CardinalityEstimator::new();
    card_est.add_table_stats("Person", crate::query::optimizer::TableStats::new(1000));

    // Assemble the plan bottom-up: scan -> filter -> return.
    let person_scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });

    let age_above_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op: crate::query::plan::BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(
            grafeo_common::types::Value::Int64(30),
        )),
    };

    let filtered = LogicalOperator::Filter(FilterOp {
        predicate: age_above_30,
        input: Box::new(person_scan),
        pushdown_hint: None,
    });

    let plan = LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filtered),
    });

    let tree_cost = model.estimate_tree(&plan, &card_est);

    assert!(tree_cost.cpu > 0.0, "Tree should have CPU cost");
    assert!(tree_cost.io > 0.0, "Tree should have IO cost from scan");

    // Baseline: cost of the root operator only, at the plan's estimated cardinality.
    let root_only_card = card_est.estimate(&plan);
    let root_only_cost = model.estimate(&plan, root_only_card);

    assert!(
        tree_cost.total() > root_only_cost.total(),
        "Recursive tree cost ({}) should exceed root-only cost ({})",
        tree_cost.total(),
        root_only_cost.total()
    );
}
1525
/// When the cardinality passed to `node_scan_cost` (100) equals the registered
/// size of the scanned label (100), a model with statistics and a model
/// without them must report IO costs within 0.1 of each other.
#[test]
fn test_statistics_driven_vs_default_cost() {
    let default_model = CostModel::new();
    let stats_model = CostModel::new()
        .with_label_cardinalities(std::collections::HashMap::from([(
            "Person".to_string(),
            100_u64,
        )]))
        .with_graph_totals(100, 500);

    let person_scan = NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    };

    let default_cost = default_model.node_scan_cost(&person_scan, 100.0);
    let stats_cost = stats_model.node_scan_cost(&person_scan, 100.0);
    let io_gap = (default_cost.io - stats_cost.io).abs();

    assert!(
        io_gap < 0.1,
        "When cardinality matches label size, costs should be similar"
    );
}
1554
/// When the smallest input cardinality is exactly 1.0, the leapfrog estimate
/// must still have positive CPU and memory components.
#[test]
fn test_leapfrog_join_cost_unit_min_cardinality() {
    let input_sizes = [1.0, 100.0, 200.0];
    let estimate = CostModel::new().leapfrog_join_cost(3, &input_sizes, 50.0);

    assert!(estimate.cpu > 0.0);
    assert!(estimate.memory > 0.0);
}
1563
/// `prefer_leapfrog_join` must answer `false` when fewer than three
/// cardinalities are supplied, regardless of the count passed as first argument.
#[test]
fn test_prefer_leapfrog_join_cardinalities_below_three() {
    let model = CostModel::new();

    let with_two_inputs = model.prefer_leapfrog_join(3, &[100.0, 200.0], 50.0);
    assert!(!with_two_inputs);

    let with_no_inputs = model.prefer_leapfrog_join(5, &[], 10.0);
    assert!(!with_no_inputs);
}
1571
/// With zero hops, `factorized_benefit` must return exactly 1.0.
#[test]
fn test_factorized_benefit_zero_hops() {
    let ratio = CostModel::new().factorized_benefit(10.0, 0);
    assert_eq!(ratio, 1.0);
}
1578
/// With a fanout of exactly 1.0 over five hops, `factorized_benefit` must
/// return exactly 1.0.
#[test]
fn test_factorized_benefit_unit_fanout_guard() {
    let ratio = CostModel::new().factorized_benefit(1.0, 5);
    assert_eq!(ratio, 1.0);
}
1585}