1use crate::query::plan::{
6 AggregateOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp, LogicalOperator,
7 NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp,
8};
9
/// Estimated execution cost of a plan operator, broken down by resource.
///
/// Costs are accumulated component-wise with `+` / `+=` and collapsed into a
/// single comparable number with [`Cost::total`]. `Cost::default()` is the
/// all-zero cost, identical to [`Cost::zero`].
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct Cost {
    /// CPU component of the estimate.
    pub cpu: f64,
    /// I/O component of the estimate (node scans record page reads here).
    pub io: f64,
    /// Memory component of the estimate (estimators use rows × average tuple size).
    pub memory: f64,
    /// Network component of the estimate.
    pub network: f64,
}
24
25impl Cost {
26 #[must_use]
28 pub fn zero() -> Self {
29 Self {
30 cpu: 0.0,
31 io: 0.0,
32 memory: 0.0,
33 network: 0.0,
34 }
35 }
36
37 #[must_use]
39 pub fn cpu(cpu: f64) -> Self {
40 Self {
41 cpu,
42 io: 0.0,
43 memory: 0.0,
44 network: 0.0,
45 }
46 }
47
48 #[must_use]
50 pub fn with_io(mut self, io: f64) -> Self {
51 self.io = io;
52 self
53 }
54
55 #[must_use]
57 pub fn with_memory(mut self, memory: f64) -> Self {
58 self.memory = memory;
59 self
60 }
61
62 #[must_use]
66 pub fn total(&self) -> f64 {
67 self.cpu + self.io * 10.0 + self.memory * 0.1 + self.network * 100.0
68 }
69
70 #[must_use]
72 pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
73 self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
74 }
75}
76
77impl std::ops::Add for Cost {
78 type Output = Self;
79
80 fn add(self, other: Self) -> Self {
81 Self {
82 cpu: self.cpu + other.cpu,
83 io: self.io + other.io,
84 memory: self.memory + other.memory,
85 network: self.network + other.network,
86 }
87 }
88}
89
90impl std::ops::AddAssign for Cost {
91 fn add_assign(&mut self, other: Self) {
92 self.cpu += other.cpu;
93 self.io += other.io;
94 self.memory += other.memory;
95 self.network += other.network;
96 }
97}
98
/// Tunable constants used to estimate the cost of logical plan operators.
///
/// Construct with [`CostModel::new`] (or `Default`) for the built-in constants.
#[derive(Debug, Clone)]
pub struct CostModel {
    /// CPU cost charged per tuple processed.
    cpu_tuple_cost: f64,
    /// Cost charged per page read. The `allow` keeps the build warning-free
    /// if no estimator reads this field.
    #[allow(dead_code)]
    io_page_cost: f64,
    /// CPU cost of one hash-table operation (charged for both build and probe).
    hash_lookup_cost: f64,
    /// CPU cost of one comparison while sorting.
    sort_comparison_cost: f64,
    /// Assumed average tuple width, used for memory and page estimates.
    avg_tuple_size: f64,
    /// Page size used to convert row volume into a page count.
    page_size: f64,
}
115
116impl CostModel {
117 #[must_use]
119 pub fn new() -> Self {
120 Self {
121 cpu_tuple_cost: 0.01,
122 io_page_cost: 1.0,
123 hash_lookup_cost: 0.02,
124 sort_comparison_cost: 0.02,
125 avg_tuple_size: 100.0,
126 page_size: 8192.0,
127 }
128 }
129
130 #[must_use]
132 pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
133 match op {
134 LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
135 LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
136 LogicalOperator::Project(project) => self.project_cost(project, cardinality),
137 LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
138 LogicalOperator::Join(join) => self.join_cost(join, cardinality),
139 LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
140 LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
141 LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
142 LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
143 LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
144 LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
145 LogicalOperator::Empty => Cost::zero(),
146 _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
147 }
148 }
149
150 fn node_scan_cost(&self, _scan: &NodeScanOp, cardinality: f64) -> Cost {
152 let pages = (cardinality * self.avg_tuple_size) / self.page_size;
153 Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
154 }
155
156 fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
158 Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
160 }
161
162 fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
164 let expr_count = project.projections.len() as f64;
166 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
167 }
168
169 fn expand_cost(&self, _expand: &ExpandOp, cardinality: f64) -> Cost {
171 let lookup_cost = cardinality * self.hash_lookup_cost;
173 let avg_fanout = 10.0;
175 let output_cost = cardinality * avg_fanout * self.cpu_tuple_cost;
176 Cost::cpu(lookup_cost + output_cost)
177 }
178
179 fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
181 match join.join_type {
183 JoinType::Cross => {
184 Cost::cpu(cardinality * self.cpu_tuple_cost)
186 }
187 JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
188 let build_cardinality = cardinality.sqrt(); let probe_cardinality = cardinality.sqrt();
192
193 let build_cost = build_cardinality * self.hash_lookup_cost;
195 let memory_cost = build_cardinality * self.avg_tuple_size;
196
197 let probe_cost = probe_cardinality * self.hash_lookup_cost;
199
200 let output_cost = cardinality * self.cpu_tuple_cost;
202
203 Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
204 }
205 JoinType::Semi | JoinType::Anti => {
206 let build_cardinality = cardinality.sqrt();
208 let probe_cardinality = cardinality.sqrt();
209
210 let build_cost = build_cardinality * self.hash_lookup_cost;
211 let probe_cost = probe_cardinality * self.hash_lookup_cost;
212
213 Cost::cpu(build_cost + probe_cost)
214 .with_memory(build_cardinality * self.avg_tuple_size)
215 }
216 }
217 }
218
219 fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
221 let hash_cost = cardinality * self.hash_lookup_cost;
223
224 let agg_count = agg.aggregates.len() as f64;
226 let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
227
228 let distinct_groups = (cardinality / 10.0).max(1.0); let memory_cost = distinct_groups * self.avg_tuple_size;
231
232 Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
233 }
234
235 fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
237 if cardinality <= 1.0 {
238 return Cost::zero();
239 }
240
241 let comparisons = cardinality * cardinality.log2();
243 let key_count = sort.keys.len() as f64;
244
245 let memory_cost = cardinality * self.avg_tuple_size;
247
248 Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
249 }
250
251 fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
253 let hash_cost = cardinality * self.hash_lookup_cost;
255 let memory_cost = cardinality * self.avg_tuple_size * 0.5; Cost::cpu(hash_cost).with_memory(memory_cost)
258 }
259
260 fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
262 Cost::cpu(limit.count as f64 * self.cpu_tuple_cost * 0.1)
264 }
265
266 fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
268 Cost::cpu(skip.count as f64 * self.cpu_tuple_cost)
270 }
271
272 fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
274 let expr_count = ret.items.len() as f64;
276 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
277 }
278
279 #[must_use]
281 pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
282 if a.total() <= b.total() { a } else { b }
283 }
284
285 #[must_use]
301 pub fn leapfrog_join_cost(
302 &self,
303 num_relations: usize,
304 cardinalities: &[f64],
305 output_cardinality: f64,
306 ) -> Cost {
307 if cardinalities.is_empty() {
308 return Cost::zero();
309 }
310
311 let total_input: f64 = cardinalities.iter().sum();
312 let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
313
314 let materialize_cost = total_input * self.cpu_tuple_cost * 2.0; let seek_cost = if min_card > 1.0 {
319 output_cardinality * (num_relations as f64) * min_card.log2() * self.hash_lookup_cost
320 } else {
321 output_cardinality * self.cpu_tuple_cost
322 };
323
324 let output_cost = output_cardinality * self.cpu_tuple_cost;
326
327 let memory = total_input * self.avg_tuple_size * 2.0;
329
330 Cost::cpu(materialize_cost + seek_cost + output_cost).with_memory(memory)
331 }
332
333 #[must_use]
337 pub fn prefer_leapfrog_join(
338 &self,
339 num_relations: usize,
340 cardinalities: &[f64],
341 output_cardinality: f64,
342 ) -> bool {
343 if num_relations < 3 || cardinalities.len() < 3 {
344 return false;
346 }
347
348 let leapfrog_cost =
349 self.leapfrog_join_cost(num_relations, cardinalities, output_cardinality);
350
351 let mut hash_cascade_cost = Cost::zero();
355 let mut intermediate_cardinality = cardinalities[0];
356
357 for card in &cardinalities[1..] {
358 let join_output = (intermediate_cardinality * card).sqrt(); let join = JoinOp {
361 left: Box::new(LogicalOperator::Empty),
362 right: Box::new(LogicalOperator::Empty),
363 join_type: JoinType::Inner,
364 conditions: vec![],
365 };
366 hash_cascade_cost += self.join_cost(&join, join_output);
367 intermediate_cardinality = join_output;
368 }
369
370 leapfrog_cost.total() < hash_cascade_cost.total()
371 }
372
373 #[must_use]
381 pub fn factorized_benefit(&self, avg_fanout: f64, num_hops: usize) -> f64 {
382 if num_hops <= 1 || avg_fanout <= 1.0 {
383 return 1.0; }
385
386 let full_size = avg_fanout.powi(num_hops as i32);
392 let factorized_size = if avg_fanout > 1.0 {
393 (avg_fanout.powi(num_hops as i32 + 1) - 1.0) / (avg_fanout - 1.0)
394 } else {
395 num_hops as f64
396 };
397
398 (factorized_size / full_size).min(1.0)
399 }
400}
401
402impl Default for CostModel {
403 fn default() -> Self {
404 Self::new()
405 }
406}
407
#[cfg(test)]
mod tests {
    //! Unit tests for `Cost` arithmetic and the per-operator `CostModel` estimators.

    use super::*;
    use crate::query::plan::{
        AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
        Projection, ReturnItem, SortOrder,
    };

    // `+` adds components independently; absent components stay zero.
    #[test]
    fn test_cost_addition() {
        let a = Cost::cpu(10.0).with_io(5.0);
        let b = Cost::cpu(20.0).with_memory(100.0);
        let c = a + b;

        assert!((c.cpu - 30.0).abs() < 0.001);
        assert!((c.io - 5.0).abs() < 0.001);
        assert!((c.memory - 100.0).abs() < 0.001);
    }

    // total = cpu + 10 * io + 0.1 * memory = 10 + 10 + 10.
    #[test]
    fn test_cost_total() {
        let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
        assert!((cost.total() - 30.0).abs() < 0.001);
    }

    // A scan charges both CPU and page I/O.
    #[test]
    fn test_cost_model_node_scan() {
        let model = CostModel::new();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: Some("Person".to_string()),
            input: None,
        };
        let cost = model.node_scan_cost(&scan, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.io > 0.0);
    }

    // `keys` is empty, so the comparison (CPU) term is zero; the growth with
    // cardinality comes from the memory term.
    #[test]
    fn test_cost_model_sort() {
        let model = CostModel::new();
        let sort = SortOp {
            keys: vec![],
            input: Box::new(LogicalOperator::Empty),
        };

        let cost_100 = model.sort_cost(&sort, 100.0);
        let cost_1000 = model.sort_cost(&sort, 1000.0);

        assert!(cost_1000.total() > cost_100.total());
    }

    #[test]
    fn test_cost_zero() {
        let cost = Cost::zero();
        assert!((cost.cpu).abs() < 0.001);
        assert!((cost.io).abs() < 0.001);
        assert!((cost.memory).abs() < 0.001);
        assert!((cost.network).abs() < 0.001);
        assert!((cost.total()).abs() < 0.001);
    }

    #[test]
    fn test_cost_add_assign() {
        let mut cost = Cost::cpu(10.0);
        cost += Cost::cpu(5.0).with_io(2.0);
        assert!((cost.cpu - 15.0).abs() < 0.001);
        assert!((cost.io - 2.0).abs() < 0.001);
    }

    // 10 * 2.0 + 2 * 5.0 + 100 * 0.5 = 20 + 10 + 50.
    #[test]
    fn test_cost_total_weighted() {
        let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
        let total = cost.total_weighted(2.0, 5.0, 0.5);
        assert!((total - 80.0).abs() < 0.001);
    }

    // Filters are CPU-only: no I/O is charged.
    #[test]
    fn test_cost_model_filter() {
        let model = CostModel::new();
        let filter = FilterOp {
            predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.filter_cost(&filter, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!((cost.io).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_project() {
        let model = CostModel::new();
        let project = ProjectOp {
            projections: vec![
                Projection {
                    expression: LogicalExpression::Variable("a".to_string()),
                    alias: None,
                },
                Projection {
                    expression: LogicalExpression::Variable("b".to_string()),
                    alias: None,
                },
            ],
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.project_cost(&project, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_expand() {
        let model = CostModel::new();
        let expand = ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: None,
            direction: ExpandDirection::Outgoing,
            edge_type: None,
            min_hops: 1,
            max_hops: Some(1),
            input: Box::new(LogicalOperator::Empty),
            path_alias: None,
        };
        let cost = model.expand_cost(&expand, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    // Hash joins charge memory for the build side.
    #[test]
    fn test_cost_model_hash_join() {
        let model = CostModel::new();
        let join = JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type: JoinType::Inner,
            conditions: vec![JoinCondition {
                left: LogicalExpression::Variable("a".to_string()),
                right: LogicalExpression::Variable("b".to_string()),
            }],
        };
        let cost = model.join_cost(&join, 10000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_cross_join() {
        let model = CostModel::new();
        let join = JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type: JoinType::Cross,
            conditions: vec![],
        };
        let cost = model.join_cost(&join, 1000000.0);

        assert!(cost.cpu > 0.0);
    }

    // Semi and inner joins are both costed; only positivity is checked here.
    #[test]
    fn test_cost_model_semi_join() {
        let model = CostModel::new();
        let join = JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type: JoinType::Semi,
            conditions: vec![],
        };
        let cost_semi = model.join_cost(&join, 1000.0);

        let inner_join = JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type: JoinType::Inner,
            conditions: vec![],
        };
        let cost_inner = model.join_cost(&inner_join, 1000.0);

        assert!(cost_semi.cpu > 0.0);
        assert!(cost_inner.cpu > 0.0);
    }

    // Aggregation charges CPU per row and memory for the group table.
    #[test]
    fn test_cost_model_aggregate() {
        let model = CostModel::new();
        let agg = AggregateOp {
            group_by: vec![],
            aggregates: vec![
                AggregateExpr {
                    function: AggregateFunction::Count,
                    expression: None,
                    distinct: false,
                    alias: Some("cnt".to_string()),
                    percentile: None,
                },
                AggregateExpr {
                    function: AggregateFunction::Sum,
                    expression: Some(LogicalExpression::Variable("x".to_string())),
                    distinct: false,
                    alias: Some("total".to_string()),
                    percentile: None,
                },
            ],
            input: Box::new(LogicalOperator::Empty),
            having: None,
        };
        let cost = model.aggregate_cost(&agg, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_distinct() {
        let model = CostModel::new();
        let distinct = DistinctOp {
            input: Box::new(LogicalOperator::Empty),
            columns: None,
        };
        let cost = model.distinct_cost(&distinct, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    // With the default constants: 10 rows * 0.01 * 0.1 = 0.01, i.e. cheap.
    #[test]
    fn test_cost_model_limit() {
        let model = CostModel::new();
        let limit = LimitOp {
            count: 10,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.limit_cost(&limit, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.cpu < 1.0); }

    #[test]
    fn test_cost_model_skip() {
        let model = CostModel::new();
        let skip = SkipOp {
            count: 100,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.skip_cost(&skip, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_return() {
        let model = CostModel::new();
        let ret = ReturnOp {
            items: vec![
                ReturnItem {
                    expression: LogicalExpression::Variable("a".to_string()),
                    alias: None,
                },
                ReturnItem {
                    expression: LogicalExpression::Variable("b".to_string()),
                    alias: None,
                },
            ],
            distinct: false,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.return_cost(&ret, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    // `cheaper` is order-independent when the totals differ.
    #[test]
    fn test_cost_cheaper() {
        let model = CostModel::new();
        let cheap = Cost::cpu(10.0);
        let expensive = Cost::cpu(100.0);

        assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
        assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
    }

    // I/O is weighted 10x in `total`: cpu_heavy = 100 + 10 = 110, while
    // io_heavy = 10 + 200 = 210.
    #[test]
    fn test_cost_comparison_prefers_lower_total() {
        let model = CostModel::new();
        let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
        let io_heavy = Cost::cpu(10.0).with_io(20.0);

        assert!(cpu_heavy.total() < io_heavy.total());
        assert_eq!(
            model.cheaper(&cpu_heavy, &io_heavy).total(),
            cpu_heavy.total()
        );
    }

    // Sort CPU scales with the number of sort keys.
    #[test]
    fn test_cost_model_sort_with_keys() {
        let model = CostModel::new();
        let sort_single = SortOp {
            keys: vec![crate::query::plan::SortKey {
                expression: LogicalExpression::Variable("a".to_string()),
                order: SortOrder::Ascending,
            }],
            input: Box::new(LogicalOperator::Empty),
        };
        let sort_multi = SortOp {
            keys: vec![
                crate::query::plan::SortKey {
                    expression: LogicalExpression::Variable("a".to_string()),
                    order: SortOrder::Ascending,
                },
                crate::query::plan::SortKey {
                    expression: LogicalExpression::Variable("b".to_string()),
                    order: SortOrder::Descending,
                },
            ],
            input: Box::new(LogicalOperator::Empty),
        };

        let cost_single = model.sort_cost(&sort_single, 1000.0);
        let cost_multi = model.sort_cost(&sort_multi, 1000.0);

        assert!(cost_multi.cpu > cost_single.cpu);
    }

    #[test]
    fn test_cost_model_empty_operator() {
        let model = CostModel::new();
        let cost = model.estimate(&LogicalOperator::Empty, 0.0);
        assert!((cost.total()).abs() < 0.001);
    }

    // Exercises `Default` together with the `estimate` dispatch.
    #[test]
    fn test_cost_model_default() {
        let model = CostModel::default();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: None,
            input: None,
        };
        let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
        assert!(cost.total() > 0.0);
    }

    #[test]
    fn test_leapfrog_join_cost() {
        let model = CostModel::new();

        let cardinalities = vec![1000.0, 1000.0, 1000.0];
        let cost = model.leapfrog_join_cost(3, &cardinalities, 100.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_leapfrog_join_cost_empty() {
        let model = CostModel::new();
        let cost = model.leapfrog_join_cost(0, &[], 0.0);
        assert!((cost.total()).abs() < 0.001);
    }

    // Smoke test: the preference decision is computed but not asserted.
    #[test]
    fn test_prefer_leapfrog_join_for_triangles() {
        let model = CostModel::new();

        let cardinalities = vec![10000.0, 10000.0, 10000.0];
        let output = 1000.0;

        let leapfrog_cost = model.leapfrog_join_cost(3, &cardinalities, output);

        assert!(leapfrog_cost.cpu > 0.0);
        assert!(leapfrog_cost.memory > 0.0);

        let _prefer = model.prefer_leapfrog_join(3, &cardinalities, output);
    }

    // Fewer than three relations short-circuits to `false`.
    #[test]
    fn test_prefer_leapfrog_join_binary_case() {
        let model = CostModel::new();

        let cardinalities = vec![1000.0, 1000.0];
        let prefer = model.prefer_leapfrog_join(2, &cardinalities, 500.0);
        assert!(!prefer, "Binary joins should use hash join, not leapfrog");
    }

    #[test]
    fn test_factorized_benefit_single_hop() {
        let model = CostModel::new();

        let benefit = model.factorized_benefit(10.0, 1);
        assert!(
            (benefit - 1.0).abs() < 0.001,
            "Single hop should have no benefit"
        );
    }

    // Only the (0, 1] range is checked here, not a specific ratio.
    #[test]
    fn test_factorized_benefit_multi_hop() {
        let model = CostModel::new();

        let benefit = model.factorized_benefit(10.0, 3);

        assert!(benefit <= 1.0, "Benefit should be <= 1.0");
        assert!(benefit > 0.0, "Benefit should be positive");
    }

    #[test]
    fn test_factorized_benefit_low_fanout() {
        let model = CostModel::new();

        let benefit = model.factorized_benefit(1.5, 2);
        assert!(
            benefit <= 1.0,
            "Low fanout still benefits from factorization"
        );
    }
}