1use crate::query::plan::{
6 AggregateOp, DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinOp, JoinType, LeftJoinOp,
7 LimitOp, LogicalOperator, MultiWayJoinOp, NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp,
8 VectorJoinOp, VectorScanOp,
9};
10use std::collections::HashMap;
11
/// Estimated resource usage of a plan operator, split by resource kind.
///
/// The components are kept separate so they can be combined with different
/// weights; see [`Cost::total`] and [`Cost::total_weighted`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Cost {
    /// CPU component (weight `1` in [`Cost::total`]).
    pub cpu: f64,
    /// I/O component (weight `10` in [`Cost::total`]).
    pub io: f64,
    /// Memory component (weight `0.1` in [`Cost::total`]).
    pub memory: f64,
    /// Network component (weight `100` in [`Cost::total`]).
    pub network: f64,
}

impl Cost {
    /// Returns a cost whose four components are all `0.0`.
    #[must_use]
    pub fn zero() -> Self {
        Self::cpu(0.0)
    }

    /// Returns a cost with the given CPU component and every other component `0.0`.
    #[must_use]
    pub fn cpu(cpu: f64) -> Self {
        Self {
            cpu,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Returns `self` with the I/O component replaced by `io`.
    #[must_use]
    pub fn with_io(self, io: f64) -> Self {
        Self { io, ..self }
    }

    /// Returns `self` with the memory component replaced by `memory`.
    #[must_use]
    pub fn with_memory(self, memory: f64) -> Self {
        Self { memory, ..self }
    }

    /// Collapses the four components into one number using fixed weights:
    /// `cpu + io * 10 + memory * 0.1 + network * 100`.
    #[must_use]
    pub fn total(&self) -> f64 {
        let Self {
            cpu,
            io,
            memory,
            network,
        } = *self;
        cpu + io * 10.0 + memory * 0.1 + network * 100.0
    }

    /// Collapses CPU, I/O and memory into one number using caller-supplied weights.
    ///
    /// The network component does not take part in this sum.
    #[must_use]
    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
        let Self {
            cpu, io, memory, ..
        } = *self;
        cpu * cpu_weight + io * io_weight + memory * mem_weight
    }
}

/// Component-wise addition of two costs.
impl std::ops::Add for Cost {
    type Output = Self;

    fn add(mut self, rhs: Self) -> Self {
        // Delegates to `AddAssign` so both operators share one definition.
        self += rhs;
        self
    }
}

/// Component-wise in-place addition of another cost.
impl std::ops::AddAssign for Cost {
    fn add_assign(&mut self, rhs: Self) {
        *self = Self {
            cpu: self.cpu + rhs.cpu,
            io: self.io + rhs.io,
            memory: self.memory + rhs.memory,
            network: self.network + rhs.network,
        };
    }
}
100
/// Tunable constants and graph statistics used to turn cardinalities into [`Cost`]s.
///
/// Built with [`CostModel::new`] and refined through the `with_*` builders.
#[derive(Debug, Clone)]
pub struct CostModel {
    /// CPU cost charged per processed row; multiplied by a cardinality wherever
    /// per-row work is costed.
    cpu_tuple_cost: f64,
    /// CPU cost charged per hash-table build or probe.
    hash_lookup_cost: f64,
    /// CPU cost charged per comparison when sorting.
    sort_comparison_cost: f64,
    /// Average size of one row; converts row counts into memory and page
    /// estimates (presumably bytes, to match `page_size` — confirm).
    avg_tuple_size: f64,
    /// Size of one storage page; divisor for the node-scan I/O estimate.
    page_size: f64,
    /// Fanout used for an expand when no per-edge-type statistics apply.
    avg_fanout: f64,
    /// Per edge type: `(outgoing degree, incoming degree)`.
    edge_type_degrees: HashMap<String, (f64, f64)>,
    /// Per node label: number of nodes carrying that label.
    label_cardinalities: HashMap<String, u64>,
    /// Total node count; used as the scan size of unlabeled scans when non-zero.
    total_nodes: u64,
    /// Total edge count; stored by `with_graph_totals`.
    total_edges: u64,
}
130
131impl CostModel {
132 #[must_use]
134 pub fn new() -> Self {
135 Self {
136 cpu_tuple_cost: 0.01,
137 hash_lookup_cost: 0.03,
138 sort_comparison_cost: 0.02,
139 avg_tuple_size: 100.0,
140 page_size: 8192.0,
141 avg_fanout: 10.0,
142 edge_type_degrees: HashMap::new(),
143 label_cardinalities: HashMap::new(),
144 total_nodes: 0,
145 total_edges: 0,
146 }
147 }
148
149 #[must_use]
151 pub fn with_avg_fanout(mut self, avg_fanout: f64) -> Self {
152 self.avg_fanout = if avg_fanout > 0.0 { avg_fanout } else { 10.0 };
153 self
154 }
155
156 #[must_use]
160 pub fn with_edge_type_degrees(mut self, degrees: HashMap<String, (f64, f64)>) -> Self {
161 self.edge_type_degrees = degrees;
162 self
163 }
164
165 #[must_use]
167 pub fn with_label_cardinalities(mut self, cardinalities: HashMap<String, u64>) -> Self {
168 self.label_cardinalities = cardinalities;
169 self
170 }
171
172 #[must_use]
174 pub fn with_graph_totals(mut self, total_nodes: u64, total_edges: u64) -> Self {
175 self.total_nodes = total_nodes;
176 self.total_edges = total_edges;
177 self
178 }
179
180 fn fanout_for_expand(&self, expand: &ExpandOp) -> f64 {
185 if expand.edge_types.is_empty() {
186 return self.avg_fanout;
187 }
188
189 let mut total_fanout = 0.0;
190 let mut all_found = true;
191
192 for edge_type in &expand.edge_types {
193 if let Some(&(out_deg, in_deg)) = self.edge_type_degrees.get(edge_type) {
194 total_fanout += match expand.direction {
195 ExpandDirection::Outgoing => out_deg,
196 ExpandDirection::Incoming => in_deg,
197 ExpandDirection::Both => out_deg + in_deg,
198 };
199 } else {
200 all_found = false;
201 break;
202 }
203 }
204
205 if all_found {
206 total_fanout
207 } else {
208 self.avg_fanout
209 }
210 }
211
212 #[must_use]
214 pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
215 match op {
216 LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
217 LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
218 LogicalOperator::Project(project) => self.project_cost(project, cardinality),
219 LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
220 LogicalOperator::Join(join) => self.join_cost(join, cardinality),
221 LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
222 LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
223 LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
224 LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
225 LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
226 LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
227 LogicalOperator::Empty => Cost::zero(),
228 LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
229 LogicalOperator::VectorJoin(join) => self.vector_join_cost(join, cardinality),
230 LogicalOperator::MultiWayJoin(mwj) => self.multi_way_join_cost(mwj, cardinality),
231 LogicalOperator::LeftJoin(lj) => {
232 self.left_join_cost(lj, cardinality, cardinality.sqrt(), cardinality.sqrt())
233 }
234 _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
235 }
236 }
237
238 fn node_scan_cost(&self, scan: &NodeScanOp, cardinality: f64) -> Cost {
244 let scan_size = if let Some(label) = &scan.label {
246 self.label_cardinalities
247 .get(label)
248 .map_or(cardinality, |&count| count as f64)
249 } else if self.total_nodes > 0 {
250 self.total_nodes as f64
251 } else {
252 cardinality
253 };
254 let pages = (scan_size * self.avg_tuple_size) / self.page_size;
255 Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
257 }
258
259 fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
261 Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
263 }
264
265 fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
267 let expr_count = project.projections.len() as f64;
269 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
270 }
271
272 fn expand_cost(&self, expand: &ExpandOp, cardinality: f64) -> Cost {
277 let fanout = self.fanout_for_expand(expand);
278 let lookup_cost = cardinality * self.hash_lookup_cost;
280 let output_cost = cardinality * fanout * self.cpu_tuple_cost;
282 Cost::cpu(lookup_cost + output_cost)
283 }
284
285 fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
290 self.join_cost_with_children(join, cardinality, None, None)
291 }
292
293 fn join_cost_with_children(
295 &self,
296 join: &JoinOp,
297 cardinality: f64,
298 left_cardinality: Option<f64>,
299 right_cardinality: Option<f64>,
300 ) -> Cost {
301 match join.join_type {
302 JoinType::Cross => Cost::cpu(cardinality * self.cpu_tuple_cost),
303 JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
304 let build_cardinality = left_cardinality.unwrap_or_else(|| cardinality.sqrt());
306 let probe_cardinality = right_cardinality.unwrap_or_else(|| cardinality.sqrt());
307
308 let build_cost = build_cardinality * self.hash_lookup_cost;
309 let memory_cost = build_cardinality * self.avg_tuple_size;
310 let probe_cost = probe_cardinality * self.hash_lookup_cost;
311 let output_cost = cardinality * self.cpu_tuple_cost;
312
313 Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
314 }
315 JoinType::Semi | JoinType::Anti => {
316 let build_cardinality = left_cardinality.unwrap_or_else(|| cardinality.sqrt());
317 let probe_cardinality = right_cardinality.unwrap_or_else(|| cardinality.sqrt());
318
319 let build_cost = build_cardinality * self.hash_lookup_cost;
320 let probe_cost = probe_cardinality * self.hash_lookup_cost;
321
322 Cost::cpu(build_cost + probe_cost)
323 .with_memory(build_cardinality * self.avg_tuple_size)
324 }
325 }
326 }
327
328 fn left_join_cost(
334 &self,
335 _lj: &LeftJoinOp,
336 cardinality: f64,
337 left_card: f64,
338 right_card: f64,
339 ) -> Cost {
340 let build_cost = right_card * self.hash_lookup_cost;
341 let memory_cost = right_card * self.avg_tuple_size;
342 let probe_cost = left_card * self.hash_lookup_cost;
343 let output_cost = cardinality * self.cpu_tuple_cost;
344
345 Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
346 }
347
348 fn multi_way_join_cost(&self, mwj: &MultiWayJoinOp, cardinality: f64) -> Cost {
353 let n = mwj.inputs.len();
354 if n == 0 {
355 return Cost::zero();
356 }
357 let per_input = cardinality.powf(1.0 / n as f64).max(1.0);
360 let cardinalities: Vec<f64> = (0..n).map(|_| per_input).collect();
361 self.leapfrog_join_cost(n, &cardinalities, cardinality)
362 }
363
364 fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
366 let hash_cost = cardinality * self.hash_lookup_cost;
368
369 let agg_count = agg.aggregates.len() as f64;
371 let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
372
373 let distinct_groups = (cardinality / 10.0).max(1.0); let memory_cost = distinct_groups * self.avg_tuple_size;
376
377 Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
378 }
379
380 fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
382 if cardinality <= 1.0 {
383 return Cost::zero();
384 }
385
386 let comparisons = cardinality * cardinality.log2();
388 let key_count = sort.keys.len() as f64;
389
390 let memory_cost = cardinality * self.avg_tuple_size;
392
393 Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
394 }
395
396 fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
398 let hash_cost = cardinality * self.hash_lookup_cost;
400 let memory_cost = cardinality * self.avg_tuple_size * 0.5; Cost::cpu(hash_cost).with_memory(memory_cost)
403 }
404
405 fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
407 Cost::cpu(limit.count.estimate() * self.cpu_tuple_cost * 0.1)
409 }
410
411 fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
413 Cost::cpu(skip.count.estimate() * self.cpu_tuple_cost)
415 }
416
417 fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
419 let expr_count = ret.items.len() as f64;
421 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
422 }
423
424 fn vector_scan_cost(&self, scan: &VectorScanOp, cardinality: f64) -> Cost {
429 let k = scan.k as f64;
431
432 let ef = 64.0;
435 let n = cardinality.max(1.0);
436 let search_cost = if scan.index_name.is_some() {
437 ef * n.ln() * self.cpu_tuple_cost * 10.0 } else {
440 n * self.cpu_tuple_cost * 10.0
442 };
443
444 let memory = k * self.avg_tuple_size * 2.0;
446
447 Cost::cpu(search_cost).with_memory(memory)
448 }
449
450 fn vector_join_cost(&self, join: &VectorJoinOp, cardinality: f64) -> Cost {
454 let k = join.k as f64;
455
456 let per_row_search_cost = if join.index_name.is_some() {
459 let ef = 64.0;
461 let n = cardinality.max(1.0);
462 ef * n.ln() * self.cpu_tuple_cost * 10.0
463 } else {
464 cardinality * self.cpu_tuple_cost * 10.0
466 };
467
468 let input_cardinality = (cardinality / k).max(1.0);
471 let total_search_cost = input_cardinality * per_row_search_cost;
472
473 let memory = cardinality * self.avg_tuple_size;
475
476 Cost::cpu(total_search_cost).with_memory(memory)
477 }
478
479 #[must_use]
485 pub fn estimate_tree(
486 &self,
487 op: &LogicalOperator,
488 card_estimator: &super::CardinalityEstimator,
489 ) -> Cost {
490 self.estimate_tree_inner(op, card_estimator)
491 }
492
493 fn estimate_tree_inner(
494 &self,
495 op: &LogicalOperator,
496 card_est: &super::CardinalityEstimator,
497 ) -> Cost {
498 let cardinality = card_est.estimate(op);
499
500 match op {
501 LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
502 LogicalOperator::Filter(filter) => {
503 let child_cost = self.estimate_tree_inner(&filter.input, card_est);
504 child_cost + self.filter_cost(filter, cardinality)
505 }
506 LogicalOperator::Project(project) => {
507 let child_cost = self.estimate_tree_inner(&project.input, card_est);
508 child_cost + self.project_cost(project, cardinality)
509 }
510 LogicalOperator::Expand(expand) => {
511 let child_cost = self.estimate_tree_inner(&expand.input, card_est);
512 child_cost + self.expand_cost(expand, cardinality)
513 }
514 LogicalOperator::Join(join) => {
515 let left_cost = self.estimate_tree_inner(&join.left, card_est);
516 let right_cost = self.estimate_tree_inner(&join.right, card_est);
517 let left_card = card_est.estimate(&join.left);
518 let right_card = card_est.estimate(&join.right);
519 let join_cost = self.join_cost_with_children(
520 join,
521 cardinality,
522 Some(left_card),
523 Some(right_card),
524 );
525 left_cost + right_cost + join_cost
526 }
527 LogicalOperator::LeftJoin(lj) => {
528 let left_cost = self.estimate_tree_inner(&lj.left, card_est);
529 let right_cost = self.estimate_tree_inner(&lj.right, card_est);
530 let left_card = card_est.estimate(&lj.left);
531 let right_card = card_est.estimate(&lj.right);
532 let join_cost = self.left_join_cost(lj, cardinality, left_card, right_card);
533 left_cost + right_cost + join_cost
534 }
535 LogicalOperator::Aggregate(agg) => {
536 let child_cost = self.estimate_tree_inner(&agg.input, card_est);
537 child_cost + self.aggregate_cost(agg, cardinality)
538 }
539 LogicalOperator::Sort(sort) => {
540 let child_cost = self.estimate_tree_inner(&sort.input, card_est);
541 child_cost + self.sort_cost(sort, cardinality)
542 }
543 LogicalOperator::Distinct(distinct) => {
544 let child_cost = self.estimate_tree_inner(&distinct.input, card_est);
545 child_cost + self.distinct_cost(distinct, cardinality)
546 }
547 LogicalOperator::Limit(limit) => {
548 let child_cost = self.estimate_tree_inner(&limit.input, card_est);
549 child_cost + self.limit_cost(limit, cardinality)
550 }
551 LogicalOperator::Skip(skip) => {
552 let child_cost = self.estimate_tree_inner(&skip.input, card_est);
553 child_cost + self.skip_cost(skip, cardinality)
554 }
555 LogicalOperator::Return(ret) => {
556 let child_cost = self.estimate_tree_inner(&ret.input, card_est);
557 child_cost + self.return_cost(ret, cardinality)
558 }
559 LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
560 LogicalOperator::VectorJoin(join) => {
561 let child_cost = self.estimate_tree_inner(&join.input, card_est);
562 child_cost + self.vector_join_cost(join, cardinality)
563 }
564 LogicalOperator::MultiWayJoin(mwj) => {
565 let mut children_cost = Cost::zero();
566 for input in &mwj.inputs {
567 children_cost += self.estimate_tree_inner(input, card_est);
568 }
569 children_cost + self.multi_way_join_cost(mwj, cardinality)
570 }
571 LogicalOperator::Empty => Cost::zero(),
572 _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
573 }
574 }
575
576 #[must_use]
578 pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
579 if a.total() <= b.total() { a } else { b }
580 }
581
582 #[must_use]
598 pub fn leapfrog_join_cost(
599 &self,
600 num_relations: usize,
601 cardinalities: &[f64],
602 output_cardinality: f64,
603 ) -> Cost {
604 if cardinalities.is_empty() {
605 return Cost::zero();
606 }
607
608 let total_input: f64 = cardinalities.iter().sum();
609 let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
610
611 let materialize_cost = total_input * self.cpu_tuple_cost * 2.0; let seek_cost = if min_card > 1.0 {
616 output_cardinality * (num_relations as f64) * min_card.log2() * self.hash_lookup_cost
617 } else {
618 output_cardinality * self.cpu_tuple_cost
619 };
620
621 let output_cost = output_cardinality * self.cpu_tuple_cost;
623
624 let memory = total_input * self.avg_tuple_size * 2.0;
626
627 Cost::cpu(materialize_cost + seek_cost + output_cost).with_memory(memory)
628 }
629
630 #[must_use]
634 pub fn prefer_leapfrog_join(
635 &self,
636 num_relations: usize,
637 cardinalities: &[f64],
638 output_cardinality: f64,
639 ) -> bool {
640 if num_relations < 3 || cardinalities.len() < 3 {
641 return false;
643 }
644
645 let leapfrog_cost =
646 self.leapfrog_join_cost(num_relations, cardinalities, output_cardinality);
647
648 let mut hash_cascade_cost = Cost::zero();
652 let mut intermediate_cardinality = cardinalities[0];
653
654 for card in &cardinalities[1..] {
655 let join_output = (intermediate_cardinality * card).sqrt(); let join = JoinOp {
658 left: Box::new(LogicalOperator::Empty),
659 right: Box::new(LogicalOperator::Empty),
660 join_type: JoinType::Inner,
661 conditions: vec![],
662 };
663 hash_cascade_cost += self.join_cost(&join, join_output);
664 intermediate_cardinality = join_output;
665 }
666
667 leapfrog_cost.total() < hash_cascade_cost.total()
668 }
669
670 #[must_use]
678 pub fn factorized_benefit(&self, avg_fanout: f64, num_hops: usize) -> f64 {
679 if num_hops <= 1 || avg_fanout <= 1.0 {
680 return 1.0; }
682
683 let full_size = avg_fanout.powi(num_hops as i32);
689 let factorized_size = if avg_fanout > 1.0 {
690 (avg_fanout.powi(num_hops as i32 + 1) - 1.0) / (avg_fanout - 1.0)
691 } else {
692 num_hops as f64
693 };
694
695 (factorized_size / full_size).min(1.0)
696 }
697}
698
699impl Default for CostModel {
700 fn default() -> Self {
701 Self::new()
702 }
703}
704
705#[cfg(test)]
706mod tests {
707 use super::*;
708 use crate::query::plan::{
709 AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
710 PathMode, Projection, ReturnItem, SortOrder,
711 };
712
713 #[test]
714 fn test_cost_addition() {
715 let a = Cost::cpu(10.0).with_io(5.0);
716 let b = Cost::cpu(20.0).with_memory(100.0);
717 let c = a + b;
718
719 assert!((c.cpu - 30.0).abs() < 0.001);
720 assert!((c.io - 5.0).abs() < 0.001);
721 assert!((c.memory - 100.0).abs() < 0.001);
722 }
723
724 #[test]
725 fn test_cost_total() {
726 let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
727 assert!((cost.total() - 30.0).abs() < 0.001);
729 }
730
731 #[test]
732 fn test_cost_model_node_scan() {
733 let model = CostModel::new();
734 let scan = NodeScanOp {
735 variable: "n".to_string(),
736 label: Some("Person".to_string()),
737 input: None,
738 };
739 let cost = model.node_scan_cost(&scan, 1000.0);
740
741 assert!(cost.cpu > 0.0);
742 assert!(cost.io > 0.0);
743 }
744
745 #[test]
746 fn test_cost_model_sort() {
747 let model = CostModel::new();
748 let sort = SortOp {
749 keys: vec![],
750 input: Box::new(LogicalOperator::Empty),
751 };
752
753 let cost_100 = model.sort_cost(&sort, 100.0);
754 let cost_1000 = model.sort_cost(&sort, 1000.0);
755
756 assert!(cost_1000.total() > cost_100.total());
758 }
759
760 #[test]
761 fn test_cost_zero() {
762 let cost = Cost::zero();
763 assert!((cost.cpu).abs() < 0.001);
764 assert!((cost.io).abs() < 0.001);
765 assert!((cost.memory).abs() < 0.001);
766 assert!((cost.network).abs() < 0.001);
767 assert!((cost.total()).abs() < 0.001);
768 }
769
770 #[test]
771 fn test_cost_add_assign() {
772 let mut cost = Cost::cpu(10.0);
773 cost += Cost::cpu(5.0).with_io(2.0);
774 assert!((cost.cpu - 15.0).abs() < 0.001);
775 assert!((cost.io - 2.0).abs() < 0.001);
776 }
777
778 #[test]
779 fn test_cost_total_weighted() {
780 let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
781 let total = cost.total_weighted(2.0, 5.0, 0.5);
783 assert!((total - 80.0).abs() < 0.001);
784 }
785
786 #[test]
787 fn test_cost_model_filter() {
788 let model = CostModel::new();
789 let filter = FilterOp {
790 predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
791 input: Box::new(LogicalOperator::Empty),
792 pushdown_hint: None,
793 };
794 let cost = model.filter_cost(&filter, 1000.0);
795
796 assert!(cost.cpu > 0.0);
798 assert!((cost.io).abs() < 0.001);
799 }
800
801 #[test]
802 fn test_cost_model_project() {
803 let model = CostModel::new();
804 let project = ProjectOp {
805 projections: vec![
806 Projection {
807 expression: LogicalExpression::Variable("a".to_string()),
808 alias: None,
809 },
810 Projection {
811 expression: LogicalExpression::Variable("b".to_string()),
812 alias: None,
813 },
814 ],
815 input: Box::new(LogicalOperator::Empty),
816 pass_through_input: false,
817 };
818 let cost = model.project_cost(&project, 1000.0);
819
820 assert!(cost.cpu > 0.0);
822 }
823
824 #[test]
825 fn test_cost_model_expand() {
826 let model = CostModel::new();
827 let expand = ExpandOp {
828 from_variable: "a".to_string(),
829 to_variable: "b".to_string(),
830 edge_variable: None,
831 direction: ExpandDirection::Outgoing,
832 edge_types: vec![],
833 min_hops: 1,
834 max_hops: Some(1),
835 input: Box::new(LogicalOperator::Empty),
836 path_alias: None,
837 path_mode: PathMode::Walk,
838 };
839 let cost = model.expand_cost(&expand, 1000.0);
840
841 assert!(cost.cpu > 0.0);
843 }
844
845 #[test]
846 fn test_cost_model_expand_with_edge_type_stats() {
847 let mut degrees = std::collections::HashMap::new();
848 degrees.insert("KNOWS".to_string(), (5.0, 5.0)); degrees.insert("WORKS_AT".to_string(), (1.0, 50.0)); let model = CostModel::new().with_edge_type_degrees(degrees);
852
853 let knows_out = ExpandOp {
855 from_variable: "a".to_string(),
856 to_variable: "b".to_string(),
857 edge_variable: None,
858 direction: ExpandDirection::Outgoing,
859 edge_types: vec!["KNOWS".to_string()],
860 min_hops: 1,
861 max_hops: Some(1),
862 input: Box::new(LogicalOperator::Empty),
863 path_alias: None,
864 path_mode: PathMode::Walk,
865 };
866 let cost_knows = model.expand_cost(&knows_out, 1000.0);
867
868 let works_out = ExpandOp {
870 from_variable: "a".to_string(),
871 to_variable: "b".to_string(),
872 edge_variable: None,
873 direction: ExpandDirection::Outgoing,
874 edge_types: vec!["WORKS_AT".to_string()],
875 min_hops: 1,
876 max_hops: Some(1),
877 input: Box::new(LogicalOperator::Empty),
878 path_alias: None,
879 path_mode: PathMode::Walk,
880 };
881 let cost_works = model.expand_cost(&works_out, 1000.0);
882
883 assert!(
885 cost_knows.cpu > cost_works.cpu,
886 "KNOWS(5) should cost more than WORKS_AT(1)"
887 );
888
889 let works_in = ExpandOp {
891 from_variable: "c".to_string(),
892 to_variable: "p".to_string(),
893 edge_variable: None,
894 direction: ExpandDirection::Incoming,
895 edge_types: vec!["WORKS_AT".to_string()],
896 min_hops: 1,
897 max_hops: Some(1),
898 input: Box::new(LogicalOperator::Empty),
899 path_alias: None,
900 path_mode: PathMode::Walk,
901 };
902 let cost_works_in = model.expand_cost(&works_in, 1000.0);
903
904 assert!(
906 cost_works_in.cpu > cost_knows.cpu,
907 "Incoming WORKS_AT(50) should cost more than KNOWS(5)"
908 );
909 }
910
911 #[test]
912 fn test_cost_model_expand_unknown_edge_type_uses_global_fanout() {
913 let model = CostModel::new().with_avg_fanout(7.0);
914 let expand = ExpandOp {
915 from_variable: "a".to_string(),
916 to_variable: "b".to_string(),
917 edge_variable: None,
918 direction: ExpandDirection::Outgoing,
919 edge_types: vec!["UNKNOWN_TYPE".to_string()],
920 min_hops: 1,
921 max_hops: Some(1),
922 input: Box::new(LogicalOperator::Empty),
923 path_alias: None,
924 path_mode: PathMode::Walk,
925 };
926 let cost_unknown = model.expand_cost(&expand, 1000.0);
927
928 let expand_no_type = ExpandOp {
930 from_variable: "a".to_string(),
931 to_variable: "b".to_string(),
932 edge_variable: None,
933 direction: ExpandDirection::Outgoing,
934 edge_types: vec![],
935 min_hops: 1,
936 max_hops: Some(1),
937 input: Box::new(LogicalOperator::Empty),
938 path_alias: None,
939 path_mode: PathMode::Walk,
940 };
941 let cost_no_type = model.expand_cost(&expand_no_type, 1000.0);
942
943 assert!(
945 (cost_unknown.cpu - cost_no_type.cpu).abs() < 0.001,
946 "Unknown edge type should use global fanout"
947 );
948 }
949
950 #[test]
951 fn test_cost_model_hash_join() {
952 let model = CostModel::new();
953 let join = JoinOp {
954 left: Box::new(LogicalOperator::Empty),
955 right: Box::new(LogicalOperator::Empty),
956 join_type: JoinType::Inner,
957 conditions: vec![JoinCondition {
958 left: LogicalExpression::Variable("a".to_string()),
959 right: LogicalExpression::Variable("b".to_string()),
960 }],
961 };
962 let cost = model.join_cost(&join, 10000.0);
963
964 assert!(cost.cpu > 0.0);
966 assert!(cost.memory > 0.0);
967 }
968
969 #[test]
970 fn test_cost_model_cross_join() {
971 let model = CostModel::new();
972 let join = JoinOp {
973 left: Box::new(LogicalOperator::Empty),
974 right: Box::new(LogicalOperator::Empty),
975 join_type: JoinType::Cross,
976 conditions: vec![],
977 };
978 let cost = model.join_cost(&join, 1000000.0);
979
980 assert!(cost.cpu > 0.0);
982 }
983
984 #[test]
985 fn test_cost_model_semi_join() {
986 let model = CostModel::new();
987 let join = JoinOp {
988 left: Box::new(LogicalOperator::Empty),
989 right: Box::new(LogicalOperator::Empty),
990 join_type: JoinType::Semi,
991 conditions: vec![],
992 };
993 let cost_semi = model.join_cost(&join, 1000.0);
994
995 let inner_join = JoinOp {
996 left: Box::new(LogicalOperator::Empty),
997 right: Box::new(LogicalOperator::Empty),
998 join_type: JoinType::Inner,
999 conditions: vec![],
1000 };
1001 let cost_inner = model.join_cost(&inner_join, 1000.0);
1002
1003 assert!(cost_semi.cpu > 0.0);
1005 assert!(cost_inner.cpu > 0.0);
1006 }
1007
1008 #[test]
1009 fn test_cost_model_aggregate() {
1010 let model = CostModel::new();
1011 let agg = AggregateOp {
1012 group_by: vec![],
1013 aggregates: vec![
1014 AggregateExpr {
1015 function: AggregateFunction::Count,
1016 expression: None,
1017 expression2: None,
1018 distinct: false,
1019 alias: Some("cnt".to_string()),
1020 percentile: None,
1021 separator: None,
1022 },
1023 AggregateExpr {
1024 function: AggregateFunction::Sum,
1025 expression: Some(LogicalExpression::Variable("x".to_string())),
1026 expression2: None,
1027 distinct: false,
1028 alias: Some("total".to_string()),
1029 percentile: None,
1030 separator: None,
1031 },
1032 ],
1033 input: Box::new(LogicalOperator::Empty),
1034 having: None,
1035 };
1036 let cost = model.aggregate_cost(&agg, 1000.0);
1037
1038 assert!(cost.cpu > 0.0);
1040 assert!(cost.memory > 0.0);
1041 }
1042
1043 #[test]
1044 fn test_cost_model_distinct() {
1045 let model = CostModel::new();
1046 let distinct = DistinctOp {
1047 input: Box::new(LogicalOperator::Empty),
1048 columns: None,
1049 };
1050 let cost = model.distinct_cost(&distinct, 1000.0);
1051
1052 assert!(cost.cpu > 0.0);
1054 assert!(cost.memory > 0.0);
1055 }
1056
1057 #[test]
1058 fn test_cost_model_limit() {
1059 let model = CostModel::new();
1060 let limit = LimitOp {
1061 count: 10.into(),
1062 input: Box::new(LogicalOperator::Empty),
1063 };
1064 let cost = model.limit_cost(&limit, 1000.0);
1065
1066 assert!(cost.cpu > 0.0);
1068 assert!(cost.cpu < 1.0); }
1070
1071 #[test]
1072 fn test_cost_model_skip() {
1073 let model = CostModel::new();
1074 let skip = SkipOp {
1075 count: 100.into(),
1076 input: Box::new(LogicalOperator::Empty),
1077 };
1078 let cost = model.skip_cost(&skip, 1000.0);
1079
1080 assert!(cost.cpu > 0.0);
1082 }
1083
1084 #[test]
1085 fn test_cost_model_return() {
1086 let model = CostModel::new();
1087 let ret = ReturnOp {
1088 items: vec![
1089 ReturnItem {
1090 expression: LogicalExpression::Variable("a".to_string()),
1091 alias: None,
1092 },
1093 ReturnItem {
1094 expression: LogicalExpression::Variable("b".to_string()),
1095 alias: None,
1096 },
1097 ],
1098 distinct: false,
1099 input: Box::new(LogicalOperator::Empty),
1100 };
1101 let cost = model.return_cost(&ret, 1000.0);
1102
1103 assert!(cost.cpu > 0.0);
1105 }
1106
1107 #[test]
1108 fn test_cost_cheaper() {
1109 let model = CostModel::new();
1110 let cheap = Cost::cpu(10.0);
1111 let expensive = Cost::cpu(100.0);
1112
1113 assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
1114 assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
1115 }
1116
1117 #[test]
1118 fn test_cost_comparison_prefers_lower_total() {
1119 let model = CostModel::new();
1120 let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
1122 let io_heavy = Cost::cpu(10.0).with_io(20.0);
1124
1125 assert!(cpu_heavy.total() < io_heavy.total());
1127 assert_eq!(
1128 model.cheaper(&cpu_heavy, &io_heavy).total(),
1129 cpu_heavy.total()
1130 );
1131 }
1132
1133 #[test]
1134 fn test_cost_model_sort_with_keys() {
1135 let model = CostModel::new();
1136 let sort_single = SortOp {
1137 keys: vec![crate::query::plan::SortKey {
1138 expression: LogicalExpression::Variable("a".to_string()),
1139 order: SortOrder::Ascending,
1140 nulls: None,
1141 }],
1142 input: Box::new(LogicalOperator::Empty),
1143 };
1144 let sort_multi = SortOp {
1145 keys: vec![
1146 crate::query::plan::SortKey {
1147 expression: LogicalExpression::Variable("a".to_string()),
1148 order: SortOrder::Ascending,
1149 nulls: None,
1150 },
1151 crate::query::plan::SortKey {
1152 expression: LogicalExpression::Variable("b".to_string()),
1153 order: SortOrder::Descending,
1154 nulls: None,
1155 },
1156 ],
1157 input: Box::new(LogicalOperator::Empty),
1158 };
1159
1160 let cost_single = model.sort_cost(&sort_single, 1000.0);
1161 let cost_multi = model.sort_cost(&sort_multi, 1000.0);
1162
1163 assert!(cost_multi.cpu > cost_single.cpu);
1165 }
1166
1167 #[test]
1168 fn test_cost_model_empty_operator() {
1169 let model = CostModel::new();
1170 let cost = model.estimate(&LogicalOperator::Empty, 0.0);
1171 assert!((cost.total()).abs() < 0.001);
1172 }
1173
1174 #[test]
1175 fn test_cost_model_default() {
1176 let model = CostModel::default();
1177 let scan = NodeScanOp {
1178 variable: "n".to_string(),
1179 label: None,
1180 input: None,
1181 };
1182 let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
1183 assert!(cost.total() > 0.0);
1184 }
1185
1186 #[test]
1187 fn test_leapfrog_join_cost() {
1188 let model = CostModel::new();
1189
1190 let cardinalities = vec![1000.0, 1000.0, 1000.0];
1192 let cost = model.leapfrog_join_cost(3, &cardinalities, 100.0);
1193
1194 assert!(cost.cpu > 0.0);
1196 assert!(cost.memory > 0.0);
1198 }
1199
1200 #[test]
1201 fn test_leapfrog_join_cost_empty() {
1202 let model = CostModel::new();
1203 let cost = model.leapfrog_join_cost(0, &[], 0.0);
1204 assert!((cost.total()).abs() < 0.001);
1205 }
1206
1207 #[test]
1208 fn test_prefer_leapfrog_join_for_triangles() {
1209 let model = CostModel::new();
1210
1211 let cardinalities = vec![10000.0, 10000.0, 10000.0];
1213 let output = 1000.0;
1214
1215 let leapfrog_cost = model.leapfrog_join_cost(3, &cardinalities, output);
1216
1217 assert!(leapfrog_cost.cpu > 0.0);
1219 assert!(leapfrog_cost.memory > 0.0);
1220
1221 let _prefer = model.prefer_leapfrog_join(3, &cardinalities, output);
1224 }
1226
1227 #[test]
1228 fn test_prefer_leapfrog_join_binary_case() {
1229 let model = CostModel::new();
1230
1231 let cardinalities = vec![1000.0, 1000.0];
1233 let prefer = model.prefer_leapfrog_join(2, &cardinalities, 500.0);
1234 assert!(!prefer, "Binary joins should use hash join, not leapfrog");
1235 }
1236
1237 #[test]
1238 fn test_factorized_benefit_single_hop() {
1239 let model = CostModel::new();
1240
1241 let benefit = model.factorized_benefit(10.0, 1);
1243 assert!(
1244 (benefit - 1.0).abs() < 0.001,
1245 "Single hop should have no benefit"
1246 );
1247 }
1248
1249 #[test]
1250 fn test_factorized_benefit_multi_hop() {
1251 let model = CostModel::new();
1252
1253 let benefit = model.factorized_benefit(10.0, 3);
1255
1256 assert!(benefit <= 1.0, "Benefit should be <= 1.0");
1260 assert!(benefit > 0.0, "Benefit should be positive");
1261 }
1262
1263 #[test]
1264 fn test_factorized_benefit_low_fanout() {
1265 let model = CostModel::new();
1266
1267 let benefit = model.factorized_benefit(1.5, 2);
1269 assert!(
1270 benefit <= 1.0,
1271 "Low fanout still benefits from factorization"
1272 );
1273 }
1274
1275 #[test]
1276 fn test_node_scan_uses_label_cardinality_for_io() {
1277 let mut label_cards = std::collections::HashMap::new();
1278 label_cards.insert("Person".to_string(), 500_u64);
1279 label_cards.insert("Company".to_string(), 50_u64);
1280
1281 let model = CostModel::new()
1282 .with_label_cardinalities(label_cards)
1283 .with_graph_totals(550, 1000);
1284
1285 let person_scan = NodeScanOp {
1286 variable: "n".to_string(),
1287 label: Some("Person".to_string()),
1288 input: None,
1289 };
1290 let company_scan = NodeScanOp {
1291 variable: "n".to_string(),
1292 label: Some("Company".to_string()),
1293 input: None,
1294 };
1295
1296 let person_cost = model.node_scan_cost(&person_scan, 500.0);
1297 let company_cost = model.node_scan_cost(&company_scan, 50.0);
1298
1299 assert!(
1301 person_cost.io > company_cost.io * 5.0,
1302 "Person ({}) should have much higher IO than Company ({})",
1303 person_cost.io,
1304 company_cost.io
1305 );
1306 }
1307
1308 #[test]
1309 fn test_node_scan_unlabeled_uses_total_nodes() {
1310 let model = CostModel::new().with_graph_totals(10_000, 50_000);
1311
1312 let scan = NodeScanOp {
1313 variable: "n".to_string(),
1314 label: None,
1315 input: None,
1316 };
1317
1318 let cost = model.node_scan_cost(&scan, 10_000.0);
1319 let expected_pages = (10_000.0 * 100.0) / 8192.0;
1320 assert!(
1321 (cost.io - expected_pages).abs() < 0.1,
1322 "Unlabeled scan should use total_nodes for IO: got {}, expected {}",
1323 cost.io,
1324 expected_pages
1325 );
1326 }
1327
1328 #[test]
1329 fn test_join_cost_with_actual_child_cardinalities() {
1330 let model = CostModel::new();
1331 let join = JoinOp {
1332 left: Box::new(LogicalOperator::Empty),
1333 right: Box::new(LogicalOperator::Empty),
1334 join_type: JoinType::Inner,
1335 conditions: vec![JoinCondition {
1336 left: LogicalExpression::Variable("a".to_string()),
1337 right: LogicalExpression::Variable("b".to_string()),
1338 }],
1339 };
1340
1341 let cost_actual = model.join_cost_with_children(&join, 500.0, Some(100.0), Some(10_000.0));
1343
1344 let cost_sqrt = model.join_cost(&join, 500.0);
1346
1347 assert!(
1351 cost_actual.cpu > cost_sqrt.cpu,
1352 "Actual child cardinalities ({}) should produce different cost than sqrt fallback ({})",
1353 cost_actual.cpu,
1354 cost_sqrt.cpu
1355 );
1356 }
1357
1358 #[test]
1359 fn test_expand_multi_edge_types() {
1360 let mut degrees = std::collections::HashMap::new();
1361 degrees.insert("KNOWS".to_string(), (5.0, 5.0));
1362 degrees.insert("FOLLOWS".to_string(), (20.0, 100.0));
1363
1364 let model = CostModel::new().with_edge_type_degrees(degrees);
1365
1366 let multi_expand = ExpandOp {
1368 from_variable: "a".to_string(),
1369 to_variable: "b".to_string(),
1370 edge_variable: None,
1371 direction: ExpandDirection::Outgoing,
1372 edge_types: vec!["KNOWS".to_string(), "FOLLOWS".to_string()],
1373 min_hops: 1,
1374 max_hops: Some(1),
1375 input: Box::new(LogicalOperator::Empty),
1376 path_alias: None,
1377 path_mode: PathMode::Walk,
1378 };
1379 let multi_cost = model.expand_cost(&multi_expand, 100.0);
1380
1381 let single_expand = ExpandOp {
1383 from_variable: "a".to_string(),
1384 to_variable: "b".to_string(),
1385 edge_variable: None,
1386 direction: ExpandDirection::Outgoing,
1387 edge_types: vec!["KNOWS".to_string()],
1388 min_hops: 1,
1389 max_hops: Some(1),
1390 input: Box::new(LogicalOperator::Empty),
1391 path_alias: None,
1392 path_mode: PathMode::Walk,
1393 };
1394 let single_cost = model.expand_cost(&single_expand, 100.0);
1395
1396 assert!(
1398 multi_cost.cpu > single_cost.cpu * 3.0,
1399 "Multi-type fanout ({}) should be much higher than single-type ({})",
1400 multi_cost.cpu,
1401 single_cost.cpu
1402 );
1403 }
1404
1405 #[test]
1406 fn test_recursive_tree_cost() {
1407 use crate::query::optimizer::CardinalityEstimator;
1408
1409 let mut label_cards = std::collections::HashMap::new();
1410 label_cards.insert("Person".to_string(), 1000_u64);
1411
1412 let model = CostModel::new()
1413 .with_label_cardinalities(label_cards)
1414 .with_graph_totals(1000, 5000)
1415 .with_avg_fanout(5.0);
1416
1417 let mut card_est = CardinalityEstimator::new();
1418 card_est.add_table_stats("Person", crate::query::optimizer::TableStats::new(1000));
1419
1420 let plan = LogicalOperator::Return(ReturnOp {
1422 items: vec![ReturnItem {
1423 expression: LogicalExpression::Variable("n".to_string()),
1424 alias: None,
1425 }],
1426 distinct: false,
1427 input: Box::new(LogicalOperator::Filter(FilterOp {
1428 predicate: LogicalExpression::Binary {
1429 left: Box::new(LogicalExpression::Property {
1430 variable: "n".to_string(),
1431 property: "age".to_string(),
1432 }),
1433 op: crate::query::plan::BinaryOp::Gt,
1434 right: Box::new(LogicalExpression::Literal(
1435 grafeo_common::types::Value::Int64(30),
1436 )),
1437 },
1438 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1439 variable: "n".to_string(),
1440 label: Some("Person".to_string()),
1441 input: None,
1442 })),
1443 pushdown_hint: None,
1444 })),
1445 });
1446
1447 let tree_cost = model.estimate_tree(&plan, &card_est);
1448
1449 assert!(tree_cost.cpu > 0.0, "Tree should have CPU cost");
1451 assert!(tree_cost.io > 0.0, "Tree should have IO cost from scan");
1452
1453 let root_only_card = card_est.estimate(&plan);
1455 let root_only_cost = model.estimate(&plan, root_only_card);
1456
1457 assert!(
1459 tree_cost.total() > root_only_cost.total(),
1460 "Recursive tree cost ({}) should exceed root-only cost ({})",
1461 tree_cost.total(),
1462 root_only_cost.total()
1463 );
1464 }
1465
1466 #[test]
1467 fn test_statistics_driven_vs_default_cost() {
1468 let default_model = CostModel::new();
1469
1470 let mut label_cards = std::collections::HashMap::new();
1471 label_cards.insert("Person".to_string(), 100_u64);
1472 let stats_model = CostModel::new()
1473 .with_label_cardinalities(label_cards)
1474 .with_graph_totals(100, 500);
1475
1476 let scan = NodeScanOp {
1478 variable: "n".to_string(),
1479 label: Some("Person".to_string()),
1480 input: None,
1481 };
1482
1483 let default_cost = default_model.node_scan_cost(&scan, 100.0);
1484 let stats_cost = stats_model.node_scan_cost(&scan, 100.0);
1485
1486 assert!(
1490 (default_cost.io - stats_cost.io).abs() < 0.1,
1491 "When cardinality matches label size, costs should be similar"
1492 );
1493 }
1494}