1use super::cardinality::CardinalityEstimator;
14use super::cost::{Cost, CostModel};
15use crate::query::plan::{JoinCondition, JoinOp, JoinType, LogicalExpression, LogicalOperator};
16use std::collections::{HashMap, HashSet};
17
18#[derive(Debug, Clone)]
20pub struct JoinNode {
21 pub id: usize,
23 pub variable: String,
25 pub relation: LogicalOperator,
27}
28
29impl PartialEq for JoinNode {
30 fn eq(&self, other: &Self) -> bool {
31 self.id == other.id && self.variable == other.variable
32 }
33}
34
35impl Eq for JoinNode {}
36
37impl std::hash::Hash for JoinNode {
38 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
39 self.id.hash(state);
40 self.variable.hash(state);
41 }
42}
43
44#[derive(Debug, Clone)]
46pub struct JoinEdge {
47 pub from: usize,
49 pub to: usize,
51 pub conditions: Vec<JoinCondition>,
53}
54
55#[derive(Debug)]
57pub struct JoinGraph {
58 nodes: Vec<JoinNode>,
60 edges: Vec<JoinEdge>,
62 adjacency: HashMap<usize, HashSet<usize>>,
64}
65
66impl JoinGraph {
67 pub fn new() -> Self {
69 Self {
70 nodes: Vec::new(),
71 edges: Vec::new(),
72 adjacency: HashMap::new(),
73 }
74 }
75
76 pub fn add_node(&mut self, variable: String, relation: LogicalOperator) -> usize {
78 let id = self.nodes.len();
79 self.nodes.push(JoinNode {
80 id,
81 variable,
82 relation,
83 });
84 self.adjacency.insert(id, HashSet::new());
85 id
86 }
87
88 pub fn add_edge(&mut self, from: usize, to: usize, conditions: Vec<JoinCondition>) {
94 self.edges.push(JoinEdge {
95 from,
96 to,
97 conditions,
98 });
99 self.adjacency
102 .get_mut(&from)
103 .expect("'from' node must be added via add_node() before add_edge()")
104 .insert(to);
105 self.adjacency
106 .get_mut(&to)
107 .expect("'to' node must be added via add_node() before add_edge()")
108 .insert(from);
109 }
110
111 pub fn node_count(&self) -> usize {
113 self.nodes.len()
114 }
115
116 pub fn nodes(&self) -> &[JoinNode] {
118 &self.nodes
119 }
120
121 pub fn neighbors(&self, node_id: usize) -> impl Iterator<Item = usize> + '_ {
123 self.adjacency.get(&node_id).into_iter().flatten().copied()
124 }
125
126 pub fn get_conditions(&self, left: &BitSet, right: &BitSet) -> Vec<JoinCondition> {
128 let mut conditions = Vec::new();
129 for edge in &self.edges {
130 let from_in_left = left.contains(edge.from);
131 let from_in_right = right.contains(edge.from);
132 let to_in_left = left.contains(edge.to);
133 let to_in_right = right.contains(edge.to);
134
135 if (from_in_left && to_in_right) || (from_in_right && to_in_left) {
137 conditions.extend(edge.conditions.clone());
138 }
139 }
140 conditions
141 }
142
143 pub fn edges(&self) -> &[JoinEdge] {
145 &self.edges
146 }
147
148 #[must_use]
152 pub fn is_cyclic(&self) -> bool {
153 if self.nodes.is_empty() {
154 return false;
155 }
156 self.edges.len() >= self.nodes.len()
157 }
158
159 pub fn are_connected(&self, left: &BitSet, right: &BitSet) -> bool {
161 for edge in &self.edges {
162 let from_in_left = left.contains(edge.from);
163 let from_in_right = right.contains(edge.from);
164 let to_in_left = left.contains(edge.to);
165 let to_in_right = right.contains(edge.to);
166
167 if (from_in_left && to_in_right) || (from_in_right && to_in_left) {
168 return true;
169 }
170 }
171 false
172 }
173}
174
175impl Default for JoinGraph {
176 fn default() -> Self {
177 Self::new()
178 }
179}
180
/// A set of ids in `0..64`, stored as a single `u64` bitmask (bit `i` set iff `i` is a member).
///
/// Ids of 64 or more cannot be stored. Inserting one panics, while membership
/// queries and removals treat it as absent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BitSet(u64);

impl BitSet {
    /// Number of distinct ids a `BitSet` can hold (`0..CAPACITY`).
    pub const CAPACITY: usize = 64;

    /// Mask with only bit `i` set.
    ///
    /// # Panics
    ///
    /// Panics if `i >= CAPACITY`. Without this check `1 << i` would overflow:
    /// that panics in debug builds but silently sets bit `i % 64` in release
    /// builds, corrupting the set.
    fn bit(i: usize) -> u64 {
        assert!(
            i < Self::CAPACITY,
            "BitSet index {} out of range (capacity {})",
            i,
            Self::CAPACITY
        );
        1_u64 << i
    }

    /// The empty set.
    pub fn empty() -> Self {
        Self(0)
    }

    /// The set `{i}`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= 64`.
    pub fn singleton(i: usize) -> Self {
        Self(Self::bit(i))
    }

    /// The set of all ids yielded by `iter`.
    ///
    /// # Panics
    ///
    /// Panics if any id is `>= 64`.
    pub fn from_iter(iter: impl Iterator<Item = usize>) -> Self {
        Self(iter.fold(0, |bits, i| bits | Self::bit(i)))
    }

    /// The set `{0, 1, ..., n - 1}`, saturating at all 64 ids when `n >= 64`.
    pub fn full(n: usize) -> Self {
        if n == 0 {
            Self::empty()
        } else if n >= Self::CAPACITY {
            Self(u64::MAX)
        } else {
            Self((1_u64 << n) - 1)
        }
    }

    /// True when the set has no members.
    pub fn is_empty(&self) -> bool {
        self.0 == 0
    }

    /// Number of members.
    pub fn len(&self) -> usize {
        self.0.count_ones() as usize
    }

    /// True when `i` is a member; ids `>= 64` are never members.
    pub fn contains(&self, i: usize) -> bool {
        i < Self::CAPACITY && (self.0 & (1_u64 << i)) != 0
    }

    /// Adds `i` to the set.
    ///
    /// # Panics
    ///
    /// Panics if `i >= 64`.
    pub fn insert(&mut self, i: usize) {
        self.0 |= Self::bit(i);
    }

    /// Removes `i` from the set; a no-op for ids that are absent or `>= 64`.
    pub fn remove(&mut self, i: usize) {
        if i < Self::CAPACITY {
            self.0 &= !(1_u64 << i);
        }
    }

    /// Members of either set.
    pub fn union(self, other: Self) -> Self {
        Self(self.0 | other.0)
    }

    /// Members of both sets.
    pub fn intersection(self, other: Self) -> Self {
        Self(self.0 & other.0)
    }

    /// Members of `self` that are not in `other`.
    pub fn difference(self, other: Self) -> Self {
        Self(self.0 & !other.0)
    }

    /// True when every member of `self` is also in `other`.
    pub fn is_subset_of(self, other: Self) -> bool {
        (self.0 & other.0) == self.0
    }

    /// Members in ascending order.
    pub fn iter(self) -> impl Iterator<Item = usize> {
        let mut remaining = self.0;
        // Visit only the set bits: take the lowest one, then clear it.
        std::iter::from_fn(move || {
            if remaining == 0 {
                None
            } else {
                let i = remaining.trailing_zeros() as usize;
                remaining &= remaining - 1;
                Some(i)
            }
        })
    }

    /// All subsets of `self` (including `self` and the empty set), in
    /// descending numeric order of their bitmasks.
    pub fn subsets(self) -> SubsetIterator {
        SubsetIterator {
            full: self.0,
            current: Some(self.0),
        }
    }
}

/// Iterator over every subset of a [`BitSet`]; see [`BitSet::subsets`].
pub struct SubsetIterator {
    /// Bitmask of the set whose subsets are enumerated.
    full: u64,
    /// Next subset to yield; `None` once the empty set has been yielded.
    current: Option<u64>,
}

impl Iterator for SubsetIterator {
    type Item = BitSet;

    fn next(&mut self) -> Option<Self::Item> {
        let current = self.current?;
        // `(current - 1) & full` is the next-smaller subset of `full`; it is
        // always strictly below `current`, so the walk ends at the empty set.
        self.current = if current == 0 {
            None
        } else {
            Some((current - 1) & self.full)
        };
        Some(BitSet(current))
    }
}
299
300#[derive(Debug, Clone)]
302pub struct JoinPlan {
303 pub nodes: BitSet,
305 pub operator: LogicalOperator,
307 pub cost: Cost,
309 pub cardinality: f64,
311}
312
313pub struct DPccp<'a> {
315 graph: &'a JoinGraph,
317 cost_model: &'a CostModel,
319 card_estimator: &'a CardinalityEstimator,
321 memo: HashMap<BitSet, JoinPlan>,
323 iterations: usize,
325}
326
327const DPCCP_ITERATION_BUDGET: usize = 100_000;
329
330impl<'a> DPccp<'a> {
331 pub fn new(
333 graph: &'a JoinGraph,
334 cost_model: &'a CostModel,
335 card_estimator: &'a CardinalityEstimator,
336 ) -> Self {
337 Self {
338 graph,
339 cost_model,
340 card_estimator,
341 memo: HashMap::new(),
342 iterations: 0,
343 }
344 }
345
346 pub fn optimize(&mut self) -> Option<JoinPlan> {
348 let n = self.graph.node_count();
349 if n == 0 {
350 return None;
351 }
352 if n == 1 {
353 let node = &self.graph.nodes[0];
354 let cardinality = self.card_estimator.estimate(&node.relation);
355 let cost = self.cost_model.estimate(&node.relation, cardinality);
356 return Some(JoinPlan {
357 nodes: BitSet::singleton(0),
358 operator: node.relation.clone(),
359 cost,
360 cardinality,
361 });
362 }
363
364 if n > 64 {
367 return None;
368 }
369
370 for (i, node) in self.graph.nodes.iter().enumerate() {
372 let subset = BitSet::singleton(i);
373 let cardinality = self.card_estimator.estimate(&node.relation);
374 let cost = self.cost_model.estimate(&node.relation, cardinality);
375 self.memo.insert(
376 subset,
377 JoinPlan {
378 nodes: subset,
379 operator: node.relation.clone(),
380 cost,
381 cardinality,
382 },
383 );
384 }
385
386 let full_set = BitSet::full(n);
388 self.enumerate_ccp(full_set);
389
390 self.memo.get(&full_set).cloned()
392 }
393
394 fn enumerate_ccp(&mut self, s: BitSet) {
396 for s1 in s.subsets() {
398 self.iterations += 1;
401 if self.iterations > DPCCP_ITERATION_BUDGET {
402 return;
403 }
404
405 if s1.is_empty() || s1 == s {
406 continue;
407 }
408
409 let s2 = s.difference(s1);
410 if s2.is_empty() {
411 continue;
412 }
413
414 if !self.is_connected(s1) || !self.is_connected(s2) {
416 continue;
417 }
418
419 if !self.graph.are_connected(&s1, &s2) {
421 continue;
422 }
423
424 if !self.memo.contains_key(&s1) {
426 self.enumerate_ccp(s1);
427 }
428 if !self.memo.contains_key(&s2) {
429 self.enumerate_ccp(s2);
430 }
431
432 if let (Some(plan1), Some(plan2)) = (self.memo.get(&s1), self.memo.get(&s2)) {
434 let conditions = self.graph.get_conditions(&s1, &s2);
435 let new_plan = self.build_join_plan(plan1.clone(), plan2.clone(), conditions);
436
437 let should_update = self.memo.get(&s).map_or(true, |existing| {
439 new_plan.cost.total() < existing.cost.total()
440 });
441
442 if should_update {
443 self.memo.insert(s, new_plan);
444 }
445 }
446 }
447 }
448
449 fn is_connected(&self, subset: BitSet) -> bool {
451 if subset.len() <= 1 {
452 return true;
453 }
454
455 let start = subset
458 .iter()
459 .next()
460 .expect("subset is non-empty: len >= 2 checked on line 400");
461 let mut visited = BitSet::singleton(start);
462 let mut queue = vec![start];
463
464 while let Some(node) = queue.pop() {
465 for neighbor in self.graph.neighbors(node) {
466 if subset.contains(neighbor) && !visited.contains(neighbor) {
467 visited.insert(neighbor);
468 queue.push(neighbor);
469 }
470 }
471 }
472
473 visited == subset
474 }
475
476 fn build_join_plan(
478 &self,
479 left: JoinPlan,
480 right: JoinPlan,
481 conditions: Vec<JoinCondition>,
482 ) -> JoinPlan {
483 let nodes = left.nodes.union(right.nodes);
484
485 let join_op = LogicalOperator::Join(JoinOp {
487 left: Box::new(left.operator),
488 right: Box::new(right.operator),
489 join_type: JoinType::Inner,
490 conditions,
491 });
492
493 let cardinality = self.card_estimator.estimate(&join_op);
495
496 let join_cost = self.cost_model.estimate(&join_op, cardinality);
498 let total_cost = left.cost + right.cost + join_cost;
499
500 JoinPlan {
501 nodes,
502 operator: join_op,
503 cost: total_cost,
504 cardinality,
505 }
506 }
507}
508
509pub struct JoinGraphBuilder {
511 graph: JoinGraph,
512 variable_to_node: HashMap<String, usize>,
513}
514
515impl JoinGraphBuilder {
516 pub fn new() -> Self {
518 Self {
519 graph: JoinGraph::new(),
520 variable_to_node: HashMap::new(),
521 }
522 }
523
524 pub fn add_relation(&mut self, variable: &str, relation: LogicalOperator) -> usize {
526 let id = self.graph.add_node(variable.to_string(), relation);
527 self.variable_to_node.insert(variable.to_string(), id);
528 id
529 }
530
531 pub fn add_join_condition(
533 &mut self,
534 left_var: &str,
535 right_var: &str,
536 left_expr: LogicalExpression,
537 right_expr: LogicalExpression,
538 ) {
539 if let (Some(&left_id), Some(&right_id)) = (
540 self.variable_to_node.get(left_var),
541 self.variable_to_node.get(right_var),
542 ) {
543 self.graph.add_edge(
544 left_id,
545 right_id,
546 vec![JoinCondition {
547 left: left_expr,
548 right: right_expr,
549 }],
550 );
551 }
552 }
553
554 pub fn build(self) -> JoinGraph {
556 self.graph
557 }
558}
559
560impl Default for JoinGraphBuilder {
561 fn default() -> Self {
562 Self::new()
563 }
564}
565
#[cfg(test)]
mod tests {
    use super::super::cardinality::TableStats;
    use super::*;
    use crate::query::plan::NodeScanOp;

    /// A labelled node scan binding `var`.
    fn create_node_scan(var: &str, label: &str) -> LogicalOperator {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: var.to_string(),
            label: Some(label.to_string()),
            input: None,
        })
    }

    /// `variable.property` expression.
    fn prop(variable: &str, property: &str) -> LogicalExpression {
        LogicalExpression::Property {
            variable: variable.to_string(),
            property: property.to_string(),
        }
    }

    /// Bare variable reference expression.
    fn var(name: &str) -> LogicalExpression {
        LogicalExpression::Variable(name.to_string())
    }

    /// Builder preloaded with one node scan per `(variable, label)` pair, in order.
    fn builder_with(relations: &[(&str, &str)]) -> JoinGraphBuilder {
        let mut builder = JoinGraphBuilder::new();
        for &(variable, label) in relations {
            builder.add_relation(variable, create_node_scan(variable, label));
        }
        builder
    }

    /// Joins `left` and `right` on `Variable(left) = Variable(right)`.
    fn link_vars(builder: &mut JoinGraphBuilder, left: &str, right: &str) {
        builder.add_join_condition(left, right, var(left), var(right));
    }

    /// Cardinality estimator seeded with the given per-label table statistics.
    fn estimator(stats: Vec<(&str, TableStats)>) -> CardinalityEstimator {
        let mut card_estimator = CardinalityEstimator::new();
        for (label, table) in stats {
            card_estimator.add_table_stats(label, table);
        }
        card_estimator
    }

    /// Runs DPccp over `graph` using a default cost model.
    fn run_dpccp(graph: &JoinGraph, card_estimator: &CardinalityEstimator) -> Option<JoinPlan> {
        let cost_model = CostModel::new();
        DPccp::new(graph, &cost_model, card_estimator).optimize()
    }

    /// A `JoinNode` bound to variable "a" over label "A".
    fn node_a(id: usize) -> JoinNode {
        JoinNode {
            id,
            variable: "a".to_string(),
            relation: create_node_scan("a", "A"),
        }
    }

    #[test]
    fn test_bitset_operations() {
        let a = BitSet::singleton(0);
        let b = BitSet::singleton(1);

        assert!(a.contains(0));
        assert!(!a.contains(1));

        let ab = a.union(b);
        assert!(ab.contains(0));
        assert!(ab.contains(1));
        assert!(!ab.contains(2));

        let full = BitSet::full(3);
        assert_eq!(full.len(), 3);
        for i in 0..3 {
            assert!(full.contains(i));
        }
    }

    #[test]
    fn test_subset_iteration() {
        let set = BitSet::from_iter([0, 1].iter().copied());
        let subsets: Vec<_> = set.subsets().collect();

        assert_eq!(subsets.len(), 4);
        for expected in [
            BitSet::empty(),
            BitSet::singleton(0),
            BitSet::singleton(1),
            set,
        ] {
            assert!(subsets.contains(&expected));
        }
    }

    #[test]
    fn test_join_graph_construction() {
        let mut builder = builder_with(&[("a", "Person"), ("b", "Person"), ("c", "Company")]);
        builder.add_join_condition("a", "b", prop("a", "id"), prop("b", "friend_id"));
        builder.add_join_condition("a", "c", prop("a", "company_id"), prop("c", "id"));

        assert_eq!(builder.build().node_count(), 3);
    }

    #[test]
    fn test_dpccp_single_relation() {
        let graph = builder_with(&[("a", "Person")]).build();
        let card_estimator = estimator(vec![("Person", TableStats::new(1000))]);

        let plan = run_dpccp(&graph, &card_estimator).expect("plan for a single relation");
        assert_eq!(plan.nodes.len(), 1);
    }

    #[test]
    fn test_dpccp_two_relations() {
        let mut builder = builder_with(&[("a", "Person"), ("b", "Company")]);
        builder.add_join_condition("a", "b", prop("a", "company_id"), prop("b", "id"));
        let graph = builder.build();
        let card_estimator = estimator(vec![
            ("Person", TableStats::new(1000)),
            ("Company", TableStats::new(100)),
        ]);

        let plan = run_dpccp(&graph, &card_estimator).expect("plan for two relations");
        assert_eq!(plan.nodes.len(), 2);
        assert!(
            matches!(plan.operator, LogicalOperator::Join(_)),
            "Expected Join operator"
        );
    }

    #[test]
    fn test_dpccp_three_relations_chain() {
        let mut builder = builder_with(&[("a", "Person"), ("b", "Person"), ("c", "Company")]);
        builder.add_join_condition("a", "b", prop("a", "knows"), prop("b", "id"));
        builder.add_join_condition("b", "c", prop("b", "company_id"), prop("c", "id"));
        let graph = builder.build();
        let card_estimator = estimator(vec![
            ("Person", TableStats::new(1000)),
            ("Company", TableStats::new(100)),
        ]);

        let plan = run_dpccp(&graph, &card_estimator).expect("plan for a chain");
        assert_eq!(plan.nodes.len(), 3);
    }

    #[test]
    fn test_dpccp_prefers_smaller_intermediate() {
        let mut builder = builder_with(&[("s", "Small"), ("m", "Medium"), ("l", "Large")]);
        builder.add_join_condition("s", "m", prop("s", "id"), prop("m", "s_id"));
        builder.add_join_condition("m", "l", prop("m", "id"), prop("l", "m_id"));
        let graph = builder.build();
        let card_estimator = estimator(vec![
            ("Small", TableStats::new(100)),
            ("Medium", TableStats::new(1000)),
            ("Large", TableStats::new(10000)),
        ]);

        let plan = run_dpccp(&graph, &card_estimator).expect("plan for a chain");
        assert_eq!(plan.nodes.len(), 3);
        assert!(plan.cost.total() > 0.0);
    }

    #[test]
    fn test_bitset_empty() {
        let empty = BitSet::empty();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        assert!(!empty.contains(0));
    }

    #[test]
    fn test_bitset_insert_remove() {
        let mut set = BitSet::empty();

        set.insert(3);
        assert!(set.contains(3));
        assert_eq!(set.len(), 1);

        set.insert(5);
        assert!(set.contains(5));
        assert_eq!(set.len(), 2);

        set.remove(3);
        assert!(!set.contains(3));
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn test_bitset_intersection() {
        let a = BitSet::from_iter([0, 1, 2].iter().copied());
        let b = BitSet::from_iter([1, 2, 3].iter().copied());
        let both = a.intersection(b);

        assert!(both.contains(1));
        assert!(both.contains(2));
        assert!(!both.contains(0));
        assert!(!both.contains(3));
        assert_eq!(both.len(), 2);
    }

    #[test]
    fn test_bitset_difference() {
        let a = BitSet::from_iter([0, 1, 2].iter().copied());
        let b = BitSet::from_iter([1, 2, 3].iter().copied());
        let only_a = a.difference(b);

        assert!(only_a.contains(0));
        assert!(!only_a.contains(1));
        assert!(!only_a.contains(2));
        assert_eq!(only_a.len(), 1);
    }

    #[test]
    fn test_bitset_is_subset_of() {
        let small = BitSet::from_iter([1, 2].iter().copied());
        let big = BitSet::from_iter([0, 1, 2, 3].iter().copied());

        assert!(small.is_subset_of(big));
        assert!(!big.is_subset_of(small));
        assert!(small.is_subset_of(small));
    }

    #[test]
    fn test_bitset_iter() {
        let set = BitSet::from_iter([0, 2, 5].iter().copied());
        assert_eq!(set.iter().collect::<Vec<_>>(), vec![0, 2, 5]);
    }

    #[test]
    fn test_join_graph_empty() {
        assert_eq!(JoinGraph::new().node_count(), 0);
    }

    #[test]
    fn test_join_graph_neighbors() {
        let mut builder = builder_with(&[("a", "A"), ("b", "B"), ("c", "C")]);
        link_vars(&mut builder, "a", "b");
        link_vars(&mut builder, "a", "c");
        let graph = builder.build();

        let of_a: Vec<_> = graph.neighbors(0).collect();
        assert_eq!(of_a.len(), 2);
        assert!(of_a.contains(&1));
        assert!(of_a.contains(&2));

        let of_b: Vec<_> = graph.neighbors(1).collect();
        assert_eq!(of_b.len(), 1);
        assert!(of_b.contains(&0));
    }

    #[test]
    fn test_join_graph_are_connected() {
        let mut builder = builder_with(&[("a", "A"), ("b", "B"), ("c", "C")]);
        link_vars(&mut builder, "a", "b");
        let graph = builder.build();

        let (set_a, set_b, set_c) = (
            BitSet::singleton(0),
            BitSet::singleton(1),
            BitSet::singleton(2),
        );

        assert!(graph.are_connected(&set_a, &set_b));
        assert!(graph.are_connected(&set_b, &set_a));
        assert!(!graph.are_connected(&set_a, &set_c));
        assert!(!graph.are_connected(&set_b, &set_c));
    }

    #[test]
    fn test_join_graph_get_conditions() {
        let mut builder = builder_with(&[("a", "A"), ("b", "B")]);
        builder.add_join_condition("a", "b", prop("a", "id"), prop("b", "a_id"));
        let graph = builder.build();

        let conditions = graph.get_conditions(&BitSet::singleton(0), &BitSet::singleton(1));
        assert_eq!(conditions.len(), 1);
    }

    #[test]
    fn test_dpccp_empty_graph() {
        let graph = JoinGraph::new();
        let card_estimator = estimator(Vec::new());

        assert!(run_dpccp(&graph, &card_estimator).is_none());
    }

    #[test]
    fn test_dpccp_star_query() {
        let mut builder = builder_with(&[("center", "Center"), ("a", "A"), ("b", "B"), ("c", "C")]);
        for leaf in ["a", "b", "c"] {
            link_vars(&mut builder, "center", leaf);
        }
        let graph = builder.build();
        let card_estimator = estimator(vec![
            ("Center", TableStats::new(100)),
            ("A", TableStats::new(1000)),
            ("B", TableStats::new(500)),
            ("C", TableStats::new(200)),
        ]);

        let plan = run_dpccp(&graph, &card_estimator).expect("plan for a star");
        assert_eq!(plan.nodes.len(), 4);
        assert!(plan.cost.total() > 0.0);
    }

    #[test]
    fn test_dpccp_cycle_query() {
        let mut builder = builder_with(&[("a", "A"), ("b", "B"), ("c", "C")]);
        link_vars(&mut builder, "a", "b");
        link_vars(&mut builder, "b", "c");
        link_vars(&mut builder, "c", "a");
        let graph = builder.build();
        let card_estimator = estimator(vec![
            ("A", TableStats::new(100)),
            ("B", TableStats::new(100)),
            ("C", TableStats::new(100)),
        ]);

        let plan = run_dpccp(&graph, &card_estimator).expect("plan for a cycle");
        assert_eq!(plan.nodes.len(), 3);
    }

    #[test]
    fn test_dpccp_four_relations() {
        let mut builder = builder_with(&[("a", "A"), ("b", "B"), ("c", "C"), ("d", "D")]);
        link_vars(&mut builder, "a", "b");
        link_vars(&mut builder, "b", "c");
        link_vars(&mut builder, "c", "d");
        let graph = builder.build();
        let card_estimator = estimator(vec![
            ("A", TableStats::new(100)),
            ("B", TableStats::new(200)),
            ("C", TableStats::new(300)),
            ("D", TableStats::new(400)),
        ]);

        let plan = run_dpccp(&graph, &card_estimator).expect("plan for four relations");
        assert_eq!(plan.nodes.len(), 4);
    }

    #[test]
    fn test_join_plan_cardinality_and_cost() {
        let mut builder = builder_with(&[("a", "A"), ("b", "B")]);
        link_vars(&mut builder, "a", "b");
        let graph = builder.build();
        let card_estimator = estimator(vec![
            ("A", TableStats::new(100)),
            ("B", TableStats::new(200)),
        ]);

        let plan = run_dpccp(&graph, &card_estimator).unwrap();
        assert!(plan.cardinality > 0.0);
        assert!(plan.cost.total() > 0.0);
    }

    #[test]
    fn test_join_graph_default() {
        assert_eq!(JoinGraph::default().node_count(), 0);
    }

    #[test]
    fn test_join_graph_builder_default() {
        assert_eq!(JoinGraphBuilder::default().build().node_count(), 0);
    }

    #[test]
    fn test_join_graph_nodes_accessor() {
        let graph = builder_with(&[("a", "A"), ("b", "B")]).build();
        let nodes = graph.nodes();

        assert_eq!(nodes.len(), 2);
        assert_eq!(nodes[0].variable, "a");
        assert_eq!(nodes[1].variable, "b");
    }

    #[test]
    fn test_join_node_equality() {
        assert_eq!(node_a(0), node_a(0));
        assert_ne!(node_a(0), node_a(1));
    }

    #[test]
    fn test_join_node_hash() {
        use std::collections::HashSet;

        let mut set = HashSet::new();
        set.insert(node_a(0));

        assert!(set.contains(&node_a(0)));
    }

    #[test]
    fn test_add_join_condition_unknown_variable() {
        let mut builder = builder_with(&[("a", "A")]);
        builder.add_join_condition("a", "unknown", var("a"), var("unknown"));

        assert_eq!(builder.build().node_count(), 1);
    }

    #[test]
    fn test_dpccp_with_different_cardinalities() {
        let mut builder = builder_with(&[("tiny", "Tiny"), ("huge", "Huge")]);
        link_vars(&mut builder, "tiny", "huge");
        let graph = builder.build();
        let card_estimator = estimator(vec![
            ("Tiny", TableStats::new(10)),
            ("Huge", TableStats::new(1000000)),
        ]);

        let plan = run_dpccp(&graph, &card_estimator).expect("plan for two relations");
        assert_eq!(plan.nodes.len(), 2);
    }

    #[test]
    fn test_join_graph_cyclic_triangle() {
        let mut builder = builder_with(&[("a", "A"), ("b", "B"), ("c", "C")]);
        link_vars(&mut builder, "a", "b");
        link_vars(&mut builder, "b", "c");
        link_vars(&mut builder, "c", "a");

        assert!(builder.build().is_cyclic());
    }

    #[test]
    fn test_join_graph_acyclic_chain() {
        let mut builder = builder_with(&[("a", "A"), ("b", "B"), ("c", "C")]);
        link_vars(&mut builder, "a", "b");
        link_vars(&mut builder, "b", "c");

        assert!(!builder.build().is_cyclic());
    }

    #[test]
    fn test_join_graph_empty_not_cyclic() {
        assert!(!JoinGraph::new().is_cyclic());
    }

    #[test]
    fn test_join_graph_edges_accessor() {
        let mut builder = builder_with(&[("a", "A"), ("b", "B")]);
        link_vars(&mut builder, "a", "b");

        assert_eq!(builder.build().edges().len(), 1);
    }

    #[test]
    fn test_bitset_full_boundary_values() {
        assert_eq!(BitSet::full(0), BitSet::empty());
        for (n, bits) in [(1, 1), (2, 3), (63, u64::MAX >> 1), (64, u64::MAX)] {
            assert_eq!(BitSet::full(n).0, bits);
        }
    }

    #[test]
    fn test_bitset_full_overflow_protection() {
        for n in [65, 100, usize::MAX] {
            assert_eq!(BitSet::full(n).0, u64::MAX);
        }
    }
}