Skip to main content

grafeo_engine/query/optimizer/
join_order.rs

1//! DPccp (Dynamic Programming with connected complement pairs) join ordering.
2//!
3//! This module implements the DPccp algorithm for finding optimal join orderings.
4//! The algorithm works by:
5//! 1. Building a join graph from the query
6//! 2. Enumerating all connected subgraphs
7//! 3. Finding optimal plans for each subgraph using dynamic programming
8//!
9//! Reference: Moerkotte, G., & Neumann, T. (2006). Analysis of Two Existing and
10//! One New Dynamic Programming Algorithm for the Generation of Optimal Bushy
11//! Join Trees without Cross Products.
12
13use super::cardinality::CardinalityEstimator;
14use super::cost::{Cost, CostModel};
15use crate::query::plan::{JoinCondition, JoinOp, JoinType, LogicalExpression, LogicalOperator};
16use std::collections::{HashMap, HashSet};
17
18/// A node in the join graph.
19#[derive(Debug, Clone)]
20pub struct JoinNode {
21    /// Unique identifier for this node.
22    pub id: usize,
23    /// Variable name (e.g., "a" for node (a:Person)).
24    pub variable: String,
25    /// The base relation (scan operator).
26    pub relation: LogicalOperator,
27}
28
29impl PartialEq for JoinNode {
30    fn eq(&self, other: &Self) -> bool {
31        self.id == other.id && self.variable == other.variable
32    }
33}
34
35impl Eq for JoinNode {}
36
37impl std::hash::Hash for JoinNode {
38    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
39        self.id.hash(state);
40        self.variable.hash(state);
41    }
42}
43
44/// An edge in the join graph representing a join condition.
45#[derive(Debug, Clone)]
46pub struct JoinEdge {
47    /// Source node id.
48    pub from: usize,
49    /// Target node id.
50    pub to: usize,
51    /// Join conditions between these nodes.
52    pub conditions: Vec<JoinCondition>,
53}
54
55/// The join graph representing all relations and join conditions in a query.
56#[derive(Debug)]
57pub struct JoinGraph {
58    /// Nodes in the graph.
59    nodes: Vec<JoinNode>,
60    /// Edges (join conditions) between nodes.
61    edges: Vec<JoinEdge>,
62    /// Adjacency list for quick neighbor lookup.
63    adjacency: HashMap<usize, HashSet<usize>>,
64}
65
66impl JoinGraph {
67    /// Creates a new empty join graph.
68    pub fn new() -> Self {
69        Self {
70            nodes: Vec::new(),
71            edges: Vec::new(),
72            adjacency: HashMap::new(),
73        }
74    }
75
76    /// Adds a node to the graph.
77    pub fn add_node(&mut self, variable: String, relation: LogicalOperator) -> usize {
78        let id = self.nodes.len();
79        self.nodes.push(JoinNode {
80            id,
81            variable,
82            relation,
83        });
84        self.adjacency.insert(id, HashSet::new());
85        id
86    }
87
88    /// Adds a join edge between two nodes.
89    ///
90    /// # Panics
91    ///
92    /// Panics if `from` or `to` were not previously added via `add_node`.
93    pub fn add_edge(&mut self, from: usize, to: usize, conditions: Vec<JoinCondition>) {
94        self.edges.push(JoinEdge {
95            from,
96            to,
97            conditions,
98        });
99        // Invariant: add_node() inserts node ID into adjacency map (line 84),
100        // so get_mut succeeds for any ID returned by add_node()
101        self.adjacency
102            .get_mut(&from)
103            .expect("'from' node must be added via add_node() before add_edge()")
104            .insert(to);
105        self.adjacency
106            .get_mut(&to)
107            .expect("'to' node must be added via add_node() before add_edge()")
108            .insert(from);
109    }
110
111    /// Returns the number of nodes.
112    pub fn node_count(&self) -> usize {
113        self.nodes.len()
114    }
115
116    /// Returns a reference to the nodes.
117    pub fn nodes(&self) -> &[JoinNode] {
118        &self.nodes
119    }
120
121    /// Returns neighbors of a node.
122    pub fn neighbors(&self, node_id: usize) -> impl Iterator<Item = usize> + '_ {
123        self.adjacency.get(&node_id).into_iter().flatten().copied()
124    }
125
126    /// Gets the join conditions between two node sets.
127    pub fn get_conditions(&self, left: &BitSet, right: &BitSet) -> Vec<JoinCondition> {
128        let mut conditions = Vec::new();
129        for edge in &self.edges {
130            let from_in_left = left.contains(edge.from);
131            let from_in_right = right.contains(edge.from);
132            let to_in_left = left.contains(edge.to);
133            let to_in_right = right.contains(edge.to);
134
135            // Edge crosses between left and right
136            if (from_in_left && to_in_right) || (from_in_right && to_in_left) {
137                conditions.extend(edge.conditions.clone());
138            }
139        }
140        conditions
141    }
142
143    /// Returns the edges in the graph.
144    pub fn edges(&self) -> &[JoinEdge] {
145        &self.edges
146    }
147
148    /// Returns true if the join graph contains a cycle.
149    ///
150    /// A connected graph with N nodes and E edges is cyclic when E >= N.
151    #[must_use]
152    pub fn is_cyclic(&self) -> bool {
153        if self.nodes.is_empty() {
154            return false;
155        }
156        self.edges.len() >= self.nodes.len()
157    }
158
159    /// Checks if two node sets are connected by at least one edge.
160    pub fn are_connected(&self, left: &BitSet, right: &BitSet) -> bool {
161        for edge in &self.edges {
162            let from_in_left = left.contains(edge.from);
163            let from_in_right = right.contains(edge.from);
164            let to_in_left = left.contains(edge.to);
165            let to_in_right = right.contains(edge.to);
166
167            if (from_in_left && to_in_right) || (from_in_right && to_in_left) {
168                return true;
169            }
170        }
171        false
172    }
173}
174
175impl Default for JoinGraph {
176    fn default() -> Self {
177        Self::new()
178    }
179}
180
/// A set of small integers packed into a single `u64`.
///
/// Bit `i` of the word is set exactly when element `i` is a member, so the
/// set can hold the indices `0..BitSet::CAPACITY`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BitSet(u64);

impl BitSet {
    /// Number of distinct elements a `BitSet` can hold (indices `0..CAPACITY`).
    pub const CAPACITY: usize = 64;

    /// Returns the single-bit mask for element `i`.
    ///
    /// Panics with a descriptive message when `i` is out of range. Without
    /// this check `1 << i` would panic only in debug builds and silently wrap
    /// to element `i % 64` in release builds.
    fn bit(i: usize) -> u64 {
        assert!(
            i < Self::CAPACITY,
            "BitSet index {} out of range (capacity is {})",
            i,
            Self::CAPACITY
        );
        1_u64 << i
    }

    /// Creates an empty bitset.
    pub fn empty() -> Self {
        Self(0)
    }

    /// Creates a bitset with a single element.
    ///
    /// # Panics
    ///
    /// Panics if `i >= BitSet::CAPACITY`.
    pub fn singleton(i: usize) -> Self {
        Self(Self::bit(i))
    }

    /// Creates a bitset from an iterator of indices.
    ///
    /// # Panics
    ///
    /// Panics if any index is `>= BitSet::CAPACITY`.
    pub fn from_iter(iter: impl Iterator<Item = usize>) -> Self {
        let mut set = Self::empty();
        for i in iter {
            set.insert(i);
        }
        set
    }

    /// Creates a full bitset with elements 0..n.
    ///
    /// `n` is clamped to the capacity: any `n >= 64` yields all 64 elements.
    pub fn full(n: usize) -> Self {
        if n == 0 {
            Self::empty()
        } else if n >= Self::CAPACITY {
            Self(u64::MAX)
        } else {
            Self((1_u64 << n) - 1)
        }
    }

    /// Checks if the set is empty.
    pub fn is_empty(&self) -> bool {
        self.0 == 0
    }

    /// Returns the number of elements.
    pub fn len(&self) -> usize {
        self.0.count_ones() as usize
    }

    /// Checks if the set contains an element.
    ///
    /// Indices outside `0..CAPACITY` can never be members, so they return
    /// `false` instead of overflowing the shift.
    pub fn contains(&self, i: usize) -> bool {
        i < Self::CAPACITY && (self.0 >> i) & 1 == 1
    }

    /// Inserts an element.
    ///
    /// # Panics
    ///
    /// Panics if `i >= BitSet::CAPACITY`.
    pub fn insert(&mut self, i: usize) {
        self.0 |= Self::bit(i);
    }

    /// Removes an element.
    ///
    /// Removing an index outside `0..CAPACITY` is a no-op, since such an
    /// index can never be a member.
    pub fn remove(&mut self, i: usize) {
        if i < Self::CAPACITY {
            self.0 &= !(1_u64 << i);
        }
    }

    /// Returns the union of two sets.
    pub fn union(self, other: Self) -> Self {
        Self(self.0 | other.0)
    }

    /// Returns the intersection of two sets.
    pub fn intersection(self, other: Self) -> Self {
        Self(self.0 & other.0)
    }

    /// Returns the difference (self - other).
    pub fn difference(self, other: Self) -> Self {
        Self(self.0 & !other.0)
    }

    /// Checks if this set is a subset of another.
    pub fn is_subset_of(self, other: Self) -> bool {
        (self.0 & other.0) == self.0
    }

    /// Iterates over all elements in the set in ascending order.
    pub fn iter(self) -> impl Iterator<Item = usize> {
        let mut remaining = self.0;
        std::iter::from_fn(move || {
            if remaining == 0 {
                return None;
            }
            let lowest = remaining.trailing_zeros() as usize;
            // Clear the lowest set bit so the next call yields the next element.
            remaining &= remaining - 1;
            Some(lowest)
        })
    }

    /// Iterates over all subsets of this set, including the set itself and
    /// the empty set.
    ///
    /// Subsets are produced in decreasing order of their bit pattern, so the
    /// full set comes first and the empty set last.
    pub fn subsets(self) -> SubsetIterator {
        SubsetIterator {
            full: self.0,
            current: Some(self.0),
        }
    }
}

/// Iterator over all subsets of a bitset (see [`BitSet::subsets`]).
#[derive(Debug, Clone)]
pub struct SubsetIterator {
    /// Bit pattern of the set whose subsets are enumerated.
    full: u64,
    /// Next subset to yield, or `None` once the empty set has been yielded.
    current: Option<u64>,
}

impl Iterator for SubsetIterator {
    type Item = BitSet;

    fn next(&mut self) -> Option<Self::Item> {
        let current = self.current?;
        // Standard submask walk: `(x - 1) & full` is the next smaller subset
        // of `full`. The empty set is the last one, after which we stop.
        self.current = if current == 0 {
            None
        } else {
            Some((current - 1) & self.full)
        };
        Some(BitSet(current))
    }
}
299
300/// Represents a (partial) join plan.
301#[derive(Debug, Clone)]
302pub struct JoinPlan {
303    /// The subset of nodes covered by this plan.
304    pub nodes: BitSet,
305    /// The logical operator representing this plan.
306    pub operator: LogicalOperator,
307    /// Estimated cost of this plan.
308    pub cost: Cost,
309    /// Estimated cardinality.
310    pub cardinality: f64,
311}
312
313/// DPccp join ordering optimizer.
314pub struct DPccp<'a> {
315    /// The join graph.
316    graph: &'a JoinGraph,
317    /// Cost model for estimating operator costs.
318    cost_model: &'a CostModel,
319    /// Cardinality estimator.
320    card_estimator: &'a CardinalityEstimator,
321    /// Memoization table: subset -> best plan.
322    memo: HashMap<BitSet, JoinPlan>,
323    /// Iteration counter for budget enforcement.
324    iterations: usize,
325}
326
327/// Maximum iterations before falling back to heuristic ordering.
328const DPCCP_ITERATION_BUDGET: usize = 100_000;
329
330impl<'a> DPccp<'a> {
331    /// Creates a new DPccp optimizer.
332    pub fn new(
333        graph: &'a JoinGraph,
334        cost_model: &'a CostModel,
335        card_estimator: &'a CardinalityEstimator,
336    ) -> Self {
337        Self {
338            graph,
339            cost_model,
340            card_estimator,
341            memo: HashMap::new(),
342            iterations: 0,
343        }
344    }
345
346    /// Finds the optimal join order for the graph.
347    pub fn optimize(&mut self) -> Option<JoinPlan> {
348        let n = self.graph.node_count();
349        if n == 0 {
350            return None;
351        }
352        if n == 1 {
353            let node = &self.graph.nodes[0];
354            let cardinality = self.card_estimator.estimate(&node.relation);
355            let cost = self.cost_model.estimate(&node.relation, cardinality);
356            return Some(JoinPlan {
357                nodes: BitSet::singleton(0),
358                operator: node.relation.clone(),
359                cost,
360                cardinality,
361            });
362        }
363
364        // BitSet uses u64, so we can only handle up to 64 relations.
365        // For larger queries, return None to fall back to heuristic ordering.
366        if n > 64 {
367            return None;
368        }
369
370        // Initialize with single relations
371        for (i, node) in self.graph.nodes.iter().enumerate() {
372            let subset = BitSet::singleton(i);
373            let cardinality = self.card_estimator.estimate(&node.relation);
374            let cost = self.cost_model.estimate(&node.relation, cardinality);
375            self.memo.insert(
376                subset,
377                JoinPlan {
378                    nodes: subset,
379                    operator: node.relation.clone(),
380                    cost,
381                    cardinality,
382                },
383            );
384        }
385
386        // Enumerate connected subgraph pairs (ccp)
387        let full_set = BitSet::full(n);
388        self.enumerate_ccp(full_set);
389
390        // Return the best plan for the full set
391        self.memo.get(&full_set).cloned()
392    }
393
394    /// Enumerates connected complement pairs using DPccp algorithm.
395    fn enumerate_ccp(&mut self, s: BitSet) {
396        // Iterate over all proper non-empty subsets
397        for s1 in s.subsets() {
398            // Budget enforcement: stop exploring if we've exceeded the limit.
399            // The best plan found so far (if any) will be returned.
400            self.iterations += 1;
401            if self.iterations > DPCCP_ITERATION_BUDGET {
402                return;
403            }
404
405            if s1.is_empty() || s1 == s {
406                continue;
407            }
408
409            let s2 = s.difference(s1);
410            if s2.is_empty() {
411                continue;
412            }
413
414            // Both s1 and s2 must be connected subsets
415            if !self.is_connected(s1) || !self.is_connected(s2) {
416                continue;
417            }
418
419            // s1 and s2 must be connected to each other
420            if !self.graph.are_connected(&s1, &s2) {
421                continue;
422            }
423
424            // Recursively solve subproblems
425            if !self.memo.contains_key(&s1) {
426                self.enumerate_ccp(s1);
427            }
428            if !self.memo.contains_key(&s2) {
429                self.enumerate_ccp(s2);
430            }
431
432            // Try to build a plan for s by joining s1 and s2
433            if let (Some(plan1), Some(plan2)) = (self.memo.get(&s1), self.memo.get(&s2)) {
434                let conditions = self.graph.get_conditions(&s1, &s2);
435                let new_plan = self.build_join_plan(plan1.clone(), plan2.clone(), conditions);
436
437                // Update memo if this is a better plan
438                let should_update = self.memo.get(&s).map_or(true, |existing| {
439                    new_plan.cost.total() < existing.cost.total()
440                });
441
442                if should_update {
443                    self.memo.insert(s, new_plan);
444                }
445            }
446        }
447    }
448
449    /// Checks if a subset forms a connected subgraph.
450    fn is_connected(&self, subset: BitSet) -> bool {
451        if subset.len() <= 1 {
452            return true;
453        }
454
455        // BFS to check connectivity
456        // Invariant: subset.len() >= 2 (guard on line 400), so iter().next() returns Some
457        let start = subset
458            .iter()
459            .next()
460            .expect("subset is non-empty: len >= 2 checked on line 400");
461        let mut visited = BitSet::singleton(start);
462        let mut queue = vec![start];
463
464        while let Some(node) = queue.pop() {
465            for neighbor in self.graph.neighbors(node) {
466                if subset.contains(neighbor) && !visited.contains(neighbor) {
467                    visited.insert(neighbor);
468                    queue.push(neighbor);
469                }
470            }
471        }
472
473        visited == subset
474    }
475
476    /// Builds a join plan from two sub-plans.
477    fn build_join_plan(
478        &self,
479        left: JoinPlan,
480        right: JoinPlan,
481        conditions: Vec<JoinCondition>,
482    ) -> JoinPlan {
483        let nodes = left.nodes.union(right.nodes);
484
485        // Create the join operator
486        let join_op = LogicalOperator::Join(JoinOp {
487            left: Box::new(left.operator),
488            right: Box::new(right.operator),
489            join_type: JoinType::Inner,
490            conditions,
491        });
492
493        // Estimate cardinality
494        let cardinality = self.card_estimator.estimate(&join_op);
495
496        // Calculate cost (child costs + join cost)
497        let join_cost = self.cost_model.estimate(&join_op, cardinality);
498        let total_cost = left.cost + right.cost + join_cost;
499
500        JoinPlan {
501            nodes,
502            operator: join_op,
503            cost: total_cost,
504            cardinality,
505        }
506    }
507}
508
509/// Extracts a join graph from a query pattern.
510pub struct JoinGraphBuilder {
511    graph: JoinGraph,
512    variable_to_node: HashMap<String, usize>,
513}
514
515impl JoinGraphBuilder {
516    /// Creates a new builder.
517    pub fn new() -> Self {
518        Self {
519            graph: JoinGraph::new(),
520            variable_to_node: HashMap::new(),
521        }
522    }
523
524    /// Adds a base relation (scan).
525    pub fn add_relation(&mut self, variable: &str, relation: LogicalOperator) -> usize {
526        let id = self.graph.add_node(variable.to_string(), relation);
527        self.variable_to_node.insert(variable.to_string(), id);
528        id
529    }
530
531    /// Adds a join condition between two variables.
532    pub fn add_join_condition(
533        &mut self,
534        left_var: &str,
535        right_var: &str,
536        left_expr: LogicalExpression,
537        right_expr: LogicalExpression,
538    ) {
539        if let (Some(&left_id), Some(&right_id)) = (
540            self.variable_to_node.get(left_var),
541            self.variable_to_node.get(right_var),
542        ) {
543            self.graph.add_edge(
544                left_id,
545                right_id,
546                vec![JoinCondition {
547                    left: left_expr,
548                    right: right_expr,
549                }],
550            );
551        }
552    }
553
554    /// Builds the join graph.
555    pub fn build(self) -> JoinGraph {
556        self.graph
557    }
558}
559
560impl Default for JoinGraphBuilder {
561    fn default() -> Self {
562        Self::new()
563    }
564}
565
566#[cfg(test)]
567mod tests {
568    use super::*;
569    use crate::query::plan::NodeScanOp;
570
571    fn create_node_scan(var: &str, label: &str) -> LogicalOperator {
572        LogicalOperator::NodeScan(NodeScanOp {
573            variable: var.to_string(),
574            label: Some(label.to_string()),
575            input: None,
576        })
577    }
578
579    #[test]
580    fn test_bitset_operations() {
581        let a = BitSet::singleton(0);
582        let b = BitSet::singleton(1);
583        let _c = BitSet::singleton(2);
584
585        assert!(a.contains(0));
586        assert!(!a.contains(1));
587
588        let ab = a.union(b);
589        assert!(ab.contains(0));
590        assert!(ab.contains(1));
591        assert!(!ab.contains(2));
592
593        let full = BitSet::full(3);
594        assert_eq!(full.len(), 3);
595        assert!(full.contains(0));
596        assert!(full.contains(1));
597        assert!(full.contains(2));
598    }
599
600    #[test]
601    fn test_subset_iteration() {
602        let set = BitSet::from_iter([0, 1].into_iter());
603        let subsets: Vec<_> = set.subsets().collect();
604
605        // Should have 4 subsets: {}, {0}, {1}, {0,1}
606        assert_eq!(subsets.len(), 4);
607        assert!(subsets.contains(&BitSet::empty()));
608        assert!(subsets.contains(&BitSet::singleton(0)));
609        assert!(subsets.contains(&BitSet::singleton(1)));
610        assert!(subsets.contains(&set));
611    }
612
613    #[test]
614    fn test_join_graph_construction() {
615        let mut builder = JoinGraphBuilder::new();
616
617        builder.add_relation("a", create_node_scan("a", "Person"));
618        builder.add_relation("b", create_node_scan("b", "Person"));
619        builder.add_relation("c", create_node_scan("c", "Company"));
620
621        builder.add_join_condition(
622            "a",
623            "b",
624            LogicalExpression::Property {
625                variable: "a".to_string(),
626                property: "id".to_string(),
627            },
628            LogicalExpression::Property {
629                variable: "b".to_string(),
630                property: "friend_id".to_string(),
631            },
632        );
633
634        builder.add_join_condition(
635            "a",
636            "c",
637            LogicalExpression::Property {
638                variable: "a".to_string(),
639                property: "company_id".to_string(),
640            },
641            LogicalExpression::Property {
642                variable: "c".to_string(),
643                property: "id".to_string(),
644            },
645        );
646
647        let graph = builder.build();
648        assert_eq!(graph.node_count(), 3);
649    }
650
651    #[test]
652    fn test_dpccp_single_relation() {
653        let mut builder = JoinGraphBuilder::new();
654        builder.add_relation("a", create_node_scan("a", "Person"));
655        let graph = builder.build();
656
657        let cost_model = CostModel::new();
658        let mut card_estimator = CardinalityEstimator::new();
659        card_estimator.add_table_stats("Person", super::super::cardinality::TableStats::new(1000));
660
661        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
662        let plan = dpccp.optimize();
663
664        assert!(plan.is_some());
665        let plan = plan.unwrap();
666        assert_eq!(plan.nodes.len(), 1);
667    }
668
669    #[test]
670    fn test_dpccp_two_relations() {
671        let mut builder = JoinGraphBuilder::new();
672        builder.add_relation("a", create_node_scan("a", "Person"));
673        builder.add_relation("b", create_node_scan("b", "Company"));
674
675        builder.add_join_condition(
676            "a",
677            "b",
678            LogicalExpression::Property {
679                variable: "a".to_string(),
680                property: "company_id".to_string(),
681            },
682            LogicalExpression::Property {
683                variable: "b".to_string(),
684                property: "id".to_string(),
685            },
686        );
687
688        let graph = builder.build();
689
690        let cost_model = CostModel::new();
691        let mut card_estimator = CardinalityEstimator::new();
692        card_estimator.add_table_stats("Person", super::super::cardinality::TableStats::new(1000));
693        card_estimator.add_table_stats("Company", super::super::cardinality::TableStats::new(100));
694
695        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
696        let plan = dpccp.optimize();
697
698        assert!(plan.is_some());
699        let plan = plan.unwrap();
700        assert_eq!(plan.nodes.len(), 2);
701
702        // The result should be a join
703        if let LogicalOperator::Join(_) = plan.operator {
704            // Good
705        } else {
706            panic!("Expected Join operator");
707        }
708    }
709
710    #[test]
711    fn test_dpccp_three_relations_chain() {
712        // a -[knows]-> b -[works_at]-> c
713        let mut builder = JoinGraphBuilder::new();
714        builder.add_relation("a", create_node_scan("a", "Person"));
715        builder.add_relation("b", create_node_scan("b", "Person"));
716        builder.add_relation("c", create_node_scan("c", "Company"));
717
718        builder.add_join_condition(
719            "a",
720            "b",
721            LogicalExpression::Property {
722                variable: "a".to_string(),
723                property: "knows".to_string(),
724            },
725            LogicalExpression::Property {
726                variable: "b".to_string(),
727                property: "id".to_string(),
728            },
729        );
730
731        builder.add_join_condition(
732            "b",
733            "c",
734            LogicalExpression::Property {
735                variable: "b".to_string(),
736                property: "company_id".to_string(),
737            },
738            LogicalExpression::Property {
739                variable: "c".to_string(),
740                property: "id".to_string(),
741            },
742        );
743
744        let graph = builder.build();
745
746        let cost_model = CostModel::new();
747        let mut card_estimator = CardinalityEstimator::new();
748        card_estimator.add_table_stats("Person", super::super::cardinality::TableStats::new(1000));
749        card_estimator.add_table_stats("Company", super::super::cardinality::TableStats::new(100));
750
751        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
752        let plan = dpccp.optimize();
753
754        assert!(plan.is_some());
755        let plan = plan.unwrap();
756        assert_eq!(plan.nodes.len(), 3);
757    }
758
759    #[test]
760    fn test_dpccp_prefers_smaller_intermediate() {
        // Test that DPccp prefers joining smaller tables first
        // Setup: Small (100) -[r1]-> Medium (1000) -[r2]-> Large (10000)
        // Small and Large share no join condition, so (Small ⋈ Large) ⋈ Medium would
        // be a cross product and is never considered by the optimizer.
        // With cost-based ordering, should get (Small ⋈ Medium) ⋈ Large rather than
        // Small ⋈ (Medium ⋈ Large)
765
766        let mut builder = JoinGraphBuilder::new();
767        builder.add_relation("s", create_node_scan("s", "Small"));
768        builder.add_relation("m", create_node_scan("m", "Medium"));
769        builder.add_relation("l", create_node_scan("l", "Large"));
770
        // Connect the three relations in a chain: s - m - l
772        builder.add_join_condition(
773            "s",
774            "m",
775            LogicalExpression::Property {
776                variable: "s".to_string(),
777                property: "id".to_string(),
778            },
779            LogicalExpression::Property {
780                variable: "m".to_string(),
781                property: "s_id".to_string(),
782            },
783        );
784
785        builder.add_join_condition(
786            "m",
787            "l",
788            LogicalExpression::Property {
789                variable: "m".to_string(),
790                property: "id".to_string(),
791            },
792            LogicalExpression::Property {
793                variable: "l".to_string(),
794                property: "m_id".to_string(),
795            },
796        );
797
798        let graph = builder.build();
799
800        let cost_model = CostModel::new();
801        let mut card_estimator = CardinalityEstimator::new();
802        card_estimator.add_table_stats("Small", super::super::cardinality::TableStats::new(100));
803        card_estimator.add_table_stats("Medium", super::super::cardinality::TableStats::new(1000));
804        card_estimator.add_table_stats("Large", super::super::cardinality::TableStats::new(10000));
805
806        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
807        let plan = dpccp.optimize();
808
809        assert!(plan.is_some());
810        let plan = plan.unwrap();
811
812        // The plan should cover all three relations
813        assert_eq!(plan.nodes.len(), 3);
814
815        // We can't easily verify the exact join order without inspecting the tree,
816        // but we can verify the plan was created successfully
817        assert!(plan.cost.total() > 0.0);
818    }
819
820    // Additional BitSet tests
821
822    #[test]
823    fn test_bitset_empty() {
824        let empty = BitSet::empty();
825        assert!(empty.is_empty());
826        assert_eq!(empty.len(), 0);
827        assert!(!empty.contains(0));
828    }
829
830    #[test]
831    fn test_bitset_insert_remove() {
832        let mut set = BitSet::empty();
833        set.insert(3);
834        assert!(set.contains(3));
835        assert_eq!(set.len(), 1);
836
837        set.insert(5);
838        assert!(set.contains(5));
839        assert_eq!(set.len(), 2);
840
841        set.remove(3);
842        assert!(!set.contains(3));
843        assert_eq!(set.len(), 1);
844    }
845
846    #[test]
847    fn test_bitset_intersection() {
848        let a = BitSet::from_iter([0, 1, 2].into_iter());
849        let b = BitSet::from_iter([1, 2, 3].into_iter());
850        let intersection = a.intersection(b);
851
852        assert!(intersection.contains(1));
853        assert!(intersection.contains(2));
854        assert!(!intersection.contains(0));
855        assert!(!intersection.contains(3));
856        assert_eq!(intersection.len(), 2);
857    }
858
859    #[test]
860    fn test_bitset_difference() {
861        let a = BitSet::from_iter([0, 1, 2].into_iter());
862        let b = BitSet::from_iter([1, 2, 3].into_iter());
863        let diff = a.difference(b);
864
865        assert!(diff.contains(0));
866        assert!(!diff.contains(1));
867        assert!(!diff.contains(2));
868        assert_eq!(diff.len(), 1);
869    }
870
871    #[test]
872    fn test_bitset_is_subset_of() {
873        let a = BitSet::from_iter([1, 2].into_iter());
874        let b = BitSet::from_iter([0, 1, 2, 3].into_iter());
875
876        assert!(a.is_subset_of(b));
877        assert!(!b.is_subset_of(a));
878        assert!(a.is_subset_of(a));
879    }
880
881    #[test]
882    fn test_bitset_iter() {
883        let set = BitSet::from_iter([0, 2, 5].into_iter());
884        let elements: Vec<_> = set.iter().collect();
885
886        assert_eq!(elements, vec![0, 2, 5]);
887    }
888
889    // Additional JoinGraph tests
890
891    #[test]
892    fn test_join_graph_empty() {
893        let graph = JoinGraph::new();
894        assert_eq!(graph.node_count(), 0);
895    }
896
897    #[test]
898    fn test_join_graph_neighbors() {
899        let mut builder = JoinGraphBuilder::new();
900        builder.add_relation("a", create_node_scan("a", "A"));
901        builder.add_relation("b", create_node_scan("b", "B"));
902        builder.add_relation("c", create_node_scan("c", "C"));
903
904        builder.add_join_condition(
905            "a",
906            "b",
907            LogicalExpression::Variable("a".to_string()),
908            LogicalExpression::Variable("b".to_string()),
909        );
910        builder.add_join_condition(
911            "a",
912            "c",
913            LogicalExpression::Variable("a".to_string()),
914            LogicalExpression::Variable("c".to_string()),
915        );
916
917        let graph = builder.build();
918
919        // 'a' should have neighbors 'b' and 'c' (indices 1 and 2)
920        let neighbors_a: Vec<_> = graph.neighbors(0).collect();
921        assert_eq!(neighbors_a.len(), 2);
922        assert!(neighbors_a.contains(&1));
923        assert!(neighbors_a.contains(&2));
924
925        // 'b' should have only neighbor 'a'
926        let neighbors_b: Vec<_> = graph.neighbors(1).collect();
927        assert_eq!(neighbors_b.len(), 1);
928        assert!(neighbors_b.contains(&0));
929    }
930
931    #[test]
932    fn test_join_graph_are_connected() {
933        let mut builder = JoinGraphBuilder::new();
934        builder.add_relation("a", create_node_scan("a", "A"));
935        builder.add_relation("b", create_node_scan("b", "B"));
936        builder.add_relation("c", create_node_scan("c", "C"));
937
938        builder.add_join_condition(
939            "a",
940            "b",
941            LogicalExpression::Variable("a".to_string()),
942            LogicalExpression::Variable("b".to_string()),
943        );
944
945        let graph = builder.build();
946
947        let set_a = BitSet::singleton(0);
948        let set_b = BitSet::singleton(1);
949        let set_c = BitSet::singleton(2);
950
951        assert!(graph.are_connected(&set_a, &set_b));
952        assert!(graph.are_connected(&set_b, &set_a));
953        assert!(!graph.are_connected(&set_a, &set_c));
954        assert!(!graph.are_connected(&set_b, &set_c));
955    }
956
957    #[test]
958    fn test_join_graph_get_conditions() {
959        let mut builder = JoinGraphBuilder::new();
960        builder.add_relation("a", create_node_scan("a", "A"));
961        builder.add_relation("b", create_node_scan("b", "B"));
962
963        builder.add_join_condition(
964            "a",
965            "b",
966            LogicalExpression::Property {
967                variable: "a".to_string(),
968                property: "id".to_string(),
969            },
970            LogicalExpression::Property {
971                variable: "b".to_string(),
972                property: "a_id".to_string(),
973            },
974        );
975
976        let graph = builder.build();
977
978        let set_a = BitSet::singleton(0);
979        let set_b = BitSet::singleton(1);
980
981        let conditions = graph.get_conditions(&set_a, &set_b);
982        assert_eq!(conditions.len(), 1);
983    }
984
985    // Additional DPccp tests
986
987    #[test]
988    fn test_dpccp_empty_graph() {
989        let graph = JoinGraph::new();
990        let cost_model = CostModel::new();
991        let card_estimator = CardinalityEstimator::new();
992
993        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
994        let plan = dpccp.optimize();
995
996        assert!(plan.is_none());
997    }
998
999    #[test]
1000    fn test_dpccp_star_query() {
1001        // Star schema: center connected to all others
1002        // center -> a, center -> b, center -> c
1003        let mut builder = JoinGraphBuilder::new();
1004        builder.add_relation("center", create_node_scan("center", "Center"));
1005        builder.add_relation("a", create_node_scan("a", "A"));
1006        builder.add_relation("b", create_node_scan("b", "B"));
1007        builder.add_relation("c", create_node_scan("c", "C"));
1008
1009        builder.add_join_condition(
1010            "center",
1011            "a",
1012            LogicalExpression::Variable("center".to_string()),
1013            LogicalExpression::Variable("a".to_string()),
1014        );
1015        builder.add_join_condition(
1016            "center",
1017            "b",
1018            LogicalExpression::Variable("center".to_string()),
1019            LogicalExpression::Variable("b".to_string()),
1020        );
1021        builder.add_join_condition(
1022            "center",
1023            "c",
1024            LogicalExpression::Variable("center".to_string()),
1025            LogicalExpression::Variable("c".to_string()),
1026        );
1027
1028        let graph = builder.build();
1029
1030        let cost_model = CostModel::new();
1031        let mut card_estimator = CardinalityEstimator::new();
1032        card_estimator.add_table_stats("Center", super::super::cardinality::TableStats::new(100));
1033        card_estimator.add_table_stats("A", super::super::cardinality::TableStats::new(1000));
1034        card_estimator.add_table_stats("B", super::super::cardinality::TableStats::new(500));
1035        card_estimator.add_table_stats("C", super::super::cardinality::TableStats::new(200));
1036
1037        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
1038        let plan = dpccp.optimize();
1039
1040        assert!(plan.is_some());
1041        let plan = plan.unwrap();
1042        assert_eq!(plan.nodes.len(), 4);
1043        assert!(plan.cost.total() > 0.0);
1044    }
1045
1046    #[test]
1047    fn test_dpccp_cycle_query() {
1048        // Cycle: a -> b -> c -> a
1049        let mut builder = JoinGraphBuilder::new();
1050        builder.add_relation("a", create_node_scan("a", "A"));
1051        builder.add_relation("b", create_node_scan("b", "B"));
1052        builder.add_relation("c", create_node_scan("c", "C"));
1053
1054        builder.add_join_condition(
1055            "a",
1056            "b",
1057            LogicalExpression::Variable("a".to_string()),
1058            LogicalExpression::Variable("b".to_string()),
1059        );
1060        builder.add_join_condition(
1061            "b",
1062            "c",
1063            LogicalExpression::Variable("b".to_string()),
1064            LogicalExpression::Variable("c".to_string()),
1065        );
1066        builder.add_join_condition(
1067            "c",
1068            "a",
1069            LogicalExpression::Variable("c".to_string()),
1070            LogicalExpression::Variable("a".to_string()),
1071        );
1072
1073        let graph = builder.build();
1074
1075        let cost_model = CostModel::new();
1076        let mut card_estimator = CardinalityEstimator::new();
1077        card_estimator.add_table_stats("A", super::super::cardinality::TableStats::new(100));
1078        card_estimator.add_table_stats("B", super::super::cardinality::TableStats::new(100));
1079        card_estimator.add_table_stats("C", super::super::cardinality::TableStats::new(100));
1080
1081        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
1082        let plan = dpccp.optimize();
1083
1084        assert!(plan.is_some());
1085        let plan = plan.unwrap();
1086        assert_eq!(plan.nodes.len(), 3);
1087    }
1088
1089    #[test]
1090    fn test_dpccp_four_relations() {
1091        // Chain: a -> b -> c -> d
1092        let mut builder = JoinGraphBuilder::new();
1093        builder.add_relation("a", create_node_scan("a", "A"));
1094        builder.add_relation("b", create_node_scan("b", "B"));
1095        builder.add_relation("c", create_node_scan("c", "C"));
1096        builder.add_relation("d", create_node_scan("d", "D"));
1097
1098        builder.add_join_condition(
1099            "a",
1100            "b",
1101            LogicalExpression::Variable("a".to_string()),
1102            LogicalExpression::Variable("b".to_string()),
1103        );
1104        builder.add_join_condition(
1105            "b",
1106            "c",
1107            LogicalExpression::Variable("b".to_string()),
1108            LogicalExpression::Variable("c".to_string()),
1109        );
1110        builder.add_join_condition(
1111            "c",
1112            "d",
1113            LogicalExpression::Variable("c".to_string()),
1114            LogicalExpression::Variable("d".to_string()),
1115        );
1116
1117        let graph = builder.build();
1118
1119        let cost_model = CostModel::new();
1120        let mut card_estimator = CardinalityEstimator::new();
1121        card_estimator.add_table_stats("A", super::super::cardinality::TableStats::new(100));
1122        card_estimator.add_table_stats("B", super::super::cardinality::TableStats::new(200));
1123        card_estimator.add_table_stats("C", super::super::cardinality::TableStats::new(300));
1124        card_estimator.add_table_stats("D", super::super::cardinality::TableStats::new(400));
1125
1126        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
1127        let plan = dpccp.optimize();
1128
1129        assert!(plan.is_some());
1130        let plan = plan.unwrap();
1131        assert_eq!(plan.nodes.len(), 4);
1132    }
1133
1134    #[test]
1135    fn test_join_plan_cardinality_and_cost() {
1136        let mut builder = JoinGraphBuilder::new();
1137        builder.add_relation("a", create_node_scan("a", "A"));
1138        builder.add_relation("b", create_node_scan("b", "B"));
1139
1140        builder.add_join_condition(
1141            "a",
1142            "b",
1143            LogicalExpression::Variable("a".to_string()),
1144            LogicalExpression::Variable("b".to_string()),
1145        );
1146
1147        let graph = builder.build();
1148
1149        let cost_model = CostModel::new();
1150        let mut card_estimator = CardinalityEstimator::new();
1151        card_estimator.add_table_stats("A", super::super::cardinality::TableStats::new(100));
1152        card_estimator.add_table_stats("B", super::super::cardinality::TableStats::new(200));
1153
1154        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
1155        let plan = dpccp.optimize().unwrap();
1156
1157        // Plan should have non-zero cardinality and cost
1158        assert!(plan.cardinality > 0.0);
1159        assert!(plan.cost.total() > 0.0);
1160    }
1161
1162    #[test]
1163    fn test_join_graph_default() {
1164        let graph = JoinGraph::default();
1165        assert_eq!(graph.node_count(), 0);
1166    }
1167
1168    #[test]
1169    fn test_join_graph_builder_default() {
1170        let builder = JoinGraphBuilder::default();
1171        let graph = builder.build();
1172        assert_eq!(graph.node_count(), 0);
1173    }
1174
1175    #[test]
1176    fn test_join_graph_nodes_accessor() {
1177        let mut builder = JoinGraphBuilder::new();
1178        builder.add_relation("a", create_node_scan("a", "A"));
1179        builder.add_relation("b", create_node_scan("b", "B"));
1180
1181        let graph = builder.build();
1182        let nodes = graph.nodes();
1183
1184        assert_eq!(nodes.len(), 2);
1185        assert_eq!(nodes[0].variable, "a");
1186        assert_eq!(nodes[1].variable, "b");
1187    }
1188
1189    #[test]
1190    fn test_join_node_equality() {
1191        let node1 = JoinNode {
1192            id: 0,
1193            variable: "a".to_string(),
1194            relation: create_node_scan("a", "A"),
1195        };
1196        let node2 = JoinNode {
1197            id: 0,
1198            variable: "a".to_string(),
1199            relation: create_node_scan("a", "A"),
1200        };
1201        let node3 = JoinNode {
1202            id: 1,
1203            variable: "a".to_string(),
1204            relation: create_node_scan("a", "A"),
1205        };
1206
1207        assert_eq!(node1, node2);
1208        assert_ne!(node1, node3);
1209    }
1210
1211    #[test]
1212    fn test_join_node_hash() {
1213        use std::collections::HashSet;
1214
1215        let node1 = JoinNode {
1216            id: 0,
1217            variable: "a".to_string(),
1218            relation: create_node_scan("a", "A"),
1219        };
1220        let node2 = JoinNode {
1221            id: 0,
1222            variable: "a".to_string(),
1223            relation: create_node_scan("a", "A"),
1224        };
1225
1226        let mut set = HashSet::new();
1227        set.insert(node1.clone());
1228
1229        // Same id and variable should be considered equal
1230        assert!(set.contains(&node2));
1231    }
1232
1233    #[test]
1234    fn test_add_join_condition_unknown_variable() {
1235        let mut builder = JoinGraphBuilder::new();
1236        builder.add_relation("a", create_node_scan("a", "A"));
1237
1238        // Adding condition with unknown variable should do nothing (no panic)
1239        builder.add_join_condition(
1240            "a",
1241            "unknown",
1242            LogicalExpression::Variable("a".to_string()),
1243            LogicalExpression::Variable("unknown".to_string()),
1244        );
1245
1246        let graph = builder.build();
1247        assert_eq!(graph.node_count(), 1);
1248    }
1249
1250    #[test]
1251    fn test_dpccp_with_different_cardinalities() {
1252        // Test that DPccp handles vastly different cardinalities
1253        let mut builder = JoinGraphBuilder::new();
1254        builder.add_relation("tiny", create_node_scan("tiny", "Tiny"));
1255        builder.add_relation("huge", create_node_scan("huge", "Huge"));
1256
1257        builder.add_join_condition(
1258            "tiny",
1259            "huge",
1260            LogicalExpression::Variable("tiny".to_string()),
1261            LogicalExpression::Variable("huge".to_string()),
1262        );
1263
1264        let graph = builder.build();
1265
1266        let cost_model = CostModel::new();
1267        let mut card_estimator = CardinalityEstimator::new();
1268        card_estimator.add_table_stats("Tiny", super::super::cardinality::TableStats::new(10));
1269        card_estimator.add_table_stats("Huge", super::super::cardinality::TableStats::new(1000000));
1270
1271        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
1272        let plan = dpccp.optimize();
1273
1274        assert!(plan.is_some());
1275        let plan = plan.unwrap();
1276        assert_eq!(plan.nodes.len(), 2);
1277    }
1278
1279    #[test]
1280    fn test_join_graph_cyclic_triangle() {
1281        // Triangle: a-b, b-c, c-a (3 nodes, 3 edges -> cyclic)
1282        let mut builder = JoinGraphBuilder::new();
1283        builder.add_relation("a", create_node_scan("a", "A"));
1284        builder.add_relation("b", create_node_scan("b", "B"));
1285        builder.add_relation("c", create_node_scan("c", "C"));
1286
1287        builder.add_join_condition(
1288            "a",
1289            "b",
1290            LogicalExpression::Variable("a".to_string()),
1291            LogicalExpression::Variable("b".to_string()),
1292        );
1293        builder.add_join_condition(
1294            "b",
1295            "c",
1296            LogicalExpression::Variable("b".to_string()),
1297            LogicalExpression::Variable("c".to_string()),
1298        );
1299        builder.add_join_condition(
1300            "c",
1301            "a",
1302            LogicalExpression::Variable("c".to_string()),
1303            LogicalExpression::Variable("a".to_string()),
1304        );
1305
1306        let graph = builder.build();
1307        assert!(graph.is_cyclic());
1308    }
1309
1310    #[test]
1311    fn test_join_graph_acyclic_chain() {
1312        // Chain: a-b, b-c (3 nodes, 2 edges -> acyclic)
1313        let mut builder = JoinGraphBuilder::new();
1314        builder.add_relation("a", create_node_scan("a", "A"));
1315        builder.add_relation("b", create_node_scan("b", "B"));
1316        builder.add_relation("c", create_node_scan("c", "C"));
1317
1318        builder.add_join_condition(
1319            "a",
1320            "b",
1321            LogicalExpression::Variable("a".to_string()),
1322            LogicalExpression::Variable("b".to_string()),
1323        );
1324        builder.add_join_condition(
1325            "b",
1326            "c",
1327            LogicalExpression::Variable("b".to_string()),
1328            LogicalExpression::Variable("c".to_string()),
1329        );
1330
1331        let graph = builder.build();
1332        assert!(!graph.is_cyclic());
1333    }
1334
1335    #[test]
1336    fn test_join_graph_empty_not_cyclic() {
1337        let graph = JoinGraph::new();
1338        assert!(!graph.is_cyclic());
1339    }
1340
1341    #[test]
1342    fn test_join_graph_edges_accessor() {
1343        let mut builder = JoinGraphBuilder::new();
1344        builder.add_relation("a", create_node_scan("a", "A"));
1345        builder.add_relation("b", create_node_scan("b", "B"));
1346
1347        builder.add_join_condition(
1348            "a",
1349            "b",
1350            LogicalExpression::Variable("a".to_string()),
1351            LogicalExpression::Variable("b".to_string()),
1352        );
1353
1354        let graph = builder.build();
1355        assert_eq!(graph.edges().len(), 1);
1356    }
1357
1358    #[test]
1359    fn test_bitset_full_boundary_values() {
1360        // H4: BitSet::full must handle edge cases without overflow
1361        assert_eq!(BitSet::full(0), BitSet::empty());
1362        assert_eq!(BitSet::full(1).0, 1); // 0b1
1363        assert_eq!(BitSet::full(2).0, 3); // 0b11
1364        assert_eq!(BitSet::full(63).0, u64::MAX >> 1); // 2^63 - 1
1365        assert_eq!(BitSet::full(64).0, u64::MAX); // All bits set
1366    }
1367
1368    #[test]
1369    fn test_bitset_full_overflow_protection() {
1370        // H4: n > 64 must saturate to all-bits-set, not panic or wrap
1371        let big = BitSet::full(65);
1372        assert_eq!(big.0, u64::MAX);
1373        let huge = BitSet::full(100);
1374        assert_eq!(huge.0, u64::MAX);
1375        let max = BitSet::full(usize::MAX);
1376        assert_eq!(max.0, u64::MAX);
1377    }
1378}