1use crate::query::plan::{
6 AggregateOp, DistinctOp, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp, LogicalOperator,
7 NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp, VectorJoinOp, VectorScanOp,
8};
9
/// Multi-dimensional cost estimate for a plan or sub-plan.
///
/// Every component is a unitless, relative quantity. `Cost::total` folds the
/// components into one comparable scalar. `Cost::default()` is the all-zero cost,
/// identical to `Cost::zero()`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct Cost {
    /// CPU work.
    pub cpu: f64,
    /// I/O work.
    pub io: f64,
    /// Memory footprint.
    pub memory: f64,
    /// Network transfer.
    pub network: f64,
}
24
25impl Cost {
26 #[must_use]
28 pub fn zero() -> Self {
29 Self {
30 cpu: 0.0,
31 io: 0.0,
32 memory: 0.0,
33 network: 0.0,
34 }
35 }
36
37 #[must_use]
39 pub fn cpu(cpu: f64) -> Self {
40 Self {
41 cpu,
42 io: 0.0,
43 memory: 0.0,
44 network: 0.0,
45 }
46 }
47
48 #[must_use]
50 pub fn with_io(mut self, io: f64) -> Self {
51 self.io = io;
52 self
53 }
54
55 #[must_use]
57 pub fn with_memory(mut self, memory: f64) -> Self {
58 self.memory = memory;
59 self
60 }
61
62 #[must_use]
66 pub fn total(&self) -> f64 {
67 self.cpu + self.io * 10.0 + self.memory * 0.1 + self.network * 100.0
68 }
69
70 #[must_use]
72 pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
73 self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
74 }
75}
76
77impl std::ops::Add for Cost {
78 type Output = Self;
79
80 fn add(self, other: Self) -> Self {
81 Self {
82 cpu: self.cpu + other.cpu,
83 io: self.io + other.io,
84 memory: self.memory + other.memory,
85 network: self.network + other.network,
86 }
87 }
88}
89
90impl std::ops::AddAssign for Cost {
91 fn add_assign(&mut self, other: Self) {
92 self.cpu += other.cpu;
93 self.io += other.io;
94 self.memory += other.memory;
95 self.network += other.network;
96 }
97}
98
/// Tunable constants used to turn cardinality estimates into [`Cost`] values.
///
/// Obtain one with `CostModel::new()` or `CostModel::default()`.
#[derive(Debug, Clone)]
pub struct CostModel {
    /// CPU cost of processing one tuple.
    cpu_tuple_cost: f64,
    /// Cost of one page of I/O.
    // `dead_code` is allowed deliberately so this knob can be kept even while no
    // estimator reads it.
    #[allow(dead_code)]
    io_page_cost: f64,
    /// Cost of one hash-table lookup.
    hash_lookup_cost: f64,
    /// Cost of one comparison while sorting.
    sort_comparison_cost: f64,
    /// Assumed width of one tuple (presumably bytes).
    avg_tuple_size: f64,
    /// Size of one storage page, in the same unit as `avg_tuple_size`.
    page_size: f64,
}
115
116impl CostModel {
117 #[must_use]
119 pub fn new() -> Self {
120 Self {
121 cpu_tuple_cost: 0.01,
122 io_page_cost: 1.0,
123 hash_lookup_cost: 0.02,
124 sort_comparison_cost: 0.02,
125 avg_tuple_size: 100.0,
126 page_size: 8192.0,
127 }
128 }
129
130 #[must_use]
132 pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
133 match op {
134 LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
135 LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
136 LogicalOperator::Project(project) => self.project_cost(project, cardinality),
137 LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
138 LogicalOperator::Join(join) => self.join_cost(join, cardinality),
139 LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
140 LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
141 LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
142 LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
143 LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
144 LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
145 LogicalOperator::Empty => Cost::zero(),
146 LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
147 LogicalOperator::VectorJoin(join) => self.vector_join_cost(join, cardinality),
148 _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
149 }
150 }
151
152 fn node_scan_cost(&self, _scan: &NodeScanOp, cardinality: f64) -> Cost {
154 let pages = (cardinality * self.avg_tuple_size) / self.page_size;
155 Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
156 }
157
158 fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
160 Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
162 }
163
164 fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
166 let expr_count = project.projections.len() as f64;
168 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
169 }
170
171 fn expand_cost(&self, _expand: &ExpandOp, cardinality: f64) -> Cost {
173 let lookup_cost = cardinality * self.hash_lookup_cost;
175 let avg_fanout = 10.0;
177 let output_cost = cardinality * avg_fanout * self.cpu_tuple_cost;
178 Cost::cpu(lookup_cost + output_cost)
179 }
180
181 fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
183 match join.join_type {
185 JoinType::Cross => {
186 Cost::cpu(cardinality * self.cpu_tuple_cost)
188 }
189 JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
190 let build_cardinality = cardinality.sqrt(); let probe_cardinality = cardinality.sqrt();
194
195 let build_cost = build_cardinality * self.hash_lookup_cost;
197 let memory_cost = build_cardinality * self.avg_tuple_size;
198
199 let probe_cost = probe_cardinality * self.hash_lookup_cost;
201
202 let output_cost = cardinality * self.cpu_tuple_cost;
204
205 Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
206 }
207 JoinType::Semi | JoinType::Anti => {
208 let build_cardinality = cardinality.sqrt();
210 let probe_cardinality = cardinality.sqrt();
211
212 let build_cost = build_cardinality * self.hash_lookup_cost;
213 let probe_cost = probe_cardinality * self.hash_lookup_cost;
214
215 Cost::cpu(build_cost + probe_cost)
216 .with_memory(build_cardinality * self.avg_tuple_size)
217 }
218 }
219 }
220
221 fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
223 let hash_cost = cardinality * self.hash_lookup_cost;
225
226 let agg_count = agg.aggregates.len() as f64;
228 let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
229
230 let distinct_groups = (cardinality / 10.0).max(1.0); let memory_cost = distinct_groups * self.avg_tuple_size;
233
234 Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
235 }
236
237 fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
239 if cardinality <= 1.0 {
240 return Cost::zero();
241 }
242
243 let comparisons = cardinality * cardinality.log2();
245 let key_count = sort.keys.len() as f64;
246
247 let memory_cost = cardinality * self.avg_tuple_size;
249
250 Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
251 }
252
253 fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
255 let hash_cost = cardinality * self.hash_lookup_cost;
257 let memory_cost = cardinality * self.avg_tuple_size * 0.5; Cost::cpu(hash_cost).with_memory(memory_cost)
260 }
261
262 fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
264 Cost::cpu(limit.count as f64 * self.cpu_tuple_cost * 0.1)
266 }
267
268 fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
270 Cost::cpu(skip.count as f64 * self.cpu_tuple_cost)
272 }
273
274 fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
276 let expr_count = ret.items.len() as f64;
278 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
279 }
280
281 fn vector_scan_cost(&self, scan: &VectorScanOp, cardinality: f64) -> Cost {
286 let k = scan.k as f64;
288
289 let ef = 64.0;
292 let n = cardinality.max(1.0);
293 let search_cost = if scan.index_name.is_some() {
294 ef * n.ln() * self.cpu_tuple_cost * 10.0 } else {
297 n * self.cpu_tuple_cost * 10.0
299 };
300
301 let memory = k * self.avg_tuple_size * 2.0;
303
304 Cost::cpu(search_cost).with_memory(memory)
305 }
306
307 fn vector_join_cost(&self, join: &VectorJoinOp, cardinality: f64) -> Cost {
311 let k = join.k as f64;
312
313 let per_row_search_cost = if join.index_name.is_some() {
316 let ef = 64.0;
318 let n = cardinality.max(1.0);
319 ef * n.ln() * self.cpu_tuple_cost * 10.0
320 } else {
321 cardinality * self.cpu_tuple_cost * 10.0
323 };
324
325 let input_cardinality = (cardinality / k).max(1.0);
328 let total_search_cost = input_cardinality * per_row_search_cost;
329
330 let memory = cardinality * self.avg_tuple_size;
332
333 Cost::cpu(total_search_cost).with_memory(memory)
334 }
335
336 #[must_use]
338 pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
339 if a.total() <= b.total() { a } else { b }
340 }
341
342 #[must_use]
358 pub fn leapfrog_join_cost(
359 &self,
360 num_relations: usize,
361 cardinalities: &[f64],
362 output_cardinality: f64,
363 ) -> Cost {
364 if cardinalities.is_empty() {
365 return Cost::zero();
366 }
367
368 let total_input: f64 = cardinalities.iter().sum();
369 let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
370
371 let materialize_cost = total_input * self.cpu_tuple_cost * 2.0; let seek_cost = if min_card > 1.0 {
376 output_cardinality * (num_relations as f64) * min_card.log2() * self.hash_lookup_cost
377 } else {
378 output_cardinality * self.cpu_tuple_cost
379 };
380
381 let output_cost = output_cardinality * self.cpu_tuple_cost;
383
384 let memory = total_input * self.avg_tuple_size * 2.0;
386
387 Cost::cpu(materialize_cost + seek_cost + output_cost).with_memory(memory)
388 }
389
390 #[must_use]
394 pub fn prefer_leapfrog_join(
395 &self,
396 num_relations: usize,
397 cardinalities: &[f64],
398 output_cardinality: f64,
399 ) -> bool {
400 if num_relations < 3 || cardinalities.len() < 3 {
401 return false;
403 }
404
405 let leapfrog_cost =
406 self.leapfrog_join_cost(num_relations, cardinalities, output_cardinality);
407
408 let mut hash_cascade_cost = Cost::zero();
412 let mut intermediate_cardinality = cardinalities[0];
413
414 for card in &cardinalities[1..] {
415 let join_output = (intermediate_cardinality * card).sqrt(); let join = JoinOp {
418 left: Box::new(LogicalOperator::Empty),
419 right: Box::new(LogicalOperator::Empty),
420 join_type: JoinType::Inner,
421 conditions: vec![],
422 };
423 hash_cascade_cost += self.join_cost(&join, join_output);
424 intermediate_cardinality = join_output;
425 }
426
427 leapfrog_cost.total() < hash_cascade_cost.total()
428 }
429
430 #[must_use]
438 pub fn factorized_benefit(&self, avg_fanout: f64, num_hops: usize) -> f64 {
439 if num_hops <= 1 || avg_fanout <= 1.0 {
440 return 1.0; }
442
443 let full_size = avg_fanout.powi(num_hops as i32);
449 let factorized_size = if avg_fanout > 1.0 {
450 (avg_fanout.powi(num_hops as i32 + 1) - 1.0) / (avg_fanout - 1.0)
451 } else {
452 num_hops as f64
453 };
454
455 (factorized_size / full_size).min(1.0)
456 }
457}
458
459impl Default for CostModel {
460 fn default() -> Self {
461 Self::new()
462 }
463}
464
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::plan::{
        AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
        Projection, ReturnItem, SortOrder,
    };

    #[test]
    fn test_cost_addition() {
        let a = Cost::cpu(10.0).with_io(5.0);
        let b = Cost::cpu(20.0).with_memory(100.0);
        let c = a + b;

        assert!((c.cpu - 30.0).abs() < 0.001);
        assert!((c.io - 5.0).abs() < 0.001);
        assert!((c.memory - 100.0).abs() < 0.001);
    }

    #[test]
    fn test_cost_total() {
        // cpu 10 + io 1 * 10 + memory 100 * 0.1 = 30
        let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
        assert!((cost.total() - 30.0).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_node_scan() {
        let model = CostModel::new();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: Some("Person".to_string()),
            input: None,
        };
        let cost = model.node_scan_cost(&scan, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.io > 0.0);
    }

    #[test]
    fn test_cost_model_sort() {
        let model = CostModel::new();
        let sort = SortOp {
            keys: vec![],
            input: Box::new(LogicalOperator::Empty),
        };

        let cost_100 = model.sort_cost(&sort, 100.0);
        let cost_1000 = model.sort_cost(&sort, 1000.0);

        assert!(cost_1000.total() > cost_100.total());
    }

    #[test]
    fn test_cost_model_sort_trivial_input_is_free() {
        let model = CostModel::new();
        let sort = SortOp {
            keys: vec![],
            input: Box::new(LogicalOperator::Empty),
        };

        // Zero or one row never needs sorting.
        assert!((model.sort_cost(&sort, 1.0).total()).abs() < 0.001);
        assert!((model.sort_cost(&sort, 0.0).total()).abs() < 0.001);
    }

    #[test]
    fn test_cost_zero() {
        let cost = Cost::zero();
        assert!((cost.cpu).abs() < 0.001);
        assert!((cost.io).abs() < 0.001);
        assert!((cost.memory).abs() < 0.001);
        assert!((cost.network).abs() < 0.001);
        assert!((cost.total()).abs() < 0.001);
    }

    #[test]
    fn test_cost_add_assign() {
        let mut cost = Cost::cpu(10.0);
        cost += Cost::cpu(5.0).with_io(2.0);
        assert!((cost.cpu - 15.0).abs() < 0.001);
        assert!((cost.io - 2.0).abs() < 0.001);
    }

    #[test]
    fn test_cost_total_weighted() {
        // 10 * 2 + 2 * 5 + 100 * 0.5 = 80
        let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
        let total = cost.total_weighted(2.0, 5.0, 0.5);
        assert!((total - 80.0).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_filter() {
        let model = CostModel::new();
        let filter = FilterOp {
            predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.filter_cost(&filter, 1000.0);

        assert!(cost.cpu > 0.0);
        // Filtering is pure CPU work.
        assert!((cost.io).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_project() {
        let model = CostModel::new();
        let project = ProjectOp {
            projections: vec![
                Projection {
                    expression: LogicalExpression::Variable("a".to_string()),
                    alias: None,
                },
                Projection {
                    expression: LogicalExpression::Variable("b".to_string()),
                    alias: None,
                },
            ],
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.project_cost(&project, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_expand() {
        let model = CostModel::new();
        let expand = ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: None,
            direction: ExpandDirection::Outgoing,
            edge_type: None,
            min_hops: 1,
            max_hops: Some(1),
            input: Box::new(LogicalOperator::Empty),
            path_alias: None,
        };
        let cost = model.expand_cost(&expand, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_hash_join() {
        let model = CostModel::new();
        let join = JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type: JoinType::Inner,
            conditions: vec![JoinCondition {
                left: LogicalExpression::Variable("a".to_string()),
                right: LogicalExpression::Variable("b".to_string()),
            }],
        };
        let cost = model.join_cost(&join, 10000.0);

        assert!(cost.cpu > 0.0);
        // The build side of a hash join is held in memory.
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_cross_join() {
        let model = CostModel::new();
        let join = JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type: JoinType::Cross,
            conditions: vec![],
        };
        let cost = model.join_cost(&join, 1000000.0);

        assert!(cost.cpu > 0.0);
        // A cross join builds no hash table.
        assert!((cost.memory).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_semi_join() {
        let model = CostModel::new();
        let join = JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type: JoinType::Semi,
            conditions: vec![],
        };
        let cost_semi = model.join_cost(&join, 1000.0);

        let inner_join = JoinOp {
            left: Box::new(LogicalOperator::Empty),
            right: Box::new(LogicalOperator::Empty),
            join_type: JoinType::Inner,
            conditions: vec![],
        };
        let cost_inner = model.join_cost(&inner_join, 1000.0);

        assert!(cost_semi.cpu > 0.0);
        assert!(cost_inner.cpu > 0.0);
        // Semi joins skip the per-row output cost that inner joins pay.
        assert!(
            cost_semi.cpu < cost_inner.cpu,
            "Semi join should be cheaper than inner join at equal cardinality"
        );
    }

    #[test]
    fn test_cost_model_aggregate() {
        let model = CostModel::new();
        let agg = AggregateOp {
            group_by: vec![],
            aggregates: vec![
                AggregateExpr {
                    function: AggregateFunction::Count,
                    expression: None,
                    distinct: false,
                    alias: Some("cnt".to_string()),
                    percentile: None,
                },
                AggregateExpr {
                    function: AggregateFunction::Sum,
                    expression: Some(LogicalExpression::Variable("x".to_string())),
                    distinct: false,
                    alias: Some("total".to_string()),
                    percentile: None,
                },
            ],
            input: Box::new(LogicalOperator::Empty),
            having: None,
        };
        let cost = model.aggregate_cost(&agg, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_distinct() {
        let model = CostModel::new();
        let distinct = DistinctOp {
            input: Box::new(LogicalOperator::Empty),
            columns: None,
        };
        let cost = model.distinct_cost(&distinct, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_limit() {
        let model = CostModel::new();
        let limit = LimitOp {
            count: 10,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.limit_cost(&limit, 1000.0);

        assert!(cost.cpu > 0.0);
        // A small limit is cheap regardless of the input size.
        assert!(cost.cpu < 1.0);
    }

    #[test]
    fn test_cost_model_skip() {
        let model = CostModel::new();
        let skip = SkipOp {
            count: 100,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.skip_cost(&skip, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_return() {
        let model = CostModel::new();
        let ret = ReturnOp {
            items: vec![
                ReturnItem {
                    expression: LogicalExpression::Variable("a".to_string()),
                    alias: None,
                },
                ReturnItem {
                    expression: LogicalExpression::Variable("b".to_string()),
                    alias: None,
                },
            ],
            distinct: false,
            input: Box::new(LogicalOperator::Empty),
        };
        let cost = model.return_cost(&ret, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_cheaper() {
        let model = CostModel::new();
        let cheap = Cost::cpu(10.0);
        let expensive = Cost::cpu(100.0);

        assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
        assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
    }

    #[test]
    fn test_cost_comparison_prefers_lower_total() {
        let model = CostModel::new();
        // total = 100 + 1 * 10 = 110
        let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
        // total = 10 + 20 * 10 = 210
        let io_heavy = Cost::cpu(10.0).with_io(20.0);

        assert!(cpu_heavy.total() < io_heavy.total());
        assert_eq!(
            model.cheaper(&cpu_heavy, &io_heavy).total(),
            cpu_heavy.total()
        );
    }

    #[test]
    fn test_cost_model_sort_with_keys() {
        let model = CostModel::new();
        let sort_single = SortOp {
            keys: vec![crate::query::plan::SortKey {
                expression: LogicalExpression::Variable("a".to_string()),
                order: SortOrder::Ascending,
            }],
            input: Box::new(LogicalOperator::Empty),
        };
        let sort_multi = SortOp {
            keys: vec![
                crate::query::plan::SortKey {
                    expression: LogicalExpression::Variable("a".to_string()),
                    order: SortOrder::Ascending,
                },
                crate::query::plan::SortKey {
                    expression: LogicalExpression::Variable("b".to_string()),
                    order: SortOrder::Descending,
                },
            ],
            input: Box::new(LogicalOperator::Empty),
        };

        let cost_single = model.sort_cost(&sort_single, 1000.0);
        let cost_multi = model.sort_cost(&sort_multi, 1000.0);

        assert!(cost_multi.cpu > cost_single.cpu);
    }

    #[test]
    fn test_cost_model_empty_operator() {
        let model = CostModel::new();
        let cost = model.estimate(&LogicalOperator::Empty, 0.0);
        assert!((cost.total()).abs() < 0.001);
    }

    #[test]
    fn test_cost_model_default() {
        let model = CostModel::default();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: None,
            input: None,
        };
        let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
        assert!(cost.total() > 0.0);
    }

    #[test]
    fn test_cost_model_estimate_dispatches_node_scan() {
        let model = CostModel::new();
        let direct = model.node_scan_cost(
            &NodeScanOp {
                variable: "n".to_string(),
                label: None,
                input: None,
            },
            500.0,
        );
        let via_estimate = model.estimate(
            &LogicalOperator::NodeScan(NodeScanOp {
                variable: "n".to_string(),
                label: None,
                input: None,
            }),
            500.0,
        );
        assert_eq!(direct, via_estimate);
    }

    #[test]
    fn test_leapfrog_join_cost() {
        let model = CostModel::new();

        let cardinalities = vec![1000.0, 1000.0, 1000.0];
        let cost = model.leapfrog_join_cost(3, &cardinalities, 100.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_leapfrog_join_cost_empty() {
        let model = CostModel::new();
        let cost = model.leapfrog_join_cost(0, &[], 0.0);
        assert!((cost.total()).abs() < 0.001);
    }

    #[test]
    fn test_prefer_leapfrog_join_for_triangles() {
        let model = CostModel::new();

        let cardinalities = vec![10000.0, 10000.0, 10000.0];
        let output = 1000.0;

        let leapfrog_cost = model.leapfrog_join_cost(3, &cardinalities, output);

        assert!(leapfrog_cost.cpu > 0.0);
        assert!(leapfrog_cost.memory > 0.0);

        // The decision depends on the tuning constants; this only checks that the
        // call succeeds for a triangle-shaped input.
        let _prefer = model.prefer_leapfrog_join(3, &cardinalities, output);
    }

    #[test]
    fn test_prefer_leapfrog_join_binary_case() {
        let model = CostModel::new();

        let cardinalities = vec![1000.0, 1000.0];
        let prefer = model.prefer_leapfrog_join(2, &cardinalities, 500.0);
        assert!(!prefer, "Binary joins should use hash join, not leapfrog");
    }

    #[test]
    fn test_prefer_leapfrog_join_needs_three_cardinalities() {
        let model = CostModel::new();
        // Three relations are claimed, but only two cardinalities are supplied.
        assert!(!model.prefer_leapfrog_join(3, &[1000.0, 1000.0], 10.0));
    }

    #[test]
    fn test_factorized_benefit_single_hop() {
        let model = CostModel::new();

        let benefit = model.factorized_benefit(10.0, 1);
        assert!(
            (benefit - 1.0).abs() < 0.001,
            "Single hop should have no benefit"
        );
    }

    #[test]
    fn test_factorized_benefit_multi_hop() {
        let model = CostModel::new();

        let benefit = model.factorized_benefit(10.0, 3);

        assert!(benefit <= 1.0, "Benefit should be <= 1.0");
        assert!(benefit > 0.0, "Benefit should be positive");
    }

    #[test]
    fn test_factorized_benefit_low_fanout() {
        let model = CostModel::new();

        let benefit = model.factorized_benefit(1.5, 2);
        assert!(
            benefit <= 1.0,
            "Low fanout still benefits from factorization"
        );
    }

    #[test]
    fn test_factorized_benefit_no_fanout() {
        let model = CostModel::new();
        // A fan-out of at most one never multiplies rows, so there is nothing to factor.
        assert!((model.factorized_benefit(1.0, 3) - 1.0).abs() < 0.001);
        assert!((model.factorized_benefit(0.5, 4) - 1.0).abs() < 0.001);
    }
}