1use super::cardinality::CardinalityEstimator;
14use super::cost::{Cost, CostModel};
15use crate::query::plan::{JoinCondition, JoinOp, JoinType, LogicalExpression, LogicalOperator};
16use std::collections::{HashMap, HashSet};
17
18#[derive(Debug, Clone)]
20pub struct JoinNode {
21 pub id: usize,
23 pub variable: String,
25 pub relation: LogicalOperator,
27}
28
29impl PartialEq for JoinNode {
30 fn eq(&self, other: &Self) -> bool {
31 self.id == other.id && self.variable == other.variable
32 }
33}
34
35impl Eq for JoinNode {}
36
37impl std::hash::Hash for JoinNode {
38 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
39 self.id.hash(state);
40 self.variable.hash(state);
41 }
42}
43
/// A join relationship between two nodes of a [`JoinGraph`].
///
/// `JoinGraph::add_edge` records the edge in the adjacency sets of both
/// endpoints, so `from`/`to` only preserve the order the caller supplied.
#[derive(Debug, Clone)]
pub struct JoinEdge {
    /// Id of the first endpoint.
    pub from: usize,
    /// Id of the second endpoint.
    pub to: usize,
    /// Join conditions attached to this edge.
    pub conditions: Vec<JoinCondition>,
}
54
55#[derive(Debug)]
57pub struct JoinGraph {
58 nodes: Vec<JoinNode>,
60 edges: Vec<JoinEdge>,
62 adjacency: HashMap<usize, HashSet<usize>>,
64}
65
66impl JoinGraph {
67 pub fn new() -> Self {
69 Self {
70 nodes: Vec::new(),
71 edges: Vec::new(),
72 adjacency: HashMap::new(),
73 }
74 }
75
76 pub fn add_node(&mut self, variable: String, relation: LogicalOperator) -> usize {
78 let id = self.nodes.len();
79 self.nodes.push(JoinNode {
80 id,
81 variable,
82 relation,
83 });
84 self.adjacency.insert(id, HashSet::new());
85 id
86 }
87
88 pub fn add_edge(&mut self, from: usize, to: usize, conditions: Vec<JoinCondition>) {
94 self.edges.push(JoinEdge {
95 from,
96 to,
97 conditions,
98 });
99 self.adjacency
102 .get_mut(&from)
103 .expect("'from' node must be added via add_node() before add_edge()")
104 .insert(to);
105 self.adjacency
106 .get_mut(&to)
107 .expect("'to' node must be added via add_node() before add_edge()")
108 .insert(from);
109 }
110
111 pub fn node_count(&self) -> usize {
113 self.nodes.len()
114 }
115
116 pub fn nodes(&self) -> &[JoinNode] {
118 &self.nodes
119 }
120
121 pub fn neighbors(&self, node_id: usize) -> impl Iterator<Item = usize> + '_ {
123 self.adjacency.get(&node_id).into_iter().flatten().copied()
124 }
125
126 pub fn get_conditions(&self, left: &BitSet, right: &BitSet) -> Vec<JoinCondition> {
128 let mut conditions = Vec::new();
129 for edge in &self.edges {
130 let from_in_left = left.contains(edge.from);
131 let from_in_right = right.contains(edge.from);
132 let to_in_left = left.contains(edge.to);
133 let to_in_right = right.contains(edge.to);
134
135 if (from_in_left && to_in_right) || (from_in_right && to_in_left) {
137 conditions.extend(edge.conditions.clone());
138 }
139 }
140 conditions
141 }
142
143 pub fn edges(&self) -> &[JoinEdge] {
145 &self.edges
146 }
147
148 #[must_use]
152 pub fn is_cyclic(&self) -> bool {
153 if self.nodes.is_empty() {
154 return false;
155 }
156 self.edges.len() >= self.nodes.len()
157 }
158
159 pub fn are_connected(&self, left: &BitSet, right: &BitSet) -> bool {
161 for edge in &self.edges {
162 let from_in_left = left.contains(edge.from);
163 let from_in_right = right.contains(edge.from);
164 let to_in_left = left.contains(edge.to);
165 let to_in_right = right.contains(edge.to);
166
167 if (from_in_left && to_in_right) || (from_in_right && to_in_left) {
168 return true;
169 }
170 }
171 false
172 }
173}
174
175impl Default for JoinGraph {
176 fn default() -> Self {
177 Self::new()
178 }
179}
180
181#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
183pub struct BitSet(u64);
184
185impl BitSet {
186 pub fn empty() -> Self {
188 Self(0)
189 }
190
191 pub fn singleton(i: usize) -> Self {
193 Self(1 << i)
194 }
195
196 pub fn from_iter(iter: impl Iterator<Item = usize>) -> Self {
198 let mut bits = 0u64;
199 for i in iter {
200 bits |= 1 << i;
201 }
202 Self(bits)
203 }
204
205 pub fn full(n: usize) -> Self {
207 Self((1 << n) - 1)
208 }
209
210 pub fn is_empty(&self) -> bool {
212 self.0 == 0
213 }
214
215 pub fn len(&self) -> usize {
217 self.0.count_ones() as usize
218 }
219
220 pub fn contains(&self, i: usize) -> bool {
222 (self.0 & (1 << i)) != 0
223 }
224
225 pub fn insert(&mut self, i: usize) {
227 self.0 |= 1 << i;
228 }
229
230 pub fn remove(&mut self, i: usize) {
232 self.0 &= !(1 << i);
233 }
234
235 pub fn union(self, other: Self) -> Self {
237 Self(self.0 | other.0)
238 }
239
240 pub fn intersection(self, other: Self) -> Self {
242 Self(self.0 & other.0)
243 }
244
245 pub fn difference(self, other: Self) -> Self {
247 Self(self.0 & !other.0)
248 }
249
250 pub fn is_subset_of(self, other: Self) -> bool {
252 (self.0 & other.0) == self.0
253 }
254
255 pub fn iter(self) -> impl Iterator<Item = usize> {
257 (0..64).filter(move |&i| self.contains(i))
258 }
259
260 pub fn subsets(self) -> SubsetIterator {
262 SubsetIterator {
263 full: self.0,
264 current: Some(self.0),
265 }
266 }
267}
268
269pub struct SubsetIterator {
271 full: u64,
272 current: Option<u64>,
273}
274
275impl Iterator for SubsetIterator {
276 type Item = BitSet;
277
278 fn next(&mut self) -> Option<Self::Item> {
279 let current = self.current?;
280 if current == 0 {
281 self.current = None;
282 return Some(BitSet(0));
283 }
284 let result = current;
285 self.current = Some((current.wrapping_sub(1)) & self.full);
287 if self.current == Some(self.full) {
288 self.current = None;
289 }
290 Some(BitSet(result))
291 }
292}
293
/// A candidate plan for joining a subset of the graph's relations.
#[derive(Debug, Clone)]
pub struct JoinPlan {
    /// Graph node indices covered by this plan.
    pub nodes: BitSet,
    /// Operator tree that produces the joined result.
    pub operator: LogicalOperator,
    /// Cost of the plan; for joins this is both inputs' cost plus the join's own.
    pub cost: Cost,
    /// Cardinality reported by the estimator for `operator`.
    pub cardinality: f64,
}
306
307pub struct DPccp<'a> {
309 graph: &'a JoinGraph,
311 cost_model: &'a CostModel,
313 card_estimator: &'a CardinalityEstimator,
315 memo: HashMap<BitSet, JoinPlan>,
317}
318
319impl<'a> DPccp<'a> {
320 pub fn new(
322 graph: &'a JoinGraph,
323 cost_model: &'a CostModel,
324 card_estimator: &'a CardinalityEstimator,
325 ) -> Self {
326 Self {
327 graph,
328 cost_model,
329 card_estimator,
330 memo: HashMap::new(),
331 }
332 }
333
334 pub fn optimize(&mut self) -> Option<JoinPlan> {
336 let n = self.graph.node_count();
337 if n == 0 {
338 return None;
339 }
340 if n == 1 {
341 let node = &self.graph.nodes[0];
342 let cardinality = self.card_estimator.estimate(&node.relation);
343 let cost = self.cost_model.estimate(&node.relation, cardinality);
344 return Some(JoinPlan {
345 nodes: BitSet::singleton(0),
346 operator: node.relation.clone(),
347 cost,
348 cardinality,
349 });
350 }
351
352 for (i, node) in self.graph.nodes.iter().enumerate() {
354 let subset = BitSet::singleton(i);
355 let cardinality = self.card_estimator.estimate(&node.relation);
356 let cost = self.cost_model.estimate(&node.relation, cardinality);
357 self.memo.insert(
358 subset,
359 JoinPlan {
360 nodes: subset,
361 operator: node.relation.clone(),
362 cost,
363 cardinality,
364 },
365 );
366 }
367
368 let full_set = BitSet::full(n);
370 self.enumerate_ccp(full_set);
371
372 self.memo.get(&full_set).cloned()
374 }
375
376 fn enumerate_ccp(&mut self, s: BitSet) {
378 for s1 in s.subsets() {
380 if s1.is_empty() || s1 == s {
381 continue;
382 }
383
384 let s2 = s.difference(s1);
385 if s2.is_empty() {
386 continue;
387 }
388
389 if !self.is_connected(s1) || !self.is_connected(s2) {
391 continue;
392 }
393
394 if !self.graph.are_connected(&s1, &s2) {
396 continue;
397 }
398
399 if !self.memo.contains_key(&s1) {
401 self.enumerate_ccp(s1);
402 }
403 if !self.memo.contains_key(&s2) {
404 self.enumerate_ccp(s2);
405 }
406
407 if let (Some(plan1), Some(plan2)) = (self.memo.get(&s1), self.memo.get(&s2)) {
409 let conditions = self.graph.get_conditions(&s1, &s2);
410 let new_plan = self.build_join_plan(plan1.clone(), plan2.clone(), conditions);
411
412 let should_update = self.memo.get(&s).map_or(true, |existing| {
414 new_plan.cost.total() < existing.cost.total()
415 });
416
417 if should_update {
418 self.memo.insert(s, new_plan);
419 }
420 }
421 }
422 }
423
424 fn is_connected(&self, subset: BitSet) -> bool {
426 if subset.len() <= 1 {
427 return true;
428 }
429
430 let start = subset
433 .iter()
434 .next()
435 .expect("subset is non-empty: len >= 2 checked on line 400");
436 let mut visited = BitSet::singleton(start);
437 let mut queue = vec![start];
438
439 while let Some(node) = queue.pop() {
440 for neighbor in self.graph.neighbors(node) {
441 if subset.contains(neighbor) && !visited.contains(neighbor) {
442 visited.insert(neighbor);
443 queue.push(neighbor);
444 }
445 }
446 }
447
448 visited == subset
449 }
450
451 fn build_join_plan(
453 &self,
454 left: JoinPlan,
455 right: JoinPlan,
456 conditions: Vec<JoinCondition>,
457 ) -> JoinPlan {
458 let nodes = left.nodes.union(right.nodes);
459
460 let join_op = LogicalOperator::Join(JoinOp {
462 left: Box::new(left.operator),
463 right: Box::new(right.operator),
464 join_type: JoinType::Inner,
465 conditions,
466 });
467
468 let cardinality = self.card_estimator.estimate(&join_op);
470
471 let join_cost = self.cost_model.estimate(&join_op, cardinality);
473 let total_cost = left.cost + right.cost + join_cost;
474
475 JoinPlan {
476 nodes,
477 operator: join_op,
478 cost: total_cost,
479 cardinality,
480 }
481 }
482}
483
484pub struct JoinGraphBuilder {
486 graph: JoinGraph,
487 variable_to_node: HashMap<String, usize>,
488}
489
490impl JoinGraphBuilder {
491 pub fn new() -> Self {
493 Self {
494 graph: JoinGraph::new(),
495 variable_to_node: HashMap::new(),
496 }
497 }
498
499 pub fn add_relation(&mut self, variable: &str, relation: LogicalOperator) -> usize {
501 let id = self.graph.add_node(variable.to_string(), relation);
502 self.variable_to_node.insert(variable.to_string(), id);
503 id
504 }
505
506 pub fn add_join_condition(
508 &mut self,
509 left_var: &str,
510 right_var: &str,
511 left_expr: LogicalExpression,
512 right_expr: LogicalExpression,
513 ) {
514 if let (Some(&left_id), Some(&right_id)) = (
515 self.variable_to_node.get(left_var),
516 self.variable_to_node.get(right_var),
517 ) {
518 self.graph.add_edge(
519 left_id,
520 right_id,
521 vec![JoinCondition {
522 left: left_expr,
523 right: right_expr,
524 }],
525 );
526 }
527 }
528
529 pub fn build(self) -> JoinGraph {
531 self.graph
532 }
533}
534
535impl Default for JoinGraphBuilder {
536 fn default() -> Self {
537 Self::new()
538 }
539}
540
#[cfg(test)]
mod tests {
    use super::super::cardinality::TableStats;
    use super::*;
    use crate::query::plan::NodeScanOp;

    // ----- helpers ---------------------------------------------------------

    /// Node scan over `label`, bound to `var`.
    fn create_node_scan(var: &str, label: &str) -> LogicalOperator {
        LogicalOperator::NodeScan(NodeScanOp {
            variable: var.to_string(),
            label: Some(label.to_string()),
            input: None,
        })
    }

    /// Builder pre-loaded with one node scan per `(variable, label)` pair.
    fn builder_with(relations: &[(&str, &str)]) -> JoinGraphBuilder {
        let mut builder = JoinGraphBuilder::new();
        for &(variable, label) in relations {
            builder.add_relation(variable, create_node_scan(variable, label));
        }
        builder
    }

    /// The expression `variable.property`.
    fn prop(variable: &str, property: &str) -> LogicalExpression {
        LogicalExpression::Property {
            variable: variable.to_string(),
            property: property.to_string(),
        }
    }

    /// Adds the condition `left.0.left.1 = right.0.right.1`.
    fn join_on_props(builder: &mut JoinGraphBuilder, left: (&str, &str), right: (&str, &str)) {
        builder.add_join_condition(
            left.0,
            right.0,
            prop(left.0, left.1),
            prop(right.0, right.1),
        );
    }

    /// Adds a condition that compares the two variables directly.
    fn join_on_vars(builder: &mut JoinGraphBuilder, left: &str, right: &str) {
        builder.add_join_condition(
            left,
            right,
            LogicalExpression::Variable(left.to_string()),
            LogicalExpression::Variable(right.to_string()),
        );
    }

    /// Estimator that knows the given per-label statistics.
    fn estimator_with(
        tables: impl IntoIterator<Item = (&'static str, TableStats)>,
    ) -> CardinalityEstimator {
        let mut estimator = CardinalityEstimator::new();
        for (label, stats) in tables {
            estimator.add_table_stats(label, stats);
        }
        estimator
    }

    /// Runs the optimizer over `graph` with a default cost model.
    fn plan_for(graph: &JoinGraph, estimator: &CardinalityEstimator) -> Option<JoinPlan> {
        let cost_model = CostModel::new();
        let mut dpccp = DPccp::new(graph, &cost_model, estimator);
        dpccp.optimize()
    }

    /// Join node with the given id, bound to variable `a`.
    fn node_a(id: usize) -> JoinNode {
        JoinNode {
            id,
            variable: "a".to_string(),
            relation: create_node_scan("a", "A"),
        }
    }

    // ----- BitSet ----------------------------------------------------------

    #[test]
    fn test_bitset_operations() {
        let a = BitSet::singleton(0);
        let b = BitSet::singleton(1);

        assert!(a.contains(0));
        assert!(!a.contains(1));

        let ab = a.union(b);
        assert!(ab.contains(0));
        assert!(ab.contains(1));
        assert!(!ab.contains(2));

        let full = BitSet::full(3);
        assert_eq!(full.len(), 3);
        assert!(full.contains(0));
        assert!(full.contains(1));
        assert!(full.contains(2));
    }

    #[test]
    fn test_subset_iteration() {
        let set = BitSet::from_iter([0, 1].into_iter());
        let subsets: Vec<_> = set.subsets().collect();

        assert_eq!(subsets.len(), 4);
        assert!(subsets.contains(&BitSet::empty()));
        assert!(subsets.contains(&BitSet::singleton(0)));
        assert!(subsets.contains(&BitSet::singleton(1)));
        assert!(subsets.contains(&set));
    }

    #[test]
    fn test_bitset_empty() {
        let empty = BitSet::empty();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        assert!(!empty.contains(0));
    }

    #[test]
    fn test_bitset_insert_remove() {
        let mut set = BitSet::empty();
        set.insert(3);
        assert!(set.contains(3));
        assert_eq!(set.len(), 1);

        set.insert(5);
        assert!(set.contains(5));
        assert_eq!(set.len(), 2);

        set.remove(3);
        assert!(!set.contains(3));
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn test_bitset_intersection() {
        let a = BitSet::from_iter([0, 1, 2].into_iter());
        let b = BitSet::from_iter([1, 2, 3].into_iter());
        let intersection = a.intersection(b);

        assert!(intersection.contains(1));
        assert!(intersection.contains(2));
        assert!(!intersection.contains(0));
        assert!(!intersection.contains(3));
        assert_eq!(intersection.len(), 2);
    }

    #[test]
    fn test_bitset_difference() {
        let a = BitSet::from_iter([0, 1, 2].into_iter());
        let b = BitSet::from_iter([1, 2, 3].into_iter());
        let diff = a.difference(b);

        assert!(diff.contains(0));
        assert!(!diff.contains(1));
        assert!(!diff.contains(2));
        assert_eq!(diff.len(), 1);
    }

    #[test]
    fn test_bitset_is_subset_of() {
        let a = BitSet::from_iter([1, 2].into_iter());
        let b = BitSet::from_iter([0, 1, 2, 3].into_iter());

        assert!(a.is_subset_of(b));
        assert!(!b.is_subset_of(a));
        assert!(a.is_subset_of(a));
    }

    #[test]
    fn test_bitset_iter() {
        let set = BitSet::from_iter([0, 2, 5].into_iter());
        let elements: Vec<_> = set.iter().collect();

        assert_eq!(elements, vec![0, 2, 5]);
    }

    // ----- JoinGraph / JoinGraphBuilder / JoinNode -------------------------

    #[test]
    fn test_join_graph_construction() {
        let mut builder = builder_with(&[("a", "Person"), ("b", "Person"), ("c", "Company")]);
        join_on_props(&mut builder, ("a", "id"), ("b", "friend_id"));
        join_on_props(&mut builder, ("a", "company_id"), ("c", "id"));

        let graph = builder.build();
        assert_eq!(graph.node_count(), 3);
    }

    #[test]
    fn test_join_graph_empty() {
        let graph = JoinGraph::new();
        assert_eq!(graph.node_count(), 0);
    }

    #[test]
    fn test_join_graph_neighbors() {
        let mut builder = builder_with(&[("a", "A"), ("b", "B"), ("c", "C")]);
        join_on_vars(&mut builder, "a", "b");
        join_on_vars(&mut builder, "a", "c");

        let graph = builder.build();

        let neighbors_a: Vec<_> = graph.neighbors(0).collect();
        assert_eq!(neighbors_a.len(), 2);
        assert!(neighbors_a.contains(&1));
        assert!(neighbors_a.contains(&2));

        let neighbors_b: Vec<_> = graph.neighbors(1).collect();
        assert_eq!(neighbors_b.len(), 1);
        assert!(neighbors_b.contains(&0));
    }

    #[test]
    fn test_join_graph_are_connected() {
        let mut builder = builder_with(&[("a", "A"), ("b", "B"), ("c", "C")]);
        join_on_vars(&mut builder, "a", "b");

        let graph = builder.build();

        let set_a = BitSet::singleton(0);
        let set_b = BitSet::singleton(1);
        let set_c = BitSet::singleton(2);

        assert!(graph.are_connected(&set_a, &set_b));
        assert!(graph.are_connected(&set_b, &set_a));
        assert!(!graph.are_connected(&set_a, &set_c));
        assert!(!graph.are_connected(&set_b, &set_c));
    }

    #[test]
    fn test_join_graph_get_conditions() {
        let mut builder = builder_with(&[("a", "A"), ("b", "B")]);
        join_on_props(&mut builder, ("a", "id"), ("b", "a_id"));

        let graph = builder.build();

        let set_a = BitSet::singleton(0);
        let set_b = BitSet::singleton(1);

        let conditions = graph.get_conditions(&set_a, &set_b);
        assert_eq!(conditions.len(), 1);
    }

    #[test]
    fn test_join_graph_default() {
        let graph = JoinGraph::default();
        assert_eq!(graph.node_count(), 0);
    }

    #[test]
    fn test_join_graph_builder_default() {
        let builder = JoinGraphBuilder::default();
        let graph = builder.build();
        assert_eq!(graph.node_count(), 0);
    }

    #[test]
    fn test_join_graph_nodes_accessor() {
        let graph = builder_with(&[("a", "A"), ("b", "B")]).build();
        let nodes = graph.nodes();

        assert_eq!(nodes.len(), 2);
        assert_eq!(nodes[0].variable, "a");
        assert_eq!(nodes[1].variable, "b");
    }

    #[test]
    fn test_join_node_equality() {
        let node1 = node_a(0);
        let node2 = node_a(0);
        let node3 = node_a(1);

        assert_eq!(node1, node2);
        assert_ne!(node1, node3);
    }

    #[test]
    fn test_join_node_hash() {
        use std::collections::HashSet;

        let node1 = node_a(0);
        let node2 = node_a(0);

        let mut set = HashSet::new();
        set.insert(node1.clone());

        assert!(set.contains(&node2));
    }

    #[test]
    fn test_add_join_condition_unknown_variable() {
        let mut builder = builder_with(&[("a", "A")]);

        // "unknown" was never registered, so the condition is dropped.
        join_on_vars(&mut builder, "a", "unknown");

        let graph = builder.build();
        assert_eq!(graph.node_count(), 1);
    }

    #[test]
    fn test_join_graph_cyclic_triangle() {
        let mut builder = builder_with(&[("a", "A"), ("b", "B"), ("c", "C")]);
        join_on_vars(&mut builder, "a", "b");
        join_on_vars(&mut builder, "b", "c");
        join_on_vars(&mut builder, "c", "a");

        let graph = builder.build();
        assert!(graph.is_cyclic());
    }

    #[test]
    fn test_join_graph_acyclic_chain() {
        let mut builder = builder_with(&[("a", "A"), ("b", "B"), ("c", "C")]);
        join_on_vars(&mut builder, "a", "b");
        join_on_vars(&mut builder, "b", "c");

        let graph = builder.build();
        assert!(!graph.is_cyclic());
    }

    #[test]
    fn test_join_graph_empty_not_cyclic() {
        let graph = JoinGraph::new();
        assert!(!graph.is_cyclic());
    }

    #[test]
    fn test_join_graph_edges_accessor() {
        let mut builder = builder_with(&[("a", "A"), ("b", "B")]);
        join_on_vars(&mut builder, "a", "b");

        let graph = builder.build();
        assert_eq!(graph.edges().len(), 1);
    }

    // ----- DPccp -----------------------------------------------------------

    #[test]
    fn test_dpccp_single_relation() {
        let graph = builder_with(&[("a", "Person")]).build();
        let estimator = estimator_with([("Person", TableStats::new(1000))]);

        let plan = plan_for(&graph, &estimator).expect("optimizer should produce a plan");
        assert_eq!(plan.nodes.len(), 1);
    }

    #[test]
    fn test_dpccp_two_relations() {
        let mut builder = builder_with(&[("a", "Person"), ("b", "Company")]);
        join_on_props(&mut builder, ("a", "company_id"), ("b", "id"));
        let graph = builder.build();

        let estimator = estimator_with([
            ("Person", TableStats::new(1000)),
            ("Company", TableStats::new(100)),
        ]);

        let plan = plan_for(&graph, &estimator).expect("optimizer should produce a plan");
        assert_eq!(plan.nodes.len(), 2);
        assert!(
            matches!(plan.operator, LogicalOperator::Join(_)),
            "Expected Join operator"
        );
    }

    #[test]
    fn test_dpccp_three_relations_chain() {
        let mut builder = builder_with(&[("a", "Person"), ("b", "Person"), ("c", "Company")]);
        join_on_props(&mut builder, ("a", "knows"), ("b", "id"));
        join_on_props(&mut builder, ("b", "company_id"), ("c", "id"));
        let graph = builder.build();

        let estimator = estimator_with([
            ("Person", TableStats::new(1000)),
            ("Company", TableStats::new(100)),
        ]);

        let plan = plan_for(&graph, &estimator).expect("optimizer should produce a plan");
        assert_eq!(plan.nodes.len(), 3);
    }

    #[test]
    fn test_dpccp_prefers_smaller_intermediate() {
        let mut builder = builder_with(&[("s", "Small"), ("m", "Medium"), ("l", "Large")]);
        join_on_props(&mut builder, ("s", "id"), ("m", "s_id"));
        join_on_props(&mut builder, ("m", "id"), ("l", "m_id"));
        let graph = builder.build();

        let estimator = estimator_with([
            ("Small", TableStats::new(100)),
            ("Medium", TableStats::new(1000)),
            ("Large", TableStats::new(10000)),
        ]);

        let plan = plan_for(&graph, &estimator).expect("optimizer should produce a plan");

        // Only coverage and a positive cost are checked here.
        assert_eq!(plan.nodes.len(), 3);
        assert!(plan.cost.total() > 0.0);
    }

    #[test]
    fn test_dpccp_empty_graph() {
        let graph = JoinGraph::new();
        let estimator = CardinalityEstimator::new();

        let plan = plan_for(&graph, &estimator);

        assert!(plan.is_none());
    }

    #[test]
    fn test_dpccp_star_query() {
        let mut builder = builder_with(&[
            ("center", "Center"),
            ("a", "A"),
            ("b", "B"),
            ("c", "C"),
        ]);
        join_on_vars(&mut builder, "center", "a");
        join_on_vars(&mut builder, "center", "b");
        join_on_vars(&mut builder, "center", "c");
        let graph = builder.build();

        let estimator = estimator_with([
            ("Center", TableStats::new(100)),
            ("A", TableStats::new(1000)),
            ("B", TableStats::new(500)),
            ("C", TableStats::new(200)),
        ]);

        let plan = plan_for(&graph, &estimator).expect("optimizer should produce a plan");
        assert_eq!(plan.nodes.len(), 4);
        assert!(plan.cost.total() > 0.0);
    }

    #[test]
    fn test_dpccp_cycle_query() {
        let mut builder = builder_with(&[("a", "A"), ("b", "B"), ("c", "C")]);
        join_on_vars(&mut builder, "a", "b");
        join_on_vars(&mut builder, "b", "c");
        join_on_vars(&mut builder, "c", "a");
        let graph = builder.build();

        let estimator = estimator_with([
            ("A", TableStats::new(100)),
            ("B", TableStats::new(100)),
            ("C", TableStats::new(100)),
        ]);

        let plan = plan_for(&graph, &estimator).expect("optimizer should produce a plan");
        assert_eq!(plan.nodes.len(), 3);
    }

    #[test]
    fn test_dpccp_four_relations() {
        let mut builder = builder_with(&[("a", "A"), ("b", "B"), ("c", "C"), ("d", "D")]);
        join_on_vars(&mut builder, "a", "b");
        join_on_vars(&mut builder, "b", "c");
        join_on_vars(&mut builder, "c", "d");
        let graph = builder.build();

        let estimator = estimator_with([
            ("A", TableStats::new(100)),
            ("B", TableStats::new(200)),
            ("C", TableStats::new(300)),
            ("D", TableStats::new(400)),
        ]);

        let plan = plan_for(&graph, &estimator).expect("optimizer should produce a plan");
        assert_eq!(plan.nodes.len(), 4);
    }

    #[test]
    fn test_join_plan_cardinality_and_cost() {
        let mut builder = builder_with(&[("a", "A"), ("b", "B")]);
        join_on_vars(&mut builder, "a", "b");
        let graph = builder.build();

        let estimator = estimator_with([
            ("A", TableStats::new(100)),
            ("B", TableStats::new(200)),
        ]);

        let plan = plan_for(&graph, &estimator).unwrap();

        assert!(plan.cardinality > 0.0);
        assert!(plan.cost.total() > 0.0);
    }

    #[test]
    fn test_dpccp_with_different_cardinalities() {
        let mut builder = builder_with(&[("tiny", "Tiny"), ("huge", "Huge")]);
        join_on_vars(&mut builder, "tiny", "huge");
        let graph = builder.build();

        let estimator = estimator_with([
            ("Tiny", TableStats::new(10)),
            ("Huge", TableStats::new(1000000)),
        ]);

        let plan = plan_for(&graph, &estimator).expect("optimizer should produce a plan");
        assert_eq!(plan.nodes.len(), 2);
    }
}