use super::cardinality::CardinalityEstimator;
use super::cost::{Cost, CostModel};
use crate::query::plan::{JoinCondition, JoinOp, JoinType, LogicalExpression, LogicalOperator};
use std::collections::{HashMap, HashSet};
#[derive(Debug, Clone)]
pub struct JoinNode {
pub id: usize,
pub variable: String,
pub relation: LogicalOperator,
}
impl PartialEq for JoinNode {
fn eq(&self, other: &Self) -> bool {
self.id == other.id && self.variable == other.variable
}
}
impl Eq for JoinNode {}
impl std::hash::Hash for JoinNode {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.id.hash(state);
self.variable.hash(state);
}
}
#[derive(Debug, Clone)]
pub struct JoinEdge {
pub from: usize,
pub to: usize,
pub conditions: Vec<JoinCondition>,
}
#[derive(Debug)]
pub struct JoinGraph {
nodes: Vec<JoinNode>,
edges: Vec<JoinEdge>,
adjacency: HashMap<usize, HashSet<usize>>,
}
impl JoinGraph {
pub fn new() -> Self {
Self {
nodes: Vec::new(),
edges: Vec::new(),
adjacency: HashMap::new(),
}
}
pub fn add_node(&mut self, variable: String, relation: LogicalOperator) -> usize {
let id = self.nodes.len();
self.nodes.push(JoinNode {
id,
variable,
relation,
});
self.adjacency.insert(id, HashSet::new());
id
}
pub fn add_edge(&mut self, from: usize, to: usize, conditions: Vec<JoinCondition>) {
self.edges.push(JoinEdge {
from,
to,
conditions,
});
self.adjacency
.get_mut(&from)
.expect("'from' node must be added via add_node() before add_edge()")
.insert(to);
self.adjacency
.get_mut(&to)
.expect("'to' node must be added via add_node() before add_edge()")
.insert(from);
}
pub fn node_count(&self) -> usize {
self.nodes.len()
}
pub fn nodes(&self) -> &[JoinNode] {
&self.nodes
}
pub fn neighbors(&self, node_id: usize) -> impl Iterator<Item = usize> + '_ {
self.adjacency.get(&node_id).into_iter().flatten().copied()
}
pub fn get_conditions(&self, left: &BitSet, right: &BitSet) -> Vec<JoinCondition> {
let mut conditions = Vec::new();
for edge in &self.edges {
let from_in_left = left.contains(edge.from);
let from_in_right = right.contains(edge.from);
let to_in_left = left.contains(edge.to);
let to_in_right = right.contains(edge.to);
if (from_in_left && to_in_right) || (from_in_right && to_in_left) {
conditions.extend(edge.conditions.clone());
}
}
conditions
}
pub fn edges(&self) -> &[JoinEdge] {
&self.edges
}
#[must_use]
pub fn is_cyclic(&self) -> bool {
if self.nodes.is_empty() {
return false;
}
self.edges.len() >= self.nodes.len()
}
pub fn are_connected(&self, left: &BitSet, right: &BitSet) -> bool {
for edge in &self.edges {
let from_in_left = left.contains(edge.from);
let from_in_right = right.contains(edge.from);
let to_in_left = left.contains(edge.to);
let to_in_right = right.contains(edge.to);
if (from_in_left && to_in_right) || (from_in_right && to_in_left) {
return true;
}
}
false
}
}
impl Default for JoinGraph {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BitSet(u64);
impl BitSet {
pub fn empty() -> Self {
Self(0)
}
pub fn singleton(i: usize) -> Self {
Self(1 << i)
}
pub fn from_iter(iter: impl Iterator<Item = usize>) -> Self {
let mut bits = 0u64;
for i in iter {
bits |= 1 << i;
}
Self(bits)
}
pub fn full(n: usize) -> Self {
Self((1 << n) - 1)
}
pub fn is_empty(&self) -> bool {
self.0 == 0
}
pub fn len(&self) -> usize {
self.0.count_ones() as usize
}
pub fn contains(&self, i: usize) -> bool {
(self.0 & (1 << i)) != 0
}
pub fn insert(&mut self, i: usize) {
self.0 |= 1 << i;
}
pub fn remove(&mut self, i: usize) {
self.0 &= !(1 << i);
}
pub fn union(self, other: Self) -> Self {
Self(self.0 | other.0)
}
pub fn intersection(self, other: Self) -> Self {
Self(self.0 & other.0)
}
pub fn difference(self, other: Self) -> Self {
Self(self.0 & !other.0)
}
pub fn is_subset_of(self, other: Self) -> bool {
(self.0 & other.0) == self.0
}
pub fn iter(self) -> impl Iterator<Item = usize> {
(0..64).filter(move |&i| self.contains(i))
}
pub fn subsets(self) -> SubsetIterator {
SubsetIterator {
full: self.0,
current: Some(self.0),
}
}
}
pub struct SubsetIterator {
full: u64,
current: Option<u64>,
}
impl Iterator for SubsetIterator {
type Item = BitSet;
fn next(&mut self) -> Option<Self::Item> {
let current = self.current?;
if current == 0 {
self.current = None;
return Some(BitSet(0));
}
let result = current;
self.current = Some((current.wrapping_sub(1)) & self.full);
if self.current == Some(self.full) {
self.current = None;
}
Some(BitSet(result))
}
}
#[derive(Debug, Clone)]
pub struct JoinPlan {
pub nodes: BitSet,
pub operator: LogicalOperator,
pub cost: Cost,
pub cardinality: f64,
}
pub struct DPccp<'a> {
graph: &'a JoinGraph,
cost_model: &'a CostModel,
card_estimator: &'a CardinalityEstimator,
memo: HashMap<BitSet, JoinPlan>,
}
impl<'a> DPccp<'a> {
pub fn new(
graph: &'a JoinGraph,
cost_model: &'a CostModel,
card_estimator: &'a CardinalityEstimator,
) -> Self {
Self {
graph,
cost_model,
card_estimator,
memo: HashMap::new(),
}
}
pub fn optimize(&mut self) -> Option<JoinPlan> {
let n = self.graph.node_count();
if n == 0 {
return None;
}
if n == 1 {
let node = &self.graph.nodes[0];
let cardinality = self.card_estimator.estimate(&node.relation);
let cost = self.cost_model.estimate(&node.relation, cardinality);
return Some(JoinPlan {
nodes: BitSet::singleton(0),
operator: node.relation.clone(),
cost,
cardinality,
});
}
for (i, node) in self.graph.nodes.iter().enumerate() {
let subset = BitSet::singleton(i);
let cardinality = self.card_estimator.estimate(&node.relation);
let cost = self.cost_model.estimate(&node.relation, cardinality);
self.memo.insert(
subset,
JoinPlan {
nodes: subset,
operator: node.relation.clone(),
cost,
cardinality,
},
);
}
let full_set = BitSet::full(n);
self.enumerate_ccp(full_set);
self.memo.get(&full_set).cloned()
}
fn enumerate_ccp(&mut self, s: BitSet) {
for s1 in s.subsets() {
if s1.is_empty() || s1 == s {
continue;
}
let s2 = s.difference(s1);
if s2.is_empty() {
continue;
}
if !self.is_connected(s1) || !self.is_connected(s2) {
continue;
}
if !self.graph.are_connected(&s1, &s2) {
continue;
}
if !self.memo.contains_key(&s1) {
self.enumerate_ccp(s1);
}
if !self.memo.contains_key(&s2) {
self.enumerate_ccp(s2);
}
if let (Some(plan1), Some(plan2)) = (self.memo.get(&s1), self.memo.get(&s2)) {
let conditions = self.graph.get_conditions(&s1, &s2);
let new_plan = self.build_join_plan(plan1.clone(), plan2.clone(), conditions);
let should_update = self.memo.get(&s).map_or(true, |existing| {
new_plan.cost.total() < existing.cost.total()
});
if should_update {
self.memo.insert(s, new_plan);
}
}
}
}
fn is_connected(&self, subset: BitSet) -> bool {
if subset.len() <= 1 {
return true;
}
let start = subset
.iter()
.next()
.expect("subset is non-empty: len >= 2 checked on line 400");
let mut visited = BitSet::singleton(start);
let mut queue = vec![start];
while let Some(node) = queue.pop() {
for neighbor in self.graph.neighbors(node) {
if subset.contains(neighbor) && !visited.contains(neighbor) {
visited.insert(neighbor);
queue.push(neighbor);
}
}
}
visited == subset
}
fn build_join_plan(
&self,
left: JoinPlan,
right: JoinPlan,
conditions: Vec<JoinCondition>,
) -> JoinPlan {
let nodes = left.nodes.union(right.nodes);
let join_op = LogicalOperator::Join(JoinOp {
left: Box::new(left.operator),
right: Box::new(right.operator),
join_type: JoinType::Inner,
conditions,
});
let cardinality = self.card_estimator.estimate(&join_op);
let join_cost = self.cost_model.estimate(&join_op, cardinality);
let total_cost = left.cost + right.cost + join_cost;
JoinPlan {
nodes,
operator: join_op,
cost: total_cost,
cardinality,
}
}
}
pub struct JoinGraphBuilder {
graph: JoinGraph,
variable_to_node: HashMap<String, usize>,
}
impl JoinGraphBuilder {
pub fn new() -> Self {
Self {
graph: JoinGraph::new(),
variable_to_node: HashMap::new(),
}
}
pub fn add_relation(&mut self, variable: &str, relation: LogicalOperator) -> usize {
let id = self.graph.add_node(variable.to_string(), relation);
self.variable_to_node.insert(variable.to_string(), id);
id
}
pub fn add_join_condition(
&mut self,
left_var: &str,
right_var: &str,
left_expr: LogicalExpression,
right_expr: LogicalExpression,
) {
if let (Some(&left_id), Some(&right_id)) = (
self.variable_to_node.get(left_var),
self.variable_to_node.get(right_var),
) {
self.graph.add_edge(
left_id,
right_id,
vec![JoinCondition {
left: left_expr,
right: right_expr,
}],
);
}
}
pub fn build(self) -> JoinGraph {
self.graph
}
}
impl Default for JoinGraphBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::query::plan::NodeScanOp;
fn create_node_scan(var: &str, label: &str) -> LogicalOperator {
LogicalOperator::NodeScan(NodeScanOp {
variable: var.to_string(),
label: Some(label.to_string()),
input: None,
})
}
#[test]
fn test_bitset_operations() {
let a = BitSet::singleton(0);
let b = BitSet::singleton(1);
let _c = BitSet::singleton(2);
assert!(a.contains(0));
assert!(!a.contains(1));
let ab = a.union(b);
assert!(ab.contains(0));
assert!(ab.contains(1));
assert!(!ab.contains(2));
let full = BitSet::full(3);
assert_eq!(full.len(), 3);
assert!(full.contains(0));
assert!(full.contains(1));
assert!(full.contains(2));
}
#[test]
fn test_subset_iteration() {
let set = BitSet::from_iter([0, 1].into_iter());
let subsets: Vec<_> = set.subsets().collect();
assert_eq!(subsets.len(), 4);
assert!(subsets.contains(&BitSet::empty()));
assert!(subsets.contains(&BitSet::singleton(0)));
assert!(subsets.contains(&BitSet::singleton(1)));
assert!(subsets.contains(&set));
}
#[test]
fn test_join_graph_construction() {
let mut builder = JoinGraphBuilder::new();
builder.add_relation("a", create_node_scan("a", "Person"));
builder.add_relation("b", create_node_scan("b", "Person"));
builder.add_relation("c", create_node_scan("c", "Company"));
builder.add_join_condition(
"a",
"b",
LogicalExpression::Property {
variable: "a".to_string(),
property: "id".to_string(),
},
LogicalExpression::Property {
variable: "b".to_string(),
property: "friend_id".to_string(),
},
);
builder.add_join_condition(
"a",
"c",
LogicalExpression::Property {
variable: "a".to_string(),
property: "company_id".to_string(),
},
LogicalExpression::Property {
variable: "c".to_string(),
property: "id".to_string(),
},
);
let graph = builder.build();
assert_eq!(graph.node_count(), 3);
}
#[test]
fn test_dpccp_single_relation() {
let mut builder = JoinGraphBuilder::new();
builder.add_relation("a", create_node_scan("a", "Person"));
let graph = builder.build();
let cost_model = CostModel::new();
let mut card_estimator = CardinalityEstimator::new();
card_estimator.add_table_stats("Person", super::super::cardinality::TableStats::new(1000));
let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
let plan = dpccp.optimize();
assert!(plan.is_some());
let plan = plan.unwrap();
assert_eq!(plan.nodes.len(), 1);
}
#[test]
fn test_dpccp_two_relations() {
let mut builder = JoinGraphBuilder::new();
builder.add_relation("a", create_node_scan("a", "Person"));
builder.add_relation("b", create_node_scan("b", "Company"));
builder.add_join_condition(
"a",
"b",
LogicalExpression::Property {
variable: "a".to_string(),
property: "company_id".to_string(),
},
LogicalExpression::Property {
variable: "b".to_string(),
property: "id".to_string(),
},
);
let graph = builder.build();
let cost_model = CostModel::new();
let mut card_estimator = CardinalityEstimator::new();
card_estimator.add_table_stats("Person", super::super::cardinality::TableStats::new(1000));
card_estimator.add_table_stats("Company", super::super::cardinality::TableStats::new(100));
let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
let plan = dpccp.optimize();
assert!(plan.is_some());
let plan = plan.unwrap();
assert_eq!(plan.nodes.len(), 2);
if let LogicalOperator::Join(_) = plan.operator {
} else {
panic!("Expected Join operator");
}
}
#[test]
fn test_dpccp_three_relations_chain() {
let mut builder = JoinGraphBuilder::new();
builder.add_relation("a", create_node_scan("a", "Person"));
builder.add_relation("b", create_node_scan("b", "Person"));
builder.add_relation("c", create_node_scan("c", "Company"));
builder.add_join_condition(
"a",
"b",
LogicalExpression::Property {
variable: "a".to_string(),
property: "knows".to_string(),
},
LogicalExpression::Property {
variable: "b".to_string(),
property: "id".to_string(),
},
);
builder.add_join_condition(
"b",
"c",
LogicalExpression::Property {
variable: "b".to_string(),
property: "company_id".to_string(),
},
LogicalExpression::Property {
variable: "c".to_string(),
property: "id".to_string(),
},
);
let graph = builder.build();
let cost_model = CostModel::new();
let mut card_estimator = CardinalityEstimator::new();
card_estimator.add_table_stats("Person", super::super::cardinality::TableStats::new(1000));
card_estimator.add_table_stats("Company", super::super::cardinality::TableStats::new(100));
let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
let plan = dpccp.optimize();
assert!(plan.is_some());
let plan = plan.unwrap();
assert_eq!(plan.nodes.len(), 3);
}
#[test]
fn test_dpccp_prefers_smaller_intermediate() {
let mut builder = JoinGraphBuilder::new();
builder.add_relation("s", create_node_scan("s", "Small"));
builder.add_relation("m", create_node_scan("m", "Medium"));
builder.add_relation("l", create_node_scan("l", "Large"));
builder.add_join_condition(
"s",
"m",
LogicalExpression::Property {
variable: "s".to_string(),
property: "id".to_string(),
},
LogicalExpression::Property {
variable: "m".to_string(),
property: "s_id".to_string(),
},
);
builder.add_join_condition(
"m",
"l",
LogicalExpression::Property {
variable: "m".to_string(),
property: "id".to_string(),
},
LogicalExpression::Property {
variable: "l".to_string(),
property: "m_id".to_string(),
},
);
let graph = builder.build();
let cost_model = CostModel::new();
let mut card_estimator = CardinalityEstimator::new();
card_estimator.add_table_stats("Small", super::super::cardinality::TableStats::new(100));
card_estimator.add_table_stats("Medium", super::super::cardinality::TableStats::new(1000));
card_estimator.add_table_stats("Large", super::super::cardinality::TableStats::new(10000));
let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
let plan = dpccp.optimize();
assert!(plan.is_some());
let plan = plan.unwrap();
assert_eq!(plan.nodes.len(), 3);
assert!(plan.cost.total() > 0.0);
}
#[test]
fn test_bitset_empty() {
let empty = BitSet::empty();
assert!(empty.is_empty());
assert_eq!(empty.len(), 0);
assert!(!empty.contains(0));
}
#[test]
fn test_bitset_insert_remove() {
let mut set = BitSet::empty();
set.insert(3);
assert!(set.contains(3));
assert_eq!(set.len(), 1);
set.insert(5);
assert!(set.contains(5));
assert_eq!(set.len(), 2);
set.remove(3);
assert!(!set.contains(3));
assert_eq!(set.len(), 1);
}
#[test]
fn test_bitset_intersection() {
let a = BitSet::from_iter([0, 1, 2].into_iter());
let b = BitSet::from_iter([1, 2, 3].into_iter());
let intersection = a.intersection(b);
assert!(intersection.contains(1));
assert!(intersection.contains(2));
assert!(!intersection.contains(0));
assert!(!intersection.contains(3));
assert_eq!(intersection.len(), 2);
}
#[test]
fn test_bitset_difference() {
let a = BitSet::from_iter([0, 1, 2].into_iter());
let b = BitSet::from_iter([1, 2, 3].into_iter());
let diff = a.difference(b);
assert!(diff.contains(0));
assert!(!diff.contains(1));
assert!(!diff.contains(2));
assert_eq!(diff.len(), 1);
}
#[test]
fn test_bitset_is_subset_of() {
let a = BitSet::from_iter([1, 2].into_iter());
let b = BitSet::from_iter([0, 1, 2, 3].into_iter());
assert!(a.is_subset_of(b));
assert!(!b.is_subset_of(a));
assert!(a.is_subset_of(a));
}
#[test]
fn test_bitset_iter() {
let set = BitSet::from_iter([0, 2, 5].into_iter());
let elements: Vec<_> = set.iter().collect();
assert_eq!(elements, vec![0, 2, 5]);
}
#[test]
fn test_join_graph_empty() {
let graph = JoinGraph::new();
assert_eq!(graph.node_count(), 0);
}
#[test]
fn test_join_graph_neighbors() {
let mut builder = JoinGraphBuilder::new();
builder.add_relation("a", create_node_scan("a", "A"));
builder.add_relation("b", create_node_scan("b", "B"));
builder.add_relation("c", create_node_scan("c", "C"));
builder.add_join_condition(
"a",
"b",
LogicalExpression::Variable("a".to_string()),
LogicalExpression::Variable("b".to_string()),
);
builder.add_join_condition(
"a",
"c",
LogicalExpression::Variable("a".to_string()),
LogicalExpression::Variable("c".to_string()),
);
let graph = builder.build();
let neighbors_a: Vec<_> = graph.neighbors(0).collect();
assert_eq!(neighbors_a.len(), 2);
assert!(neighbors_a.contains(&1));
assert!(neighbors_a.contains(&2));
let neighbors_b: Vec<_> = graph.neighbors(1).collect();
assert_eq!(neighbors_b.len(), 1);
assert!(neighbors_b.contains(&0));
}
#[test]
fn test_join_graph_are_connected() {
let mut builder = JoinGraphBuilder::new();
builder.add_relation("a", create_node_scan("a", "A"));
builder.add_relation("b", create_node_scan("b", "B"));
builder.add_relation("c", create_node_scan("c", "C"));
builder.add_join_condition(
"a",
"b",
LogicalExpression::Variable("a".to_string()),
LogicalExpression::Variable("b".to_string()),
);
let graph = builder.build();
let set_a = BitSet::singleton(0);
let set_b = BitSet::singleton(1);
let set_c = BitSet::singleton(2);
assert!(graph.are_connected(&set_a, &set_b));
assert!(graph.are_connected(&set_b, &set_a));
assert!(!graph.are_connected(&set_a, &set_c));
assert!(!graph.are_connected(&set_b, &set_c));
}
#[test]
fn test_join_graph_get_conditions() {
let mut builder = JoinGraphBuilder::new();
builder.add_relation("a", create_node_scan("a", "A"));
builder.add_relation("b", create_node_scan("b", "B"));
builder.add_join_condition(
"a",
"b",
LogicalExpression::Property {
variable: "a".to_string(),
property: "id".to_string(),
},
LogicalExpression::Property {
variable: "b".to_string(),
property: "a_id".to_string(),
},
);
let graph = builder.build();
let set_a = BitSet::singleton(0);
let set_b = BitSet::singleton(1);
let conditions = graph.get_conditions(&set_a, &set_b);
assert_eq!(conditions.len(), 1);
}
#[test]
fn test_dpccp_empty_graph() {
let graph = JoinGraph::new();
let cost_model = CostModel::new();
let card_estimator = CardinalityEstimator::new();
let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
let plan = dpccp.optimize();
assert!(plan.is_none());
}
#[test]
fn test_dpccp_star_query() {
let mut builder = JoinGraphBuilder::new();
builder.add_relation("center", create_node_scan("center", "Center"));
builder.add_relation("a", create_node_scan("a", "A"));
builder.add_relation("b", create_node_scan("b", "B"));
builder.add_relation("c", create_node_scan("c", "C"));
builder.add_join_condition(
"center",
"a",
LogicalExpression::Variable("center".to_string()),
LogicalExpression::Variable("a".to_string()),
);
builder.add_join_condition(
"center",
"b",
LogicalExpression::Variable("center".to_string()),
LogicalExpression::Variable("b".to_string()),
);
builder.add_join_condition(
"center",
"c",
LogicalExpression::Variable("center".to_string()),
LogicalExpression::Variable("c".to_string()),
);
let graph = builder.build();
let cost_model = CostModel::new();
let mut card_estimator = CardinalityEstimator::new();
card_estimator.add_table_stats("Center", super::super::cardinality::TableStats::new(100));
card_estimator.add_table_stats("A", super::super::cardinality::TableStats::new(1000));
card_estimator.add_table_stats("B", super::super::cardinality::TableStats::new(500));
card_estimator.add_table_stats("C", super::super::cardinality::TableStats::new(200));
let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
let plan = dpccp.optimize();
assert!(plan.is_some());
let plan = plan.unwrap();
assert_eq!(plan.nodes.len(), 4);
assert!(plan.cost.total() > 0.0);
}
#[test]
fn test_dpccp_cycle_query() {
let mut builder = JoinGraphBuilder::new();
builder.add_relation("a", create_node_scan("a", "A"));
builder.add_relation("b", create_node_scan("b", "B"));
builder.add_relation("c", create_node_scan("c", "C"));
builder.add_join_condition(
"a",
"b",
LogicalExpression::Variable("a".to_string()),
LogicalExpression::Variable("b".to_string()),
);
builder.add_join_condition(
"b",
"c",
LogicalExpression::Variable("b".to_string()),
LogicalExpression::Variable("c".to_string()),
);
builder.add_join_condition(
"c",
"a",
LogicalExpression::Variable("c".to_string()),
LogicalExpression::Variable("a".to_string()),
);
let graph = builder.build();
let cost_model = CostModel::new();
let mut card_estimator = CardinalityEstimator::new();
card_estimator.add_table_stats("A", super::super::cardinality::TableStats::new(100));
card_estimator.add_table_stats("B", super::super::cardinality::TableStats::new(100));
card_estimator.add_table_stats("C", super::super::cardinality::TableStats::new(100));
let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
let plan = dpccp.optimize();
assert!(plan.is_some());
let plan = plan.unwrap();
assert_eq!(plan.nodes.len(), 3);
}
#[test]
fn test_dpccp_four_relations() {
let mut builder = JoinGraphBuilder::new();
builder.add_relation("a", create_node_scan("a", "A"));
builder.add_relation("b", create_node_scan("b", "B"));
builder.add_relation("c", create_node_scan("c", "C"));
builder.add_relation("d", create_node_scan("d", "D"));
builder.add_join_condition(
"a",
"b",
LogicalExpression::Variable("a".to_string()),
LogicalExpression::Variable("b".to_string()),
);
builder.add_join_condition(
"b",
"c",
LogicalExpression::Variable("b".to_string()),
LogicalExpression::Variable("c".to_string()),
);
builder.add_join_condition(
"c",
"d",
LogicalExpression::Variable("c".to_string()),
LogicalExpression::Variable("d".to_string()),
);
let graph = builder.build();
let cost_model = CostModel::new();
let mut card_estimator = CardinalityEstimator::new();
card_estimator.add_table_stats("A", super::super::cardinality::TableStats::new(100));
card_estimator.add_table_stats("B", super::super::cardinality::TableStats::new(200));
card_estimator.add_table_stats("C", super::super::cardinality::TableStats::new(300));
card_estimator.add_table_stats("D", super::super::cardinality::TableStats::new(400));
let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
let plan = dpccp.optimize();
assert!(plan.is_some());
let plan = plan.unwrap();
assert_eq!(plan.nodes.len(), 4);
}
#[test]
fn test_join_plan_cardinality_and_cost() {
let mut builder = JoinGraphBuilder::new();
builder.add_relation("a", create_node_scan("a", "A"));
builder.add_relation("b", create_node_scan("b", "B"));
builder.add_join_condition(
"a",
"b",
LogicalExpression::Variable("a".to_string()),
LogicalExpression::Variable("b".to_string()),
);
let graph = builder.build();
let cost_model = CostModel::new();
let mut card_estimator = CardinalityEstimator::new();
card_estimator.add_table_stats("A", super::super::cardinality::TableStats::new(100));
card_estimator.add_table_stats("B", super::super::cardinality::TableStats::new(200));
let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
let plan = dpccp.optimize().unwrap();
assert!(plan.cardinality > 0.0);
assert!(plan.cost.total() > 0.0);
}
#[test]
fn test_join_graph_default() {
let graph = JoinGraph::default();
assert_eq!(graph.node_count(), 0);
}
#[test]
fn test_join_graph_builder_default() {
let builder = JoinGraphBuilder::default();
let graph = builder.build();
assert_eq!(graph.node_count(), 0);
}
#[test]
fn test_join_graph_nodes_accessor() {
let mut builder = JoinGraphBuilder::new();
builder.add_relation("a", create_node_scan("a", "A"));
builder.add_relation("b", create_node_scan("b", "B"));
let graph = builder.build();
let nodes = graph.nodes();
assert_eq!(nodes.len(), 2);
assert_eq!(nodes[0].variable, "a");
assert_eq!(nodes[1].variable, "b");
}
#[test]
fn test_join_node_equality() {
let node1 = JoinNode {
id: 0,
variable: "a".to_string(),
relation: create_node_scan("a", "A"),
};
let node2 = JoinNode {
id: 0,
variable: "a".to_string(),
relation: create_node_scan("a", "A"),
};
let node3 = JoinNode {
id: 1,
variable: "a".to_string(),
relation: create_node_scan("a", "A"),
};
assert_eq!(node1, node2);
assert_ne!(node1, node3);
}
#[test]
fn test_join_node_hash() {
use std::collections::HashSet;
let node1 = JoinNode {
id: 0,
variable: "a".to_string(),
relation: create_node_scan("a", "A"),
};
let node2 = JoinNode {
id: 0,
variable: "a".to_string(),
relation: create_node_scan("a", "A"),
};
let mut set = HashSet::new();
set.insert(node1.clone());
assert!(set.contains(&node2));
}
#[test]
fn test_add_join_condition_unknown_variable() {
let mut builder = JoinGraphBuilder::new();
builder.add_relation("a", create_node_scan("a", "A"));
builder.add_join_condition(
"a",
"unknown",
LogicalExpression::Variable("a".to_string()),
LogicalExpression::Variable("unknown".to_string()),
);
let graph = builder.build();
assert_eq!(graph.node_count(), 1);
}
#[test]
fn test_dpccp_with_different_cardinalities() {
let mut builder = JoinGraphBuilder::new();
builder.add_relation("tiny", create_node_scan("tiny", "Tiny"));
builder.add_relation("huge", create_node_scan("huge", "Huge"));
builder.add_join_condition(
"tiny",
"huge",
LogicalExpression::Variable("tiny".to_string()),
LogicalExpression::Variable("huge".to_string()),
);
let graph = builder.build();
let cost_model = CostModel::new();
let mut card_estimator = CardinalityEstimator::new();
card_estimator.add_table_stats("Tiny", super::super::cardinality::TableStats::new(10));
card_estimator.add_table_stats("Huge", super::super::cardinality::TableStats::new(1000000));
let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
let plan = dpccp.optimize();
assert!(plan.is_some());
let plan = plan.unwrap();
assert_eq!(plan.nodes.len(), 2);
}
#[test]
fn test_join_graph_cyclic_triangle() {
let mut builder = JoinGraphBuilder::new();
builder.add_relation("a", create_node_scan("a", "A"));
builder.add_relation("b", create_node_scan("b", "B"));
builder.add_relation("c", create_node_scan("c", "C"));
builder.add_join_condition(
"a",
"b",
LogicalExpression::Variable("a".to_string()),
LogicalExpression::Variable("b".to_string()),
);
builder.add_join_condition(
"b",
"c",
LogicalExpression::Variable("b".to_string()),
LogicalExpression::Variable("c".to_string()),
);
builder.add_join_condition(
"c",
"a",
LogicalExpression::Variable("c".to_string()),
LogicalExpression::Variable("a".to_string()),
);
let graph = builder.build();
assert!(graph.is_cyclic());
}
#[test]
fn test_join_graph_acyclic_chain() {
let mut builder = JoinGraphBuilder::new();
builder.add_relation("a", create_node_scan("a", "A"));
builder.add_relation("b", create_node_scan("b", "B"));
builder.add_relation("c", create_node_scan("c", "C"));
builder.add_join_condition(
"a",
"b",
LogicalExpression::Variable("a".to_string()),
LogicalExpression::Variable("b".to_string()),
);
builder.add_join_condition(
"b",
"c",
LogicalExpression::Variable("b".to_string()),
LogicalExpression::Variable("c".to_string()),
);
let graph = builder.build();
assert!(!graph.is_cyclic());
}
#[test]
fn test_join_graph_empty_not_cyclic() {
let graph = JoinGraph::new();
assert!(!graph.is_cyclic());
}
#[test]
fn test_join_graph_edges_accessor() {
let mut builder = JoinGraphBuilder::new();
builder.add_relation("a", create_node_scan("a", "A"));
builder.add_relation("b", create_node_scan("b", "B"));
builder.add_join_condition(
"a",
"b",
LogicalExpression::Variable("a".to_string()),
LogicalExpression::Variable("b".to_string()),
);
let graph = builder.build();
assert_eq!(graph.edges().len(), 1);
}
}