Skip to main content

grafeo_engine/query/optimizer/
join_order.rs

1//! DPccp (Dynamic Programming with connected complement pairs) join ordering.
2//!
3//! This module implements the DPccp algorithm for finding optimal join orderings.
4//! The algorithm works by:
5//! 1. Building a join graph from the query
6//! 2. Enumerating all connected subgraphs
7//! 3. Finding optimal plans for each subgraph using dynamic programming
8//!
9//! Reference: Moerkotte, G., & Neumann, T. (2006). Analysis of Two Existing and
10//! One New Dynamic Programming Algorithm for the Generation of Optimal Bushy
11//! Join Trees without Cross Products.
12
13use super::cardinality::CardinalityEstimator;
14use super::cost::{Cost, CostModel};
15use crate::query::plan::{JoinCondition, JoinOp, JoinType, LogicalExpression, LogicalOperator};
16use std::collections::{HashMap, HashSet};
17
18/// A node in the join graph.
19#[derive(Debug, Clone)]
20pub struct JoinNode {
21    /// Unique identifier for this node.
22    pub id: usize,
23    /// Variable name (e.g., "a" for node (a:Person)).
24    pub variable: String,
25    /// The base relation (scan operator).
26    pub relation: LogicalOperator,
27}
28
29impl PartialEq for JoinNode {
30    fn eq(&self, other: &Self) -> bool {
31        self.id == other.id && self.variable == other.variable
32    }
33}
34
35impl Eq for JoinNode {}
36
37impl std::hash::Hash for JoinNode {
38    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
39        self.id.hash(state);
40        self.variable.hash(state);
41    }
42}
43
44/// An edge in the join graph representing a join condition.
45#[derive(Debug, Clone)]
46pub struct JoinEdge {
47    /// Source node id.
48    pub from: usize,
49    /// Target node id.
50    pub to: usize,
51    /// Join conditions between these nodes.
52    pub conditions: Vec<JoinCondition>,
53}
54
55/// The join graph representing all relations and join conditions in a query.
56#[derive(Debug)]
57pub struct JoinGraph {
58    /// Nodes in the graph.
59    nodes: Vec<JoinNode>,
60    /// Edges (join conditions) between nodes.
61    edges: Vec<JoinEdge>,
62    /// Adjacency list for quick neighbor lookup.
63    adjacency: HashMap<usize, HashSet<usize>>,
64}
65
66impl JoinGraph {
67    /// Creates a new empty join graph.
68    pub fn new() -> Self {
69        Self {
70            nodes: Vec::new(),
71            edges: Vec::new(),
72            adjacency: HashMap::new(),
73        }
74    }
75
76    /// Adds a node to the graph.
77    pub fn add_node(&mut self, variable: String, relation: LogicalOperator) -> usize {
78        let id = self.nodes.len();
79        self.nodes.push(JoinNode {
80            id,
81            variable,
82            relation,
83        });
84        self.adjacency.insert(id, HashSet::new());
85        id
86    }
87
88    /// Adds a join edge between two nodes.
89    ///
90    /// # Panics
91    ///
92    /// Panics if `from` or `to` were not previously added via `add_node`.
93    pub fn add_edge(&mut self, from: usize, to: usize, conditions: Vec<JoinCondition>) {
94        self.edges.push(JoinEdge {
95            from,
96            to,
97            conditions,
98        });
99        // Invariant: add_node() inserts node ID into adjacency map (line 84),
100        // so get_mut succeeds for any ID returned by add_node()
101        self.adjacency
102            .get_mut(&from)
103            .expect("'from' node must be added via add_node() before add_edge()")
104            .insert(to);
105        self.adjacency
106            .get_mut(&to)
107            .expect("'to' node must be added via add_node() before add_edge()")
108            .insert(from);
109    }
110
111    /// Returns the number of nodes.
112    pub fn node_count(&self) -> usize {
113        self.nodes.len()
114    }
115
116    /// Returns a reference to the nodes.
117    pub fn nodes(&self) -> &[JoinNode] {
118        &self.nodes
119    }
120
121    /// Returns neighbors of a node.
122    pub fn neighbors(&self, node_id: usize) -> impl Iterator<Item = usize> + '_ {
123        self.adjacency.get(&node_id).into_iter().flatten().copied()
124    }
125
126    /// Gets the join conditions between two node sets.
127    pub fn get_conditions(&self, left: &BitSet, right: &BitSet) -> Vec<JoinCondition> {
128        let mut conditions = Vec::new();
129        for edge in &self.edges {
130            let from_in_left = left.contains(edge.from);
131            let from_in_right = right.contains(edge.from);
132            let to_in_left = left.contains(edge.to);
133            let to_in_right = right.contains(edge.to);
134
135            // Edge crosses between left and right
136            if (from_in_left && to_in_right) || (from_in_right && to_in_left) {
137                conditions.extend(edge.conditions.clone());
138            }
139        }
140        conditions
141    }
142
143    /// Checks if two node sets are connected by at least one edge.
144    pub fn are_connected(&self, left: &BitSet, right: &BitSet) -> bool {
145        for edge in &self.edges {
146            let from_in_left = left.contains(edge.from);
147            let from_in_right = right.contains(edge.from);
148            let to_in_left = left.contains(edge.to);
149            let to_in_right = right.contains(edge.to);
150
151            if (from_in_left && to_in_right) || (from_in_right && to_in_left) {
152                return true;
153            }
154        }
155        false
156    }
157}
158
159impl Default for JoinGraph {
160    fn default() -> Self {
161        Self::new()
162    }
163}
164
/// A set of small integers packed into a single `u64`.
///
/// Valid elements are `0..BitSet::CAPACITY` (i.e. `0..64`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BitSet(u64);

impl BitSet {
    /// Number of distinct elements a `BitSet` can hold.
    pub const CAPACITY: usize = 64;

    /// Returns the mask with only bit `i` set.
    ///
    /// # Panics
    ///
    /// Panics if `i >= Self::CAPACITY`; such an element cannot be stored.
    fn mask(i: usize) -> u64 {
        assert!(
            i < Self::CAPACITY,
            "BitSet element {} is out of range 0..{}",
            i,
            Self::CAPACITY
        );
        1u64 << i
    }

    /// Creates an empty bitset.
    pub fn empty() -> Self {
        Self(0)
    }

    /// Creates a bitset with a single element.
    ///
    /// # Panics
    ///
    /// Panics if `i >= Self::CAPACITY`.
    pub fn singleton(i: usize) -> Self {
        Self(Self::mask(i))
    }

    /// Creates a bitset from an iterator of indices.
    ///
    /// # Panics
    ///
    /// Panics if any index is `>= Self::CAPACITY`.
    pub fn from_iter(iter: impl Iterator<Item = usize>) -> Self {
        Self(iter.fold(0u64, |bits, i| bits | Self::mask(i)))
    }

    /// Creates a full bitset with elements `0..n`.
    ///
    /// `n == Self::CAPACITY` is supported and yields the set of all 64 elements.
    ///
    /// # Panics
    ///
    /// Panics if `n > Self::CAPACITY`.
    pub fn full(n: usize) -> Self {
        assert!(
            n <= Self::CAPACITY,
            "BitSet cannot hold {} elements (capacity is {})",
            n,
            Self::CAPACITY
        );
        if n == Self::CAPACITY {
            // `1 << 64` would overflow, so the all-ones mask is spelled out.
            Self(u64::MAX)
        } else {
            Self((1u64 << n) - 1)
        }
    }

    /// Checks if the set is empty.
    pub fn is_empty(&self) -> bool {
        self.0 == 0
    }

    /// Returns the number of elements.
    pub fn len(&self) -> usize {
        self.0.count_ones() as usize
    }

    /// Checks if the set contains an element.
    ///
    /// Elements `>= Self::CAPACITY` can never be stored, so they are reported
    /// as absent rather than overflowing the shift.
    pub fn contains(&self, i: usize) -> bool {
        i < Self::CAPACITY && (self.0 & (1u64 << i)) != 0
    }

    /// Inserts an element.
    ///
    /// # Panics
    ///
    /// Panics if `i >= Self::CAPACITY`.
    pub fn insert(&mut self, i: usize) {
        self.0 |= Self::mask(i);
    }

    /// Removes an element. Removing an element that is not present
    /// (including any `i >= Self::CAPACITY`) is a no-op.
    pub fn remove(&mut self, i: usize) {
        if i < Self::CAPACITY {
            self.0 &= !(1u64 << i);
        }
    }

    /// Returns the union of two sets.
    pub fn union(self, other: Self) -> Self {
        Self(self.0 | other.0)
    }

    /// Returns the intersection of two sets.
    pub fn intersection(self, other: Self) -> Self {
        Self(self.0 & other.0)
    }

    /// Returns the difference (self - other).
    pub fn difference(self, other: Self) -> Self {
        Self(self.0 & !other.0)
    }

    /// Checks if this set is a subset of another.
    pub fn is_subset_of(self, other: Self) -> bool {
        (self.0 & other.0) == self.0
    }

    /// Iterates over all elements in the set, in ascending order.
    pub fn iter(self) -> impl Iterator<Item = usize> {
        let mut remaining = self.0;
        std::iter::from_fn(move || {
            if remaining == 0 {
                return None;
            }
            let lowest = remaining.trailing_zeros() as usize;
            // Clear the lowest set bit.
            remaining &= remaining - 1;
            Some(lowest)
        })
    }

    /// Iterates over all subsets of this set, starting with the set itself
    /// and ending with the empty set.
    pub fn subsets(self) -> SubsetIterator {
        SubsetIterator {
            full: self.0,
            current: Some(self.0),
        }
    }
}

/// Iterator over all subsets of a bitset, including the set itself and the
/// empty set. Subsets are produced in decreasing order of their bit mask.
#[derive(Debug, Clone)]
pub struct SubsetIterator {
    /// Mask of the set whose subsets are enumerated.
    full: u64,
    /// Next subset to yield; `None` once the empty set has been yielded.
    current: Option<u64>,
}

impl Iterator for SubsetIterator {
    type Item = BitSet;

    fn next(&mut self) -> Option<Self::Item> {
        let current = self.current?;
        // Standard submask walk: `(s - 1) & full` is the next smaller submask
        // of `full`. The empty set is the last one, so stop after yielding it.
        self.current = if current == 0 {
            None
        } else {
            Some((current - 1) & self.full)
        };
        Some(BitSet(current))
    }
}
277
278/// Represents a (partial) join plan.
279#[derive(Debug, Clone)]
280pub struct JoinPlan {
281    /// The subset of nodes covered by this plan.
282    pub nodes: BitSet,
283    /// The logical operator representing this plan.
284    pub operator: LogicalOperator,
285    /// Estimated cost of this plan.
286    pub cost: Cost,
287    /// Estimated cardinality.
288    pub cardinality: f64,
289}
290
291/// DPccp join ordering optimizer.
292pub struct DPccp<'a> {
293    /// The join graph.
294    graph: &'a JoinGraph,
295    /// Cost model for estimating operator costs.
296    cost_model: &'a CostModel,
297    /// Cardinality estimator.
298    card_estimator: &'a CardinalityEstimator,
299    /// Memoization table: subset -> best plan.
300    memo: HashMap<BitSet, JoinPlan>,
301}
302
303impl<'a> DPccp<'a> {
304    /// Creates a new DPccp optimizer.
305    pub fn new(
306        graph: &'a JoinGraph,
307        cost_model: &'a CostModel,
308        card_estimator: &'a CardinalityEstimator,
309    ) -> Self {
310        Self {
311            graph,
312            cost_model,
313            card_estimator,
314            memo: HashMap::new(),
315        }
316    }
317
318    /// Finds the optimal join order for the graph.
319    pub fn optimize(&mut self) -> Option<JoinPlan> {
320        let n = self.graph.node_count();
321        if n == 0 {
322            return None;
323        }
324        if n == 1 {
325            let node = &self.graph.nodes[0];
326            let cardinality = self.card_estimator.estimate(&node.relation);
327            let cost = self.cost_model.estimate(&node.relation, cardinality);
328            return Some(JoinPlan {
329                nodes: BitSet::singleton(0),
330                operator: node.relation.clone(),
331                cost,
332                cardinality,
333            });
334        }
335
336        // Initialize with single relations
337        for (i, node) in self.graph.nodes.iter().enumerate() {
338            let subset = BitSet::singleton(i);
339            let cardinality = self.card_estimator.estimate(&node.relation);
340            let cost = self.cost_model.estimate(&node.relation, cardinality);
341            self.memo.insert(
342                subset,
343                JoinPlan {
344                    nodes: subset,
345                    operator: node.relation.clone(),
346                    cost,
347                    cardinality,
348                },
349            );
350        }
351
352        // Enumerate connected subgraph pairs (ccp)
353        let full_set = BitSet::full(n);
354        self.enumerate_ccp(full_set);
355
356        // Return the best plan for the full set
357        self.memo.get(&full_set).cloned()
358    }
359
360    /// Enumerates connected complement pairs using DPccp algorithm.
361    fn enumerate_ccp(&mut self, s: BitSet) {
362        // Iterate over all proper non-empty subsets
363        for s1 in s.subsets() {
364            if s1.is_empty() || s1 == s {
365                continue;
366            }
367
368            let s2 = s.difference(s1);
369            if s2.is_empty() {
370                continue;
371            }
372
373            // Both s1 and s2 must be connected subsets
374            if !self.is_connected(s1) || !self.is_connected(s2) {
375                continue;
376            }
377
378            // s1 and s2 must be connected to each other
379            if !self.graph.are_connected(&s1, &s2) {
380                continue;
381            }
382
383            // Recursively solve subproblems
384            if !self.memo.contains_key(&s1) {
385                self.enumerate_ccp(s1);
386            }
387            if !self.memo.contains_key(&s2) {
388                self.enumerate_ccp(s2);
389            }
390
391            // Try to build a plan for s by joining s1 and s2
392            if let (Some(plan1), Some(plan2)) = (self.memo.get(&s1), self.memo.get(&s2)) {
393                let conditions = self.graph.get_conditions(&s1, &s2);
394                let new_plan = self.build_join_plan(plan1.clone(), plan2.clone(), conditions);
395
396                // Update memo if this is a better plan
397                let should_update = self.memo.get(&s).map_or(true, |existing| {
398                    new_plan.cost.total() < existing.cost.total()
399                });
400
401                if should_update {
402                    self.memo.insert(s, new_plan);
403                }
404            }
405        }
406    }
407
408    /// Checks if a subset forms a connected subgraph.
409    fn is_connected(&self, subset: BitSet) -> bool {
410        if subset.len() <= 1 {
411            return true;
412        }
413
414        // BFS to check connectivity
415        // Invariant: subset.len() >= 2 (guard on line 400), so iter().next() returns Some
416        let start = subset
417            .iter()
418            .next()
419            .expect("subset is non-empty: len >= 2 checked on line 400");
420        let mut visited = BitSet::singleton(start);
421        let mut queue = vec![start];
422
423        while let Some(node) = queue.pop() {
424            for neighbor in self.graph.neighbors(node) {
425                if subset.contains(neighbor) && !visited.contains(neighbor) {
426                    visited.insert(neighbor);
427                    queue.push(neighbor);
428                }
429            }
430        }
431
432        visited == subset
433    }
434
435    /// Builds a join plan from two sub-plans.
436    fn build_join_plan(
437        &self,
438        left: JoinPlan,
439        right: JoinPlan,
440        conditions: Vec<JoinCondition>,
441    ) -> JoinPlan {
442        let nodes = left.nodes.union(right.nodes);
443
444        // Create the join operator
445        let join_op = LogicalOperator::Join(JoinOp {
446            left: Box::new(left.operator),
447            right: Box::new(right.operator),
448            join_type: JoinType::Inner,
449            conditions,
450        });
451
452        // Estimate cardinality
453        let cardinality = self.card_estimator.estimate(&join_op);
454
455        // Calculate cost (child costs + join cost)
456        let join_cost = self.cost_model.estimate(&join_op, cardinality);
457        let total_cost = left.cost + right.cost + join_cost;
458
459        JoinPlan {
460            nodes,
461            operator: join_op,
462            cost: total_cost,
463            cardinality,
464        }
465    }
466}
467
468/// Extracts a join graph from a query pattern.
469pub struct JoinGraphBuilder {
470    graph: JoinGraph,
471    variable_to_node: HashMap<String, usize>,
472}
473
474impl JoinGraphBuilder {
475    /// Creates a new builder.
476    pub fn new() -> Self {
477        Self {
478            graph: JoinGraph::new(),
479            variable_to_node: HashMap::new(),
480        }
481    }
482
483    /// Adds a base relation (scan).
484    pub fn add_relation(&mut self, variable: &str, relation: LogicalOperator) -> usize {
485        let id = self.graph.add_node(variable.to_string(), relation);
486        self.variable_to_node.insert(variable.to_string(), id);
487        id
488    }
489
490    /// Adds a join condition between two variables.
491    pub fn add_join_condition(
492        &mut self,
493        left_var: &str,
494        right_var: &str,
495        left_expr: LogicalExpression,
496        right_expr: LogicalExpression,
497    ) {
498        if let (Some(&left_id), Some(&right_id)) = (
499            self.variable_to_node.get(left_var),
500            self.variable_to_node.get(right_var),
501        ) {
502            self.graph.add_edge(
503                left_id,
504                right_id,
505                vec![JoinCondition {
506                    left: left_expr,
507                    right: right_expr,
508                }],
509            );
510        }
511    }
512
513    /// Builds the join graph.
514    pub fn build(self) -> JoinGraph {
515        self.graph
516    }
517}
518
519impl Default for JoinGraphBuilder {
520    fn default() -> Self {
521        Self::new()
522    }
523}
524
525#[cfg(test)]
526mod tests {
527    use super::*;
528    use crate::query::plan::NodeScanOp;
529
530    fn create_node_scan(var: &str, label: &str) -> LogicalOperator {
531        LogicalOperator::NodeScan(NodeScanOp {
532            variable: var.to_string(),
533            label: Some(label.to_string()),
534            input: None,
535        })
536    }
537
538    #[test]
539    fn test_bitset_operations() {
540        let a = BitSet::singleton(0);
541        let b = BitSet::singleton(1);
542        let _c = BitSet::singleton(2);
543
544        assert!(a.contains(0));
545        assert!(!a.contains(1));
546
547        let ab = a.union(b);
548        assert!(ab.contains(0));
549        assert!(ab.contains(1));
550        assert!(!ab.contains(2));
551
552        let full = BitSet::full(3);
553        assert_eq!(full.len(), 3);
554        assert!(full.contains(0));
555        assert!(full.contains(1));
556        assert!(full.contains(2));
557    }
558
559    #[test]
560    fn test_subset_iteration() {
561        let set = BitSet::from_iter([0, 1].into_iter());
562        let subsets: Vec<_> = set.subsets().collect();
563
564        // Should have 4 subsets: {}, {0}, {1}, {0,1}
565        assert_eq!(subsets.len(), 4);
566        assert!(subsets.contains(&BitSet::empty()));
567        assert!(subsets.contains(&BitSet::singleton(0)));
568        assert!(subsets.contains(&BitSet::singleton(1)));
569        assert!(subsets.contains(&set));
570    }
571
572    #[test]
573    fn test_join_graph_construction() {
574        let mut builder = JoinGraphBuilder::new();
575
576        builder.add_relation("a", create_node_scan("a", "Person"));
577        builder.add_relation("b", create_node_scan("b", "Person"));
578        builder.add_relation("c", create_node_scan("c", "Company"));
579
580        builder.add_join_condition(
581            "a",
582            "b",
583            LogicalExpression::Property {
584                variable: "a".to_string(),
585                property: "id".to_string(),
586            },
587            LogicalExpression::Property {
588                variable: "b".to_string(),
589                property: "friend_id".to_string(),
590            },
591        );
592
593        builder.add_join_condition(
594            "a",
595            "c",
596            LogicalExpression::Property {
597                variable: "a".to_string(),
598                property: "company_id".to_string(),
599            },
600            LogicalExpression::Property {
601                variable: "c".to_string(),
602                property: "id".to_string(),
603            },
604        );
605
606        let graph = builder.build();
607        assert_eq!(graph.node_count(), 3);
608    }
609
610    #[test]
611    fn test_dpccp_single_relation() {
612        let mut builder = JoinGraphBuilder::new();
613        builder.add_relation("a", create_node_scan("a", "Person"));
614        let graph = builder.build();
615
616        let cost_model = CostModel::new();
617        let mut card_estimator = CardinalityEstimator::new();
618        card_estimator.add_table_stats("Person", super::super::cardinality::TableStats::new(1000));
619
620        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
621        let plan = dpccp.optimize();
622
623        assert!(plan.is_some());
624        let plan = plan.unwrap();
625        assert_eq!(plan.nodes.len(), 1);
626    }
627
628    #[test]
629    fn test_dpccp_two_relations() {
630        let mut builder = JoinGraphBuilder::new();
631        builder.add_relation("a", create_node_scan("a", "Person"));
632        builder.add_relation("b", create_node_scan("b", "Company"));
633
634        builder.add_join_condition(
635            "a",
636            "b",
637            LogicalExpression::Property {
638                variable: "a".to_string(),
639                property: "company_id".to_string(),
640            },
641            LogicalExpression::Property {
642                variable: "b".to_string(),
643                property: "id".to_string(),
644            },
645        );
646
647        let graph = builder.build();
648
649        let cost_model = CostModel::new();
650        let mut card_estimator = CardinalityEstimator::new();
651        card_estimator.add_table_stats("Person", super::super::cardinality::TableStats::new(1000));
652        card_estimator.add_table_stats("Company", super::super::cardinality::TableStats::new(100));
653
654        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
655        let plan = dpccp.optimize();
656
657        assert!(plan.is_some());
658        let plan = plan.unwrap();
659        assert_eq!(plan.nodes.len(), 2);
660
661        // The result should be a join
662        if let LogicalOperator::Join(_) = plan.operator {
663            // Good
664        } else {
665            panic!("Expected Join operator");
666        }
667    }
668
669    #[test]
670    fn test_dpccp_three_relations_chain() {
671        // a -[knows]-> b -[works_at]-> c
672        let mut builder = JoinGraphBuilder::new();
673        builder.add_relation("a", create_node_scan("a", "Person"));
674        builder.add_relation("b", create_node_scan("b", "Person"));
675        builder.add_relation("c", create_node_scan("c", "Company"));
676
677        builder.add_join_condition(
678            "a",
679            "b",
680            LogicalExpression::Property {
681                variable: "a".to_string(),
682                property: "knows".to_string(),
683            },
684            LogicalExpression::Property {
685                variable: "b".to_string(),
686                property: "id".to_string(),
687            },
688        );
689
690        builder.add_join_condition(
691            "b",
692            "c",
693            LogicalExpression::Property {
694                variable: "b".to_string(),
695                property: "company_id".to_string(),
696            },
697            LogicalExpression::Property {
698                variable: "c".to_string(),
699                property: "id".to_string(),
700            },
701        );
702
703        let graph = builder.build();
704
705        let cost_model = CostModel::new();
706        let mut card_estimator = CardinalityEstimator::new();
707        card_estimator.add_table_stats("Person", super::super::cardinality::TableStats::new(1000));
708        card_estimator.add_table_stats("Company", super::super::cardinality::TableStats::new(100));
709
710        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
711        let plan = dpccp.optimize();
712
713        assert!(plan.is_some());
714        let plan = plan.unwrap();
715        assert_eq!(plan.nodes.len(), 3);
716    }
717
718    #[test]
719    fn test_dpccp_prefers_smaller_intermediate() {
720        // Test that DPccp prefers joining smaller tables first
721        // Setup: Small (100) -[r1]-> Medium (1000) -[r2]-> Large (10000)
722        // Without cost-based ordering, might get (Small ⋈ Large) ⋈ Medium
723        // With cost-based ordering, should get (Small ⋈ Medium) ⋈ Large
724
725        let mut builder = JoinGraphBuilder::new();
726        builder.add_relation("s", create_node_scan("s", "Small"));
727        builder.add_relation("m", create_node_scan("m", "Medium"));
728        builder.add_relation("l", create_node_scan("l", "Large"));
729
730        // Connect all three (star schema)
731        builder.add_join_condition(
732            "s",
733            "m",
734            LogicalExpression::Property {
735                variable: "s".to_string(),
736                property: "id".to_string(),
737            },
738            LogicalExpression::Property {
739                variable: "m".to_string(),
740                property: "s_id".to_string(),
741            },
742        );
743
744        builder.add_join_condition(
745            "m",
746            "l",
747            LogicalExpression::Property {
748                variable: "m".to_string(),
749                property: "id".to_string(),
750            },
751            LogicalExpression::Property {
752                variable: "l".to_string(),
753                property: "m_id".to_string(),
754            },
755        );
756
757        let graph = builder.build();
758
759        let cost_model = CostModel::new();
760        let mut card_estimator = CardinalityEstimator::new();
761        card_estimator.add_table_stats("Small", super::super::cardinality::TableStats::new(100));
762        card_estimator.add_table_stats("Medium", super::super::cardinality::TableStats::new(1000));
763        card_estimator.add_table_stats("Large", super::super::cardinality::TableStats::new(10000));
764
765        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
766        let plan = dpccp.optimize();
767
768        assert!(plan.is_some());
769        let plan = plan.unwrap();
770
771        // The plan should cover all three relations
772        assert_eq!(plan.nodes.len(), 3);
773
774        // We can't easily verify the exact join order without inspecting the tree,
775        // but we can verify the plan was created successfully
776        assert!(plan.cost.total() > 0.0);
777    }
778
779    // Additional BitSet tests
780
781    #[test]
782    fn test_bitset_empty() {
783        let empty = BitSet::empty();
784        assert!(empty.is_empty());
785        assert_eq!(empty.len(), 0);
786        assert!(!empty.contains(0));
787    }
788
789    #[test]
790    fn test_bitset_insert_remove() {
791        let mut set = BitSet::empty();
792        set.insert(3);
793        assert!(set.contains(3));
794        assert_eq!(set.len(), 1);
795
796        set.insert(5);
797        assert!(set.contains(5));
798        assert_eq!(set.len(), 2);
799
800        set.remove(3);
801        assert!(!set.contains(3));
802        assert_eq!(set.len(), 1);
803    }
804
805    #[test]
806    fn test_bitset_intersection() {
807        let a = BitSet::from_iter([0, 1, 2].into_iter());
808        let b = BitSet::from_iter([1, 2, 3].into_iter());
809        let intersection = a.intersection(b);
810
811        assert!(intersection.contains(1));
812        assert!(intersection.contains(2));
813        assert!(!intersection.contains(0));
814        assert!(!intersection.contains(3));
815        assert_eq!(intersection.len(), 2);
816    }
817
818    #[test]
819    fn test_bitset_difference() {
820        let a = BitSet::from_iter([0, 1, 2].into_iter());
821        let b = BitSet::from_iter([1, 2, 3].into_iter());
822        let diff = a.difference(b);
823
824        assert!(diff.contains(0));
825        assert!(!diff.contains(1));
826        assert!(!diff.contains(2));
827        assert_eq!(diff.len(), 1);
828    }
829
830    #[test]
831    fn test_bitset_is_subset_of() {
832        let a = BitSet::from_iter([1, 2].into_iter());
833        let b = BitSet::from_iter([0, 1, 2, 3].into_iter());
834
835        assert!(a.is_subset_of(b));
836        assert!(!b.is_subset_of(a));
837        assert!(a.is_subset_of(a));
838    }
839
840    #[test]
841    fn test_bitset_iter() {
842        let set = BitSet::from_iter([0, 2, 5].into_iter());
843        let elements: Vec<_> = set.iter().collect();
844
845        assert_eq!(elements, vec![0, 2, 5]);
846    }
847
848    // Additional JoinGraph tests
849
850    #[test]
851    fn test_join_graph_empty() {
852        let graph = JoinGraph::new();
853        assert_eq!(graph.node_count(), 0);
854    }
855
856    #[test]
857    fn test_join_graph_neighbors() {
858        let mut builder = JoinGraphBuilder::new();
859        builder.add_relation("a", create_node_scan("a", "A"));
860        builder.add_relation("b", create_node_scan("b", "B"));
861        builder.add_relation("c", create_node_scan("c", "C"));
862
863        builder.add_join_condition(
864            "a",
865            "b",
866            LogicalExpression::Variable("a".to_string()),
867            LogicalExpression::Variable("b".to_string()),
868        );
869        builder.add_join_condition(
870            "a",
871            "c",
872            LogicalExpression::Variable("a".to_string()),
873            LogicalExpression::Variable("c".to_string()),
874        );
875
876        let graph = builder.build();
877
878        // 'a' should have neighbors 'b' and 'c' (indices 1 and 2)
879        let neighbors_a: Vec<_> = graph.neighbors(0).collect();
880        assert_eq!(neighbors_a.len(), 2);
881        assert!(neighbors_a.contains(&1));
882        assert!(neighbors_a.contains(&2));
883
884        // 'b' should have only neighbor 'a'
885        let neighbors_b: Vec<_> = graph.neighbors(1).collect();
886        assert_eq!(neighbors_b.len(), 1);
887        assert!(neighbors_b.contains(&0));
888    }
889
890    #[test]
891    fn test_join_graph_are_connected() {
892        let mut builder = JoinGraphBuilder::new();
893        builder.add_relation("a", create_node_scan("a", "A"));
894        builder.add_relation("b", create_node_scan("b", "B"));
895        builder.add_relation("c", create_node_scan("c", "C"));
896
897        builder.add_join_condition(
898            "a",
899            "b",
900            LogicalExpression::Variable("a".to_string()),
901            LogicalExpression::Variable("b".to_string()),
902        );
903
904        let graph = builder.build();
905
906        let set_a = BitSet::singleton(0);
907        let set_b = BitSet::singleton(1);
908        let set_c = BitSet::singleton(2);
909
910        assert!(graph.are_connected(&set_a, &set_b));
911        assert!(graph.are_connected(&set_b, &set_a));
912        assert!(!graph.are_connected(&set_a, &set_c));
913        assert!(!graph.are_connected(&set_b, &set_c));
914    }
915
916    #[test]
917    fn test_join_graph_get_conditions() {
918        let mut builder = JoinGraphBuilder::new();
919        builder.add_relation("a", create_node_scan("a", "A"));
920        builder.add_relation("b", create_node_scan("b", "B"));
921
922        builder.add_join_condition(
923            "a",
924            "b",
925            LogicalExpression::Property {
926                variable: "a".to_string(),
927                property: "id".to_string(),
928            },
929            LogicalExpression::Property {
930                variable: "b".to_string(),
931                property: "a_id".to_string(),
932            },
933        );
934
935        let graph = builder.build();
936
937        let set_a = BitSet::singleton(0);
938        let set_b = BitSet::singleton(1);
939
940        let conditions = graph.get_conditions(&set_a, &set_b);
941        assert_eq!(conditions.len(), 1);
942    }
943
944    // Additional DPccp tests
945
946    #[test]
947    fn test_dpccp_empty_graph() {
948        let graph = JoinGraph::new();
949        let cost_model = CostModel::new();
950        let card_estimator = CardinalityEstimator::new();
951
952        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
953        let plan = dpccp.optimize();
954
955        assert!(plan.is_none());
956    }
957
958    #[test]
959    fn test_dpccp_star_query() {
960        // Star schema: center connected to all others
961        // center -> a, center -> b, center -> c
962        let mut builder = JoinGraphBuilder::new();
963        builder.add_relation("center", create_node_scan("center", "Center"));
964        builder.add_relation("a", create_node_scan("a", "A"));
965        builder.add_relation("b", create_node_scan("b", "B"));
966        builder.add_relation("c", create_node_scan("c", "C"));
967
968        builder.add_join_condition(
969            "center",
970            "a",
971            LogicalExpression::Variable("center".to_string()),
972            LogicalExpression::Variable("a".to_string()),
973        );
974        builder.add_join_condition(
975            "center",
976            "b",
977            LogicalExpression::Variable("center".to_string()),
978            LogicalExpression::Variable("b".to_string()),
979        );
980        builder.add_join_condition(
981            "center",
982            "c",
983            LogicalExpression::Variable("center".to_string()),
984            LogicalExpression::Variable("c".to_string()),
985        );
986
987        let graph = builder.build();
988
989        let cost_model = CostModel::new();
990        let mut card_estimator = CardinalityEstimator::new();
991        card_estimator.add_table_stats("Center", super::super::cardinality::TableStats::new(100));
992        card_estimator.add_table_stats("A", super::super::cardinality::TableStats::new(1000));
993        card_estimator.add_table_stats("B", super::super::cardinality::TableStats::new(500));
994        card_estimator.add_table_stats("C", super::super::cardinality::TableStats::new(200));
995
996        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
997        let plan = dpccp.optimize();
998
999        assert!(plan.is_some());
1000        let plan = plan.unwrap();
1001        assert_eq!(plan.nodes.len(), 4);
1002        assert!(plan.cost.total() > 0.0);
1003    }
1004
1005    #[test]
1006    fn test_dpccp_cycle_query() {
1007        // Cycle: a -> b -> c -> a
1008        let mut builder = JoinGraphBuilder::new();
1009        builder.add_relation("a", create_node_scan("a", "A"));
1010        builder.add_relation("b", create_node_scan("b", "B"));
1011        builder.add_relation("c", create_node_scan("c", "C"));
1012
1013        builder.add_join_condition(
1014            "a",
1015            "b",
1016            LogicalExpression::Variable("a".to_string()),
1017            LogicalExpression::Variable("b".to_string()),
1018        );
1019        builder.add_join_condition(
1020            "b",
1021            "c",
1022            LogicalExpression::Variable("b".to_string()),
1023            LogicalExpression::Variable("c".to_string()),
1024        );
1025        builder.add_join_condition(
1026            "c",
1027            "a",
1028            LogicalExpression::Variable("c".to_string()),
1029            LogicalExpression::Variable("a".to_string()),
1030        );
1031
1032        let graph = builder.build();
1033
1034        let cost_model = CostModel::new();
1035        let mut card_estimator = CardinalityEstimator::new();
1036        card_estimator.add_table_stats("A", super::super::cardinality::TableStats::new(100));
1037        card_estimator.add_table_stats("B", super::super::cardinality::TableStats::new(100));
1038        card_estimator.add_table_stats("C", super::super::cardinality::TableStats::new(100));
1039
1040        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
1041        let plan = dpccp.optimize();
1042
1043        assert!(plan.is_some());
1044        let plan = plan.unwrap();
1045        assert_eq!(plan.nodes.len(), 3);
1046    }
1047
1048    #[test]
1049    fn test_dpccp_four_relations() {
1050        // Chain: a -> b -> c -> d
1051        let mut builder = JoinGraphBuilder::new();
1052        builder.add_relation("a", create_node_scan("a", "A"));
1053        builder.add_relation("b", create_node_scan("b", "B"));
1054        builder.add_relation("c", create_node_scan("c", "C"));
1055        builder.add_relation("d", create_node_scan("d", "D"));
1056
1057        builder.add_join_condition(
1058            "a",
1059            "b",
1060            LogicalExpression::Variable("a".to_string()),
1061            LogicalExpression::Variable("b".to_string()),
1062        );
1063        builder.add_join_condition(
1064            "b",
1065            "c",
1066            LogicalExpression::Variable("b".to_string()),
1067            LogicalExpression::Variable("c".to_string()),
1068        );
1069        builder.add_join_condition(
1070            "c",
1071            "d",
1072            LogicalExpression::Variable("c".to_string()),
1073            LogicalExpression::Variable("d".to_string()),
1074        );
1075
1076        let graph = builder.build();
1077
1078        let cost_model = CostModel::new();
1079        let mut card_estimator = CardinalityEstimator::new();
1080        card_estimator.add_table_stats("A", super::super::cardinality::TableStats::new(100));
1081        card_estimator.add_table_stats("B", super::super::cardinality::TableStats::new(200));
1082        card_estimator.add_table_stats("C", super::super::cardinality::TableStats::new(300));
1083        card_estimator.add_table_stats("D", super::super::cardinality::TableStats::new(400));
1084
1085        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
1086        let plan = dpccp.optimize();
1087
1088        assert!(plan.is_some());
1089        let plan = plan.unwrap();
1090        assert_eq!(plan.nodes.len(), 4);
1091    }
1092
1093    #[test]
1094    fn test_join_plan_cardinality_and_cost() {
1095        let mut builder = JoinGraphBuilder::new();
1096        builder.add_relation("a", create_node_scan("a", "A"));
1097        builder.add_relation("b", create_node_scan("b", "B"));
1098
1099        builder.add_join_condition(
1100            "a",
1101            "b",
1102            LogicalExpression::Variable("a".to_string()),
1103            LogicalExpression::Variable("b".to_string()),
1104        );
1105
1106        let graph = builder.build();
1107
1108        let cost_model = CostModel::new();
1109        let mut card_estimator = CardinalityEstimator::new();
1110        card_estimator.add_table_stats("A", super::super::cardinality::TableStats::new(100));
1111        card_estimator.add_table_stats("B", super::super::cardinality::TableStats::new(200));
1112
1113        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
1114        let plan = dpccp.optimize().unwrap();
1115
1116        // Plan should have non-zero cardinality and cost
1117        assert!(plan.cardinality > 0.0);
1118        assert!(plan.cost.total() > 0.0);
1119    }
1120
1121    #[test]
1122    fn test_join_graph_default() {
1123        let graph = JoinGraph::default();
1124        assert_eq!(graph.node_count(), 0);
1125    }
1126
1127    #[test]
1128    fn test_join_graph_builder_default() {
1129        let builder = JoinGraphBuilder::default();
1130        let graph = builder.build();
1131        assert_eq!(graph.node_count(), 0);
1132    }
1133
1134    #[test]
1135    fn test_join_graph_nodes_accessor() {
1136        let mut builder = JoinGraphBuilder::new();
1137        builder.add_relation("a", create_node_scan("a", "A"));
1138        builder.add_relation("b", create_node_scan("b", "B"));
1139
1140        let graph = builder.build();
1141        let nodes = graph.nodes();
1142
1143        assert_eq!(nodes.len(), 2);
1144        assert_eq!(nodes[0].variable, "a");
1145        assert_eq!(nodes[1].variable, "b");
1146    }
1147
1148    #[test]
1149    fn test_join_node_equality() {
1150        let node1 = JoinNode {
1151            id: 0,
1152            variable: "a".to_string(),
1153            relation: create_node_scan("a", "A"),
1154        };
1155        let node2 = JoinNode {
1156            id: 0,
1157            variable: "a".to_string(),
1158            relation: create_node_scan("a", "A"),
1159        };
1160        let node3 = JoinNode {
1161            id: 1,
1162            variable: "a".to_string(),
1163            relation: create_node_scan("a", "A"),
1164        };
1165
1166        assert_eq!(node1, node2);
1167        assert_ne!(node1, node3);
1168    }
1169
1170    #[test]
1171    fn test_join_node_hash() {
1172        use std::collections::HashSet;
1173
1174        let node1 = JoinNode {
1175            id: 0,
1176            variable: "a".to_string(),
1177            relation: create_node_scan("a", "A"),
1178        };
1179        let node2 = JoinNode {
1180            id: 0,
1181            variable: "a".to_string(),
1182            relation: create_node_scan("a", "A"),
1183        };
1184
1185        let mut set = HashSet::new();
1186        set.insert(node1.clone());
1187
1188        // Same id and variable should be considered equal
1189        assert!(set.contains(&node2));
1190    }
1191
1192    #[test]
1193    fn test_add_join_condition_unknown_variable() {
1194        let mut builder = JoinGraphBuilder::new();
1195        builder.add_relation("a", create_node_scan("a", "A"));
1196
1197        // Adding condition with unknown variable should do nothing (no panic)
1198        builder.add_join_condition(
1199            "a",
1200            "unknown",
1201            LogicalExpression::Variable("a".to_string()),
1202            LogicalExpression::Variable("unknown".to_string()),
1203        );
1204
1205        let graph = builder.build();
1206        assert_eq!(graph.node_count(), 1);
1207    }
1208
1209    #[test]
1210    fn test_dpccp_with_different_cardinalities() {
1211        // Test that DPccp handles vastly different cardinalities
1212        let mut builder = JoinGraphBuilder::new();
1213        builder.add_relation("tiny", create_node_scan("tiny", "Tiny"));
1214        builder.add_relation("huge", create_node_scan("huge", "Huge"));
1215
1216        builder.add_join_condition(
1217            "tiny",
1218            "huge",
1219            LogicalExpression::Variable("tiny".to_string()),
1220            LogicalExpression::Variable("huge".to_string()),
1221        );
1222
1223        let graph = builder.build();
1224
1225        let cost_model = CostModel::new();
1226        let mut card_estimator = CardinalityEstimator::new();
1227        card_estimator.add_table_stats("Tiny", super::super::cardinality::TableStats::new(10));
1228        card_estimator.add_table_stats("Huge", super::super::cardinality::TableStats::new(1000000));
1229
1230        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
1231        let plan = dpccp.optimize();
1232
1233        assert!(plan.is_some());
1234        let plan = plan.unwrap();
1235        assert_eq!(plan.nodes.len(), 2);
1236    }
1237}