1use crate::query::plan::{
6 AggregateOp, DistinctOp, ExpandDirection, ExpandOp, FilterOp, JoinOp, JoinType, LimitOp,
7 LogicalOperator, MultiWayJoinOp, NodeScanOp, ProjectOp, ReturnOp, SkipOp, SortOp, VectorJoinOp,
8 VectorScanOp,
9};
10
/// A cost estimate broken down into four independent components.
///
/// Components are kept separate so that they can be combined with different
/// weights; see [`Cost::total`] for the built-in weighting and
/// [`Cost::total_weighted`] for a caller-supplied one.
///
/// `Default` yields the same all-zero value as [`Cost::zero`].
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct Cost {
    /// CPU component (weight 1 in [`Cost::total`]).
    pub cpu: f64,
    /// I/O component (weight 10 in [`Cost::total`]).
    pub io: f64,
    /// Memory component (weight 0.1 in [`Cost::total`]).
    pub memory: f64,
    /// Network component (weight 100 in [`Cost::total`]).
    pub network: f64,
}

impl Cost {
    /// Returns a cost whose four components are all `0.0`.
    #[must_use]
    pub fn zero() -> Self {
        Self::default()
    }

    /// Returns a cost with only the CPU component set; the rest are `0.0`.
    #[must_use]
    pub fn cpu(cpu: f64) -> Self {
        Self {
            cpu,
            ..Self::default()
        }
    }

    /// Returns `self` with the I/O component replaced by `io`.
    #[must_use]
    pub fn with_io(self, io: f64) -> Self {
        Self { io, ..self }
    }

    /// Returns `self` with the memory component replaced by `memory`.
    #[must_use]
    pub fn with_memory(self, memory: f64) -> Self {
        Self { memory, ..self }
    }

    /// Returns `self` with the network component replaced by `network`.
    ///
    /// Counterpart of [`Cost::with_io`] and [`Cost::with_memory`], so that a
    /// builder chain can express every component.
    #[must_use]
    pub fn with_network(self, network: f64) -> Self {
        Self { network, ..self }
    }

    /// Collapses the cost into one number using fixed weights:
    /// `cpu + io * 10 + memory * 0.1 + network * 100`.
    #[must_use]
    pub fn total(&self) -> f64 {
        self.cpu + self.io * 10.0 + self.memory * 0.1 + self.network * 100.0
    }

    /// Collapses the cost into one number using caller-supplied weights.
    ///
    /// Only the CPU, I/O and memory components take part; the network
    /// component is ignored by this method.
    #[must_use]
    pub fn total_weighted(&self, cpu_weight: f64, io_weight: f64, mem_weight: f64) -> f64 {
        self.cpu * cpu_weight + self.io * io_weight + self.memory * mem_weight
    }
}

impl std::ops::Add for Cost {
    type Output = Self;

    /// Component-wise sum of two costs.
    fn add(self, other: Self) -> Self {
        Self {
            cpu: self.cpu + other.cpu,
            io: self.io + other.io,
            memory: self.memory + other.memory,
            network: self.network + other.network,
        }
    }
}

impl std::ops::AddAssign for Cost {
    /// Component-wise in-place sum; `Cost` is `Copy`, so this reuses `Add`.
    fn add_assign(&mut self, other: Self) {
        *self = *self + other;
    }
}
99
/// Tunable constants used to turn operator cardinalities into [`Cost`]s.
///
/// `Debug` and `Clone` are derived so the model can be logged and copied.
#[derive(Debug, Clone)]
pub struct CostModel {
    /// CPU cost charged per tuple handled by an operator.
    cpu_tuple_cost: f64,
    /// CPU cost charged per hash lookup (expand, joins, aggregation,
    /// distinct, leapfrog seeks).
    hash_lookup_cost: f64,
    /// CPU cost charged per comparison when sorting.
    sort_comparison_cost: f64,
    /// Average size of one tuple; multiplied by tuple counts to obtain memory
    /// and scan volume. Presumably bytes — TODO confirm.
    avg_tuple_size: f64,
    /// Divisor that converts scan volume into pages for the I/O component.
    page_size: f64,
    /// Fanout assumed for an expand when no per-edge-type statistic applies.
    avg_fanout: f64,
    /// Per edge type: `(out_degree, in_degree)` used as the expand fanout.
    edge_type_degrees: std::collections::HashMap<String, (f64, f64)>,
}
123
124impl CostModel {
125 #[must_use]
127 pub fn new() -> Self {
128 Self {
129 cpu_tuple_cost: 0.01,
130 hash_lookup_cost: 0.03,
131 sort_comparison_cost: 0.02,
132 avg_tuple_size: 100.0,
133 page_size: 8192.0,
134 avg_fanout: 10.0,
135 edge_type_degrees: std::collections::HashMap::new(),
136 }
137 }
138
139 #[must_use]
141 pub fn with_avg_fanout(mut self, avg_fanout: f64) -> Self {
142 self.avg_fanout = if avg_fanout > 0.0 { avg_fanout } else { 10.0 };
143 self
144 }
145
146 #[must_use]
150 pub fn with_edge_type_degrees(
151 mut self,
152 degrees: std::collections::HashMap<String, (f64, f64)>,
153 ) -> Self {
154 self.edge_type_degrees = degrees;
155 self
156 }
157
158 fn fanout_for_expand(&self, expand: &ExpandOp) -> f64 {
163 if expand.edge_types.len() == 1
164 && let Some(&(out_deg, in_deg)) = self.edge_type_degrees.get(&expand.edge_types[0])
165 {
166 return match expand.direction {
167 ExpandDirection::Outgoing => out_deg,
168 ExpandDirection::Incoming => in_deg,
169 ExpandDirection::Both => out_deg + in_deg,
170 };
171 }
172 self.avg_fanout
173 }
174
175 #[must_use]
177 pub fn estimate(&self, op: &LogicalOperator, cardinality: f64) -> Cost {
178 match op {
179 LogicalOperator::NodeScan(scan) => self.node_scan_cost(scan, cardinality),
180 LogicalOperator::Filter(filter) => self.filter_cost(filter, cardinality),
181 LogicalOperator::Project(project) => self.project_cost(project, cardinality),
182 LogicalOperator::Expand(expand) => self.expand_cost(expand, cardinality),
183 LogicalOperator::Join(join) => self.join_cost(join, cardinality),
184 LogicalOperator::Aggregate(agg) => self.aggregate_cost(agg, cardinality),
185 LogicalOperator::Sort(sort) => self.sort_cost(sort, cardinality),
186 LogicalOperator::Distinct(distinct) => self.distinct_cost(distinct, cardinality),
187 LogicalOperator::Limit(limit) => self.limit_cost(limit, cardinality),
188 LogicalOperator::Skip(skip) => self.skip_cost(skip, cardinality),
189 LogicalOperator::Return(ret) => self.return_cost(ret, cardinality),
190 LogicalOperator::Empty => Cost::zero(),
191 LogicalOperator::VectorScan(scan) => self.vector_scan_cost(scan, cardinality),
192 LogicalOperator::VectorJoin(join) => self.vector_join_cost(join, cardinality),
193 LogicalOperator::MultiWayJoin(mwj) => self.multi_way_join_cost(mwj, cardinality),
194 _ => Cost::cpu(cardinality * self.cpu_tuple_cost),
195 }
196 }
197
198 fn node_scan_cost(&self, _scan: &NodeScanOp, cardinality: f64) -> Cost {
200 let pages = (cardinality * self.avg_tuple_size) / self.page_size;
201 Cost::cpu(cardinality * self.cpu_tuple_cost).with_io(pages)
202 }
203
204 fn filter_cost(&self, _filter: &FilterOp, cardinality: f64) -> Cost {
206 Cost::cpu(cardinality * self.cpu_tuple_cost * 1.5)
208 }
209
210 fn project_cost(&self, project: &ProjectOp, cardinality: f64) -> Cost {
212 let expr_count = project.projections.len() as f64;
214 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
215 }
216
217 fn expand_cost(&self, expand: &ExpandOp, cardinality: f64) -> Cost {
222 let fanout = self.fanout_for_expand(expand);
223 let lookup_cost = cardinality * self.hash_lookup_cost;
225 let output_cost = cardinality * fanout * self.cpu_tuple_cost;
227 Cost::cpu(lookup_cost + output_cost)
228 }
229
230 fn join_cost(&self, join: &JoinOp, cardinality: f64) -> Cost {
232 match join.join_type {
234 JoinType::Cross => {
235 Cost::cpu(cardinality * self.cpu_tuple_cost)
237 }
238 JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
239 let build_cardinality = cardinality.sqrt(); let probe_cardinality = cardinality.sqrt();
243
244 let build_cost = build_cardinality * self.hash_lookup_cost;
246 let memory_cost = build_cardinality * self.avg_tuple_size;
247
248 let probe_cost = probe_cardinality * self.hash_lookup_cost;
250
251 let output_cost = cardinality * self.cpu_tuple_cost;
253
254 Cost::cpu(build_cost + probe_cost + output_cost).with_memory(memory_cost)
255 }
256 JoinType::Semi | JoinType::Anti => {
257 let build_cardinality = cardinality.sqrt();
259 let probe_cardinality = cardinality.sqrt();
260
261 let build_cost = build_cardinality * self.hash_lookup_cost;
262 let probe_cost = probe_cardinality * self.hash_lookup_cost;
263
264 Cost::cpu(build_cost + probe_cost)
265 .with_memory(build_cardinality * self.avg_tuple_size)
266 }
267 }
268 }
269
270 fn multi_way_join_cost(&self, mwj: &MultiWayJoinOp, cardinality: f64) -> Cost {
275 let n = mwj.inputs.len();
276 if n == 0 {
277 return Cost::zero();
278 }
279 let per_input = cardinality.powf(1.0 / n as f64).max(1.0);
282 let cardinalities: Vec<f64> = (0..n).map(|_| per_input).collect();
283 self.leapfrog_join_cost(n, &cardinalities, cardinality)
284 }
285
286 fn aggregate_cost(&self, agg: &AggregateOp, cardinality: f64) -> Cost {
288 let hash_cost = cardinality * self.hash_lookup_cost;
290
291 let agg_count = agg.aggregates.len() as f64;
293 let agg_cost = cardinality * self.cpu_tuple_cost * agg_count;
294
295 let distinct_groups = (cardinality / 10.0).max(1.0); let memory_cost = distinct_groups * self.avg_tuple_size;
298
299 Cost::cpu(hash_cost + agg_cost).with_memory(memory_cost)
300 }
301
302 fn sort_cost(&self, sort: &SortOp, cardinality: f64) -> Cost {
304 if cardinality <= 1.0 {
305 return Cost::zero();
306 }
307
308 let comparisons = cardinality * cardinality.log2();
310 let key_count = sort.keys.len() as f64;
311
312 let memory_cost = cardinality * self.avg_tuple_size;
314
315 Cost::cpu(comparisons * self.sort_comparison_cost * key_count).with_memory(memory_cost)
316 }
317
318 fn distinct_cost(&self, _distinct: &DistinctOp, cardinality: f64) -> Cost {
320 let hash_cost = cardinality * self.hash_lookup_cost;
322 let memory_cost = cardinality * self.avg_tuple_size * 0.5; Cost::cpu(hash_cost).with_memory(memory_cost)
325 }
326
327 fn limit_cost(&self, limit: &LimitOp, _cardinality: f64) -> Cost {
329 Cost::cpu(limit.count.estimate() * self.cpu_tuple_cost * 0.1)
331 }
332
333 fn skip_cost(&self, skip: &SkipOp, _cardinality: f64) -> Cost {
335 Cost::cpu(skip.count.estimate() * self.cpu_tuple_cost)
337 }
338
339 fn return_cost(&self, ret: &ReturnOp, cardinality: f64) -> Cost {
341 let expr_count = ret.items.len() as f64;
343 Cost::cpu(cardinality * self.cpu_tuple_cost * expr_count)
344 }
345
346 fn vector_scan_cost(&self, scan: &VectorScanOp, cardinality: f64) -> Cost {
351 let k = scan.k as f64;
353
354 let ef = 64.0;
357 let n = cardinality.max(1.0);
358 let search_cost = if scan.index_name.is_some() {
359 ef * n.ln() * self.cpu_tuple_cost * 10.0 } else {
362 n * self.cpu_tuple_cost * 10.0
364 };
365
366 let memory = k * self.avg_tuple_size * 2.0;
368
369 Cost::cpu(search_cost).with_memory(memory)
370 }
371
372 fn vector_join_cost(&self, join: &VectorJoinOp, cardinality: f64) -> Cost {
376 let k = join.k as f64;
377
378 let per_row_search_cost = if join.index_name.is_some() {
381 let ef = 64.0;
383 let n = cardinality.max(1.0);
384 ef * n.ln() * self.cpu_tuple_cost * 10.0
385 } else {
386 cardinality * self.cpu_tuple_cost * 10.0
388 };
389
390 let input_cardinality = (cardinality / k).max(1.0);
393 let total_search_cost = input_cardinality * per_row_search_cost;
394
395 let memory = cardinality * self.avg_tuple_size;
397
398 Cost::cpu(total_search_cost).with_memory(memory)
399 }
400
401 #[must_use]
403 pub fn cheaper<'a>(&self, a: &'a Cost, b: &'a Cost) -> &'a Cost {
404 if a.total() <= b.total() { a } else { b }
405 }
406
407 #[must_use]
423 pub fn leapfrog_join_cost(
424 &self,
425 num_relations: usize,
426 cardinalities: &[f64],
427 output_cardinality: f64,
428 ) -> Cost {
429 if cardinalities.is_empty() {
430 return Cost::zero();
431 }
432
433 let total_input: f64 = cardinalities.iter().sum();
434 let min_card = cardinalities.iter().copied().fold(f64::INFINITY, f64::min);
435
436 let materialize_cost = total_input * self.cpu_tuple_cost * 2.0; let seek_cost = if min_card > 1.0 {
441 output_cardinality * (num_relations as f64) * min_card.log2() * self.hash_lookup_cost
442 } else {
443 output_cardinality * self.cpu_tuple_cost
444 };
445
446 let output_cost = output_cardinality * self.cpu_tuple_cost;
448
449 let memory = total_input * self.avg_tuple_size * 2.0;
451
452 Cost::cpu(materialize_cost + seek_cost + output_cost).with_memory(memory)
453 }
454
455 #[must_use]
459 pub fn prefer_leapfrog_join(
460 &self,
461 num_relations: usize,
462 cardinalities: &[f64],
463 output_cardinality: f64,
464 ) -> bool {
465 if num_relations < 3 || cardinalities.len() < 3 {
466 return false;
468 }
469
470 let leapfrog_cost =
471 self.leapfrog_join_cost(num_relations, cardinalities, output_cardinality);
472
473 let mut hash_cascade_cost = Cost::zero();
477 let mut intermediate_cardinality = cardinalities[0];
478
479 for card in &cardinalities[1..] {
480 let join_output = (intermediate_cardinality * card).sqrt(); let join = JoinOp {
483 left: Box::new(LogicalOperator::Empty),
484 right: Box::new(LogicalOperator::Empty),
485 join_type: JoinType::Inner,
486 conditions: vec![],
487 };
488 hash_cascade_cost += self.join_cost(&join, join_output);
489 intermediate_cardinality = join_output;
490 }
491
492 leapfrog_cost.total() < hash_cascade_cost.total()
493 }
494
495 #[must_use]
503 pub fn factorized_benefit(&self, avg_fanout: f64, num_hops: usize) -> f64 {
504 if num_hops <= 1 || avg_fanout <= 1.0 {
505 return 1.0; }
507
508 let full_size = avg_fanout.powi(num_hops as i32);
514 let factorized_size = if avg_fanout > 1.0 {
515 (avg_fanout.powi(num_hops as i32 + 1) - 1.0) / (avg_fanout - 1.0)
516 } else {
517 num_hops as f64
518 };
519
520 (factorized_size / full_size).min(1.0)
521 }
522}
523
524impl Default for CostModel {
525 fn default() -> Self {
526 Self::new()
527 }
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::plan::{
        AggregateExpr, AggregateFunction, ExpandDirection, JoinCondition, LogicalExpression,
        PathMode, Projection, ReturnItem, SortOrder,
    };

    // ------------------------------------------------------------------
    // Shared fixtures
    // ------------------------------------------------------------------

    /// Absolute tolerance for float comparisons.
    const EPS: f64 = 0.001;

    /// True when `actual` is within [`EPS`] of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    /// Boxed `Empty` operator used as the input of every operator under test.
    fn leaf() -> Box<LogicalOperator> {
        Box::new(LogicalOperator::Empty)
    }

    /// Variable-reference expression.
    fn var(name: &str) -> LogicalExpression {
        LogicalExpression::Variable(name.to_string())
    }

    /// Single-hop expand between two variables over the given edge types.
    fn one_hop(
        from: &str,
        to: &str,
        direction: ExpandDirection,
        edge_types: &[&str],
    ) -> ExpandOp {
        ExpandOp {
            from_variable: from.to_string(),
            to_variable: to.to_string(),
            edge_variable: None,
            direction,
            edge_types: edge_types.iter().map(|ty| (*ty).to_string()).collect(),
            min_hops: 1,
            max_hops: Some(1),
            input: leaf(),
            path_alias: None,
            path_mode: PathMode::Walk,
        }
    }

    /// Binary join over two empty inputs.
    fn join_of(join_type: JoinType, conditions: Vec<JoinCondition>) -> JoinOp {
        JoinOp {
            left: leaf(),
            right: leaf(),
            join_type,
            conditions,
        }
    }

    /// Sort key on a variable, with no explicit null ordering.
    fn key_on(name: &str, order: SortOrder) -> crate::query::plan::SortKey {
        crate::query::plan::SortKey {
            expression: var(name),
            order,
            nulls: None,
        }
    }

    /// Sort operator over an empty input.
    fn sort_by(keys: Vec<crate::query::plan::SortKey>) -> SortOp {
        SortOp {
            keys,
            input: leaf(),
        }
    }

    /// Non-distinct aggregate expression with no secondary arguments.
    fn aggregate(
        function: AggregateFunction,
        expression: Option<LogicalExpression>,
        alias: &str,
    ) -> AggregateExpr {
        AggregateExpr {
            function,
            expression,
            expression2: None,
            distinct: false,
            alias: Some(alias.to_string()),
            percentile: None,
            separator: None,
        }
    }

    // ------------------------------------------------------------------
    // Cost arithmetic
    // ------------------------------------------------------------------

    #[test]
    fn test_cost_addition() {
        let sum = Cost::cpu(10.0).with_io(5.0) + Cost::cpu(20.0).with_memory(100.0);

        assert!(close(sum.cpu, 30.0));
        assert!(close(sum.io, 5.0));
        assert!(close(sum.memory, 100.0));
    }

    #[test]
    fn test_cost_total() {
        let cost = Cost::cpu(10.0).with_io(1.0).with_memory(100.0);
        assert!(close(cost.total(), 30.0));
    }

    #[test]
    fn test_cost_model_node_scan() {
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: Some("Person".to_string()),
            input: None,
        };
        let cost = CostModel::new().node_scan_cost(&scan, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.io > 0.0);
    }

    #[test]
    fn test_cost_model_sort() {
        let model = CostModel::new();
        let sort = sort_by(vec![]);

        let small = model.sort_cost(&sort, 100.0);
        let large = model.sort_cost(&sort, 1000.0);

        assert!(large.total() > small.total());
    }

    #[test]
    fn test_cost_zero() {
        let cost = Cost::zero();
        for component in [cost.cpu, cost.io, cost.memory, cost.network, cost.total()] {
            assert!(close(component, 0.0));
        }
    }

    #[test]
    fn test_cost_add_assign() {
        let mut cost = Cost::cpu(10.0);
        cost += Cost::cpu(5.0).with_io(2.0);

        assert!(close(cost.cpu, 15.0));
        assert!(close(cost.io, 2.0));
    }

    #[test]
    fn test_cost_total_weighted() {
        let cost = Cost::cpu(10.0).with_io(2.0).with_memory(100.0);
        let total = cost.total_weighted(2.0, 5.0, 0.5);
        assert!(close(total, 80.0));
    }

    // ------------------------------------------------------------------
    // Per-operator estimators
    // ------------------------------------------------------------------

    #[test]
    fn test_cost_model_filter() {
        let filter = FilterOp {
            predicate: LogicalExpression::Literal(grafeo_common::types::Value::Bool(true)),
            input: leaf(),
            pushdown_hint: None,
        };
        let cost = CostModel::new().filter_cost(&filter, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(close(cost.io, 0.0));
    }

    #[test]
    fn test_cost_model_project() {
        let projections = ["a", "b"]
            .into_iter()
            .map(|name| Projection {
                expression: var(name),
                alias: None,
            })
            .collect();
        let project = ProjectOp {
            projections,
            input: leaf(),
        };
        let cost = CostModel::new().project_cost(&project, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_expand() {
        let expand = one_hop("a", "b", ExpandDirection::Outgoing, &[]);
        let cost = CostModel::new().expand_cost(&expand, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_expand_with_edge_type_stats() {
        let degrees = std::collections::HashMap::from([
            ("KNOWS".to_string(), (5.0, 5.0)),
            ("WORKS_AT".to_string(), (1.0, 50.0)),
        ]);
        let model = CostModel::new().with_edge_type_degrees(degrees);

        let knows_out = one_hop("a", "b", ExpandDirection::Outgoing, &["KNOWS"]);
        let cost_knows = model.expand_cost(&knows_out, 1000.0);

        let works_out = one_hop("a", "b", ExpandDirection::Outgoing, &["WORKS_AT"]);
        let cost_works = model.expand_cost(&works_out, 1000.0);

        assert!(
            cost_knows.cpu > cost_works.cpu,
            "KNOWS(5) should cost more than WORKS_AT(1)"
        );

        let works_in = one_hop("c", "p", ExpandDirection::Incoming, &["WORKS_AT"]);
        let cost_works_in = model.expand_cost(&works_in, 1000.0);

        assert!(
            cost_works_in.cpu > cost_knows.cpu,
            "Incoming WORKS_AT(50) should cost more than KNOWS(5)"
        );
    }

    #[test]
    fn test_cost_model_expand_unknown_edge_type_uses_global_fanout() {
        let model = CostModel::new().with_avg_fanout(7.0);

        let typed = one_hop("a", "b", ExpandDirection::Outgoing, &["UNKNOWN_TYPE"]);
        let cost_unknown = model.expand_cost(&typed, 1000.0);

        let untyped = one_hop("a", "b", ExpandDirection::Outgoing, &[]);
        let cost_no_type = model.expand_cost(&untyped, 1000.0);

        assert!(
            (cost_unknown.cpu - cost_no_type.cpu).abs() < 0.001,
            "Unknown edge type should use global fanout"
        );
    }

    #[test]
    fn test_cost_model_hash_join() {
        let on_a_eq_b = JoinCondition {
            left: var("a"),
            right: var("b"),
        };
        let join = join_of(JoinType::Inner, vec![on_a_eq_b]);
        let cost = CostModel::new().join_cost(&join, 10000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_cross_join() {
        let join = join_of(JoinType::Cross, vec![]);
        let cost = CostModel::new().join_cost(&join, 1000000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_semi_join() {
        let model = CostModel::new();

        let cost_semi = model.join_cost(&join_of(JoinType::Semi, vec![]), 1000.0);
        let cost_inner = model.join_cost(&join_of(JoinType::Inner, vec![]), 1000.0);

        assert!(cost_semi.cpu > 0.0);
        assert!(cost_inner.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_aggregate() {
        let agg = AggregateOp {
            group_by: vec![],
            aggregates: vec![
                aggregate(AggregateFunction::Count, None, "cnt"),
                aggregate(AggregateFunction::Sum, Some(var("x")), "total"),
            ],
            input: leaf(),
            having: None,
        };
        let cost = CostModel::new().aggregate_cost(&agg, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_distinct() {
        let distinct = DistinctOp {
            input: leaf(),
            columns: None,
        };
        let cost = CostModel::new().distinct_cost(&distinct, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_cost_model_limit() {
        let limit = LimitOp {
            count: 10.into(),
            input: leaf(),
        };
        let cost = CostModel::new().limit_cost(&limit, 1000.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.cpu < 1.0);
    }

    #[test]
    fn test_cost_model_skip() {
        let skip = SkipOp {
            count: 100.into(),
            input: leaf(),
        };
        let cost = CostModel::new().skip_cost(&skip, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    #[test]
    fn test_cost_model_return() {
        let items = ["a", "b"]
            .into_iter()
            .map(|name| ReturnItem {
                expression: var(name),
                alias: None,
            })
            .collect();
        let ret = ReturnOp {
            items,
            distinct: false,
            input: leaf(),
        };
        let cost = CostModel::new().return_cost(&ret, 1000.0);

        assert!(cost.cpu > 0.0);
    }

    // ------------------------------------------------------------------
    // Cost comparison
    // ------------------------------------------------------------------

    #[test]
    fn test_cost_cheaper() {
        let model = CostModel::new();
        let cheap = Cost::cpu(10.0);
        let expensive = Cost::cpu(100.0);

        // The cheaper cost must win regardless of argument order.
        assert_eq!(model.cheaper(&cheap, &expensive).total(), cheap.total());
        assert_eq!(model.cheaper(&expensive, &cheap).total(), cheap.total());
    }

    #[test]
    fn test_cost_comparison_prefers_lower_total() {
        let model = CostModel::new();
        let cpu_heavy = Cost::cpu(100.0).with_io(1.0);
        let io_heavy = Cost::cpu(10.0).with_io(20.0);

        assert!(cpu_heavy.total() < io_heavy.total());
        assert_eq!(
            model.cheaper(&cpu_heavy, &io_heavy).total(),
            cpu_heavy.total()
        );
    }

    #[test]
    fn test_cost_model_sort_with_keys() {
        let model = CostModel::new();
        let sort_single = sort_by(vec![key_on("a", SortOrder::Ascending)]);
        let sort_multi = sort_by(vec![
            key_on("a", SortOrder::Ascending),
            key_on("b", SortOrder::Descending),
        ]);

        let cost_single = model.sort_cost(&sort_single, 1000.0);
        let cost_multi = model.sort_cost(&sort_multi, 1000.0);

        assert!(cost_multi.cpu > cost_single.cpu);
    }

    #[test]
    fn test_cost_model_empty_operator() {
        let cost = CostModel::new().estimate(&LogicalOperator::Empty, 0.0);
        assert!(close(cost.total(), 0.0));
    }

    #[test]
    fn test_cost_model_default() {
        let scan = NodeScanOp {
            variable: "n".to_string(),
            label: None,
            input: None,
        };
        let cost = CostModel::default().estimate(&LogicalOperator::NodeScan(scan), 100.0);
        assert!(cost.total() > 0.0);
    }

    // ------------------------------------------------------------------
    // Leapfrog join
    // ------------------------------------------------------------------

    #[test]
    fn test_leapfrog_join_cost() {
        let cardinalities = [1000.0; 3];
        let cost = CostModel::new().leapfrog_join_cost(3, &cardinalities, 100.0);

        assert!(cost.cpu > 0.0);
        assert!(cost.memory > 0.0);
    }

    #[test]
    fn test_leapfrog_join_cost_empty() {
        let cost = CostModel::new().leapfrog_join_cost(0, &[], 0.0);
        assert!(close(cost.total(), 0.0));
    }

    #[test]
    fn test_prefer_leapfrog_join_for_triangles() {
        let model = CostModel::new();
        let cardinalities = [10000.0; 3];
        let output = 1000.0;

        let leapfrog_cost = model.leapfrog_join_cost(3, &cardinalities, output);

        assert!(leapfrog_cost.cpu > 0.0);
        assert!(leapfrog_cost.memory > 0.0);

        // Only exercised for absence of panics; the decision is not asserted.
        let _prefer = model.prefer_leapfrog_join(3, &cardinalities, output);
    }

    #[test]
    fn test_prefer_leapfrog_join_binary_case() {
        let cardinalities = [1000.0; 2];
        let prefer = CostModel::new().prefer_leapfrog_join(2, &cardinalities, 500.0);
        assert!(!prefer, "Binary joins should use hash join, not leapfrog");
    }

    // ------------------------------------------------------------------
    // Factorization benefit
    // ------------------------------------------------------------------

    #[test]
    fn test_factorized_benefit_single_hop() {
        let benefit = CostModel::new().factorized_benefit(10.0, 1);
        assert!(
            (benefit - 1.0).abs() < 0.001,
            "Single hop should have no benefit"
        );
    }

    #[test]
    fn test_factorized_benefit_multi_hop() {
        let benefit = CostModel::new().factorized_benefit(10.0, 3);

        assert!(benefit <= 1.0, "Benefit should be <= 1.0");
        assert!(benefit > 0.0, "Benefit should be positive");
    }

    #[test]
    fn test_factorized_benefit_low_fanout() {
        let benefit = CostModel::new().factorized_benefit(1.5, 2);
        assert!(
            benefit <= 1.0,
            "Low fanout still benefits from factorization"
        );
    }
}
1098}