Skip to main content

grafeo_engine/query/optimizer/
join_order.rs

//! DPccp (dynamic programming over connected-subgraph / complement pairs) join ordering.
//!
//! This module implements the DPccp algorithm for finding optimal join orderings.
//! The algorithm works by:
//! 1. Building a join graph from the query
//! 2. Enumerating all connected subgraphs
//! 3. Finding optimal plans for each subgraph using dynamic programming
//!
//! Reference: Moerkotte, G., & Neumann, T. (2006). Analysis of Two Existing and
//! One New Dynamic Programming Algorithm for the Generation of Optimal Bushy
//! Join Trees without Cross Products. In Proceedings of VLDB 2006.
12
13use super::cardinality::CardinalityEstimator;
14use super::cost::{Cost, CostModel};
15use crate::query::plan::{JoinCondition, JoinOp, JoinType, LogicalExpression, LogicalOperator};
16use std::collections::{HashMap, HashSet};
17
/// A node in the join graph.
///
/// Pairs a base relation with the query variable that names it. Equality and
/// hashing are implemented by hand in this file and look only at `id` and
/// `variable`, never at `relation`.
#[derive(Debug, Clone)]
pub struct JoinNode {
    /// Unique identifier for this node. `JoinGraph::add_node` assigns the
    /// node's index in the graph's node list.
    pub id: usize,
    /// Variable name (e.g., "a" for node (a:Person)).
    pub variable: String,
    /// The base relation (scan operator).
    pub relation: LogicalOperator,
}
28
29impl PartialEq for JoinNode {
30    fn eq(&self, other: &Self) -> bool {
31        self.id == other.id && self.variable == other.variable
32    }
33}
34
// `id` (usize) and `variable` (String) both have total equality, so the
// manual `PartialEq` for `JoinNode` is a full equivalence relation.
impl Eq for JoinNode {}
36
37impl std::hash::Hash for JoinNode {
38    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
39        self.id.hash(state);
40        self.variable.hash(state);
41    }
42}
43
/// An edge in the join graph representing a join condition.
///
/// `JoinGraph` treats edges as undirected: `add_edge` records adjacency in
/// both directions and the crossing checks accept either orientation, so
/// `from` and `to` only preserve the order the endpoints were supplied in.
#[derive(Debug, Clone)]
pub struct JoinEdge {
    /// Source node id.
    pub from: usize,
    /// Target node id.
    pub to: usize,
    /// Join conditions between these nodes.
    pub conditions: Vec<JoinCondition>,
}
54
/// The join graph representing all relations and join conditions in a query.
#[derive(Debug)]
pub struct JoinGraph {
    /// Nodes in the graph; `add_node` gives each node its index in this
    /// vector as its id.
    nodes: Vec<JoinNode>,
    /// Edges (join conditions) between nodes, in insertion order. Every
    /// `add_edge` call appends one entry, so a pair of nodes may appear in
    /// several edges.
    edges: Vec<JoinEdge>,
    /// Adjacency list for quick neighbor lookup: node id -> ids of the nodes
    /// it shares at least one edge with (stored in both directions).
    adjacency: HashMap<usize, HashSet<usize>>,
}
65
66impl JoinGraph {
67    /// Creates a new empty join graph.
68    pub fn new() -> Self {
69        Self {
70            nodes: Vec::new(),
71            edges: Vec::new(),
72            adjacency: HashMap::new(),
73        }
74    }
75
76    /// Adds a node to the graph.
77    pub fn add_node(&mut self, variable: String, relation: LogicalOperator) -> usize {
78        let id = self.nodes.len();
79        self.nodes.push(JoinNode {
80            id,
81            variable,
82            relation,
83        });
84        self.adjacency.insert(id, HashSet::new());
85        id
86    }
87
88    /// Adds a join edge between two nodes.
89    ///
90    /// # Panics
91    ///
92    /// Panics if `from` or `to` were not previously added via `add_node`.
93    pub fn add_edge(&mut self, from: usize, to: usize, conditions: Vec<JoinCondition>) {
94        self.edges.push(JoinEdge {
95            from,
96            to,
97            conditions,
98        });
99        // Invariant: add_node() inserts node ID into adjacency map (line 84),
100        // so get_mut succeeds for any ID returned by add_node()
101        self.adjacency
102            .get_mut(&from)
103            .expect("'from' node must be added via add_node() before add_edge()")
104            .insert(to);
105        self.adjacency
106            .get_mut(&to)
107            .expect("'to' node must be added via add_node() before add_edge()")
108            .insert(from);
109    }
110
111    /// Returns the number of nodes.
112    pub fn node_count(&self) -> usize {
113        self.nodes.len()
114    }
115
116    /// Returns a reference to the nodes.
117    pub fn nodes(&self) -> &[JoinNode] {
118        &self.nodes
119    }
120
121    /// Returns neighbors of a node.
122    pub fn neighbors(&self, node_id: usize) -> impl Iterator<Item = usize> + '_ {
123        self.adjacency.get(&node_id).into_iter().flatten().copied()
124    }
125
126    /// Gets the join conditions between two node sets.
127    pub fn get_conditions(&self, left: &BitSet, right: &BitSet) -> Vec<JoinCondition> {
128        let mut conditions = Vec::new();
129        for edge in &self.edges {
130            let from_in_left = left.contains(edge.from);
131            let from_in_right = right.contains(edge.from);
132            let to_in_left = left.contains(edge.to);
133            let to_in_right = right.contains(edge.to);
134
135            // Edge crosses between left and right
136            if (from_in_left && to_in_right) || (from_in_right && to_in_left) {
137                conditions.extend(edge.conditions.clone());
138            }
139        }
140        conditions
141    }
142
143    /// Returns the edges in the graph.
144    pub fn edges(&self) -> &[JoinEdge] {
145        &self.edges
146    }
147
148    /// Returns true if the join graph contains a cycle.
149    ///
150    /// A connected graph with N nodes and E edges is cyclic when E >= N.
151    #[must_use]
152    pub fn is_cyclic(&self) -> bool {
153        if self.nodes.is_empty() {
154            return false;
155        }
156        self.edges.len() >= self.nodes.len()
157    }
158
159    /// Checks if two node sets are connected by at least one edge.
160    pub fn are_connected(&self, left: &BitSet, right: &BitSet) -> bool {
161        for edge in &self.edges {
162            let from_in_left = left.contains(edge.from);
163            let from_in_right = right.contains(edge.from);
164            let to_in_left = left.contains(edge.to);
165            let to_in_right = right.contains(edge.to);
166
167            if (from_in_left && to_in_right) || (from_in_right && to_in_left) {
168                return true;
169            }
170        }
171        false
172    }
173}
174
175impl Default for JoinGraph {
176    fn default() -> Self {
177        Self::new()
178    }
179}
180
/// A bitset for efficient subset representation.
///
/// Backed by a single `u64`, so it can hold the elements `0..64`; element `i`
/// is a member when bit `i` is set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BitSet(u64);

impl BitSet {
    /// Number of distinct elements the backing word can represent.
    const CAPACITY: usize = u64::BITS as usize;

    /// Returns the single-bit mask for element `i`.
    ///
    /// Centralizes the range check: a bare `1 << i` with `i >= 64` panics in
    /// debug builds and silently wraps to `i % 64` in release builds.
    ///
    /// # Panics
    ///
    /// Panics if `i >= 64`.
    fn bit(i: usize) -> u64 {
        assert!(
            i < Self::CAPACITY,
            "BitSet element {} out of range (capacity {})",
            i,
            Self::CAPACITY
        );
        1u64 << i
    }

    /// Creates an empty bitset.
    pub fn empty() -> Self {
        Self(0)
    }

    /// Creates a bitset with a single element.
    ///
    /// # Panics
    ///
    /// Panics if `i >= 64`.
    pub fn singleton(i: usize) -> Self {
        Self(Self::bit(i))
    }

    /// Creates a bitset from an iterator of indices.
    ///
    /// # Panics
    ///
    /// Panics if any index is `>= 64`.
    pub fn from_iter(iter: impl Iterator<Item = usize>) -> Self {
        Self(iter.fold(0u64, |bits, i| bits | Self::bit(i)))
    }

    /// Creates a full bitset with elements 0..n.
    ///
    /// # Panics
    ///
    /// Panics if `n > 64`.
    pub fn full(n: usize) -> Self {
        assert!(
            n <= Self::CAPACITY,
            "BitSet::full({}) exceeds capacity {}",
            n,
            Self::CAPACITY
        );
        if n == Self::CAPACITY {
            // `1 << 64` would overflow; all 64 bits set is the answer.
            Self(u64::MAX)
        } else {
            Self((1u64 << n) - 1)
        }
    }

    /// Checks if the set is empty.
    pub fn is_empty(&self) -> bool {
        self.0 == 0
    }

    /// Returns the number of elements.
    pub fn len(&self) -> usize {
        self.0.count_ones() as usize
    }

    /// Checks if the set contains an element.
    ///
    /// Indices `>= 64` can never be members, so they simply yield `false`.
    pub fn contains(&self, i: usize) -> bool {
        i < Self::CAPACITY && (self.0 & (1u64 << i)) != 0
    }

    /// Inserts an element.
    ///
    /// # Panics
    ///
    /// Panics if `i >= 64`.
    pub fn insert(&mut self, i: usize) {
        self.0 |= Self::bit(i);
    }

    /// Removes an element.
    ///
    /// Removing an index `>= 64` is a no-op, since it cannot be a member.
    pub fn remove(&mut self, i: usize) {
        if i < Self::CAPACITY {
            self.0 &= !(1u64 << i);
        }
    }

    /// Returns the union of two sets.
    pub fn union(self, other: Self) -> Self {
        Self(self.0 | other.0)
    }

    /// Returns the intersection of two sets.
    pub fn intersection(self, other: Self) -> Self {
        Self(self.0 & other.0)
    }

    /// Returns the difference (self - other).
    pub fn difference(self, other: Self) -> Self {
        Self(self.0 & !other.0)
    }

    /// Checks if this set is a subset of another.
    pub fn is_subset_of(self, other: Self) -> bool {
        (self.0 & other.0) == self.0
    }

    /// Iterates over all elements in the set, in ascending order.
    pub fn iter(self) -> impl Iterator<Item = usize> {
        let mut remaining = self.0;
        std::iter::from_fn(move || {
            if remaining == 0 {
                return None;
            }
            let lowest = remaining.trailing_zeros() as usize;
            // Clear the lowest set bit.
            remaining &= remaining - 1;
            Some(lowest)
        })
    }

    /// Iterates over all subsets, including the set itself (yielded first)
    /// and the empty set (yielded last).
    pub fn subsets(self) -> SubsetIterator {
        SubsetIterator {
            full: self.0,
            current: Some(self.0),
        }
    }
}

/// Iterator over all subsets of a bitset.
///
/// Yields the full set first, then every other sub-mask in strictly
/// decreasing numeric order, finishing with the empty set.
#[derive(Debug, Clone)]
pub struct SubsetIterator {
    /// The set whose subsets are being enumerated.
    full: u64,
    /// The next subset to yield; `None` once the empty set has been yielded.
    current: Option<u64>,
}

impl Iterator for SubsetIterator {
    type Item = BitSet;

    fn next(&mut self) -> Option<Self::Item> {
        let current = self.current?;
        // Standard sub-mask enumeration: `(m - 1) & full` is the next smaller
        // sub-mask of `full`. The empty set is the last one, so stop after it.
        self.current = if current == 0 {
            None
        } else {
            Some((current - 1) & self.full)
        };
        Some(BitSet(current))
    }
}
293
/// Represents a (partial) join plan.
#[derive(Debug, Clone)]
pub struct JoinPlan {
    /// The subset of nodes covered by this plan.
    pub nodes: BitSet,
    /// The logical operator representing this plan.
    pub operator: LogicalOperator,
    /// Estimated cost of this plan. For joins built by `DPccp` this is the
    /// sum of both inputs' costs plus the cost of the join itself.
    pub cost: Cost,
    /// Estimated cardinality.
    pub cardinality: f64,
}
306
/// DPccp join ordering optimizer.
///
/// Borrows the join graph, cost model and cardinality estimator for `'a` and
/// owns the memo table that `optimize` fills in.
pub struct DPccp<'a> {
    /// The join graph.
    graph: &'a JoinGraph,
    /// Cost model for estimating operator costs.
    cost_model: &'a CostModel,
    /// Cardinality estimator.
    card_estimator: &'a CardinalityEstimator,
    /// Memoization table: subset -> best plan found so far for that subset.
    memo: HashMap<BitSet, JoinPlan>,
}
318
319impl<'a> DPccp<'a> {
320    /// Creates a new DPccp optimizer.
321    pub fn new(
322        graph: &'a JoinGraph,
323        cost_model: &'a CostModel,
324        card_estimator: &'a CardinalityEstimator,
325    ) -> Self {
326        Self {
327            graph,
328            cost_model,
329            card_estimator,
330            memo: HashMap::new(),
331        }
332    }
333
334    /// Finds the optimal join order for the graph.
335    pub fn optimize(&mut self) -> Option<JoinPlan> {
336        let n = self.graph.node_count();
337        if n == 0 {
338            return None;
339        }
340        if n == 1 {
341            let node = &self.graph.nodes[0];
342            let cardinality = self.card_estimator.estimate(&node.relation);
343            let cost = self.cost_model.estimate(&node.relation, cardinality);
344            return Some(JoinPlan {
345                nodes: BitSet::singleton(0),
346                operator: node.relation.clone(),
347                cost,
348                cardinality,
349            });
350        }
351
352        // Initialize with single relations
353        for (i, node) in self.graph.nodes.iter().enumerate() {
354            let subset = BitSet::singleton(i);
355            let cardinality = self.card_estimator.estimate(&node.relation);
356            let cost = self.cost_model.estimate(&node.relation, cardinality);
357            self.memo.insert(
358                subset,
359                JoinPlan {
360                    nodes: subset,
361                    operator: node.relation.clone(),
362                    cost,
363                    cardinality,
364                },
365            );
366        }
367
368        // Enumerate connected subgraph pairs (ccp)
369        let full_set = BitSet::full(n);
370        self.enumerate_ccp(full_set);
371
372        // Return the best plan for the full set
373        self.memo.get(&full_set).cloned()
374    }
375
376    /// Enumerates connected complement pairs using DPccp algorithm.
377    fn enumerate_ccp(&mut self, s: BitSet) {
378        // Iterate over all proper non-empty subsets
379        for s1 in s.subsets() {
380            if s1.is_empty() || s1 == s {
381                continue;
382            }
383
384            let s2 = s.difference(s1);
385            if s2.is_empty() {
386                continue;
387            }
388
389            // Both s1 and s2 must be connected subsets
390            if !self.is_connected(s1) || !self.is_connected(s2) {
391                continue;
392            }
393
394            // s1 and s2 must be connected to each other
395            if !self.graph.are_connected(&s1, &s2) {
396                continue;
397            }
398
399            // Recursively solve subproblems
400            if !self.memo.contains_key(&s1) {
401                self.enumerate_ccp(s1);
402            }
403            if !self.memo.contains_key(&s2) {
404                self.enumerate_ccp(s2);
405            }
406
407            // Try to build a plan for s by joining s1 and s2
408            if let (Some(plan1), Some(plan2)) = (self.memo.get(&s1), self.memo.get(&s2)) {
409                let conditions = self.graph.get_conditions(&s1, &s2);
410                let new_plan = self.build_join_plan(plan1.clone(), plan2.clone(), conditions);
411
412                // Update memo if this is a better plan
413                let should_update = self.memo.get(&s).map_or(true, |existing| {
414                    new_plan.cost.total() < existing.cost.total()
415                });
416
417                if should_update {
418                    self.memo.insert(s, new_plan);
419                }
420            }
421        }
422    }
423
424    /// Checks if a subset forms a connected subgraph.
425    fn is_connected(&self, subset: BitSet) -> bool {
426        if subset.len() <= 1 {
427            return true;
428        }
429
430        // BFS to check connectivity
431        // Invariant: subset.len() >= 2 (guard on line 400), so iter().next() returns Some
432        let start = subset
433            .iter()
434            .next()
435            .expect("subset is non-empty: len >= 2 checked on line 400");
436        let mut visited = BitSet::singleton(start);
437        let mut queue = vec![start];
438
439        while let Some(node) = queue.pop() {
440            for neighbor in self.graph.neighbors(node) {
441                if subset.contains(neighbor) && !visited.contains(neighbor) {
442                    visited.insert(neighbor);
443                    queue.push(neighbor);
444                }
445            }
446        }
447
448        visited == subset
449    }
450
451    /// Builds a join plan from two sub-plans.
452    fn build_join_plan(
453        &self,
454        left: JoinPlan,
455        right: JoinPlan,
456        conditions: Vec<JoinCondition>,
457    ) -> JoinPlan {
458        let nodes = left.nodes.union(right.nodes);
459
460        // Create the join operator
461        let join_op = LogicalOperator::Join(JoinOp {
462            left: Box::new(left.operator),
463            right: Box::new(right.operator),
464            join_type: JoinType::Inner,
465            conditions,
466        });
467
468        // Estimate cardinality
469        let cardinality = self.card_estimator.estimate(&join_op);
470
471        // Calculate cost (child costs + join cost)
472        let join_cost = self.cost_model.estimate(&join_op, cardinality);
473        let total_cost = left.cost + right.cost + join_cost;
474
475        JoinPlan {
476            nodes,
477            operator: join_op,
478            cost: total_cost,
479            cardinality,
480        }
481    }
482}
483
/// Extracts a join graph from a query pattern.
///
/// Relations are registered by variable name; join conditions are then added
/// by naming the two variables they connect.
pub struct JoinGraphBuilder {
    /// The graph under construction.
    graph: JoinGraph,
    /// Maps each registered variable name to the id of the node most recently
    /// added for it.
    variable_to_node: HashMap<String, usize>,
}
489
490impl JoinGraphBuilder {
491    /// Creates a new builder.
492    pub fn new() -> Self {
493        Self {
494            graph: JoinGraph::new(),
495            variable_to_node: HashMap::new(),
496        }
497    }
498
499    /// Adds a base relation (scan).
500    pub fn add_relation(&mut self, variable: &str, relation: LogicalOperator) -> usize {
501        let id = self.graph.add_node(variable.to_string(), relation);
502        self.variable_to_node.insert(variable.to_string(), id);
503        id
504    }
505
506    /// Adds a join condition between two variables.
507    pub fn add_join_condition(
508        &mut self,
509        left_var: &str,
510        right_var: &str,
511        left_expr: LogicalExpression,
512        right_expr: LogicalExpression,
513    ) {
514        if let (Some(&left_id), Some(&right_id)) = (
515            self.variable_to_node.get(left_var),
516            self.variable_to_node.get(right_var),
517        ) {
518            self.graph.add_edge(
519                left_id,
520                right_id,
521                vec![JoinCondition {
522                    left: left_expr,
523                    right: right_expr,
524                }],
525            );
526        }
527    }
528
529    /// Builds the join graph.
530    pub fn build(self) -> JoinGraph {
531        self.graph
532    }
533}
534
535impl Default for JoinGraphBuilder {
536    fn default() -> Self {
537        Self::new()
538    }
539}
540
541#[cfg(test)]
542mod tests {
543    use super::*;
544    use crate::query::plan::NodeScanOp;
545
546    fn create_node_scan(var: &str, label: &str) -> LogicalOperator {
547        LogicalOperator::NodeScan(NodeScanOp {
548            variable: var.to_string(),
549            label: Some(label.to_string()),
550            input: None,
551        })
552    }
553
554    #[test]
555    fn test_bitset_operations() {
556        let a = BitSet::singleton(0);
557        let b = BitSet::singleton(1);
558        let _c = BitSet::singleton(2);
559
560        assert!(a.contains(0));
561        assert!(!a.contains(1));
562
563        let ab = a.union(b);
564        assert!(ab.contains(0));
565        assert!(ab.contains(1));
566        assert!(!ab.contains(2));
567
568        let full = BitSet::full(3);
569        assert_eq!(full.len(), 3);
570        assert!(full.contains(0));
571        assert!(full.contains(1));
572        assert!(full.contains(2));
573    }
574
575    #[test]
576    fn test_subset_iteration() {
577        let set = BitSet::from_iter([0, 1].into_iter());
578        let subsets: Vec<_> = set.subsets().collect();
579
580        // Should have 4 subsets: {}, {0}, {1}, {0,1}
581        assert_eq!(subsets.len(), 4);
582        assert!(subsets.contains(&BitSet::empty()));
583        assert!(subsets.contains(&BitSet::singleton(0)));
584        assert!(subsets.contains(&BitSet::singleton(1)));
585        assert!(subsets.contains(&set));
586    }
587
588    #[test]
589    fn test_join_graph_construction() {
590        let mut builder = JoinGraphBuilder::new();
591
592        builder.add_relation("a", create_node_scan("a", "Person"));
593        builder.add_relation("b", create_node_scan("b", "Person"));
594        builder.add_relation("c", create_node_scan("c", "Company"));
595
596        builder.add_join_condition(
597            "a",
598            "b",
599            LogicalExpression::Property {
600                variable: "a".to_string(),
601                property: "id".to_string(),
602            },
603            LogicalExpression::Property {
604                variable: "b".to_string(),
605                property: "friend_id".to_string(),
606            },
607        );
608
609        builder.add_join_condition(
610            "a",
611            "c",
612            LogicalExpression::Property {
613                variable: "a".to_string(),
614                property: "company_id".to_string(),
615            },
616            LogicalExpression::Property {
617                variable: "c".to_string(),
618                property: "id".to_string(),
619            },
620        );
621
622        let graph = builder.build();
623        assert_eq!(graph.node_count(), 3);
624    }
625
626    #[test]
627    fn test_dpccp_single_relation() {
628        let mut builder = JoinGraphBuilder::new();
629        builder.add_relation("a", create_node_scan("a", "Person"));
630        let graph = builder.build();
631
632        let cost_model = CostModel::new();
633        let mut card_estimator = CardinalityEstimator::new();
634        card_estimator.add_table_stats("Person", super::super::cardinality::TableStats::new(1000));
635
636        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
637        let plan = dpccp.optimize();
638
639        assert!(plan.is_some());
640        let plan = plan.unwrap();
641        assert_eq!(plan.nodes.len(), 1);
642    }
643
644    #[test]
645    fn test_dpccp_two_relations() {
646        let mut builder = JoinGraphBuilder::new();
647        builder.add_relation("a", create_node_scan("a", "Person"));
648        builder.add_relation("b", create_node_scan("b", "Company"));
649
650        builder.add_join_condition(
651            "a",
652            "b",
653            LogicalExpression::Property {
654                variable: "a".to_string(),
655                property: "company_id".to_string(),
656            },
657            LogicalExpression::Property {
658                variable: "b".to_string(),
659                property: "id".to_string(),
660            },
661        );
662
663        let graph = builder.build();
664
665        let cost_model = CostModel::new();
666        let mut card_estimator = CardinalityEstimator::new();
667        card_estimator.add_table_stats("Person", super::super::cardinality::TableStats::new(1000));
668        card_estimator.add_table_stats("Company", super::super::cardinality::TableStats::new(100));
669
670        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
671        let plan = dpccp.optimize();
672
673        assert!(plan.is_some());
674        let plan = plan.unwrap();
675        assert_eq!(plan.nodes.len(), 2);
676
677        // The result should be a join
678        if let LogicalOperator::Join(_) = plan.operator {
679            // Good
680        } else {
681            panic!("Expected Join operator");
682        }
683    }
684
685    #[test]
686    fn test_dpccp_three_relations_chain() {
687        // a -[knows]-> b -[works_at]-> c
688        let mut builder = JoinGraphBuilder::new();
689        builder.add_relation("a", create_node_scan("a", "Person"));
690        builder.add_relation("b", create_node_scan("b", "Person"));
691        builder.add_relation("c", create_node_scan("c", "Company"));
692
693        builder.add_join_condition(
694            "a",
695            "b",
696            LogicalExpression::Property {
697                variable: "a".to_string(),
698                property: "knows".to_string(),
699            },
700            LogicalExpression::Property {
701                variable: "b".to_string(),
702                property: "id".to_string(),
703            },
704        );
705
706        builder.add_join_condition(
707            "b",
708            "c",
709            LogicalExpression::Property {
710                variable: "b".to_string(),
711                property: "company_id".to_string(),
712            },
713            LogicalExpression::Property {
714                variable: "c".to_string(),
715                property: "id".to_string(),
716            },
717        );
718
719        let graph = builder.build();
720
721        let cost_model = CostModel::new();
722        let mut card_estimator = CardinalityEstimator::new();
723        card_estimator.add_table_stats("Person", super::super::cardinality::TableStats::new(1000));
724        card_estimator.add_table_stats("Company", super::super::cardinality::TableStats::new(100));
725
726        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
727        let plan = dpccp.optimize();
728
729        assert!(plan.is_some());
730        let plan = plan.unwrap();
731        assert_eq!(plan.nodes.len(), 3);
732    }
733
734    #[test]
735    fn test_dpccp_prefers_smaller_intermediate() {
736        // Test that DPccp prefers joining smaller tables first
737        // Setup: Small (100) -[r1]-> Medium (1000) -[r2]-> Large (10000)
738        // Without cost-based ordering, might get (Small ⋈ Large) ⋈ Medium
739        // With cost-based ordering, should get (Small ⋈ Medium) ⋈ Large
740
741        let mut builder = JoinGraphBuilder::new();
742        builder.add_relation("s", create_node_scan("s", "Small"));
743        builder.add_relation("m", create_node_scan("m", "Medium"));
744        builder.add_relation("l", create_node_scan("l", "Large"));
745
746        // Connect all three (star schema)
747        builder.add_join_condition(
748            "s",
749            "m",
750            LogicalExpression::Property {
751                variable: "s".to_string(),
752                property: "id".to_string(),
753            },
754            LogicalExpression::Property {
755                variable: "m".to_string(),
756                property: "s_id".to_string(),
757            },
758        );
759
760        builder.add_join_condition(
761            "m",
762            "l",
763            LogicalExpression::Property {
764                variable: "m".to_string(),
765                property: "id".to_string(),
766            },
767            LogicalExpression::Property {
768                variable: "l".to_string(),
769                property: "m_id".to_string(),
770            },
771        );
772
773        let graph = builder.build();
774
775        let cost_model = CostModel::new();
776        let mut card_estimator = CardinalityEstimator::new();
777        card_estimator.add_table_stats("Small", super::super::cardinality::TableStats::new(100));
778        card_estimator.add_table_stats("Medium", super::super::cardinality::TableStats::new(1000));
779        card_estimator.add_table_stats("Large", super::super::cardinality::TableStats::new(10000));
780
781        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
782        let plan = dpccp.optimize();
783
784        assert!(plan.is_some());
785        let plan = plan.unwrap();
786
787        // The plan should cover all three relations
788        assert_eq!(plan.nodes.len(), 3);
789
790        // We can't easily verify the exact join order without inspecting the tree,
791        // but we can verify the plan was created successfully
792        assert!(plan.cost.total() > 0.0);
793    }
794
795    // Additional BitSet tests
796
797    #[test]
798    fn test_bitset_empty() {
799        let empty = BitSet::empty();
800        assert!(empty.is_empty());
801        assert_eq!(empty.len(), 0);
802        assert!(!empty.contains(0));
803    }
804
805    #[test]
806    fn test_bitset_insert_remove() {
807        let mut set = BitSet::empty();
808        set.insert(3);
809        assert!(set.contains(3));
810        assert_eq!(set.len(), 1);
811
812        set.insert(5);
813        assert!(set.contains(5));
814        assert_eq!(set.len(), 2);
815
816        set.remove(3);
817        assert!(!set.contains(3));
818        assert_eq!(set.len(), 1);
819    }
820
821    #[test]
822    fn test_bitset_intersection() {
823        let a = BitSet::from_iter([0, 1, 2].into_iter());
824        let b = BitSet::from_iter([1, 2, 3].into_iter());
825        let intersection = a.intersection(b);
826
827        assert!(intersection.contains(1));
828        assert!(intersection.contains(2));
829        assert!(!intersection.contains(0));
830        assert!(!intersection.contains(3));
831        assert_eq!(intersection.len(), 2);
832    }
833
834    #[test]
835    fn test_bitset_difference() {
836        let a = BitSet::from_iter([0, 1, 2].into_iter());
837        let b = BitSet::from_iter([1, 2, 3].into_iter());
838        let diff = a.difference(b);
839
840        assert!(diff.contains(0));
841        assert!(!diff.contains(1));
842        assert!(!diff.contains(2));
843        assert_eq!(diff.len(), 1);
844    }
845
846    #[test]
847    fn test_bitset_is_subset_of() {
848        let a = BitSet::from_iter([1, 2].into_iter());
849        let b = BitSet::from_iter([0, 1, 2, 3].into_iter());
850
851        assert!(a.is_subset_of(b));
852        assert!(!b.is_subset_of(a));
853        assert!(a.is_subset_of(a));
854    }
855
856    #[test]
857    fn test_bitset_iter() {
858        let set = BitSet::from_iter([0, 2, 5].into_iter());
859        let elements: Vec<_> = set.iter().collect();
860
861        assert_eq!(elements, vec![0, 2, 5]);
862    }
863
864    // Additional JoinGraph tests
865
866    #[test]
867    fn test_join_graph_empty() {
868        let graph = JoinGraph::new();
869        assert_eq!(graph.node_count(), 0);
870    }
871
872    #[test]
873    fn test_join_graph_neighbors() {
874        let mut builder = JoinGraphBuilder::new();
875        builder.add_relation("a", create_node_scan("a", "A"));
876        builder.add_relation("b", create_node_scan("b", "B"));
877        builder.add_relation("c", create_node_scan("c", "C"));
878
879        builder.add_join_condition(
880            "a",
881            "b",
882            LogicalExpression::Variable("a".to_string()),
883            LogicalExpression::Variable("b".to_string()),
884        );
885        builder.add_join_condition(
886            "a",
887            "c",
888            LogicalExpression::Variable("a".to_string()),
889            LogicalExpression::Variable("c".to_string()),
890        );
891
892        let graph = builder.build();
893
894        // 'a' should have neighbors 'b' and 'c' (indices 1 and 2)
895        let neighbors_a: Vec<_> = graph.neighbors(0).collect();
896        assert_eq!(neighbors_a.len(), 2);
897        assert!(neighbors_a.contains(&1));
898        assert!(neighbors_a.contains(&2));
899
900        // 'b' should have only neighbor 'a'
901        let neighbors_b: Vec<_> = graph.neighbors(1).collect();
902        assert_eq!(neighbors_b.len(), 1);
903        assert!(neighbors_b.contains(&0));
904    }
905
906    #[test]
907    fn test_join_graph_are_connected() {
908        let mut builder = JoinGraphBuilder::new();
909        builder.add_relation("a", create_node_scan("a", "A"));
910        builder.add_relation("b", create_node_scan("b", "B"));
911        builder.add_relation("c", create_node_scan("c", "C"));
912
913        builder.add_join_condition(
914            "a",
915            "b",
916            LogicalExpression::Variable("a".to_string()),
917            LogicalExpression::Variable("b".to_string()),
918        );
919
920        let graph = builder.build();
921
922        let set_a = BitSet::singleton(0);
923        let set_b = BitSet::singleton(1);
924        let set_c = BitSet::singleton(2);
925
926        assert!(graph.are_connected(&set_a, &set_b));
927        assert!(graph.are_connected(&set_b, &set_a));
928        assert!(!graph.are_connected(&set_a, &set_c));
929        assert!(!graph.are_connected(&set_b, &set_c));
930    }
931
/// Checks that `get_conditions` returns the one condition registered between two relations.
#[test]
fn test_join_graph_get_conditions() {
    // Small helper so both sides of the condition are built the same way.
    let property_of = |variable: &str, property: &str| LogicalExpression::Property {
        variable: variable.to_string(),
        property: property.to_string(),
    };

    let mut graph_builder = JoinGraphBuilder::new();
    graph_builder.add_relation("a", create_node_scan("a", "A"));
    graph_builder.add_relation("b", create_node_scan("b", "B"));

    // Condition between property `id` of `a` and property `a_id` of `b`.
    graph_builder.add_join_condition(
        "a",
        "b",
        property_of("a", "id"),
        property_of("b", "a_id"),
    );

    let join_graph = graph_builder.build();

    let left = BitSet::singleton(0);
    let right = BitSet::singleton(1);

    let found = join_graph.get_conditions(&left, &right);
    assert_eq!(found.len(), 1);
}
959
960    // Additional DPccp tests
961
/// Checks that optimizing a graph without any relations yields no plan.
#[test]
fn test_dpccp_empty_graph() {
    let empty_graph = JoinGraph::new();
    let costs = CostModel::new();
    let estimator = CardinalityEstimator::new();

    let mut optimizer = DPccp::new(&empty_graph, &costs, &estimator);

    // No relations were added, so the optimizer is expected to return `None`.
    assert!(optimizer.optimize().is_none());
}
973
/// Optimizes a four-relation star query and checks that the plan covers every relation
/// and carries a positive cost.
#[test]
fn test_dpccp_star_query() {
    use super::super::cardinality::TableStats;

    // Star schema: `center` is joined to each of `a`, `b` and `c`; the satellites are
    // not joined to one another.
    let mut builder = JoinGraphBuilder::new();
    for &(variable, label) in &[("center", "Center"), ("a", "A"), ("b", "B"), ("c", "C")] {
        builder.add_relation(variable, create_node_scan(variable, label));
    }
    for &satellite in &["a", "b", "c"] {
        builder.add_join_condition(
            "center",
            satellite,
            LogicalExpression::Variable("center".to_string()),
            LogicalExpression::Variable(satellite.to_string()),
        );
    }

    let graph = builder.build();

    let cost_model = CostModel::new();
    let mut card_estimator = CardinalityEstimator::new();
    for &(label, rows) in &[("Center", 100), ("A", 1000), ("B", 500), ("C", 200)] {
        card_estimator.add_table_stats(label, TableStats::new(rows));
    }

    let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);

    // `expect` replaces the former `is_some()` assertion followed by `unwrap()`: a missing
    // plan still fails the test, now with a message that says what was expected.
    let plan = dpccp
        .optimize()
        .expect("star query should produce a join plan");

    assert_eq!(plan.nodes.len(), 4);
    assert!(plan.cost.total() > 0.0);
}
1020
/// Optimizes a three-relation cyclic query and checks that the plan covers every relation.
#[test]
fn test_dpccp_cycle_query() {
    use super::super::cardinality::TableStats;

    // Cycle: a -> b -> c -> a
    let mut builder = JoinGraphBuilder::new();
    for &(variable, label) in &[("a", "A"), ("b", "B"), ("c", "C")] {
        builder.add_relation(variable, create_node_scan(variable, label));
    }
    for &(from, to) in &[("a", "b"), ("b", "c"), ("c", "a")] {
        builder.add_join_condition(
            from,
            to,
            LogicalExpression::Variable(from.to_string()),
            LogicalExpression::Variable(to.to_string()),
        );
    }

    let graph = builder.build();

    let cost_model = CostModel::new();
    let mut card_estimator = CardinalityEstimator::new();
    // All three tables get the same row count.
    for &label in &["A", "B", "C"] {
        card_estimator.add_table_stats(label, TableStats::new(100));
    }

    let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);

    // `expect` replaces the former `is_some()` assertion followed by `unwrap()`: a missing
    // plan still fails the test, now with a message that says what was expected.
    let plan = dpccp
        .optimize()
        .expect("cycle query should produce a join plan");

    assert_eq!(plan.nodes.len(), 3);
}
1063
/// Optimizes a four-relation chain query and checks that the plan covers every relation.
#[test]
fn test_dpccp_four_relations() {
    use super::super::cardinality::TableStats;

    // Chain: a -> b -> c -> d
    let mut builder = JoinGraphBuilder::new();
    for &(variable, label) in &[("a", "A"), ("b", "B"), ("c", "C"), ("d", "D")] {
        builder.add_relation(variable, create_node_scan(variable, label));
    }
    for &(from, to) in &[("a", "b"), ("b", "c"), ("c", "d")] {
        builder.add_join_condition(
            from,
            to,
            LogicalExpression::Variable(from.to_string()),
            LogicalExpression::Variable(to.to_string()),
        );
    }

    let graph = builder.build();

    let cost_model = CostModel::new();
    let mut card_estimator = CardinalityEstimator::new();
    for &(label, rows) in &[("A", 100), ("B", 200), ("C", 300), ("D", 400)] {
        card_estimator.add_table_stats(label, TableStats::new(rows));
    }

    let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);

    // `expect` replaces the former `is_some()` assertion followed by `unwrap()`: a missing
    // plan still fails the test, now with a message that says what was expected.
    let plan = dpccp
        .optimize()
        .expect("chain query should produce a join plan");

    assert_eq!(plan.nodes.len(), 4);
}
1108
/// Checks that the plan for a two-relation join reports a positive cardinality and cost.
#[test]
fn test_join_plan_cardinality_and_cost() {
    use super::super::cardinality::TableStats;

    let mut graph_builder = JoinGraphBuilder::new();
    graph_builder.add_relation("a", create_node_scan("a", "A"));
    graph_builder.add_relation("b", create_node_scan("b", "B"));
    graph_builder.add_join_condition(
        "a",
        "b",
        LogicalExpression::Variable(String::from("a")),
        LogicalExpression::Variable(String::from("b")),
    );
    let join_graph = graph_builder.build();

    let costs = CostModel::new();
    let mut estimator = CardinalityEstimator::new();
    estimator.add_table_stats("A", TableStats::new(100));
    estimator.add_table_stats("B", TableStats::new(200));

    let mut optimizer = DPccp::new(&join_graph, &costs, &estimator);
    let best = optimizer.optimize().unwrap();

    // Both figures must be strictly greater than zero.
    assert!(best.cardinality > 0.0);
    assert!(best.cost.total() > 0.0);
}
1136
/// Checks that a default-constructed `JoinGraph` reports zero nodes.
#[test]
fn test_join_graph_default() {
    let node_total = JoinGraph::default().node_count();
    assert_eq!(node_total, 0);
}
1142
/// Checks that a default-constructed `JoinGraphBuilder` builds a graph with zero nodes.
#[test]
fn test_join_graph_builder_default() {
    let built = JoinGraphBuilder::default().build();
    assert_eq!(built.node_count(), 0);
}
1149
/// Checks that `nodes()` exposes the registered relations, in the order they were added.
#[test]
fn test_join_graph_nodes_accessor() {
    let mut graph_builder = JoinGraphBuilder::new();
    for &(variable, label) in &[("a", "A"), ("b", "B")] {
        graph_builder.add_relation(variable, create_node_scan(variable, label));
    }
    let join_graph = graph_builder.build();

    // Comparing the collected variable names checks both the count and the order at once.
    let variables: Vec<&str> = join_graph
        .nodes()
        .iter()
        .map(|node| node.variable.as_str())
        .collect();
    assert_eq!(variables, ["a", "b"]);
}
1163
/// Checks that `JoinNode` equality is decided by `id` and `variable` only.
///
/// Besides the identical / different-id pair, this also covers a differing `variable`
/// (must be unequal) and a differing `relation` (must still be equal), so a regression
/// in either half of the comparison is caught.
#[test]
fn test_join_node_equality() {
    let node1 = JoinNode {
        id: 0,
        variable: "a".to_string(),
        relation: create_node_scan("a", "A"),
    };
    let node2 = JoinNode {
        id: 0,
        variable: "a".to_string(),
        relation: create_node_scan("a", "A"),
    };
    let different_id = JoinNode {
        id: 1,
        variable: "a".to_string(),
        relation: create_node_scan("a", "A"),
    };
    let different_variable = JoinNode {
        id: 0,
        variable: "b".to_string(),
        relation: create_node_scan("a", "A"),
    };
    // Same id and variable as `node1`, but the relation is built from other arguments.
    let different_relation = JoinNode {
        id: 0,
        variable: "a".to_string(),
        relation: create_node_scan("b", "B"),
    };

    assert_eq!(node1, node2);
    assert_ne!(node1, different_id);
    assert_ne!(node1, different_variable);
    // `relation` does not take part in the comparison.
    assert_eq!(node1, different_relation);
}
1185
/// Checks that `JoinNode` can be used as a `HashSet` key identified by `id` and `variable`.
///
/// The probe nodes cover a hit with a different `relation` and a miss with a different
/// `id`, so the test exercises what the identity actually depends on.
#[test]
fn test_join_node_hash() {
    use std::collections::HashSet;

    let node1 = JoinNode {
        id: 0,
        variable: "a".to_string(),
        relation: create_node_scan("a", "A"),
    };
    let node2 = JoinNode {
        id: 0,
        variable: "a".to_string(),
        relation: create_node_scan("a", "A"),
    };
    // Same id and variable as `node1`, but the relation is built from other arguments.
    let different_relation = JoinNode {
        id: 0,
        variable: "a".to_string(),
        relation: create_node_scan("b", "B"),
    };
    let different_id = JoinNode {
        id: 1,
        variable: "a".to_string(),
        relation: create_node_scan("a", "A"),
    };

    let mut set = HashSet::new();
    // `node1` is not used again, so it is moved into the set instead of cloned.
    set.insert(node1);

    // Same id and variable should be considered equal
    assert!(set.contains(&node2));
    assert!(set.contains(&different_relation));

    // A different id must not be found.
    assert!(!set.contains(&different_id));
}
1207
/// Checks that a join condition naming an unregistered variable neither panics nor
/// changes the node count.
#[test]
fn test_add_join_condition_unknown_variable() {
    let mut graph_builder = JoinGraphBuilder::new();
    graph_builder.add_relation("a", create_node_scan("a", "A"));

    // Only `a` was registered; the other end of the condition refers to nothing.
    let (known, missing) = ("a", "unknown");
    graph_builder.add_join_condition(
        known,
        missing,
        LogicalExpression::Variable(known.to_string()),
        LogicalExpression::Variable(missing.to_string()),
    );

    // Reaching this point means the call did not panic; the graph still has one node.
    let node_total = graph_builder.build().node_count();
    assert_eq!(node_total, 1);
}
1224
/// Optimizes a two-relation join whose table sizes differ by several orders of magnitude
/// and checks that the plan covers both relations.
#[test]
fn test_dpccp_with_different_cardinalities() {
    use super::super::cardinality::TableStats;

    let mut builder = JoinGraphBuilder::new();
    builder.add_relation("tiny", create_node_scan("tiny", "Tiny"));
    builder.add_relation("huge", create_node_scan("huge", "Huge"));

    builder.add_join_condition(
        "tiny",
        "huge",
        LogicalExpression::Variable("tiny".to_string()),
        LogicalExpression::Variable("huge".to_string()),
    );

    let graph = builder.build();

    let cost_model = CostModel::new();
    let mut card_estimator = CardinalityEstimator::new();
    // 10 rows against 1,000,000 rows.
    card_estimator.add_table_stats("Tiny", TableStats::new(10));
    card_estimator.add_table_stats("Huge", TableStats::new(1000000));

    let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);

    // `expect` replaces the former `is_some()` assertion followed by `unwrap()`: a missing
    // plan still fails the test, now with a message that says what was expected.
    let plan = dpccp
        .optimize()
        .expect("two-relation join should produce a join plan");

    assert_eq!(plan.nodes.len(), 2);
}
1253
/// Checks that a triangle of join conditions is reported as cyclic.
#[test]
fn test_join_graph_cyclic_triangle() {
    let mut graph_builder = JoinGraphBuilder::new();
    for &(variable, label) in &[("a", "A"), ("b", "B"), ("c", "C")] {
        graph_builder.add_relation(variable, create_node_scan(variable, label));
    }

    // Three conditions over three relations: a-b, b-c and c-a.
    for &(from, to) in &[("a", "b"), ("b", "c"), ("c", "a")] {
        graph_builder.add_join_condition(
            from,
            to,
            LogicalExpression::Variable(from.to_string()),
            LogicalExpression::Variable(to.to_string()),
        );
    }

    let triangle = graph_builder.build();
    assert!(triangle.is_cyclic());
}
1284
/// Checks that a simple chain of join conditions is not reported as cyclic.
#[test]
fn test_join_graph_acyclic_chain() {
    let mut graph_builder = JoinGraphBuilder::new();
    for &(variable, label) in &[("a", "A"), ("b", "B"), ("c", "C")] {
        graph_builder.add_relation(variable, create_node_scan(variable, label));
    }

    // Two conditions over three relations: a-b and b-c.
    for &(from, to) in &[("a", "b"), ("b", "c")] {
        graph_builder.add_join_condition(
            from,
            to,
            LogicalExpression::Variable(from.to_string()),
            LogicalExpression::Variable(to.to_string()),
        );
    }

    let chain = graph_builder.build();
    assert!(!chain.is_cyclic());
}
1309
/// Checks that a graph without any relations is not reported as cyclic.
#[test]
fn test_join_graph_empty_not_cyclic() {
    let cyclic = JoinGraph::new().is_cyclic();
    assert!(!cyclic);
}
1315
/// Checks that `edges()` exposes the single edge created by one join condition.
#[test]
fn test_join_graph_edges_accessor() {
    let mut graph_builder = JoinGraphBuilder::new();
    for &(variable, label) in &[("a", "A"), ("b", "B")] {
        graph_builder.add_relation(variable, create_node_scan(variable, label));
    }
    graph_builder.add_join_condition(
        "a",
        "b",
        LogicalExpression::Variable(String::from("a")),
        LogicalExpression::Variable(String::from("b")),
    );

    let join_graph = graph_builder.build();
    let edge_total = join_graph.edges().len();
    assert_eq!(edge_total, 1);
}
1332}