1use crate::query::plan::{
6 AggregateOp, DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
7 LogicalOperator, NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp, VectorJoinOp, VectorScanOp,
8};
9
/// Estimated resource consumption of a query operator, split by resource kind.
///
/// The components are unitless estimates; use [`Cost::total`] or
/// [`Cost::total_weighted`] to fold them into one comparable number.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct Cost {
    /// Estimated CPU work.
    pub cpu: f64,
    /// Estimated I/O work.
    pub io: f64,
    /// Estimated memory footprint.
    pub memory: f64,
    /// Estimated network transfer.
    pub network: f64,
}

impl Cost {
    /// Returns a cost whose components are all zero (same as `Cost::default()`).
    #[must_use]
    pub fn zero() -> Self {
        Self::default()
    }

    /// Returns a cost with only the CPU component set.
    #[must_use]
    pub fn cpu(cpu: f64) -> Self {
        Self {
            cpu,
            ..Self::default()
        }
    }

    /// Replaces the I/O component and returns the updated cost.
    #[must_use]
    pub fn with_io(mut self, io: f64) -> Self {
        self.io = io;
        self
    }

    /// Replaces the memory component and returns the updated cost.
    #[must_use]
    pub fn with_memory(mut self, memory: f64) -> Self {
        self.memory = memory;
        self
    }

    /// Replaces the network component and returns the updated cost.
    ///
    /// Provided for parity with [`Cost::with_io`] and [`Cost::with_memory`]:
    /// `network` carries the largest weight in [`Cost::total`].
    #[must_use]
    pub fn with_network(mut self, network: f64) -> Self {
        self.network = network;
        self
    }

    /// Collapses the cost into one number using fixed weights:
    /// `cpu + 10 * io + 0.1 * memory + 100 * network`.
    #[must_use]
    pub fn total(&self) -> f64 {
        self.cpu + self.io * 10.0 + self.memory * 0.1 + self.network * 100.0
    }

    /// Collapses the cost using caller-supplied weights.
    ///
    /// Only CPU, I/O and memory take part; the network component is not
    /// included in this sum.
    #[must_use]
    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
        self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
    }
}

impl std::ops::AddAssign for Cost {
    /// Adds `other` to `self` component by component.
    fn add_assign(&mut self, other: Self) {
        self.cpu += other.cpu;
        self.io += other.io;
        self.memory += other.memory;
        self.network += other.network;
    }
}

impl std::ops::Add for Cost {
    type Output = Self;

    /// Component-wise sum; delegates to `+=` so both operators stay in sync.
    fn add(mut self, other: Self) -> Self {
        self += other;
        self
    }
}
98
/// Tunable constants used to turn operator cardinalities into [`Cost`] estimates.
#[derive(Debug, Clone)]
pub struct CostModel {
    /// CPU cost charged per tuple processed.
    cpu_tuple_cost: f64,
    /// CPU cost charged per hash-table lookup.
    hash_lookup_cost: f64,
    /// CPU cost charged per sort comparison.
    sort_comparison_cost: f64,
    /// Average tuple size used for memory and page estimates
    /// (presumably bytes; confirm against the statistics source).
    avg_tuple_size: f64,
    /// Page size used to convert scanned data volume into I/O pages.
    page_size: f64,
    /// Fanout assumed for an expansion when no per-edge-type statistics apply.
    avg_fanout: f64,
    /// Per-edge-type degree statistics, stored as `(outgoing, incoming)` pairs.
    edge_type_degrees: std::collections::HashMap<String, (f64, f64)>,
}
122
123impl CostModel {
124 #[must_use]
126 pub fn new() -> Self {
127 Self {
128 cpu_tuple_cost: 0.01,
129 hash_lookup_cost: 0.03,
130 sort_comparison_cost: 0.02,
131 avg_tuple_size: 100.0,
132 page_size: 8192.0,
133 avg_fanout: 10.0,
134 edge_type_degrees: std::collections::HashMap::new(),
135 }
136 }
137
138 #[must_use]
140 pub fn with_avg_fanout(mut self, avg_fanout: f64) -> Self {
141 self.avg_fanout = if avg_fanout > 0.0 { avg_fanout } else { 10.0 };
142 self
143 }
144
145 #[must_use]
149 pub fn with_edge_type_degrees(
150 mut self,
151 degrees: std::collections::HashMap<String, (f64, f64)>,
152 ) -> Self {
153 self.edge_type_degrees = degrees;
154 self
155 }
156
157 fn fanout_for_expand(&self, expand: &ExpandOp) -> f64 {
162 if expand.edge_types.len() == 1
163 && let Some(&(out_deg, in_deg)) = self.edge_type_degrees.get(&expand.edge_types[0])
164 {
165 return match expand.direction {
166 ExpandDirection::Outgoing => out_deg,
167 ExpandDirection::Incoming => in_deg,
168 ExpandDirection::Both => out_deg + in_deg,
169 };
170 }
171 self.avg_fanout
172 }
173
174 #[must_use]
176 pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
177 match op {
178 LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
179 LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
180 LogicalOperator::Project(project) => self.project_cost(project, cardinality),
181 LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
182 LogicalOperator::Join(join) => self.join_cost(join, cardinality),
183 LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
184 LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
185 LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
186 LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
187 LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
188 LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
189 LogicalOperator::Empty => Cost::zero(),
190 LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
191 LogicalOperator::VectorJoin(join) => self.vector_join_cost(join, cardinality),
192 _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
193 }
194 }
195
196 fn node_scan_cost(&self, _scan: &NodeScanOp, cardinality: f64) -> Cost {
198 let pages = (cardinality * self.avg_tuple_size) / self.page_size;
199 Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
200 }
201
202 fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
204 Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
206 }
207
208 fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
210 let expr_count = project.projections.len() as f64;
212 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
213 }
214
215 fn expand_cost(&self, expand: &ExpandOp, cardinality: f64) -> Cost {
220 let fanout = self.fanout_for_expand(expand);
221 let lookup_cost = cardinality * self.hash_lookup_cost;
223 let output_cost = cardinality * fanout * self.cpu_tuple_cost;
225 Cost::cpu(lookup_cost + output_cost)
226 }
227
228 fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
230 match join.join_type {
232 JoinType::Cross => {
233 Cost::cpu(cardinality * self.cpu_tuple_cost)
235 }
236 JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
237 let build_cardinality = cardinality.sqrt(); let probe_cardinality = cardinality.sqrt();
241
242 let build_cost = build_cardinality * self.hash_lookup_cost;
244 let memory_cost = build_cardinality * self.avg_tuple_size;
245
246 let probe_cost = probe_cardinality * self.hash_lookup_cost;
248
249 let output_cost = cardinality * self.cpu_tuple_cost;
251
252 Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
253 }
254 JoinType::Semi | JoinType::Anti => {
255 let build_cardinality = cardinality.sqrt();
257 let probe_cardinality = cardinality.sqrt();
258
259 let build_cost = build_cardinality * self.hash_lookup_cost;
260 let probe_cost = probe_cardinality * self.hash_lookup_cost;
261
262 Cost::cpu(build_cost + probe_cost)
263 .with_memory(build_cardinality * self.avg_tuple_size)
264 }
265 }
266 }
267
268 fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
270 let hash_cost = cardinality * self.hash_lookup_cost;
272
273 let agg_count = agg.aggregates.len() as f64;
275 let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
276
277 let distinct_groups = (cardinality / 10.0).max(1.0); let memory_cost = distinct_groups * self.avg_tuple_size;
280
281 Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
282 }
283
284 fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
286 if cardinality <= 1.0 {
287 return Cost::zero();
288 }
289
290 let comparisons = cardinality * cardinality.log2();
292 let key_count = sort.keys.len() as f64;
293
294 let memory_cost = cardinality * self.avg_tuple_size;
296
297 Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
298 }
299
300 fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
302 let hash_cost = cardinality * self.hash_lookup_cost;
304 let memory_cost = cardinality * self.avg_tuple_size * 0.5; Cost::cpu(hash_cost).with_memory(memory_cost)
307 }
308
309 fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
311 Cost::cpu(limit.count as f64 * self.cpu_tuple_cost * 0.1)
313 }
314
315 fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
317 Cost::cpu(skip.count as f64 * self.cpu_tuple_cost)
319 }
320
321 fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
323 let expr_count = ret.items.len() as f64;
325 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
326 }
327
328 fn vector_scan_cost(&self, scan: &VectorScanOp, cardinality: f64) -> Cost {
333 let k = scan.k as f64;
335
336 let ef = 64.0;
339 let n = cardinality.max(1.0);
340 let search_cost = if scan.index_name.is_some() {
341 ef * n.ln() * self.cpu_tuple_cost * 10.0 } else {
344 n * self.cpu_tuple_cost * 10.0
346 };
347
348 let memory = k * self.avg_tuple_size * 2.0;
350
351 Cost::cpu(search_cost).with_memory(memory)
352 }
353
354 fn vector_join_cost(&self, join: &VectorJoinOp, cardinality: f64) -> Cost {
358 let k = join.k as f64;
359
360 let per_row_search_cost = if join.index_name.is_some() {
363 let ef = 64.0;
365 let n = cardinality.max(1.0);
366 ef * n.ln() * self.cpu_tuple_cost * 10.0
367 } else {
368 cardinality * self.cpu_tuple_cost * 10.0
370 };
371
372 let input_cardinality = (cardinality / k).max(1.0);
375 let total_search_cost = input_cardinality * per_row_search_cost;
376
377 let memory = cardinality * self.avg_tuple_size;
379
380 Cost::cpu(total_search_cost).with_memory(memory)
381 }
382
383 #[must_use]
385 pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
386 if a.total() <= b.total() { a } else { b }
387 }
388
389 #[must_use]
405 pub fn leapfrog_join_cost(
406 &self,
407 num_relations: usize,
408 cardinalities: &[f64],
409 output_cardinality: f64,
410 ) -> Cost {
411 if cardinalities.is_empty() {
412 return Cost::zero();
413 }
414
415 let total_input: f64 = cardinalities.iter().sum();
416 let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
417
418 let materialize_cost = total_input * self.cpu_tuple_cost * 2.0; let seek_cost = if min_card > 1.0 {
423 output_cardinality * (num_relations as f64) * min_card.log2() * self.hash_lookup_cost
424 } else {
425 output_cardinality * self.cpu_tuple_cost
426 };
427
428 let output_cost = output_cardinality * self.cpu_tuple_cost;
430
431 let memory = total_input * self.avg_tuple_size * 2.0;
433
434 Cost::cpu(materialize_cost + seek_cost + output_cost).with_memory(memory)
435 }
436
437 #[must_use]
441 pub fn prefer_leapfrog_join(
442 &self,
443 num_relations: usize,
444 cardinalities: &[f64],
445 output_cardinality: f64,
446 ) -> bool {
447 if num_relations < 3 || cardinalities.len() < 3 {
448 return false;
450 }
451
452 let leapfrog_cost =
453 self.leapfrog_join_cost(num_relations, cardinalities, output_cardinality);
454
455 let mut hash_cascade_cost = Cost::zero();
459 let mut intermediate_cardinality = cardinalities[0];
460
461 for card in &cardinalities[1..] {
462 let join_output = (intermediate_cardinality * card).sqrt(); let join = JoinOp {
465 left: Box::new(LogicalOperator::Empty),
466 right: Box::new(LogicalOperator::Empty),
467 join_type: JoinType::Inner,
468 conditions: vec![],
469 };
470 hash_cascade_cost += self.join_cost(&join, join_output);
471 intermediate_cardinality = join_output;
472 }
473
474 leapfrog_cost.total() < hash_cascade_cost.total()
475 }
476
477 #[must_use]
485 pub fn factorized_benefit(&self, avg_fanout: f64, num_hops: usize) -> f64 {
486 if num_hops <= 1 || avg_fanout <= 1.0 {
487 return 1.0; }
489
490 let full_size = avg_fanout.powi(num_hops as i32);
496 let factorized_size = if avg_fanout > 1.0 {
497 (avg_fanout.powi(num_hops as i32 + 1) - 1.0) / (avg_fanout - 1.0)
498 } else {
499 num_hops as f64
500 };
501
502 (factorized_size / full_size).min(1.0)
503 }
504}
505
506impl Default for CostModel {
507 fn default() -> Self {
508 Self::new()
509 }
510}
511
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::plan::{
        AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
        PathMode, Projection, ReturnItem, SortKey, SortOrder,
    };

    /// Tolerance used for floating-point comparisons.
    const EPS: f64 = 0.001;

    /// Boxed `Empty` operator used as the input of every operator under test.
    fn empty_input() -> Box<LogicalOperator> {
        Box::new(LogicalOperator::Empty)
    }

    /// Variable reference expression.
    fn var(name: &str) -> LogicalExpression {
        LogicalExpression::Variable(name.to_string())
    }

    /// Single-hop expansion from `a` to `b` over the given edge types.
    fn single_hop(direction: ExpandDirection, edge_types: &[&str]) -> ExpandOp {
        ExpandOp {
            from_variable: "a".to_string(),
            to_variable: "b".to_string(),
            edge_variable: None,
            direction,
            edge_types: edge_types.iter().map(|&t| t.to_owned()).collect(),
            min_hops: 1,
            max_hops: Some(1),
            input: empty_input(),
            path_alias: None,
            path_mode: PathMode::Walk,
        }
    }

    /// Join of two empty inputs.
    fn join_of(join_type: JoinType, conditions: Vec<JoinCondition>) -> JoinOp {
        JoinOp {
            left: empty_input(),
            right: empty_input(),
            join_type,
            conditions,
        }
    }

    /// Sort over an empty input with the given keys.
    fn sort_by(keys: Vec<SortKey>) -> SortOp {
        SortOp {
            keys,
            input: empty_input(),
        }
    }

    /// Sort key on a variable.
    fn sort_key(name: &str, order: SortOrder) -> SortKey {
        SortKey {
            expression: var(name),
            order,
        }
    }

    #[test]
    fn test_cost_addition() {
        let a = Cost::cpu(10.0).with_io(5.0);
        let b = Cost::cpu(20.0).with_memory(100.0);
        let c = a + b;

        assert!((c.cpu - 30.0).abs() < EPS);
        assert!((c.io - 5.0).abs() < EPS);
        assert!((c.memory - 100.0).abs() < EPS);
    }

    #[test]
    fn test_cost_total() {
        // 10 + 1 * 10 + 100 * 0.1 = 30
        let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
        assert!((cost.total() - 30.0).abs() < EPS);
    }

    #[test]
    fn test_cost_model_node_scan() {
        let model = CostModel::new();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: Some("Person".to_string()),
            input: None,
        };
        let cost = model.node_scan_cost(&scan, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.io > 0.0);
    }

    #[test]
    fn test_cost_model_sort() {
        let model = CostModel::new();
        let sort = sort_by(vec![]);

        let cost_100 = model.sort_cost(&sort, 100.0);
        let cost_1000 = model.sort_cost(&sort, 1000.0);

        assert!(cost_1000.total() > cost_100.total());
    }

    #[test]
    fn test_cost_zero() {
        let cost = Cost::zero();
        assert!(cost.cpu.abs() < EPS);
        assert!(cost.io.abs() < EPS);
        assert!(cost.memory.abs() < EPS);
        assert!(cost.network.abs() < EPS);
        assert!(cost.total().abs() < EPS);
    }

    #[test]
    fn test_cost_add_assign() {
        let mut cost = Cost::cpu(10.0);
        cost += Cost::cpu(5.0).with_io(2.0);
        assert!((cost.cpu - 15.0).abs() < EPS);
        assert!((cost.io - 2.0).abs() < EPS);
    }

    #[test]
    fn test_cost_total_weighted() {
        // 10 * 2 + 2 * 5 + 100 * 0.5 = 80
        let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
        let total = cost.total_weighted(2.0, 5.0, 0.5);
        assert!((total - 80.0).abs() < EPS);
    }

    #[test]
    fn test_cost_model_filter() {
        let model = CostModel::new();
        let filter = FilterOp {
            predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
            input: empty_input(),
        };
        let cost = model.filter_cost(&filter, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.io.abs() < EPS);
    }

    #[test]
    fn test_cost_model_project() {
        let model = CostModel::new();
        let project = ProjectOp {
            projections: vec![
                Projection {
                    expression: var("a"),
                    alias: None,
                },
                Projection {
                    expression: var("b"),
                    alias: None,
                },
            ],
            input: empty_input(),
        };
        let cost = model.project_cost(&project, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_expand() {
        let model = CostModel::new();
        let expand = single_hop(ExpandDirection::Outgoing, &[]);
        let cost = model.expand_cost(&expand, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_expand_with_edge_type_stats() {
        let mut degrees = std::collections::HashMap::new();
        degrees.insert("KNOWS".to_string(), (5.0, 5.0));
        degrees.insert("WORKS_AT".to_string(), (1.0, 50.0));
        let model = CostModel::new().with_edge_type_degrees(degrees);

        let knows_out = single_hop(ExpandDirection::Outgoing, &["KNOWS"]);
        let cost_knows = model.expand_cost(&knows_out, 1000.0);

        let works_out = single_hop(ExpandDirection::Outgoing, &["WORKS_AT"]);
        let cost_works = model.expand_cost(&works_out, 1000.0);

        assert!(
            cost_knows.cpu > cost_works.cpu,
            "KNOWS(5) should cost more than WORKS_AT(1)"
        );

        let works_in = single_hop(ExpandDirection::Incoming, &["WORKS_AT"]);
        let cost_works_in = model.expand_cost(&works_in, 1000.0);

        assert!(
            cost_works_in.cpu > cost_knows.cpu,
            "Incoming WORKS_AT(50) should cost more than KNOWS(5)"
        );
    }

    #[test]
    fn test_cost_model_expand_unknown_edge_type_uses_global_fanout() {
        let model = CostModel::new().with_avg_fanout(7.0);

        let expand = single_hop(ExpandDirection::Outgoing, &["UNKNOWN_TYPE"]);
        let cost_unknown = model.expand_cost(&expand, 1000.0);

        let expand_no_type = single_hop(ExpandDirection::Outgoing, &[]);
        let cost_no_type = model.expand_cost(&expand_no_type, 1000.0);

        assert!(
            (cost_unknown.cpu - cost_no_type.cpu).abs() < EPS,
            "Unknown edge type should use global fanout"
        );
    }

    #[test]
    fn test_cost_model_hash_join() {
        let model = CostModel::new();
        let join = join_of(
            JoinType::Inner,
            vec![JoinCondition {
                left: var("a"),
                right: var("b"),
            }],
        );
        let cost = model.join_cost(&join, 10000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_cross_join() {
        let model = CostModel::new();
        let join = join_of(JoinType::Cross, vec![]);
        let cost = model.join_cost(&join, 1000000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_semi_join() {
        let model = CostModel::new();

        let cost_semi = model.join_cost(&join_of(JoinType::Semi, vec![]), 1000.0);
        let cost_inner = model.join_cost(&join_of(JoinType::Inner, vec![]), 1000.0);

        assert!(cost_semi.cpu > 0.0);
        assert!(cost_inner.cpu > 0.0);
        // A semi join skips the per-row output cost that an inner join pays.
        assert!(
            cost_semi.cpu < cost_inner.cpu,
            "Semi join should be cheaper than inner join at equal cardinality"
        );
    }

    #[test]
    fn test_cost_model_aggregate() {
        let model = CostModel::new();
        let agg = AggregateOp {
            group_by: vec![],
            aggregates: vec![
                AggregateExpr {
                    function: AggregateFunction::Count,
                    expression: None,
                    distinct: false,
                    alias: Some("cnt".to_string()),
                    percentile: None,
                },
                AggregateExpr {
                    function: AggregateFunction::Sum,
                    expression: Some(var("x")),
                    distinct: false,
                    alias: Some("total".to_string()),
                    percentile: None,
                },
            ],
            input: empty_input(),
            having: None,
        };
        let cost = model.aggregate_cost(&agg, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_distinct() {
        let model = CostModel::new();
        let distinct = DistinctOp {
            input: empty_input(),
            columns: None,
        };
        let cost = model.distinct_cost(&distinct, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_limit() {
        let model = CostModel::new();
        let limit = LimitOp {
            count: 10,
            input: empty_input(),
        };
        let cost = model.limit_cost(&limit, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.cpu < 1.0);
    }

    #[test]
    fn test_cost_model_skip() {
        let model = CostModel::new();
        let skip = SkipOp {
            count: 100,
            input: empty_input(),
        };
        let cost = model.skip_cost(&skip, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_return() {
        let model = CostModel::new();
        let ret = ReturnOp {
            items: vec![
                ReturnItem {
                    expression: var("a"),
                    alias: None,
                },
                ReturnItem {
                    expression: var("b"),
                    alias: None,
                },
            ],
            distinct: false,
            input: empty_input(),
        };
        let cost = model.return_cost(&ret, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_cheaper() {
        let model = CostModel::new();
        let cheap = Cost::cpu(10.0);
        let expensive = Cost::cpu(100.0);

        // The argument order must not matter.
        assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
        assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
    }

    #[test]
    fn test_cost_comparison_prefers_lower_total() {
        let model = CostModel::new();
        // 100 + 1 * 10 = 110 versus 10 + 20 * 10 = 210.
        let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
        let io_heavy = Cost::cpu(10.0).with_io(20.0);

        assert!(cpu_heavy.total() < io_heavy.total());
        assert_eq!(
            model.cheaper(&cpu_heavy, &io_heavy).total(),
            cpu_heavy.total()
        );
    }

    #[test]
    fn test_cost_model_sort_with_keys() {
        let model = CostModel::new();
        let sort_single = sort_by(vec![sort_key("a", SortOrder::Ascending)]);
        let sort_multi = sort_by(vec![
            sort_key("a", SortOrder::Ascending),
            sort_key("b", SortOrder::Descending),
        ]);

        let cost_single = model.sort_cost(&sort_single, 1000.0);
        let cost_multi = model.sort_cost(&sort_multi, 1000.0);

        assert!(cost_multi.cpu > cost_single.cpu);
    }

    #[test]
    fn test_cost_model_empty_operator() {
        let model = CostModel::new();
        let cost = model.estimate(&LogicalOperator::Empty, 0.0);
        assert!(cost.total().abs() < EPS);
    }

    #[test]
    fn test_cost_model_default() {
        let model = CostModel::default();
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: None,
            input: None,
        };
        let cost = model.estimate(&LogicalOperator::NodeScan(scan), 100.0);
        assert!(cost.total() > 0.0);
    }

    #[test]
    fn test_leapfrog_join_cost() {
        let model = CostModel::new();

        let cardinalities = vec![1000.0, 1000.0, 1000.0];
        let cost = model.leapfrog_join_cost(3, &cardinalities, 100.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_leapfrog_join_cost_empty() {
        let model = CostModel::new();
        let cost = model.leapfrog_join_cost(0, &[], 0.0);
        assert!(cost.total().abs() < EPS);
    }

    #[test]
    fn test_prefer_leapfrog_join_for_triangles() {
        let model = CostModel::new();

        let cardinalities = vec![10000.0, 10000.0, 10000.0];
        let output = 1000.0;

        let leapfrog_cost = model.leapfrog_join_cost(3, &cardinalities, output);

        assert!(leapfrog_cost.cpu > 0.0);
        assert!(leapfrog_cost.memory > 0.0);

        // The decision depends on the model constants, so only check that the
        // comparison runs for a three-way join.
        let _prefer = model.prefer_leapfrog_join(3, &cardinalities, output);
    }

    #[test]
    fn test_prefer_leapfrog_join_binary_case() {
        let model = CostModel::new();

        let cardinalities = vec![1000.0, 1000.0];
        let prefer = model.prefer_leapfrog_join(2, &cardinalities, 500.0);
        assert!(!prefer, "Binary joins should use hash join, not leapfrog");
    }

    #[test]
    fn test_factorized_benefit_single_hop() {
        let model = CostModel::new();

        let benefit = model.factorized_benefit(10.0, 1);
        assert!(
            (benefit - 1.0).abs() < EPS,
            "Single hop should have no benefit"
        );
    }

    #[test]
    fn test_factorized_benefit_multi_hop() {
        let model = CostModel::new();

        let benefit = model.factorized_benefit(10.0, 3);

        assert!(benefit <= 1.0, "Benefit should be <= 1.0");
        assert!(benefit > 0.0, "Benefit should be positive");
    }

    #[test]
    fn test_factorized_benefit_low_fanout() {
        let model = CostModel::new();

        let benefit = model.factorized_benefit(1.5, 2);
        assert!(
            benefit <= 1.0,
            "Low fanout still benefits from factorization"
        );
    }
}