1use crate::query::plan::{
6 AggregateOp, DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinOp, JoinType, LeftJoinOp,
7 LimitOp, LogicalOperator, MultiWayJoinOp, NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp,
8 VectorJoinOp, VectorScanOp,
9};
10use std::collections::HashMap;
11
/// Multi-dimensional cost of executing (part of) a query plan.
///
/// Each component is tracked separately so callers can weight them
/// differently (see [`Cost::total`] and [`Cost::total_weighted`]). Costs
/// compose component-wise with `+` and `+=`. `Cost::default()` is the same as
/// [`Cost::zero`].
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct Cost {
    /// Abstract CPU work units (tuples processed, hash lookups, comparisons).
    pub cpu: f64,
    /// I/O cost; the node-scan cost formula fills this with an estimated page count.
    pub io: f64,
    /// Memory footprint; the cost formulas fill this with row count times the
    /// average tuple size (presumably bytes).
    pub memory: f64,
    /// Network transfer cost. Nothing in this module sets it to a non-zero
    /// value; it is carried for callers that model remote/distributed work.
    pub network: f64,
}

impl Cost {
    /// Returns a cost with every component set to zero.
    #[must_use]
    pub fn zero() -> Self {
        Self {
            cpu: 0.0,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Returns a pure CPU cost; every other component is zero.
    #[must_use]
    pub fn cpu(cpu: f64) -> Self {
        Self { cpu, ..Self::zero() }
    }

    /// Replaces the I/O component (builder style).
    #[must_use]
    pub fn with_io(mut self, io: f64) -> Self {
        self.io = io;
        self
    }

    /// Replaces the memory component (builder style).
    #[must_use]
    pub fn with_memory(mut self, memory: f64) -> Self {
        self.memory = memory;
        self
    }

    /// Replaces the network component (builder style).
    ///
    /// Completes the builder API: previously `network` was the only component
    /// that could not be set through a `with_*` method.
    #[must_use]
    pub fn with_network(mut self, network: f64) -> Self {
        self.network = network;
        self
    }

    /// Collapses the cost into one scalar used to compare plans.
    ///
    /// Relative to one CPU unit, an I/O unit counts 10x, a network unit 100x
    /// and a memory unit 0.1x.
    #[must_use]
    pub fn total(&self) -> f64 {
        self.cpu + self.io * 10.0 + self.memory * 0.1 + self.network * 100.0
    }

    /// Collapses the cost using caller-supplied weights for CPU, I/O and memory.
    ///
    /// The `network` component is NOT included; use [`Cost::total`] when
    /// network cost matters.
    #[must_use]
    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
        self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
    }
}

/// Component-wise sum of two costs.
impl std::ops::Add for Cost {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            cpu: self.cpu + other.cpu,
            io: self.io + other.io,
            memory: self.memory + other.memory,
            network: self.network + other.network,
        }
    }
}

/// In-place component-wise sum; equivalent to `*self = *self + other`.
impl std::ops::AddAssign for Cost {
    fn add_assign(&mut self, other: Self) {
        self.cpu += other.cpu;
        self.io += other.io;
        self.memory += other.memory;
        self.network += other.network;
    }
}
100
/// Cost model the optimizer uses to price logical operators.
///
/// Holds per-operation unit costs plus optional graph statistics (node counts
/// per label, average degrees per edge type, graph totals). When a statistic
/// is missing, the cost formulas fall back to global defaults such as
/// `avg_fanout` or the caller-supplied cardinality.
#[derive(Debug, Clone)]
pub struct CostModel {
    /// CPU cost charged per tuple processed.
    cpu_tuple_cost: f64,
    /// CPU cost of one hash-table operation (used for both build and probe).
    hash_lookup_cost: f64,
    /// CPU cost of one comparison while sorting.
    sort_comparison_cost: f64,
    /// Average size of one tuple, used for memory and page estimates
    /// (presumably bytes, in the same unit as `page_size`).
    avg_tuple_size: f64,
    /// Size of one storage page, in the same unit as `avg_tuple_size`.
    page_size: f64,
    /// Fallback neighbours-per-node when no per-edge-type degree is known.
    avg_fanout: f64,
    /// Average `(outgoing, incoming)` degree, keyed by edge type.
    edge_type_degrees: HashMap<String, (f64, f64)>,
    /// Number of nodes, keyed by label.
    label_cardinalities: HashMap<String, u64>,
    /// Total number of nodes in the graph; 0 means unknown.
    total_nodes: u64,
    /// Total number of edges in the graph; 0 means unknown. Not currently read
    /// by the cost formulas in this module.
    total_edges: u64,
}
130
131impl CostModel {
132 #[must_use]
134 pub fn new() -> Self {
135 Self {
136 cpu_tuple_cost: 0.01,
137 hash_lookup_cost: 0.03,
138 sort_comparison_cost: 0.02,
139 avg_tuple_size: 100.0,
140 page_size: 8192.0,
141 avg_fanout: 10.0,
142 edge_type_degrees: HashMap::new(),
143 label_cardinalities: HashMap::new(),
144 total_nodes: 0,
145 total_edges: 0,
146 }
147 }
148
149 #[must_use]
151 pub fn with_avg_fanout(mut self, avg_fanout: f64) -> Self {
152 self.avg_fanout = if avg_fanout > 0.0 { avg_fanout } else { 10.0 };
153 self
154 }
155
156 #[must_use]
160 pub fn with_edge_type_degrees(mut self, degrees: HashMap<String, (f64, f64)>) -> Self {
161 self.edge_type_degrees = degrees;
162 self
163 }
164
165 #[must_use]
167 pub fn with_label_cardinalities(mut self, cardinalities: HashMap<String, u64>) -> Self {
168 self.label_cardinalities = cardinalities;
169 self
170 }
171
172 #[must_use]
174 pub fn with_graph_totals(mut self, total_nodes: u64, total_edges: u64) -> Self {
175 self.total_nodes = total_nodes;
176 self.total_edges = total_edges;
177 self
178 }
179
180 fn fanout_for_expand(&self, expand: &ExpandOp) -> f64 {
185 if expand.edge_types.is_empty() {
186 return self.avg_fanout;
187 }
188
189 let mut total_fanout = 0.0;
190 let mut all_found = true;
191
192 for edge_type in &expand.edge_types {
193 if let Some(&(out_deg, in_deg)) = self.edge_type_degrees.get(edge_type) {
194 total_fanout += match expand.direction {
195 ExpandDirection::Outgoing => out_deg,
196 ExpandDirection::Incoming => in_deg,
197 ExpandDirection::Both => out_deg + in_deg,
198 };
199 } else {
200 all_found = false;
201 break;
202 }
203 }
204
205 if all_found {
206 total_fanout
207 } else {
208 self.avg_fanout
209 }
210 }
211
212 #[must_use]
214 pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
215 match op {
216 LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
217 LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
218 LogicalOperator::Project(project) => self.project_cost(project, cardinality),
219 LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
220 LogicalOperator::Join(join) => self.join_cost(join, cardinality),
221 LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
222 LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
223 LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
224 LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
225 LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
226 LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
227 LogicalOperator::Empty => Cost::zero(),
228 LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
229 LogicalOperator::VectorJoin(join) => self.vector_join_cost(join, cardinality),
230 LogicalOperator::MultiWayJoin(mwj) => self.multi_way_join_cost(mwj, cardinality),
231 LogicalOperator::LeftJoin(lj) => {
232 let left_card = self.estimate_child_cardinality(&lj.left);
233 let right_card = self.estimate_child_cardinality(&lj.right);
234 self.left_join_cost(lj, cardinality, left_card, right_card)
235 }
236 _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
237 }
238 }
239
240 fn node_scan_cost(&self, scan: &NodeScanOp, cardinality: f64) -> Cost {
246 let scan_size = if let Some(label) = &scan.label {
248 self.label_cardinalities
249 .get(label)
250 .map_or(cardinality, |&count| count as f64)
251 } else if self.total_nodes > 0 {
252 self.total_nodes as f64
253 } else {
254 cardinality
255 };
256 let pages = (scan_size * self.avg_tuple_size) / self.page_size;
257 Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
259 }
260
261 fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
263 Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
265 }
266
267 fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
269 let expr_count = project.projections.len() as f64;
271 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
272 }
273
274 fn expand_cost(&self, expand: &ExpandOp, cardinality: f64) -> Cost {
279 let fanout = self.fanout_for_expand(expand);
280 let lookup_cost = cardinality * self.hash_lookup_cost;
282 let output_cost = cardinality * fanout * self.cpu_tuple_cost;
284 Cost::cpu(lookup_cost + output_cost)
285 }
286
287 fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
292 self.join_cost_with_children(join, cardinality, None, None)
293 }
294
295 fn join_cost_with_children(
297 &self,
298 join: &JoinOp,
299 cardinality: f64,
300 left_cardinality: Option<f64>,
301 right_cardinality: Option<f64>,
302 ) -> Cost {
303 match join.join_type {
304 JoinType::Cross => Cost::cpu(cardinality * self.cpu_tuple_cost),
305 JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
306 let build_cardinality = left_cardinality.unwrap_or_else(|| cardinality.sqrt());
308 let probe_cardinality = right_cardinality.unwrap_or_else(|| cardinality.sqrt());
309
310 let build_cost = build_cardinality * self.hash_lookup_cost;
311 let memory_cost = build_cardinality * self.avg_tuple_size;
312 let probe_cost = probe_cardinality * self.hash_lookup_cost;
313 let output_cost = cardinality * self.cpu_tuple_cost;
314
315 Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
316 }
317 JoinType::Semi | JoinType::Anti => {
318 let build_cardinality = left_cardinality.unwrap_or_else(|| cardinality.sqrt());
319 let probe_cardinality = right_cardinality.unwrap_or_else(|| cardinality.sqrt());
320
321 let build_cost = build_cardinality * self.hash_lookup_cost;
322 let probe_cost = probe_cardinality * self.hash_lookup_cost;
323
324 Cost::cpu(build_cost + probe_cost)
325 .with_memory(build_cardinality * self.avg_tuple_size)
326 }
327 }
328 }
329
330 fn left_join_cost(
336 &self,
337 _lj: &LeftJoinOp,
338 cardinality: f64,
339 left_card: f64,
340 right_card: f64,
341 ) -> Cost {
342 let build_cost = right_card * self.hash_lookup_cost;
343 let memory_cost = right_card * self.avg_tuple_size;
344 let probe_cost = left_card * self.hash_lookup_cost;
345 let output_cost = cardinality * self.cpu_tuple_cost;
346
347 Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
348 }
349
350 fn estimate_child_cardinality(&self, op: &LogicalOperator) -> f64 {
356 match op {
357 LogicalOperator::NodeScan(scan) => if let Some(label) = &scan.label {
358 self.label_cardinalities
359 .get(label)
360 .map_or(self.total_nodes as f64, |&c| c as f64)
361 } else {
362 self.total_nodes as f64
363 }
364 .max(1.0),
365 LogicalOperator::Expand(expand) => {
366 let input_card = self.estimate_child_cardinality(&expand.input);
367 let fanout = if expand.edge_types.is_empty() {
368 self.avg_fanout
369 } else {
370 self.fanout_for_expand(expand)
371 };
372 (input_card * fanout).max(1.0)
373 }
374 LogicalOperator::Filter(filter) => {
375 (self.estimate_child_cardinality(&filter.input) * 0.1).max(1.0)
377 }
378 LogicalOperator::Return(ret) => self.estimate_child_cardinality(&ret.input),
379 LogicalOperator::Limit(limit) => {
380 let input = self.estimate_child_cardinality(&limit.input);
381 input.min(100.0)
383 }
384 _ => (self.total_nodes as f64).max(1.0),
385 }
386 }
387
388 fn multi_way_join_cost(&self, mwj: &MultiWayJoinOp, cardinality: f64) -> Cost {
393 let n = mwj.inputs.len();
394 if n == 0 {
395 return Cost::zero();
396 }
397 let per_input = cardinality.powf(1.0 / n as f64).max(1.0);
400 let cardinalities: Vec<f64> = (0..n).map(|_| per_input).collect();
401 self.leapfrog_join_cost(n, &cardinalities, cardinality)
402 }
403
404 fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
406 let hash_cost = cardinality * self.hash_lookup_cost;
408
409 let agg_count = agg.aggregates.len() as f64;
411 let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
412
413 let distinct_groups = (cardinality / 10.0).max(1.0); let memory_cost = distinct_groups * self.avg_tuple_size;
416
417 Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
418 }
419
420 fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
422 if cardinality <= 1.0 {
423 return Cost::zero();
424 }
425
426 let comparisons = cardinality * cardinality.log2();
428 let key_count = sort.keys.len() as f64;
429
430 let memory_cost = cardinality * self.avg_tuple_size;
432
433 Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
434 }
435
436 fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
438 let hash_cost = cardinality * self.hash_lookup_cost;
440 let memory_cost = cardinality * self.avg_tuple_size * 0.5; Cost::cpu(hash_cost).with_memory(memory_cost)
443 }
444
445 fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
447 Cost::cpu(limit.count.estimate() * self.cpu_tuple_cost * 0.1)
449 }
450
451 fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
453 Cost::cpu(skip.count.estimate() * self.cpu_tuple_cost)
455 }
456
457 fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
459 let expr_count = ret.items.len() as f64;
461 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
462 }
463
464 fn vector_scan_cost(&self, scan: &VectorScanOp, cardinality: f64) -> Cost {
469 let k = scan.k as f64;
471
472 let ef = 64.0;
475 let n = cardinality.max(1.0);
476 let search_cost = if scan.index_name.is_some() {
477 ef * n.ln() * self.cpu_tuple_cost * 10.0 } else {
480 n * self.cpu_tuple_cost * 10.0
482 };
483
484 let memory = k * self.avg_tuple_size * 2.0;
486
487 Cost::cpu(search_cost).with_memory(memory)
488 }
489
490 fn vector_join_cost(&self, join: &VectorJoinOp, cardinality: f64) -> Cost {
494 let k = join.k as f64;
495
496 let per_row_search_cost = if join.index_name.is_some() {
499 let ef = 64.0;
501 let n = cardinality.max(1.0);
502 ef * n.ln() * self.cpu_tuple_cost * 10.0
503 } else {
504 cardinality * self.cpu_tuple_cost * 10.0
506 };
507
508 let input_cardinality = (cardinality / k).max(1.0);
511 let total_search_cost = input_cardinality * per_row_search_cost;
512
513 let memory = cardinality * self.avg_tuple_size;
515
516 Cost::cpu(total_search_cost).with_memory(memory)
517 }
518
519 #[must_use]
525 pub fn estimate_tree(
526 &self,
527 op: &LogicalOperator,
528 card_estimator: &super::CardinalityEstimator,
529 ) -> Cost {
530 self.estimate_tree_inner(op, card_estimator)
531 }
532
533 fn estimate_tree_inner(
534 &self,
535 op: &LogicalOperator,
536 card_est: &super::CardinalityEstimator,
537 ) -> Cost {
538 let cardinality = card_est.estimate(op);
539
540 match op {
541 LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
542 LogicalOperator::Filter(filter) => {
543 let child_cost = self.estimate_tree_inner(&filter.input, card_est);
544 child_cost + self.filter_cost(filter, cardinality)
545 }
546 LogicalOperator::Project(project) => {
547 let child_cost = self.estimate_tree_inner(&project.input, card_est);
548 child_cost + self.project_cost(project, cardinality)
549 }
550 LogicalOperator::Expand(expand) => {
551 let child_cost = self.estimate_tree_inner(&expand.input, card_est);
552 child_cost + self.expand_cost(expand, cardinality)
553 }
554 LogicalOperator::Join(join) => {
555 let left_cost = self.estimate_tree_inner(&join.left, card_est);
556 let right_cost = self.estimate_tree_inner(&join.right, card_est);
557 let left_card = card_est.estimate(&join.left);
558 let right_card = card_est.estimate(&join.right);
559 let join_cost = self.join_cost_with_children(
560 join,
561 cardinality,
562 Some(left_card),
563 Some(right_card),
564 );
565 left_cost + right_cost + join_cost
566 }
567 LogicalOperator::LeftJoin(lj) => {
568 let left_cost = self.estimate_tree_inner(&lj.left, card_est);
569 let right_cost = self.estimate_tree_inner(&lj.right, card_est);
570 let left_card = card_est.estimate(&lj.left);
571 let right_card = card_est.estimate(&lj.right);
572 let join_cost = self.left_join_cost(lj, cardinality, left_card, right_card);
573 left_cost + right_cost + join_cost
574 }
575 LogicalOperator::Aggregate(agg) => {
576 let child_cost = self.estimate_tree_inner(&agg.input, card_est);
577 child_cost + self.aggregate_cost(agg, cardinality)
578 }
579 LogicalOperator::Sort(sort) => {
580 let child_cost = self.estimate_tree_inner(&sort.input, card_est);
581 child_cost + self.sort_cost(sort, cardinality)
582 }
583 LogicalOperator::Distinct(distinct) => {
584 let child_cost = self.estimate_tree_inner(&distinct.input, card_est);
585 child_cost + self.distinct_cost(distinct, cardinality)
586 }
587 LogicalOperator::Limit(limit) => {
588 let child_cost = self.estimate_tree_inner(&limit.input, card_est);
589 child_cost + self.limit_cost(limit, cardinality)
590 }
591 LogicalOperator::Skip(skip) => {
592 let child_cost = self.estimate_tree_inner(&skip.input, card_est);
593 child_cost + self.skip_cost(skip, cardinality)
594 }
595 LogicalOperator::Return(ret) => {
596 let child_cost = self.estimate_tree_inner(&ret.input, card_est);
597 child_cost + self.return_cost(ret, cardinality)
598 }
599 LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
600 LogicalOperator::VectorJoin(join) => {
601 let child_cost = self.estimate_tree_inner(&join.input, card_est);
602 child_cost + self.vector_join_cost(join, cardinality)
603 }
604 LogicalOperator::MultiWayJoin(mwj) => {
605 let mut children_cost = Cost::zero();
606 for input in &mwj.inputs {
607 children_cost += self.estimate_tree_inner(input, card_est);
608 }
609 children_cost + self.multi_way_join_cost(mwj, cardinality)
610 }
611 LogicalOperator::Empty => Cost::zero(),
612 _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
613 }
614 }
615
616 #[must_use]
618 pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
619 if a.total() <= b.total() { a } else { b }
620 }
621
622 #[must_use]
638 pub fn leapfrog_join_cost(
639 &self,
640 num_relations: usize,
641 cardinalities: &[f64],
642 output_cardinality: f64,
643 ) -> Cost {
644 if cardinalities.is_empty() {
645 return Cost::zero();
646 }
647
648 let total_input: f64 = cardinalities.iter().sum();
649 let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
650
651 let materialize_cost = total_input * self.cpu_tuple_cost * 2.0; let seek_cost = if min_card > 1.0 {
656 output_cardinality * (num_relations as f64) * min_card.log2() * self.hash_lookup_cost
657 } else {
658 output_cardinality * self.cpu_tuple_cost
659 };
660
661 let output_cost = output_cardinality * self.cpu_tuple_cost;
663
664 let memory = total_input * self.avg_tuple_size * 2.0;
666
667 Cost::cpu(materialize_cost + seek_cost + output_cost).with_memory(memory)
668 }
669
670 #[must_use]
674 pub fn prefer_leapfrog_join(
675 &self,
676 num_relations: usize,
677 cardinalities: &[f64],
678 output_cardinality: f64,
679 ) -> bool {
680 if num_relations < 3 || cardinalities.len() < 3 {
681 return false;
683 }
684
685 let leapfrog_cost =
686 self.leapfrog_join_cost(num_relations, cardinalities, output_cardinality);
687
688 let mut hash_cascade_cost = Cost::zero();
692 let mut intermediate_cardinality = cardinalities[0];
693
694 for card in &cardinalities[1..] {
695 let join_output = (intermediate_cardinality * card).sqrt(); let join = JoinOp {
698 left: Box::new(LogicalOperator::Empty),
699 right: Box::new(LogicalOperator::Empty),
700 join_type: JoinType::Inner,
701 conditions: vec![],
702 };
703 hash_cascade_cost += self.join_cost(&join, join_output);
704 intermediate_cardinality = join_output;
705 }
706
707 leapfrog_cost.total() < hash_cascade_cost.total()
708 }
709
710 #[must_use]
718 pub fn factorized_benefit(&self, avg_fanout: f64, num_hops: usize) -> f64 {
719 if num_hops <= 1 || avg_fanout <= 1.0 {
720 return 1.0; }
722
723 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
730 let hops_i32 = num_hops as i32;
731 let full_size = avg_fanout.powi(hops_i32);
732 let factorized_size = if avg_fanout > 1.0 {
733 (avg_fanout.powi(hops_i32 + 1) - 1.0) / (avg_fanout - 1.0)
734 } else {
735 num_hops as f64
736 };
737
738 (factorized_size / full_size).min(1.0)
739 }
740}
741
742impl Default for CostModel {
743 fn default() -> Self {
744 Self::new()
745 }
746}
747
748#[cfg(test)]
749mod tests {
750 use super::*;
751 use crate::query::plan::{
752 AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
753 PathMode, Projection, ReturnItem, SortOrder,
754 };
755
756 #[test]
757 fn test_cost_addition() {
758 let a = Cost::cpu(10.0).with_io(5.0);
759 let b = Cost::cpu(20.0).with_memory(100.0);
760 let c = a + b;
761
762 assert!((c.cpu - 30.0).abs() < 0.001);
763 assert!((c.io - 5.0).abs() < 0.001);
764 assert!((c.memory - 100.0).abs() < 0.001);
765 }
766
767 #[test]
768 fn test_cost_total() {
769 let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
770 assert!((cost.total() - 30.0).abs() < 0.001);
772 }
773
774 #[test]
775 fn test_cost_model_node_scan() {
776 let model = CostModel::new();
777 let scan = NodeScanOp {
778 variable: "n".to_string(),
779 label: Some("Person".to_string()),
780 input: None,
781 };
782 let cost = model.node_scan_cost(&scan, 1000.0);
783
784 assert!(cost.cpu > 0.0);
785 assert!(cost.io > 0.0);
786 }
787
788 #[test]
789 fn test_cost_model_sort() {
790 let model = CostModel::new();
791 let sort = SortOp {
792 keys: vec![],
793 input: Box::new(LogicalOperator::Empty),
794 };
795
796 let cost_100 = model.sort_cost(&sort, 100.0);
797 let cost_1000 = model.sort_cost(&sort, 1000.0);
798
799 assert!(cost_1000.total() > cost_100.total());
801 }
802
803 #[test]
804 fn test_cost_zero() {
805 let cost = Cost::zero();
806 assert!((cost.cpu).abs() < 0.001);
807 assert!((cost.io).abs() < 0.001);
808 assert!((cost.memory).abs() < 0.001);
809 assert!((cost.network).abs() < 0.001);
810 assert!((cost.total()).abs() < 0.001);
811 }
812
813 #[test]
814 fn test_cost_add_assign() {
815 let mut cost = Cost::cpu(10.0);
816 cost += Cost::cpu(5.0).with_io(2.0);
817 assert!((cost.cpu - 15.0).abs() < 0.001);
818 assert!((cost.io - 2.0).abs() < 0.001);
819 }
820
821 #[test]
822 fn test_cost_total_weighted() {
823 let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
824 let total = cost.total_weighted(2.0, 5.0, 0.5);
826 assert!((total - 80.0).abs() < 0.001);
827 }
828
829 #[test]
830 fn test_cost_model_filter() {
831 let model = CostModel::new();
832 let filter = FilterOp {
833 predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
834 input: Box::new(LogicalOperator::Empty),
835 pushdown_hint: None,
836 };
837 let cost = model.filter_cost(&filter, 1000.0);
838
839 assert!(cost.cpu > 0.0);
841 assert!((cost.io).abs() < 0.001);
842 }
843
844 #[test]
845 fn test_cost_model_project() {
846 let model = CostModel::new();
847 let project = ProjectOp {
848 projections: vec![
849 Projection {
850 expression: LogicalExpression::Variable("a".to_string()),
851 alias: None,
852 },
853 Projection {
854 expression: LogicalExpression::Variable("b".to_string()),
855 alias: None,
856 },
857 ],
858 input: Box::new(LogicalOperator::Empty),
859 pass_through_input: false,
860 };
861 let cost = model.project_cost(&project, 1000.0);
862
863 assert!(cost.cpu > 0.0);
865 }
866
867 #[test]
868 fn test_cost_model_expand() {
869 let model = CostModel::new();
870 let expand = ExpandOp {
871 from_variable: "a".to_string(),
872 to_variable: "b".to_string(),
873 edge_variable: None,
874 direction: ExpandDirection::Outgoing,
875 edge_types: vec![],
876 min_hops: 1,
877 max_hops: Some(1),
878 input: Box::new(LogicalOperator::Empty),
879 path_alias: None,
880 path_mode: PathMode::Walk,
881 };
882 let cost = model.expand_cost(&expand, 1000.0);
883
884 assert!(cost.cpu > 0.0);
886 }
887
888 #[test]
889 fn test_cost_model_expand_with_edge_type_stats() {
890 let mut degrees = std::collections::HashMap::new();
891 degrees.insert("KNOWS".to_string(), (5.0, 5.0)); degrees.insert("WORKS_AT".to_string(), (1.0, 50.0)); let model = CostModel::new().with_edge_type_degrees(degrees);
895
896 let knows_out = ExpandOp {
898 from_variable: "a".to_string(),
899 to_variable: "b".to_string(),
900 edge_variable: None,
901 direction: ExpandDirection::Outgoing,
902 edge_types: vec!["KNOWS".to_string()],
903 min_hops: 1,
904 max_hops: Some(1),
905 input: Box::new(LogicalOperator::Empty),
906 path_alias: None,
907 path_mode: PathMode::Walk,
908 };
909 let cost_knows = model.expand_cost(&knows_out, 1000.0);
910
911 let works_out = ExpandOp {
913 from_variable: "a".to_string(),
914 to_variable: "b".to_string(),
915 edge_variable: None,
916 direction: ExpandDirection::Outgoing,
917 edge_types: vec!["WORKS_AT".to_string()],
918 min_hops: 1,
919 max_hops: Some(1),
920 input: Box::new(LogicalOperator::Empty),
921 path_alias: None,
922 path_mode: PathMode::Walk,
923 };
924 let cost_works = model.expand_cost(&works_out, 1000.0);
925
926 assert!(
928 cost_knows.cpu > cost_works.cpu,
929 "KNOWS(5) should cost more than WORKS_AT(1)"
930 );
931
932 let works_in = ExpandOp {
934 from_variable: "c".to_string(),
935 to_variable: "p".to_string(),
936 edge_variable: None,
937 direction: ExpandDirection::Incoming,
938 edge_types: vec!["WORKS_AT".to_string()],
939 min_hops: 1,
940 max_hops: Some(1),
941 input: Box::new(LogicalOperator::Empty),
942 path_alias: None,
943 path_mode: PathMode::Walk,
944 };
945 let cost_works_in = model.expand_cost(&works_in, 1000.0);
946
947 assert!(
949 cost_works_in.cpu > cost_knows.cpu,
950 "Incoming WORKS_AT(50) should cost more than KNOWS(5)"
951 );
952 }
953
954 #[test]
955 fn test_cost_model_expand_unknown_edge_type_uses_global_fanout() {
956 let model = CostModel::new().with_avg_fanout(7.0);
957 let expand = ExpandOp {
958 from_variable: "a".to_string(),
959 to_variable: "b".to_string(),
960 edge_variable: None,
961 direction: ExpandDirection::Outgoing,
962 edge_types: vec!["UNKNOWN_TYPE".to_string()],
963 min_hops: 1,
964 max_hops: Some(1),
965 input: Box::new(LogicalOperator::Empty),
966 path_alias: None,
967 path_mode: PathMode::Walk,
968 };
969 let cost_unknown = model.expand_cost(&expand, 1000.0);
970
971 let expand_no_type = ExpandOp {
973 from_variable: "a".to_string(),
974 to_variable: "b".to_string(),
975 edge_variable: None,
976 direction: ExpandDirection::Outgoing,
977 edge_types: vec![],
978 min_hops: 1,
979 max_hops: Some(1),
980 input: Box::new(LogicalOperator::Empty),
981 path_alias: None,
982 path_mode: PathMode::Walk,
983 };
984 let cost_no_type = model.expand_cost(&expand_no_type, 1000.0);
985
986 assert!(
988 (cost_unknown.cpu - cost_no_type.cpu).abs() < 0.001,
989 "Unknown edge type should use global fanout"
990 );
991 }
992
993 #[test]
994 fn test_cost_model_hash_join() {
995 let model = CostModel::new();
996 let join = JoinOp {
997 left: Box::new(LogicalOperator::Empty),
998 right: Box::new(LogicalOperator::Empty),
999 join_type: JoinType::Inner,
1000 conditions: vec![JoinCondition {
1001 left: LogicalExpression::Variable("a".to_string()),
1002 right: LogicalExpression::Variable("b".to_string()),
1003 }],
1004 };
1005 let cost = model.join_cost(&join, 10000.0);
1006
1007 assert!(cost.cpu > 0.0);
1009 assert!(cost.memory > 0.0);
1010 }
1011
1012 #[test]
1013 fn test_cost_model_cross_join() {
1014 let model = CostModel::new();
1015 let join = JoinOp {
1016 left: Box::new(LogicalOperator::Empty),
1017 right: Box::new(LogicalOperator::Empty),
1018 join_type: JoinType::Cross,
1019 conditions: vec![],
1020 };
1021 let cost = model.join_cost(&join, 1000000.0);
1022
1023 assert!(cost.cpu > 0.0);
1025 }
1026
1027 #[test]
1028 fn test_cost_model_semi_join() {
1029 let model = CostModel::new();
1030 let join = JoinOp {
1031 left: Box::new(LogicalOperator::Empty),
1032 right: Box::new(LogicalOperator::Empty),
1033 join_type: JoinType::Semi,
1034 conditions: vec![],
1035 };
1036 let cost_semi = model.join_cost(&join, 1000.0);
1037
1038 let inner_join = JoinOp {
1039 left: Box::new(LogicalOperator::Empty),
1040 right: Box::new(LogicalOperator::Empty),
1041 join_type: JoinType::Inner,
1042 conditions: vec![],
1043 };
1044 let cost_inner = model.join_cost(&inner_join, 1000.0);
1045
1046 assert!(cost_semi.cpu > 0.0);
1048 assert!(cost_inner.cpu > 0.0);
1049 }
1050
1051 #[test]
1052 fn test_cost_model_aggregate() {
1053 let model = CostModel::new();
1054 let agg = AggregateOp {
1055 group_by: vec![],
1056 aggregates: vec![
1057 AggregateExpr {
1058 function: AggregateFunction::Count,
1059 expression: None,
1060 expression2: None,
1061 distinct: false,
1062 alias: Some("cnt".to_string()),
1063 percentile: None,
1064 separator: None,
1065 },
1066 AggregateExpr {
1067 function: AggregateFunction::Sum,
1068 expression: Some(LogicalExpression::Variable("x".to_string())),
1069 expression2: None,
1070 distinct: false,
1071 alias: Some("total".to_string()),
1072 percentile: None,
1073 separator: None,
1074 },
1075 ],
1076 input: Box::new(LogicalOperator::Empty),
1077 having: None,
1078 };
1079 let cost = model.aggregate_cost(&agg, 1000.0);
1080
1081 assert!(cost.cpu > 0.0);
1083 assert!(cost.memory > 0.0);
1084 }
1085
1086 #[test]
1087 fn test_cost_model_distinct() {
1088 let model = CostModel::new();
1089 let distinct = DistinctOp {
1090 input: Box::new(LogicalOperator::Empty),
1091 columns: None,
1092 };
1093 let cost = model.distinct_cost(&distinct, 1000.0);
1094
1095 assert!(cost.cpu > 0.0);
1097 assert!(cost.memory > 0.0);
1098 }
1099
1100 #[test]
1101 fn test_cost_model_limit() {
1102 let model = CostModel::new();
1103 let limit = LimitOp {
1104 count: 10.into(),
1105 input: Box::new(LogicalOperator::Empty),
1106 };
1107 let cost = model.limit_cost(&limit, 1000.0);
1108
1109 assert!(cost.cpu > 0.0);
1111 assert!(cost.cpu < 1.0); }
1113
1114 #[test]
1115 fn test_cost_model_skip() {
1116 let model = CostModel::new();
1117 let skip = SkipOp {
1118 count: 100.into(),
1119 input: Box::new(LogicalOperator::Empty),
1120 };
1121 let cost = model.skip_cost(&skip, 1000.0);
1122
1123 assert!(cost.cpu > 0.0);
1125 }
1126
1127 #[test]
1128 fn test_cost_model_return() {
1129 let model = CostModel::new();
1130 let ret = ReturnOp {
1131 items: vec![
1132 ReturnItem {
1133 expression: LogicalExpression::Variable("a".to_string()),
1134 alias: None,
1135 },
1136 ReturnItem {
1137 expression: LogicalExpression::Variable("b".to_string()),
1138 alias: None,
1139 },
1140 ],
1141 distinct: false,
1142 input: Box::new(LogicalOperator::Empty),
1143 };
1144 let cost = model.return_cost(&ret, 1000.0);
1145
1146 assert!(cost.cpu > 0.0);
1148 }
1149
1150 #[test]
1151 fn test_cost_cheaper() {
1152 let model = CostModel::new();
1153 let cheap = Cost::cpu(10.0);
1154 let expensive = Cost::cpu(100.0);
1155
1156 assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
1157 assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
1158 }
1159
1160 #[test]
1161 fn test_cost_comparison_prefers_lower_total() {
1162 let model = CostModel::new();
1163 let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
1165 let io_heavy = Cost::cpu(10.0).with_io(20.0);
1167
1168 assert!(cpu_heavy.total() < io_heavy.total());
1170 assert_eq!(
1171 model.cheaper(&cpu_heavy, &io_heavy).total(),
1172 cpu_heavy.total()
1173 );
1174 }
1175
1176 #[test]
1177 fn test_cost_model_sort_with_keys() {
1178 let model = CostModel::new();
1179 let sort_single = SortOp {
1180 keys: vec![crate::query::plan::SortKey {
1181 expression: LogicalExpression::Variable("a".to_string()),
1182 order: SortOrder::Ascending,
1183 nulls: None,
1184 }],
1185 input: Box::new(LogicalOperator::Empty),
1186 };
1187 let sort_multi = SortOp {
1188 keys: vec![
1189 crate::query::plan::SortKey {
1190 expression: LogicalExpression::Variable("a".to_string()),
1191 order: SortOrder::Ascending,
1192 nulls: None,
1193 },
1194 crate::query::plan::SortKey {
1195 expression: LogicalExpression::Variable("b".to_string()),
1196 order: SortOrder::Descending,
1197 nulls: None,
1198 },
1199 ],
1200 input: Box::new(LogicalOperator::Empty),
1201 };
1202
1203 let cost_single = model.sort_cost(&sort_single, 1000.0);
1204 let cost_multi = model.sort_cost(&sort_multi, 1000.0);
1205
1206 assert!(cost_multi.cpu > cost_single.cpu);
1208 }
1209
1210 #[test]
1211 fn test_cost_model_empty_operator() {
1212 let model = CostModel::new();
1213 let cost = model.estimate(&LogicalOperator::Empty, 0.0);
1214 assert!((cost.total()).abs() < 0.001);
1215 }
1216
1217 #[test]
1218 fn test_cost_model_default() {
1219 let model = CostModel::default();
1220 let scan = NodeScanOp {
1221 variable: "n".to_string(),
1222 label: None,
1223 input: None,
1224 };
1225 let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
1226 assert!(cost.total() > 0.0);
1227 }
1228
1229 #[test]
1230 fn test_leapfrog_join_cost() {
1231 let model = CostModel::new();
1232
1233 let cardinalities = vec![1000.0, 1000.0, 1000.0];
1235 let cost = model.leapfrog_join_cost(3, &cardinalities, 100.0);
1236
1237 assert!(cost.cpu > 0.0);
1239 assert!(cost.memory > 0.0);
1241 }
1242
1243 #[test]
1244 fn test_leapfrog_join_cost_empty() {
1245 let model = CostModel::new();
1246 let cost = model.leapfrog_join_cost(0, &[], 0.0);
1247 assert!((cost.total()).abs() < 0.001);
1248 }
1249
1250 #[test]
1251 fn test_prefer_leapfrog_join_for_triangles() {
1252 let model = CostModel::new();
1253
1254 let cardinalities = vec![10000.0, 10000.0, 10000.0];
1256 let output = 1000.0;
1257
1258 let leapfrog_cost = model.leapfrog_join_cost(3, &cardinalities, output);
1259
1260 assert!(leapfrog_cost.cpu > 0.0);
1262 assert!(leapfrog_cost.memory > 0.0);
1263
1264 let _prefer = model.prefer_leapfrog_join(3, &cardinalities, output);
1267 }
1269
1270 #[test]
1271 fn test_prefer_leapfrog_join_binary_case() {
1272 let model = CostModel::new();
1273
1274 let cardinalities = vec![1000.0, 1000.0];
1276 let prefer = model.prefer_leapfrog_join(2, &cardinalities, 500.0);
1277 assert!(!prefer, "Binary joins should use hash join, not leapfrog");
1278 }
1279
#[test]
fn test_factorized_benefit_single_hop() {
    // A single hop has nothing to factorize: the benefit factor stays at 1.0.
    let benefit = CostModel::new().factorized_benefit(10.0, 1);
    let deviation = (benefit - 1.0).abs();
    assert!(deviation < 0.001, "Single hop should have no benefit");
}
1291
#[test]
fn test_factorized_benefit_multi_hop() {
    let benefit = CostModel::new().factorized_benefit(10.0, 3);

    // For a three-hop expansion the factor must lie in (0, 1].
    assert!(benefit <= 1.0, "Benefit should be <= 1.0");
    assert!(benefit > 0.0, "Benefit should be positive");
}
1305
#[test]
fn test_factorized_benefit_low_fanout() {
    // Even a fanout barely above 1 must not produce a factor above 1.0.
    let fanout = 1.5;
    let benefit = CostModel::new().factorized_benefit(fanout, 2);
    assert!(benefit <= 1.0, "Low fanout still benefits from factorization");
}
1317
#[test]
fn test_node_scan_uses_label_cardinality_for_io() {
    // Person has 10x as many nodes as Company; scan IO must reflect the gap.
    let model = CostModel::new()
        .with_label_cardinalities(std::collections::HashMap::from([
            ("Person".to_string(), 500_u64),
            ("Company".to_string(), 50_u64),
        ]))
        .with_graph_totals(550, 1000);

    let scan_of = |label: &str| NodeScanOp {
        variable: "n".to_string(),
        label: Some(label.to_string()),
        input: None,
    };

    let person_io = model.node_scan_cost(&scan_of("Person"), 500.0).io;
    let company_io = model.node_scan_cost(&scan_of("Company"), 50.0).io;

    assert!(
        person_io > company_io * 5.0,
        "Person ({}) should have much higher IO than Company ({})",
        person_io,
        company_io
    );
}
1350
#[test]
fn test_node_scan_unlabeled_uses_total_nodes() {
    let model = CostModel::new().with_graph_totals(10_000, 50_000);
    let all_nodes = NodeScanOp {
        variable: "n".to_string(),
        label: None,
        input: None,
    };

    let io = model.node_scan_cost(&all_nodes, 10_000.0).io;
    // 10k nodes * 100 bytes / 8192-byte pages; presumably these are the
    // model's default tuple and page sizes.
    let expected_pages = (10_000.0 * 100.0) / 8192.0;

    assert!(
        (io - expected_pages).abs() < 0.1,
        "Unlabeled scan should use total_nodes for IO: got {}, expected {}",
        io,
        expected_pages
    );
}
1370
#[test]
fn test_join_cost_with_actual_child_cardinalities() {
    let model = CostModel::new();
    let join = JoinOp {
        left: Box::new(LogicalOperator::Empty),
        right: Box::new(LogicalOperator::Empty),
        join_type: JoinType::Inner,
        conditions: vec![JoinCondition {
            left: LogicalExpression::Variable("a".to_string()),
            right: LogicalExpression::Variable("b".to_string()),
        }],
    };

    // Explicit, skewed child cardinalities (100 vs 10k rows).
    let cost_actual = model.join_cost_with_children(&join, 500.0, Some(100.0), Some(10_000.0));

    // Without child cardinalities the model uses its fallback estimate
    // (named the "sqrt fallback" here; exact formula not visible in this test).
    let cost_sqrt = model.join_cost(&join, 500.0);

    // The assertion requires a strictly *higher* CPU cost, so the failure
    // message states that expectation rather than merely "different".
    assert!(
        cost_actual.cpu > cost_sqrt.cpu,
        "Actual child cardinalities ({}) should produce higher CPU cost than sqrt fallback ({})",
        cost_actual.cpu,
        cost_sqrt.cpu
    );
}
1400
#[test]
fn test_expand_multi_edge_types() {
    // FOLLOWS has a much larger degree than KNOWS, so expanding over both
    // types must cost far more than expanding over KNOWS alone.
    let model = CostModel::new().with_edge_type_degrees(std::collections::HashMap::from([
        ("KNOWS".to_string(), (5.0, 5.0)),
        ("FOLLOWS".to_string(), (20.0, 100.0)),
    ]));

    // Single-hop outgoing expand a -> b over the given edge types.
    let expand_over = |types: &[&str]| ExpandOp {
        from_variable: "a".to_string(),
        to_variable: "b".to_string(),
        edge_variable: None,
        direction: ExpandDirection::Outgoing,
        edge_types: types.iter().copied().map(String::from).collect(),
        min_hops: 1,
        max_hops: Some(1),
        input: Box::new(LogicalOperator::Empty),
        path_alias: None,
        path_mode: PathMode::Walk,
    };

    let multi_cost = model.expand_cost(&expand_over(&["KNOWS", "FOLLOWS"]), 100.0);
    let single_cost = model.expand_cost(&expand_over(&["KNOWS"]), 100.0);

    assert!(
        multi_cost.cpu > single_cost.cpu * 3.0,
        "Multi-type fanout ({}) should be much higher than single-type ({})",
        multi_cost.cpu,
        single_cost.cpu
    );
}
1447
#[test]
fn test_recursive_tree_cost() {
    use crate::query::optimizer::{CardinalityEstimator, TableStats};

    let model = CostModel::new()
        .with_label_cardinalities(std::collections::HashMap::from([(
            "Person".to_string(),
            1000_u64,
        )]))
        .with_graph_totals(1000, 5000)
        .with_avg_fanout(5.0);

    let mut card_est = CardinalityEstimator::new();
    card_est.add_table_stats("Person", TableStats::new(1000));

    // Plan, built bottom-up: SCAN (n:Person) -> FILTER n.age > 30 -> RETURN n
    let scan = LogicalOperator::NodeScan(NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    });
    let age_over_30 = LogicalExpression::Binary {
        left: Box::new(LogicalExpression::Property {
            variable: "n".to_string(),
            property: "age".to_string(),
        }),
        op: crate::query::plan::BinaryOp::Gt,
        right: Box::new(LogicalExpression::Literal(
            grafeo_common::types::Value::Int64(30),
        )),
    };
    let filter = LogicalOperator::Filter(FilterOp {
        predicate: age_over_30,
        input: Box::new(scan),
        pushdown_hint: None,
    });
    let plan = LogicalOperator::Return(ReturnOp {
        items: vec![ReturnItem {
            expression: LogicalExpression::Variable("n".to_string()),
            alias: None,
        }],
        distinct: false,
        input: Box::new(filter),
    });

    let tree_cost = model.estimate_tree(&plan, &card_est);
    assert!(tree_cost.cpu > 0.0, "Tree should have CPU cost");
    assert!(tree_cost.io > 0.0, "Tree should have IO cost from scan");

    // Costing only the root operator must undercount compared with the whole tree.
    let root_only_cost = model.estimate(&plan, card_est.estimate(&plan));
    assert!(
        tree_cost.total() > root_only_cost.total(),
        "Recursive tree cost ({}) should exceed root-only cost ({})",
        tree_cost.total(),
        root_only_cost.total()
    );
}
1508
#[test]
fn test_statistics_driven_vs_default_cost() {
    let scan = NodeScanOp {
        variable: "n".to_string(),
        label: Some("Person".to_string()),
        input: None,
    };

    // Scan IO from a model with no statistics ...
    let default_io = CostModel::new().node_scan_cost(&scan, 100.0).io;
    // ... versus one whose Person label count equals the requested cardinality.
    let stats_io = CostModel::new()
        .with_label_cardinalities(std::collections::HashMap::from([(
            "Person".to_string(),
            100_u64,
        )]))
        .with_graph_totals(100, 500)
        .node_scan_cost(&scan, 100.0)
        .io;

    assert!(
        (default_io - stats_io).abs() < 0.1,
        "When cardinality matches label size, costs should be similar"
    );
}
1537
#[test]
fn test_leapfrog_join_cost_unit_min_cardinality() {
    // The smallest relation has a single row; costs must still be strictly positive.
    let inputs = [1.0, 100.0, 200.0];
    let cost = CostModel::new().leapfrog_join_cost(3, &inputs, 50.0);
    assert!(cost.cpu > 0.0);
    assert!(cost.memory > 0.0);
}
1546
#[test]
fn test_prefer_leapfrog_join_cardinalities_below_three() {
    let model = CostModel::new();

    // Fewer than three cardinalities supplied: leapfrog must not be preferred,
    // whatever relation count is passed.
    let two_given_for_three = model.prefer_leapfrog_join(3, &[100.0, 200.0], 50.0);
    assert!(!two_given_for_three);

    let none_given_for_five = model.prefer_leapfrog_join(5, &[], 10.0);
    assert!(!none_given_for_five);
}
1554
#[test]
fn test_factorized_benefit_zero_hops() {
    // Zero hops: the benefit factor is exactly the neutral value 1.0.
    let benefit = CostModel::new().factorized_benefit(10.0, 0);
    assert_eq!(benefit, 1.0);
}
1561
#[test]
fn test_factorized_benefit_unit_fanout_guard() {
    // A fanout of exactly 1 yields the neutral factor 1.0 even over many hops.
    let benefit = CostModel::new().factorized_benefit(1.0, 5);
    assert_eq!(benefit, 1.0);
}
1568}