1use crate::query::plan::{
6 AggregateOp, DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
7 LogicalOperator, MultiWayJoinOp, NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp, VectorJoinOp,
8 VectorScanOp,
9};
10use std::collections::HashMap;
11
/// Multi-dimensional cost of executing (part of) a query plan.
///
/// Each component is an abstract, unitless estimate; [`Cost::total`] folds them
/// into a single scalar that plan comparison uses.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Cost {
    /// Estimated CPU work.
    pub cpu: f64,
    /// Estimated I/O work (the node-scan estimator fills this with a page count).
    pub io: f64,
    /// Estimated memory footprint (estimators charge rows × average tuple size).
    pub memory: f64,
    /// Estimated network work.
    pub network: f64,
}

impl Cost {
    /// Returns a cost whose components are all zero.
    #[must_use]
    pub fn zero() -> Self {
        Self {
            cpu: 0.0,
            io: 0.0,
            memory: 0.0,
            network: 0.0,
        }
    }

    /// Returns a cost consisting only of `cpu` work; all other components are zero.
    #[must_use]
    pub fn cpu(cpu: f64) -> Self {
        Self {
            cpu,
            ..Self::zero()
        }
    }

    /// Returns `self` with the I/O component replaced by `io`.
    #[must_use]
    pub fn with_io(mut self, io: f64) -> Self {
        self.io = io;
        self
    }

    /// Returns `self` with the memory component replaced by `memory`.
    #[must_use]
    pub fn with_memory(mut self, memory: f64) -> Self {
        self.memory = memory;
        self
    }

    /// Returns `self` with the network component replaced by `network`.
    ///
    /// Completes the builder set so that every component weighed by
    /// [`Cost::total`] can be set fluently.
    #[must_use]
    pub fn with_network(mut self, network: f64) -> Self {
        self.network = network;
        self
    }

    /// Collapses all components into one scalar:
    /// `cpu + 10·io + 0.1·memory + 100·network`.
    #[must_use]
    pub fn total(&self) -> f64 {
        self.cpu + self.io * 10.0 + self.memory * 0.1 + self.network * 100.0
    }

    /// Like [`Cost::total`], but with caller-supplied weights for CPU, I/O and memory.
    ///
    /// Note: the `network` component is *not* included in this sum.
    #[must_use]
    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
        self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
    }
}

/// Component-wise addition of two costs.
impl std::ops::Add for Cost {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            cpu: self.cpu + other.cpu,
            io: self.io + other.io,
            memory: self.memory + other.memory,
            network: self.network + other.network,
        }
    }
}

/// In-place component-wise addition; equivalent to `*self = *self + other`.
impl std::ops::AddAssign for Cost {
    fn add_assign(&mut self, other: Self) {
        *self = *self + other;
    }
}
100
/// Cost model that turns logical operators plus cardinality estimates into
/// [`Cost`] values.
///
/// The per-unit cost constants are set by `CostModel::new`; graph statistics
/// (per-edge-type degrees, label counts, graph totals) are optional and are
/// supplied through the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct CostModel {
    /// CPU cost charged per processed tuple.
    cpu_tuple_cost: f64,
    /// CPU cost charged per hash-table insert or probe.
    hash_lookup_cost: f64,
    /// CPU cost charged per comparison while sorting.
    sort_comparison_cost: f64,
    /// Average tuple width, used for page counts and memory estimates.
    avg_tuple_size: f64,
    /// Page size, used to convert scanned tuples into page reads.
    page_size: f64,
    /// Fallback neighbours-per-row for expansions lacking per-type statistics.
    avg_fanout: f64,
    /// Edge type -> (average out-degree, average in-degree).
    edge_type_degrees: HashMap<String, (f64, f64)>,
    /// Node label -> number of nodes carrying that label.
    label_cardinalities: HashMap<String, u64>,
    /// Total node count of the graph; `0` means unknown.
    total_nodes: u64,
    /// Total edge count of the graph.
    total_edges: u64,
}
130
131impl CostModel {
132 #[must_use]
134 pub fn new() -> Self {
135 Self {
136 cpu_tuple_cost: 0.01,
137 hash_lookup_cost: 0.03,
138 sort_comparison_cost: 0.02,
139 avg_tuple_size: 100.0,
140 page_size: 8192.0,
141 avg_fanout: 10.0,
142 edge_type_degrees: HashMap::new(),
143 label_cardinalities: HashMap::new(),
144 total_nodes: 0,
145 total_edges: 0,
146 }
147 }
148
149 #[must_use]
151 pub fn with_avg_fanout(mut self, avg_fanout: f64) -> Self {
152 self.avg_fanout = if avg_fanout > 0.0 { avg_fanout } else { 10.0 };
153 self
154 }
155
156 #[must_use]
160 pub fn with_edge_type_degrees(mut self, degrees: HashMap<String, (f64, f64)>) -> Self {
161 self.edge_type_degrees = degrees;
162 self
163 }
164
165 #[must_use]
167 pub fn with_label_cardinalities(mut self, cardinalities: HashMap<String, u64>) -> Self {
168 self.label_cardinalities = cardinalities;
169 self
170 }
171
172 #[must_use]
174 pub fn with_graph_totals(mut self, total_nodes: u64, total_edges: u64) -> Self {
175 self.total_nodes = total_nodes;
176 self.total_edges = total_edges;
177 self
178 }
179
180 fn fanout_for_expand(&self, expand: &ExpandOp) -> f64 {
185 if expand.edge_types.is_empty() {
186 return self.avg_fanout;
187 }
188
189 let mut total_fanout = 0.0;
190 let mut all_found = true;
191
192 for edge_type in &expand.edge_types {
193 if let Some(&(out_deg, in_deg)) = self.edge_type_degrees.get(edge_type) {
194 total_fanout += match expand.direction {
195 ExpandDirection::Outgoing => out_deg,
196 ExpandDirection::Incoming => in_deg,
197 ExpandDirection::Both => out_deg + in_deg,
198 };
199 } else {
200 all_found = false;
201 break;
202 }
203 }
204
205 if all_found {
206 total_fanout
207 } else {
208 self.avg_fanout
209 }
210 }
211
212 #[must_use]
214 pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
215 match op {
216 LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
217 LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
218 LogicalOperator::Project(project) => self.project_cost(project, cardinality),
219 LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
220 LogicalOperator::Join(join) => self.join_cost(join, cardinality),
221 LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
222 LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
223 LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
224 LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
225 LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
226 LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
227 LogicalOperator::Empty => Cost::zero(),
228 LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
229 LogicalOperator::VectorJoin(join) => self.vector_join_cost(join, cardinality),
230 LogicalOperator::MultiWayJoin(mwj) => self.multi_way_join_cost(mwj, cardinality),
231 _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
232 }
233 }
234
235 fn node_scan_cost(&self, scan: &NodeScanOp, cardinality: f64) -> Cost {
241 let scan_size = if let Some(label) = &scan.label {
243 self.label_cardinalities
244 .get(label)
245 .map_or(cardinality, |&count| count as f64)
246 } else if self.total_nodes > 0 {
247 self.total_nodes as f64
248 } else {
249 cardinality
250 };
251 let pages = (scan_size * self.avg_tuple_size) / self.page_size;
252 Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
254 }
255
256 fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
258 Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
260 }
261
262 fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
264 let expr_count = project.projections.len() as f64;
266 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
267 }
268
269 fn expand_cost(&self, expand: &ExpandOp, cardinality: f64) -> Cost {
274 let fanout = self.fanout_for_expand(expand);
275 let lookup_cost = cardinality * self.hash_lookup_cost;
277 let output_cost = cardinality * fanout * self.cpu_tuple_cost;
279 Cost::cpu(lookup_cost + output_cost)
280 }
281
282 fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
287 self.join_cost_with_children(join, cardinality, None, None)
288 }
289
290 fn join_cost_with_children(
292 &self,
293 join: &JoinOp,
294 cardinality: f64,
295 left_cardinality: Option<f64>,
296 right_cardinality: Option<f64>,
297 ) -> Cost {
298 match join.join_type {
299 JoinType::Cross => Cost::cpu(cardinality * self.cpu_tuple_cost),
300 JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
301 let build_cardinality = left_cardinality.unwrap_or_else(|| cardinality.sqrt());
303 let probe_cardinality = right_cardinality.unwrap_or_else(|| cardinality.sqrt());
304
305 let build_cost = build_cardinality * self.hash_lookup_cost;
306 let memory_cost = build_cardinality * self.avg_tuple_size;
307 let probe_cost = probe_cardinality * self.hash_lookup_cost;
308 let output_cost = cardinality * self.cpu_tuple_cost;
309
310 Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
311 }
312 JoinType::Semi | JoinType::Anti => {
313 let build_cardinality = left_cardinality.unwrap_or_else(|| cardinality.sqrt());
314 let probe_cardinality = right_cardinality.unwrap_or_else(|| cardinality.sqrt());
315
316 let build_cost = build_cardinality * self.hash_lookup_cost;
317 let probe_cost = probe_cardinality * self.hash_lookup_cost;
318
319 Cost::cpu(build_cost + probe_cost)
320 .with_memory(build_cardinality * self.avg_tuple_size)
321 }
322 }
323 }
324
325 fn multi_way_join_cost(&self, mwj: &MultiWayJoinOp, cardinality: f64) -> Cost {
330 let n = mwj.inputs.len();
331 if n == 0 {
332 return Cost::zero();
333 }
334 let per_input = cardinality.powf(1.0 / n as f64).max(1.0);
337 let cardinalities: Vec<f64> = (0..n).map(|_| per_input).collect();
338 self.leapfrog_join_cost(n, &cardinalities, cardinality)
339 }
340
341 fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
343 let hash_cost = cardinality * self.hash_lookup_cost;
345
346 let agg_count = agg.aggregates.len() as f64;
348 let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
349
350 let distinct_groups = (cardinality / 10.0).max(1.0); let memory_cost = distinct_groups * self.avg_tuple_size;
353
354 Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
355 }
356
357 fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
359 if cardinality <= 1.0 {
360 return Cost::zero();
361 }
362
363 let comparisons = cardinality * cardinality.log2();
365 let key_count = sort.keys.len() as f64;
366
367 let memory_cost = cardinality * self.avg_tuple_size;
369
370 Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
371 }
372
373 fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
375 let hash_cost = cardinality * self.hash_lookup_cost;
377 let memory_cost = cardinality * self.avg_tuple_size * 0.5; Cost::cpu(hash_cost).with_memory(memory_cost)
380 }
381
382 fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
384 Cost::cpu(limit.count.estimate() * self.cpu_tuple_cost * 0.1)
386 }
387
388 fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
390 Cost::cpu(skip.count.estimate() * self.cpu_tuple_cost)
392 }
393
394 fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
396 let expr_count = ret.items.len() as f64;
398 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
399 }
400
401 fn vector_scan_cost(&self, scan: &VectorScanOp, cardinality: f64) -> Cost {
406 let k = scan.k as f64;
408
409 let ef = 64.0;
412 let n = cardinality.max(1.0);
413 let search_cost = if scan.index_name.is_some() {
414 ef * n.ln() * self.cpu_tuple_cost * 10.0 } else {
417 n * self.cpu_tuple_cost * 10.0
419 };
420
421 let memory = k * self.avg_tuple_size * 2.0;
423
424 Cost::cpu(search_cost).with_memory(memory)
425 }
426
427 fn vector_join_cost(&self, join: &VectorJoinOp, cardinality: f64) -> Cost {
431 let k = join.k as f64;
432
433 let per_row_search_cost = if join.index_name.is_some() {
436 let ef = 64.0;
438 let n = cardinality.max(1.0);
439 ef * n.ln() * self.cpu_tuple_cost * 10.0
440 } else {
441 cardinality * self.cpu_tuple_cost * 10.0
443 };
444
445 let input_cardinality = (cardinality / k).max(1.0);
448 let total_search_cost = input_cardinality * per_row_search_cost;
449
450 let memory = cardinality * self.avg_tuple_size;
452
453 Cost::cpu(total_search_cost).with_memory(memory)
454 }
455
456 #[must_use]
462 pub fn estimate_tree(
463 &self,
464 op: &LogicalOperator,
465 card_estimator: &super::CardinalityEstimator,
466 ) -> Cost {
467 self.estimate_tree_inner(op, card_estimator)
468 }
469
470 fn estimate_tree_inner(
471 &self,
472 op: &LogicalOperator,
473 card_est: &super::CardinalityEstimator,
474 ) -> Cost {
475 let cardinality = card_est.estimate(op);
476
477 match op {
478 LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
479 LogicalOperator::Filter(filter) => {
480 let child_cost = self.estimate_tree_inner(&filter.input, card_est);
481 child_cost + self.filter_cost(filter, cardinality)
482 }
483 LogicalOperator::Project(project) => {
484 let child_cost = self.estimate_tree_inner(&project.input, card_est);
485 child_cost + self.project_cost(project, cardinality)
486 }
487 LogicalOperator::Expand(expand) => {
488 let child_cost = self.estimate_tree_inner(&expand.input, card_est);
489 child_cost + self.expand_cost(expand, cardinality)
490 }
491 LogicalOperator::Join(join) => {
492 let left_cost = self.estimate_tree_inner(&join.left, card_est);
493 let right_cost = self.estimate_tree_inner(&join.right, card_est);
494 let left_card = card_est.estimate(&join.left);
495 let right_card = card_est.estimate(&join.right);
496 let join_cost = self.join_cost_with_children(
497 join,
498 cardinality,
499 Some(left_card),
500 Some(right_card),
501 );
502 left_cost + right_cost + join_cost
503 }
504 LogicalOperator::Aggregate(agg) => {
505 let child_cost = self.estimate_tree_inner(&agg.input, card_est);
506 child_cost + self.aggregate_cost(agg, cardinality)
507 }
508 LogicalOperator::Sort(sort) => {
509 let child_cost = self.estimate_tree_inner(&sort.input, card_est);
510 child_cost + self.sort_cost(sort, cardinality)
511 }
512 LogicalOperator::Distinct(distinct) => {
513 let child_cost = self.estimate_tree_inner(&distinct.input, card_est);
514 child_cost + self.distinct_cost(distinct, cardinality)
515 }
516 LogicalOperator::Limit(limit) => {
517 let child_cost = self.estimate_tree_inner(&limit.input, card_est);
518 child_cost + self.limit_cost(limit, cardinality)
519 }
520 LogicalOperator::Skip(skip) => {
521 let child_cost = self.estimate_tree_inner(&skip.input, card_est);
522 child_cost + self.skip_cost(skip, cardinality)
523 }
524 LogicalOperator::Return(ret) => {
525 let child_cost = self.estimate_tree_inner(&ret.input, card_est);
526 child_cost + self.return_cost(ret, cardinality)
527 }
528 LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
529 LogicalOperator::VectorJoin(join) => {
530 let child_cost = self.estimate_tree_inner(&join.input, card_est);
531 child_cost + self.vector_join_cost(join, cardinality)
532 }
533 LogicalOperator::MultiWayJoin(mwj) => {
534 let mut children_cost = Cost::zero();
535 for input in &mwj.inputs {
536 children_cost += self.estimate_tree_inner(input, card_est);
537 }
538 children_cost + self.multi_way_join_cost(mwj, cardinality)
539 }
540 LogicalOperator::Empty => Cost::zero(),
541 _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
542 }
543 }
544
545 #[must_use]
547 pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
548 if a.total() <= b.total() { a } else { b }
549 }
550
551 #[must_use]
567 pub fn leapfrog_join_cost(
568 &self,
569 num_relations: usize,
570 cardinalities: &[f64],
571 output_cardinality: f64,
572 ) -> Cost {
573 if cardinalities.is_empty() {
574 return Cost::zero();
575 }
576
577 let total_input: f64 = cardinalities.iter().sum();
578 let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
579
580 let materialize_cost = total_input * self.cpu_tuple_cost * 2.0; let seek_cost = if min_card > 1.0 {
585 output_cardinality * (num_relations as f64) * min_card.log2() * self.hash_lookup_cost
586 } else {
587 output_cardinality * self.cpu_tuple_cost
588 };
589
590 let output_cost = output_cardinality * self.cpu_tuple_cost;
592
593 let memory = total_input * self.avg_tuple_size * 2.0;
595
596 Cost::cpu(materialize_cost + seek_cost + output_cost).with_memory(memory)
597 }
598
599 #[must_use]
603 pub fn prefer_leapfrog_join(
604 &self,
605 num_relations: usize,
606 cardinalities: &[f64],
607 output_cardinality: f64,
608 ) -> bool {
609 if num_relations < 3 || cardinalities.len() < 3 {
610 return false;
612 }
613
614 let leapfrog_cost =
615 self.leapfrog_join_cost(num_relations, cardinalities, output_cardinality);
616
617 let mut hash_cascade_cost = Cost::zero();
621 let mut intermediate_cardinality = cardinalities[0];
622
623 for card in &cardinalities[1..] {
624 let join_output = (intermediate_cardinality * card).sqrt(); let join = JoinOp {
627 left: Box::new(LogicalOperator::Empty),
628 right: Box::new(LogicalOperator::Empty),
629 join_type: JoinType::Inner,
630 conditions: vec![],
631 };
632 hash_cascade_cost += self.join_cost(&join, join_output);
633 intermediate_cardinality = join_output;
634 }
635
636 leapfrog_cost.total() < hash_cascade_cost.total()
637 }
638
639 #[must_use]
647 pub fn factorized_benefit(&self, avg_fanout: f64, num_hops: usize) -> f64 {
648 if num_hops <= 1 || avg_fanout <= 1.0 {
649 return 1.0; }
651
652 let full_size = avg_fanout.powi(num_hops as i32);
658 let factorized_size = if avg_fanout > 1.0 {
659 (avg_fanout.powi(num_hops as i32 + 1) - 1.0) / (avg_fanout - 1.0)
660 } else {
661 num_hops as f64
662 };
663
664 (factorized_size / full_size).min(1.0)
665 }
666}
667
668impl Default for CostModel {
669 fn default() -> Self {
670 Self::new()
671 }
672}
673
674#[cfg(test)]
675mod tests {
676 use super::*;
677 use crate::query::plan::{
678 AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
679 PathMode, Projection, ReturnItem, SortOrder,
680 };
681
682 #[test]
683 fn test_cost_addition() {
684 let a = Cost::cpu(10.0).with_io(5.0);
685 let b = Cost::cpu(20.0).with_memory(100.0);
686 let c = a + b;
687
688 assert!((c.cpu - 30.0).abs() < 0.001);
689 assert!((c.io - 5.0).abs() < 0.001);
690 assert!((c.memory - 100.0).abs() < 0.001);
691 }
692
693 #[test]
694 fn test_cost_total() {
695 let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
696 assert!((cost.total() - 30.0).abs() < 0.001);
698 }
699
700 #[test]
701 fn test_cost_model_node_scan() {
702 let model = CostModel::new();
703 let scan = NodeScanOp {
704 variable: "n".to_string(),
705 label: Some("Person".to_string()),
706 input: None,
707 };
708 let cost = model.node_scan_cost(&scan, 1000.0);
709
710 assert!(cost.cpu > 0.0);
711 assert!(cost.io > 0.0);
712 }
713
714 #[test]
715 fn test_cost_model_sort() {
716 let model = CostModel::new();
717 let sort = SortOp {
718 keys: vec![],
719 input: Box::new(LogicalOperator::Empty),
720 };
721
722 let cost_100 = model.sort_cost(&sort, 100.0);
723 let cost_1000 = model.sort_cost(&sort, 1000.0);
724
725 assert!(cost_1000.total() > cost_100.total());
727 }
728
729 #[test]
730 fn test_cost_zero() {
731 let cost = Cost::zero();
732 assert!((cost.cpu).abs() < 0.001);
733 assert!((cost.io).abs() < 0.001);
734 assert!((cost.memory).abs() < 0.001);
735 assert!((cost.network).abs() < 0.001);
736 assert!((cost.total()).abs() < 0.001);
737 }
738
739 #[test]
740 fn test_cost_add_assign() {
741 let mut cost = Cost::cpu(10.0);
742 cost += Cost::cpu(5.0).with_io(2.0);
743 assert!((cost.cpu - 15.0).abs() < 0.001);
744 assert!((cost.io - 2.0).abs() < 0.001);
745 }
746
747 #[test]
748 fn test_cost_total_weighted() {
749 let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
750 let total = cost.total_weighted(2.0, 5.0, 0.5);
752 assert!((total - 80.0).abs() < 0.001);
753 }
754
755 #[test]
756 fn test_cost_model_filter() {
757 let model = CostModel::new();
758 let filter = FilterOp {
759 predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
760 input: Box::new(LogicalOperator::Empty),
761 pushdown_hint: None,
762 };
763 let cost = model.filter_cost(&filter, 1000.0);
764
765 assert!(cost.cpu > 0.0);
767 assert!((cost.io).abs() < 0.001);
768 }
769
770 #[test]
771 fn test_cost_model_project() {
772 let model = CostModel::new();
773 let project = ProjectOp {
774 projections: vec![
775 Projection {
776 expression: LogicalExpression::Variable("a".to_string()),
777 alias: None,
778 },
779 Projection {
780 expression: LogicalExpression::Variable("b".to_string()),
781 alias: None,
782 },
783 ],
784 input: Box::new(LogicalOperator::Empty),
785 pass_through_input: false,
786 };
787 let cost = model.project_cost(&project, 1000.0);
788
789 assert!(cost.cpu > 0.0);
791 }
792
793 #[test]
794 fn test_cost_model_expand() {
795 let model = CostModel::new();
796 let expand = ExpandOp {
797 from_variable: "a".to_string(),
798 to_variable: "b".to_string(),
799 edge_variable: None,
800 direction: ExpandDirection::Outgoing,
801 edge_types: vec![],
802 min_hops: 1,
803 max_hops: Some(1),
804 input: Box::new(LogicalOperator::Empty),
805 path_alias: None,
806 path_mode: PathMode::Walk,
807 };
808 let cost = model.expand_cost(&expand, 1000.0);
809
810 assert!(cost.cpu > 0.0);
812 }
813
814 #[test]
815 fn test_cost_model_expand_with_edge_type_stats() {
816 let mut degrees = std::collections::HashMap::new();
817 degrees.insert("KNOWS".to_string(), (5.0, 5.0)); degrees.insert("WORKS_AT".to_string(), (1.0, 50.0)); let model = CostModel::new().with_edge_type_degrees(degrees);
821
822 let knows_out = ExpandOp {
824 from_variable: "a".to_string(),
825 to_variable: "b".to_string(),
826 edge_variable: None,
827 direction: ExpandDirection::Outgoing,
828 edge_types: vec!["KNOWS".to_string()],
829 min_hops: 1,
830 max_hops: Some(1),
831 input: Box::new(LogicalOperator::Empty),
832 path_alias: None,
833 path_mode: PathMode::Walk,
834 };
835 let cost_knows = model.expand_cost(&knows_out, 1000.0);
836
837 let works_out = ExpandOp {
839 from_variable: "a".to_string(),
840 to_variable: "b".to_string(),
841 edge_variable: None,
842 direction: ExpandDirection::Outgoing,
843 edge_types: vec!["WORKS_AT".to_string()],
844 min_hops: 1,
845 max_hops: Some(1),
846 input: Box::new(LogicalOperator::Empty),
847 path_alias: None,
848 path_mode: PathMode::Walk,
849 };
850 let cost_works = model.expand_cost(&works_out, 1000.0);
851
852 assert!(
854 cost_knows.cpu > cost_works.cpu,
855 "KNOWS(5) should cost more than WORKS_AT(1)"
856 );
857
858 let works_in = ExpandOp {
860 from_variable: "c".to_string(),
861 to_variable: "p".to_string(),
862 edge_variable: None,
863 direction: ExpandDirection::Incoming,
864 edge_types: vec!["WORKS_AT".to_string()],
865 min_hops: 1,
866 max_hops: Some(1),
867 input: Box::new(LogicalOperator::Empty),
868 path_alias: None,
869 path_mode: PathMode::Walk,
870 };
871 let cost_works_in = model.expand_cost(&works_in, 1000.0);
872
873 assert!(
875 cost_works_in.cpu > cost_knows.cpu,
876 "Incoming WORKS_AT(50) should cost more than KNOWS(5)"
877 );
878 }
879
880 #[test]
881 fn test_cost_model_expand_unknown_edge_type_uses_global_fanout() {
882 let model = CostModel::new().with_avg_fanout(7.0);
883 let expand = ExpandOp {
884 from_variable: "a".to_string(),
885 to_variable: "b".to_string(),
886 edge_variable: None,
887 direction: ExpandDirection::Outgoing,
888 edge_types: vec!["UNKNOWN_TYPE".to_string()],
889 min_hops: 1,
890 max_hops: Some(1),
891 input: Box::new(LogicalOperator::Empty),
892 path_alias: None,
893 path_mode: PathMode::Walk,
894 };
895 let cost_unknown = model.expand_cost(&expand, 1000.0);
896
897 let expand_no_type = ExpandOp {
899 from_variable: "a".to_string(),
900 to_variable: "b".to_string(),
901 edge_variable: None,
902 direction: ExpandDirection::Outgoing,
903 edge_types: vec![],
904 min_hops: 1,
905 max_hops: Some(1),
906 input: Box::new(LogicalOperator::Empty),
907 path_alias: None,
908 path_mode: PathMode::Walk,
909 };
910 let cost_no_type = model.expand_cost(&expand_no_type, 1000.0);
911
912 assert!(
914 (cost_unknown.cpu - cost_no_type.cpu).abs() < 0.001,
915 "Unknown edge type should use global fanout"
916 );
917 }
918
919 #[test]
920 fn test_cost_model_hash_join() {
921 let model = CostModel::new();
922 let join = JoinOp {
923 left: Box::new(LogicalOperator::Empty),
924 right: Box::new(LogicalOperator::Empty),
925 join_type: JoinType::Inner,
926 conditions: vec![JoinCondition {
927 left: LogicalExpression::Variable("a".to_string()),
928 right: LogicalExpression::Variable("b".to_string()),
929 }],
930 };
931 let cost = model.join_cost(&join, 10000.0);
932
933 assert!(cost.cpu > 0.0);
935 assert!(cost.memory > 0.0);
936 }
937
938 #[test]
939 fn test_cost_model_cross_join() {
940 let model = CostModel::new();
941 let join = JoinOp {
942 left: Box::new(LogicalOperator::Empty),
943 right: Box::new(LogicalOperator::Empty),
944 join_type: JoinType::Cross,
945 conditions: vec![],
946 };
947 let cost = model.join_cost(&join, 1000000.0);
948
949 assert!(cost.cpu > 0.0);
951 }
952
953 #[test]
954 fn test_cost_model_semi_join() {
955 let model = CostModel::new();
956 let join = JoinOp {
957 left: Box::new(LogicalOperator::Empty),
958 right: Box::new(LogicalOperator::Empty),
959 join_type: JoinType::Semi,
960 conditions: vec![],
961 };
962 let cost_semi = model.join_cost(&join, 1000.0);
963
964 let inner_join = JoinOp {
965 left: Box::new(LogicalOperator::Empty),
966 right: Box::new(LogicalOperator::Empty),
967 join_type: JoinType::Inner,
968 conditions: vec![],
969 };
970 let cost_inner = model.join_cost(&inner_join, 1000.0);
971
972 assert!(cost_semi.cpu > 0.0);
974 assert!(cost_inner.cpu > 0.0);
975 }
976
977 #[test]
978 fn test_cost_model_aggregate() {
979 let model = CostModel::new();
980 let agg = AggregateOp {
981 group_by: vec![],
982 aggregates: vec![
983 AggregateExpr {
984 function: AggregateFunction::Count,
985 expression: None,
986 expression2: None,
987 distinct: false,
988 alias: Some("cnt".to_string()),
989 percentile: None,
990 separator: None,
991 },
992 AggregateExpr {
993 function: AggregateFunction::Sum,
994 expression: Some(LogicalExpression::Variable("x".to_string())),
995 expression2: None,
996 distinct: false,
997 alias: Some("total".to_string()),
998 percentile: None,
999 separator: None,
1000 },
1001 ],
1002 input: Box::new(LogicalOperator::Empty),
1003 having: None,
1004 };
1005 let cost = model.aggregate_cost(&agg, 1000.0);
1006
1007 assert!(cost.cpu > 0.0);
1009 assert!(cost.memory > 0.0);
1010 }
1011
1012 #[test]
1013 fn test_cost_model_distinct() {
1014 let model = CostModel::new();
1015 let distinct = DistinctOp {
1016 input: Box::new(LogicalOperator::Empty),
1017 columns: None,
1018 };
1019 let cost = model.distinct_cost(&distinct, 1000.0);
1020
1021 assert!(cost.cpu > 0.0);
1023 assert!(cost.memory > 0.0);
1024 }
1025
1026 #[test]
1027 fn test_cost_model_limit() {
1028 let model = CostModel::new();
1029 let limit = LimitOp {
1030 count: 10.into(),
1031 input: Box::new(LogicalOperator::Empty),
1032 };
1033 let cost = model.limit_cost(&limit, 1000.0);
1034
1035 assert!(cost.cpu > 0.0);
1037 assert!(cost.cpu < 1.0); }
1039
1040 #[test]
1041 fn test_cost_model_skip() {
1042 let model = CostModel::new();
1043 let skip = SkipOp {
1044 count: 100.into(),
1045 input: Box::new(LogicalOperator::Empty),
1046 };
1047 let cost = model.skip_cost(&skip, 1000.0);
1048
1049 assert!(cost.cpu > 0.0);
1051 }
1052
1053 #[test]
1054 fn test_cost_model_return() {
1055 let model = CostModel::new();
1056 let ret = ReturnOp {
1057 items: vec![
1058 ReturnItem {
1059 expression: LogicalExpression::Variable("a".to_string()),
1060 alias: None,
1061 },
1062 ReturnItem {
1063 expression: LogicalExpression::Variable("b".to_string()),
1064 alias: None,
1065 },
1066 ],
1067 distinct: false,
1068 input: Box::new(LogicalOperator::Empty),
1069 };
1070 let cost = model.return_cost(&ret, 1000.0);
1071
1072 assert!(cost.cpu > 0.0);
1074 }
1075
1076 #[test]
1077 fn test_cost_cheaper() {
1078 let model = CostModel::new();
1079 let cheap = Cost::cpu(10.0);
1080 let expensive = Cost::cpu(100.0);
1081
1082 assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
1083 assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
1084 }
1085
1086 #[test]
1087 fn test_cost_comparison_prefers_lower_total() {
1088 let model = CostModel::new();
1089 let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
1091 let io_heavy = Cost::cpu(10.0).with_io(20.0);
1093
1094 assert!(cpu_heavy.total() < io_heavy.total());
1096 assert_eq!(
1097 model.cheaper(&cpu_heavy, &io_heavy).total(),
1098 cpu_heavy.total()
1099 );
1100 }
1101
1102 #[test]
1103 fn test_cost_model_sort_with_keys() {
1104 let model = CostModel::new();
1105 let sort_single = SortOp {
1106 keys: vec![crate::query::plan::SortKey {
1107 expression: LogicalExpression::Variable("a".to_string()),
1108 order: SortOrder::Ascending,
1109 nulls: None,
1110 }],
1111 input: Box::new(LogicalOperator::Empty),
1112 };
1113 let sort_multi = SortOp {
1114 keys: vec![
1115 crate::query::plan::SortKey {
1116 expression: LogicalExpression::Variable("a".to_string()),
1117 order: SortOrder::Ascending,
1118 nulls: None,
1119 },
1120 crate::query::plan::SortKey {
1121 expression: LogicalExpression::Variable("b".to_string()),
1122 order: SortOrder::Descending,
1123 nulls: None,
1124 },
1125 ],
1126 input: Box::new(LogicalOperator::Empty),
1127 };
1128
1129 let cost_single = model.sort_cost(&sort_single, 1000.0);
1130 let cost_multi = model.sort_cost(&sort_multi, 1000.0);
1131
1132 assert!(cost_multi.cpu > cost_single.cpu);
1134 }
1135
1136 #[test]
1137 fn test_cost_model_empty_operator() {
1138 let model = CostModel::new();
1139 let cost = model.estimate(&LogicalOperator::Empty, 0.0);
1140 assert!((cost.total()).abs() < 0.001);
1141 }
1142
1143 #[test]
1144 fn test_cost_model_default() {
1145 let model = CostModel::default();
1146 let scan = NodeScanOp {
1147 variable: "n".to_string(),
1148 label: None,
1149 input: None,
1150 };
1151 let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
1152 assert!(cost.total() > 0.0);
1153 }
1154
1155 #[test]
1156 fn test_leapfrog_join_cost() {
1157 let model = CostModel::new();
1158
1159 let cardinalities = vec![1000.0, 1000.0, 1000.0];
1161 let cost = model.leapfrog_join_cost(3, &cardinalities, 100.0);
1162
1163 assert!(cost.cpu > 0.0);
1165 assert!(cost.memory > 0.0);
1167 }
1168
1169 #[test]
1170 fn test_leapfrog_join_cost_empty() {
1171 let model = CostModel::new();
1172 let cost = model.leapfrog_join_cost(0, &[], 0.0);
1173 assert!((cost.total()).abs() < 0.001);
1174 }
1175
1176 #[test]
1177 fn test_prefer_leapfrog_join_for_triangles() {
1178 let model = CostModel::new();
1179
1180 let cardinalities = vec![10000.0, 10000.0, 10000.0];
1182 let output = 1000.0;
1183
1184 let leapfrog_cost = model.leapfrog_join_cost(3, &cardinalities, output);
1185
1186 assert!(leapfrog_cost.cpu > 0.0);
1188 assert!(leapfrog_cost.memory > 0.0);
1189
1190 let _prefer = model.prefer_leapfrog_join(3, &cardinalities, output);
1193 }
1195
1196 #[test]
1197 fn test_prefer_leapfrog_join_binary_case() {
1198 let model = CostModel::new();
1199
1200 let cardinalities = vec![1000.0, 1000.0];
1202 let prefer = model.prefer_leapfrog_join(2, &cardinalities, 500.0);
1203 assert!(!prefer, "Binary joins should use hash join, not leapfrog");
1204 }
1205
1206 #[test]
1207 fn test_factorized_benefit_single_hop() {
1208 let model = CostModel::new();
1209
1210 let benefit = model.factorized_benefit(10.0, 1);
1212 assert!(
1213 (benefit - 1.0).abs() < 0.001,
1214 "Single hop should have no benefit"
1215 );
1216 }
1217
1218 #[test]
1219 fn test_factorized_benefit_multi_hop() {
1220 let model = CostModel::new();
1221
1222 let benefit = model.factorized_benefit(10.0, 3);
1224
1225 assert!(benefit <= 1.0, "Benefit should be <= 1.0");
1229 assert!(benefit > 0.0, "Benefit should be positive");
1230 }
1231
1232 #[test]
1233 fn test_factorized_benefit_low_fanout() {
1234 let model = CostModel::new();
1235
1236 let benefit = model.factorized_benefit(1.5, 2);
1238 assert!(
1239 benefit <= 1.0,
1240 "Low fanout still benefits from factorization"
1241 );
1242 }
1243
1244 #[test]
1245 fn test_node_scan_uses_label_cardinality_for_io() {
1246 let mut label_cards = std::collections::HashMap::new();
1247 label_cards.insert("Person".to_string(), 500_u64);
1248 label_cards.insert("Company".to_string(), 50_u64);
1249
1250 let model = CostModel::new()
1251 .with_label_cardinalities(label_cards)
1252 .with_graph_totals(550, 1000);
1253
1254 let person_scan = NodeScanOp {
1255 variable: "n".to_string(),
1256 label: Some("Person".to_string()),
1257 input: None,
1258 };
1259 let company_scan = NodeScanOp {
1260 variable: "n".to_string(),
1261 label: Some("Company".to_string()),
1262 input: None,
1263 };
1264
1265 let person_cost = model.node_scan_cost(&person_scan, 500.0);
1266 let company_cost = model.node_scan_cost(&company_scan, 50.0);
1267
1268 assert!(
1270 person_cost.io > company_cost.io * 5.0,
1271 "Person ({}) should have much higher IO than Company ({})",
1272 person_cost.io,
1273 company_cost.io
1274 );
1275 }
1276
1277 #[test]
1278 fn test_node_scan_unlabeled_uses_total_nodes() {
1279 let model = CostModel::new().with_graph_totals(10_000, 50_000);
1280
1281 let scan = NodeScanOp {
1282 variable: "n".to_string(),
1283 label: None,
1284 input: None,
1285 };
1286
1287 let cost = model.node_scan_cost(&scan, 10_000.0);
1288 let expected_pages = (10_000.0 * 100.0) / 8192.0;
1289 assert!(
1290 (cost.io - expected_pages).abs() < 0.1,
1291 "Unlabeled scan should use total_nodes for IO: got {}, expected {}",
1292 cost.io,
1293 expected_pages
1294 );
1295 }
1296
1297 #[test]
1298 fn test_join_cost_with_actual_child_cardinalities() {
1299 let model = CostModel::new();
1300 let join = JoinOp {
1301 left: Box::new(LogicalOperator::Empty),
1302 right: Box::new(LogicalOperator::Empty),
1303 join_type: JoinType::Inner,
1304 conditions: vec![JoinCondition {
1305 left: LogicalExpression::Variable("a".to_string()),
1306 right: LogicalExpression::Variable("b".to_string()),
1307 }],
1308 };
1309
1310 let cost_actual = model.join_cost_with_children(&join, 500.0, Some(100.0), Some(10_000.0));
1312
1313 let cost_sqrt = model.join_cost(&join, 500.0);
1315
1316 assert!(
1320 cost_actual.cpu > cost_sqrt.cpu,
1321 "Actual child cardinalities ({}) should produce different cost than sqrt fallback ({})",
1322 cost_actual.cpu,
1323 cost_sqrt.cpu
1324 );
1325 }
1326
1327 #[test]
1328 fn test_expand_multi_edge_types() {
1329 let mut degrees = std::collections::HashMap::new();
1330 degrees.insert("KNOWS".to_string(), (5.0, 5.0));
1331 degrees.insert("FOLLOWS".to_string(), (20.0, 100.0));
1332
1333 let model = CostModel::new().with_edge_type_degrees(degrees);
1334
1335 let multi_expand = ExpandOp {
1337 from_variable: "a".to_string(),
1338 to_variable: "b".to_string(),
1339 edge_variable: None,
1340 direction: ExpandDirection::Outgoing,
1341 edge_types: vec!["KNOWS".to_string(), "FOLLOWS".to_string()],
1342 min_hops: 1,
1343 max_hops: Some(1),
1344 input: Box::new(LogicalOperator::Empty),
1345 path_alias: None,
1346 path_mode: PathMode::Walk,
1347 };
1348 let multi_cost = model.expand_cost(&multi_expand, 100.0);
1349
1350 let single_expand = ExpandOp {
1352 from_variable: "a".to_string(),
1353 to_variable: "b".to_string(),
1354 edge_variable: None,
1355 direction: ExpandDirection::Outgoing,
1356 edge_types: vec!["KNOWS".to_string()],
1357 min_hops: 1,
1358 max_hops: Some(1),
1359 input: Box::new(LogicalOperator::Empty),
1360 path_alias: None,
1361 path_mode: PathMode::Walk,
1362 };
1363 let single_cost = model.expand_cost(&single_expand, 100.0);
1364
1365 assert!(
1367 multi_cost.cpu > single_cost.cpu * 3.0,
1368 "Multi-type fanout ({}) should be much higher than single-type ({})",
1369 multi_cost.cpu,
1370 single_cost.cpu
1371 );
1372 }
1373
1374 #[test]
1375 fn test_recursive_tree_cost() {
1376 use crate::query::optimizer::CardinalityEstimator;
1377
1378 let mut label_cards = std::collections::HashMap::new();
1379 label_cards.insert("Person".to_string(), 1000_u64);
1380
1381 let model = CostModel::new()
1382 .with_label_cardinalities(label_cards)
1383 .with_graph_totals(1000, 5000)
1384 .with_avg_fanout(5.0);
1385
1386 let mut card_est = CardinalityEstimator::new();
1387 card_est.add_table_stats("Person", crate::query::optimizer::TableStats::new(1000));
1388
1389 let plan = LogicalOperator::Return(ReturnOp {
1391 items: vec![ReturnItem {
1392 expression: LogicalExpression::Variable("n".to_string()),
1393 alias: None,
1394 }],
1395 distinct: false,
1396 input: Box::new(LogicalOperator::Filter(FilterOp {
1397 predicate: LogicalExpression::Binary {
1398 left: Box::new(LogicalExpression::Property {
1399 variable: "n".to_string(),
1400 property: "age".to_string(),
1401 }),
1402 op: crate::query::plan::BinaryOp::Gt,
1403 right: Box::new(LogicalExpression::Literal(
1404 grafeo_common::types::Value::Int64(30),
1405 )),
1406 },
1407 input: Box::new(LogicalOperator::NodeScan(NodeScanOp {
1408 variable: "n".to_string(),
1409 label: Some("Person".to_string()),
1410 input: None,
1411 })),
1412 pushdown_hint: None,
1413 })),
1414 });
1415
1416 let tree_cost = model.estimate_tree(&plan, &card_est);
1417
1418 assert!(tree_cost.cpu > 0.0, "Tree should have CPU cost");
1420 assert!(tree_cost.io > 0.0, "Tree should have IO cost from scan");
1421
1422 let root_only_card = card_est.estimate(&plan);
1424 let root_only_cost = model.estimate(&plan, root_only_card);
1425
1426 assert!(
1428 tree_cost.total() > root_only_cost.total(),
1429 "Recursive tree cost ({}) should exceed root-only cost ({})",
1430 tree_cost.total(),
1431 root_only_cost.total()
1432 );
1433 }
1434
1435 #[test]
1436 fn test_statistics_driven_vs_default_cost() {
1437 let default_model = CostModel::new();
1438
1439 let mut label_cards = std::collections::HashMap::new();
1440 label_cards.insert("Person".to_string(), 100_u64);
1441 let stats_model = CostModel::new()
1442 .with_label_cardinalities(label_cards)
1443 .with_graph_totals(100, 500);
1444
1445 let scan = NodeScanOp {
1447 variable: "n".to_string(),
1448 label: Some("Person".to_string()),
1449 input: None,
1450 };
1451
1452 let default_cost = default_model.node_scan_cost(&scan, 100.0);
1453 let stats_cost = stats_model.node_scan_cost(&scan, 100.0);
1454
1455 assert!(
1459 (default_cost.io - stats_cost.io).abs() < 0.1,
1460 "When cardinality matches label size, costs should be similar"
1461 );
1462 }
1463}