Skip to main content

grafeo_engine/query/optimizer/
join_order.rs

1//! DPccp (Dynamic Programming with connected complement pairs) join ordering.
2//!
3//! This module implements the DPccp algorithm for finding optimal join orderings.
4//! The algorithm works by:
5//! 1. Building a join graph from the query
6//! 2. Enumerating all connected subgraphs
7//! 3. Finding optimal plans for each subgraph using dynamic programming
8//!
9//! Reference: Moerkotte, G., & Neumann, T. (2006). Analysis of Two Existing and
10//! One New Dynamic Programming Algorithm for the Generation of Optimal Bushy
11//! Join Trees without Cross Products.
12
13use super::cardinality::CardinalityEstimator;
14use super::cost::{Cost, CostModel};
15use crate::query::plan::{JoinCondition, JoinOp, JoinType, LogicalExpression, LogicalOperator};
16use std::collections::{HashMap, HashSet};
17
18/// A node in the join graph.
19#[derive(Debug, Clone)]
20pub struct JoinNode {
21    /// Unique identifier for this node.
22    pub id: usize,
23    /// Variable name (e.g., "a" for node (a:Person)).
24    pub variable: String,
25    /// The base relation (scan operator).
26    pub relation: LogicalOperator,
27}
28
29impl PartialEq for JoinNode {
30    fn eq(&self, other: &Self) -> bool {
31        self.id == other.id && self.variable == other.variable
32    }
33}
34
35impl Eq for JoinNode {}
36
37impl std::hash::Hash for JoinNode {
38    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
39        self.id.hash(state);
40        self.variable.hash(state);
41    }
42}
43
44/// An edge in the join graph representing a join condition.
45#[derive(Debug, Clone)]
46pub struct JoinEdge {
47    /// Source node id.
48    pub from: usize,
49    /// Target node id.
50    pub to: usize,
51    /// Join conditions between these nodes.
52    pub conditions: Vec<JoinCondition>,
53}
54
55/// The join graph representing all relations and join conditions in a query.
56#[derive(Debug)]
57pub struct JoinGraph {
58    /// Nodes in the graph.
59    nodes: Vec<JoinNode>,
60    /// Edges (join conditions) between nodes.
61    edges: Vec<JoinEdge>,
62    /// Adjacency list for quick neighbor lookup.
63    adjacency: HashMap<usize, HashSet<usize>>,
64}
65
66impl JoinGraph {
67    /// Creates a new empty join graph.
68    pub fn new() -> Self {
69        Self {
70            nodes: Vec::new(),
71            edges: Vec::new(),
72            adjacency: HashMap::new(),
73        }
74    }
75
76    /// Adds a node to the graph.
77    pub fn add_node(&mut self, variable: String, relation: LogicalOperator) -> usize {
78        let id = self.nodes.len();
79        self.nodes.push(JoinNode {
80            id,
81            variable,
82            relation,
83        });
84        self.adjacency.insert(id, HashSet::new());
85        id
86    }
87
88    /// Adds a join edge between two nodes.
89    ///
90    /// # Panics
91    ///
92    /// Panics if `from` or `to` were not previously added via `add_node`.
93    pub fn add_edge(&mut self, from: usize, to: usize, conditions: Vec<JoinCondition>) {
94        self.edges.push(JoinEdge {
95            from,
96            to,
97            conditions,
98        });
99        // Invariant: add_node() inserts node ID into adjacency map (line 84),
100        // so get_mut succeeds for any ID returned by add_node()
101        self.adjacency
102            .get_mut(&from)
103            .expect("'from' node must be added via add_node() before add_edge()")
104            .insert(to);
105        self.adjacency
106            .get_mut(&to)
107            .expect("'to' node must be added via add_node() before add_edge()")
108            .insert(from);
109    }
110
111    /// Returns the number of nodes.
112    pub fn node_count(&self) -> usize {
113        self.nodes.len()
114    }
115
116    /// Returns a reference to the nodes.
117    pub fn nodes(&self) -> &[JoinNode] {
118        &self.nodes
119    }
120
121    /// Returns neighbors of a node.
122    pub fn neighbors(&self, node_id: usize) -> impl Iterator<Item = usize> + '_ {
123        self.adjacency.get(&node_id).into_iter().flatten().copied()
124    }
125
126    /// Gets the join conditions between two node sets.
127    pub fn get_conditions(&self, left: &BitSet, right: &BitSet) -> Vec<JoinCondition> {
128        let mut conditions = Vec::new();
129        for edge in &self.edges {
130            let from_in_left = left.contains(edge.from);
131            let from_in_right = right.contains(edge.from);
132            let to_in_left = left.contains(edge.to);
133            let to_in_right = right.contains(edge.to);
134
135            // Edge crosses between left and right
136            if (from_in_left && to_in_right) || (from_in_right && to_in_left) {
137                conditions.extend(edge.conditions.clone());
138            }
139        }
140        conditions
141    }
142
143    /// Checks if two node sets are connected by at least one edge.
144    pub fn are_connected(&self, left: &BitSet, right: &BitSet) -> bool {
145        for edge in &self.edges {
146            let from_in_left = left.contains(edge.from);
147            let from_in_right = right.contains(edge.from);
148            let to_in_left = left.contains(edge.to);
149            let to_in_right = right.contains(edge.to);
150
151            if (from_in_left && to_in_right) || (from_in_right && to_in_left) {
152                return true;
153            }
154        }
155        false
156    }
157}
158
159impl Default for JoinGraph {
160    fn default() -> Self {
161        Self::new()
162    }
163}
164
/// A compact set of node ids in `0..64`, stored as the bits of a `u64`.
///
/// Elements must be smaller than [`BitSet::CAPACITY`]. An operation that would
/// have to store a larger element panics. Without this check the bit shift
/// would overflow, and in builds without overflow checks the shift amount is
/// masked, so element 64 would silently alias element 0.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BitSet(u64);

impl BitSet {
    /// Number of distinct elements a `BitSet` can represent (64).
    pub const CAPACITY: usize = u64::BITS as usize;

    /// Returns the mask with only bit `i` set.
    ///
    /// # Panics
    ///
    /// Panics if `i >= BitSet::CAPACITY`.
    fn bit(i: usize) -> u64 {
        assert!(
            i < Self::CAPACITY,
            "BitSet element {} is out of range (capacity {})",
            i,
            Self::CAPACITY
        );
        1u64 << i
    }

    /// Creates an empty bitset.
    pub fn empty() -> Self {
        Self(0)
    }

    /// Creates a bitset containing only `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= BitSet::CAPACITY`.
    pub fn singleton(i: usize) -> Self {
        Self(Self::bit(i))
    }

    /// Creates a bitset from an iterator of elements.
    ///
    /// # Panics
    ///
    /// Panics if any element is `>= BitSet::CAPACITY`.
    pub fn from_iter(iter: impl Iterator<Item = usize>) -> Self {
        Self(iter.fold(0u64, |bits, i| bits | Self::bit(i)))
    }

    /// Creates the set `{0, 1, ..., n - 1}`.
    ///
    /// # Panics
    ///
    /// Panics if `n > BitSet::CAPACITY`.
    pub fn full(n: usize) -> Self {
        assert!(
            n <= Self::CAPACITY,
            "BitSet::full({}) exceeds capacity {}",
            n,
            Self::CAPACITY
        );
        // `1 << 64` overflows a u64, so the all-ones mask is special-cased.
        if n == Self::CAPACITY {
            Self(u64::MAX)
        } else {
            Self((1u64 << n) - 1)
        }
    }

    /// Checks if the set is empty.
    pub fn is_empty(&self) -> bool {
        self.0 == 0
    }

    /// Returns the number of elements.
    pub fn len(&self) -> usize {
        self.0.count_ones() as usize
    }

    /// Checks if the set contains an element.
    ///
    /// Ids at or beyond `BitSet::CAPACITY` can never be members, so this
    /// returns `false` for them.
    pub fn contains(&self, i: usize) -> bool {
        i < Self::CAPACITY && (self.0 & (1u64 << i)) != 0
    }

    /// Inserts an element.
    ///
    /// # Panics
    ///
    /// Panics if `i >= BitSet::CAPACITY`.
    pub fn insert(&mut self, i: usize) {
        self.0 |= Self::bit(i);
    }

    /// Removes an element. Removing an id that cannot be a member is a no-op.
    pub fn remove(&mut self, i: usize) {
        if i < Self::CAPACITY {
            self.0 &= !(1u64 << i);
        }
    }

    /// Returns the union of two sets.
    pub fn union(self, other: Self) -> Self {
        Self(self.0 | other.0)
    }

    /// Returns the intersection of two sets.
    pub fn intersection(self, other: Self) -> Self {
        Self(self.0 & other.0)
    }

    /// Returns the difference (self - other).
    pub fn difference(self, other: Self) -> Self {
        Self(self.0 & !other.0)
    }

    /// Checks if this set is a subset of another.
    pub fn is_subset_of(self, other: Self) -> bool {
        (self.0 & other.0) == self.0
    }

    /// Iterates over the elements in ascending order.
    pub fn iter(self) -> impl Iterator<Item = usize> {
        let mut remaining = self.0;
        std::iter::from_fn(move || {
            if remaining == 0 {
                return None;
            }
            let lowest = remaining.trailing_zeros() as usize;
            // Clear the lowest set bit.
            remaining &= remaining - 1;
            Some(lowest)
        })
    }

    /// Iterates over every subset of this set, including both the set itself
    /// and the empty set. It yields `2^len` items in decreasing numeric order:
    /// first the full set, last the empty set.
    pub fn subsets(self) -> SubsetIterator {
        SubsetIterator {
            full: self.0,
            current: Some(self.0),
        }
    }
}

impl std::iter::FromIterator<usize> for BitSet {
    /// Collects element ids into a bitset.
    ///
    /// # Panics
    ///
    /// Panics if any element is `>= BitSet::CAPACITY`.
    fn from_iter<I: IntoIterator<Item = usize>>(iter: I) -> Self {
        iter.into_iter().fold(Self::empty(), |mut set, i| {
            set.insert(i);
            set
        })
    }
}

/// Iterator over all subsets (submasks) of a bitset, from the full set down to
/// the empty set.
///
/// Uses the standard submask step `next = (current - 1) & full`, which visits
/// every submask of `full` exactly once in decreasing numeric order.
pub struct SubsetIterator {
    /// The mask whose submasks are enumerated.
    full: u64,
    /// Next submask to yield; `None` once the empty set has been yielded.
    current: Option<u64>,
}

impl Iterator for SubsetIterator {
    type Item = BitSet;

    fn next(&mut self) -> Option<Self::Item> {
        let current = self.current?;
        // The empty set is the last submask; otherwise step to the next-smaller one.
        self.current = if current == 0 {
            None
        } else {
            Some((current - 1) & self.full)
        };
        Some(BitSet(current))
    }
}
277
278/// Represents a (partial) join plan.
279#[derive(Debug, Clone)]
280pub struct JoinPlan {
281    /// The subset of nodes covered by this plan.
282    pub nodes: BitSet,
283    /// The logical operator representing this plan.
284    pub operator: LogicalOperator,
285    /// Estimated cost of this plan.
286    pub cost: Cost,
287    /// Estimated cardinality.
288    pub cardinality: f64,
289}
290
291/// DPccp join ordering optimizer.
292pub struct DPccp<'a> {
293    /// The join graph.
294    graph: &'a JoinGraph,
295    /// Cost model for estimating operator costs.
296    cost_model: &'a CostModel,
297    /// Cardinality estimator.
298    card_estimator: &'a CardinalityEstimator,
299    /// Memoization table: subset -> best plan.
300    memo: HashMap<BitSet, JoinPlan>,
301}
302
303impl<'a> DPccp<'a> {
304    /// Creates a new DPccp optimizer.
305    pub fn new(
306        graph: &'a JoinGraph,
307        cost_model: &'a CostModel,
308        card_estimator: &'a CardinalityEstimator,
309    ) -> Self {
310        Self {
311            graph,
312            cost_model,
313            card_estimator,
314            memo: HashMap::new(),
315        }
316    }
317
318    /// Finds the optimal join order for the graph.
319    pub fn optimize(&mut self) -> Option<JoinPlan> {
320        let n = self.graph.node_count();
321        if n == 0 {
322            return None;
323        }
324        if n == 1 {
325            let node = &self.graph.nodes[0];
326            let cardinality = self.card_estimator.estimate(&node.relation);
327            let cost = self.cost_model.estimate(&node.relation, cardinality);
328            return Some(JoinPlan {
329                nodes: BitSet::singleton(0),
330                operator: node.relation.clone(),
331                cost,
332                cardinality,
333            });
334        }
335
336        // Initialize with single relations
337        for (i, node) in self.graph.nodes.iter().enumerate() {
338            let subset = BitSet::singleton(i);
339            let cardinality = self.card_estimator.estimate(&node.relation);
340            let cost = self.cost_model.estimate(&node.relation, cardinality);
341            self.memo.insert(
342                subset,
343                JoinPlan {
344                    nodes: subset,
345                    operator: node.relation.clone(),
346                    cost,
347                    cardinality,
348                },
349            );
350        }
351
352        // Enumerate connected subgraph pairs (ccp)
353        let full_set = BitSet::full(n);
354        self.enumerate_ccp(full_set);
355
356        // Return the best plan for the full set
357        self.memo.get(&full_set).cloned()
358    }
359
360    /// Enumerates connected complement pairs using DPccp algorithm.
361    fn enumerate_ccp(&mut self, s: BitSet) {
362        // Iterate over all proper non-empty subsets
363        for s1 in s.subsets() {
364            if s1.is_empty() || s1 == s {
365                continue;
366            }
367
368            let s2 = s.difference(s1);
369            if s2.is_empty() {
370                continue;
371            }
372
373            // Both s1 and s2 must be connected subsets
374            if !self.is_connected(s1) || !self.is_connected(s2) {
375                continue;
376            }
377
378            // s1 and s2 must be connected to each other
379            if !self.graph.are_connected(&s1, &s2) {
380                continue;
381            }
382
383            // Recursively solve subproblems
384            if !self.memo.contains_key(&s1) {
385                self.enumerate_ccp(s1);
386            }
387            if !self.memo.contains_key(&s2) {
388                self.enumerate_ccp(s2);
389            }
390
391            // Try to build a plan for s by joining s1 and s2
392            if let (Some(plan1), Some(plan2)) = (self.memo.get(&s1), self.memo.get(&s2)) {
393                let conditions = self.graph.get_conditions(&s1, &s2);
394                let new_plan = self.build_join_plan(plan1.clone(), plan2.clone(), conditions);
395
396                // Update memo if this is a better plan
397                let should_update = self
398                    .memo
399                    .get(&s)
400                    .map(|existing| new_plan.cost.total() < existing.cost.total())
401                    .unwrap_or(true);
402
403                if should_update {
404                    self.memo.insert(s, new_plan);
405                }
406            }
407        }
408    }
409
410    /// Checks if a subset forms a connected subgraph.
411    fn is_connected(&self, subset: BitSet) -> bool {
412        if subset.len() <= 1 {
413            return true;
414        }
415
416        // BFS to check connectivity
417        // Invariant: subset.len() >= 2 (guard on line 400), so iter().next() returns Some
418        let start = subset
419            .iter()
420            .next()
421            .expect("subset is non-empty: len >= 2 checked on line 400");
422        let mut visited = BitSet::singleton(start);
423        let mut queue = vec![start];
424
425        while let Some(node) = queue.pop() {
426            for neighbor in self.graph.neighbors(node) {
427                if subset.contains(neighbor) && !visited.contains(neighbor) {
428                    visited.insert(neighbor);
429                    queue.push(neighbor);
430                }
431            }
432        }
433
434        visited == subset
435    }
436
437    /// Builds a join plan from two sub-plans.
438    fn build_join_plan(
439        &self,
440        left: JoinPlan,
441        right: JoinPlan,
442        conditions: Vec<JoinCondition>,
443    ) -> JoinPlan {
444        let nodes = left.nodes.union(right.nodes);
445
446        // Create the join operator
447        let join_op = LogicalOperator::Join(JoinOp {
448            left: Box::new(left.operator),
449            right: Box::new(right.operator),
450            join_type: JoinType::Inner,
451            conditions,
452        });
453
454        // Estimate cardinality
455        let cardinality = self.card_estimator.estimate(&join_op);
456
457        // Calculate cost (child costs + join cost)
458        let join_cost = self.cost_model.estimate(&join_op, cardinality);
459        let total_cost = left.cost + right.cost + join_cost;
460
461        JoinPlan {
462            nodes,
463            operator: join_op,
464            cost: total_cost,
465            cardinality,
466        }
467    }
468}
469
470/// Extracts a join graph from a query pattern.
471pub struct JoinGraphBuilder {
472    graph: JoinGraph,
473    variable_to_node: HashMap<String, usize>,
474}
475
476impl JoinGraphBuilder {
477    /// Creates a new builder.
478    pub fn new() -> Self {
479        Self {
480            graph: JoinGraph::new(),
481            variable_to_node: HashMap::new(),
482        }
483    }
484
485    /// Adds a base relation (scan).
486    pub fn add_relation(&mut self, variable: &str, relation: LogicalOperator) -> usize {
487        let id = self.graph.add_node(variable.to_string(), relation);
488        self.variable_to_node.insert(variable.to_string(), id);
489        id
490    }
491
492    /// Adds a join condition between two variables.
493    pub fn add_join_condition(
494        &mut self,
495        left_var: &str,
496        right_var: &str,
497        left_expr: LogicalExpression,
498        right_expr: LogicalExpression,
499    ) {
500        if let (Some(&left_id), Some(&right_id)) = (
501            self.variable_to_node.get(left_var),
502            self.variable_to_node.get(right_var),
503        ) {
504            self.graph.add_edge(
505                left_id,
506                right_id,
507                vec![JoinCondition {
508                    left: left_expr,
509                    right: right_expr,
510                }],
511            );
512        }
513    }
514
515    /// Builds the join graph.
516    pub fn build(self) -> JoinGraph {
517        self.graph
518    }
519}
520
521impl Default for JoinGraphBuilder {
522    fn default() -> Self {
523        Self::new()
524    }
525}
526
527#[cfg(test)]
528mod tests {
529    use super::*;
530    use crate::query::plan::NodeScanOp;
531
532    fn create_node_scan(var: &str, label: &str) -> LogicalOperator {
533        LogicalOperator::NodeScan(NodeScanOp {
534            variable: var.to_string(),
535            label: Some(label.to_string()),
536            input: None,
537        })
538    }
539
540    #[test]
541    fn test_bitset_operations() {
542        let a = BitSet::singleton(0);
543        let b = BitSet::singleton(1);
544        let _c = BitSet::singleton(2);
545
546        assert!(a.contains(0));
547        assert!(!a.contains(1));
548
549        let ab = a.union(b);
550        assert!(ab.contains(0));
551        assert!(ab.contains(1));
552        assert!(!ab.contains(2));
553
554        let full = BitSet::full(3);
555        assert_eq!(full.len(), 3);
556        assert!(full.contains(0));
557        assert!(full.contains(1));
558        assert!(full.contains(2));
559    }
560
561    #[test]
562    fn test_subset_iteration() {
563        let set = BitSet::from_iter([0, 1].into_iter());
564        let subsets: Vec<_> = set.subsets().collect();
565
566        // Should have 4 subsets: {}, {0}, {1}, {0,1}
567        assert_eq!(subsets.len(), 4);
568        assert!(subsets.contains(&BitSet::empty()));
569        assert!(subsets.contains(&BitSet::singleton(0)));
570        assert!(subsets.contains(&BitSet::singleton(1)));
571        assert!(subsets.contains(&set));
572    }
573
574    #[test]
575    fn test_join_graph_construction() {
576        let mut builder = JoinGraphBuilder::new();
577
578        builder.add_relation("a", create_node_scan("a", "Person"));
579        builder.add_relation("b", create_node_scan("b", "Person"));
580        builder.add_relation("c", create_node_scan("c", "Company"));
581
582        builder.add_join_condition(
583            "a",
584            "b",
585            LogicalExpression::Property {
586                variable: "a".to_string(),
587                property: "id".to_string(),
588            },
589            LogicalExpression::Property {
590                variable: "b".to_string(),
591                property: "friend_id".to_string(),
592            },
593        );
594
595        builder.add_join_condition(
596            "a",
597            "c",
598            LogicalExpression::Property {
599                variable: "a".to_string(),
600                property: "company_id".to_string(),
601            },
602            LogicalExpression::Property {
603                variable: "c".to_string(),
604                property: "id".to_string(),
605            },
606        );
607
608        let graph = builder.build();
609        assert_eq!(graph.node_count(), 3);
610    }
611
612    #[test]
613    fn test_dpccp_single_relation() {
614        let mut builder = JoinGraphBuilder::new();
615        builder.add_relation("a", create_node_scan("a", "Person"));
616        let graph = builder.build();
617
618        let cost_model = CostModel::new();
619        let mut card_estimator = CardinalityEstimator::new();
620        card_estimator.add_table_stats("Person", super::super::cardinality::TableStats::new(1000));
621
622        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
623        let plan = dpccp.optimize();
624
625        assert!(plan.is_some());
626        let plan = plan.unwrap();
627        assert_eq!(plan.nodes.len(), 1);
628    }
629
630    #[test]
631    fn test_dpccp_two_relations() {
632        let mut builder = JoinGraphBuilder::new();
633        builder.add_relation("a", create_node_scan("a", "Person"));
634        builder.add_relation("b", create_node_scan("b", "Company"));
635
636        builder.add_join_condition(
637            "a",
638            "b",
639            LogicalExpression::Property {
640                variable: "a".to_string(),
641                property: "company_id".to_string(),
642            },
643            LogicalExpression::Property {
644                variable: "b".to_string(),
645                property: "id".to_string(),
646            },
647        );
648
649        let graph = builder.build();
650
651        let cost_model = CostModel::new();
652        let mut card_estimator = CardinalityEstimator::new();
653        card_estimator.add_table_stats("Person", super::super::cardinality::TableStats::new(1000));
654        card_estimator.add_table_stats("Company", super::super::cardinality::TableStats::new(100));
655
656        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
657        let plan = dpccp.optimize();
658
659        assert!(plan.is_some());
660        let plan = plan.unwrap();
661        assert_eq!(plan.nodes.len(), 2);
662
663        // The result should be a join
664        if let LogicalOperator::Join(_) = plan.operator {
665            // Good
666        } else {
667            panic!("Expected Join operator");
668        }
669    }
670
671    #[test]
672    fn test_dpccp_three_relations_chain() {
673        // a -[knows]-> b -[works_at]-> c
674        let mut builder = JoinGraphBuilder::new();
675        builder.add_relation("a", create_node_scan("a", "Person"));
676        builder.add_relation("b", create_node_scan("b", "Person"));
677        builder.add_relation("c", create_node_scan("c", "Company"));
678
679        builder.add_join_condition(
680            "a",
681            "b",
682            LogicalExpression::Property {
683                variable: "a".to_string(),
684                property: "knows".to_string(),
685            },
686            LogicalExpression::Property {
687                variable: "b".to_string(),
688                property: "id".to_string(),
689            },
690        );
691
692        builder.add_join_condition(
693            "b",
694            "c",
695            LogicalExpression::Property {
696                variable: "b".to_string(),
697                property: "company_id".to_string(),
698            },
699            LogicalExpression::Property {
700                variable: "c".to_string(),
701                property: "id".to_string(),
702            },
703        );
704
705        let graph = builder.build();
706
707        let cost_model = CostModel::new();
708        let mut card_estimator = CardinalityEstimator::new();
709        card_estimator.add_table_stats("Person", super::super::cardinality::TableStats::new(1000));
710        card_estimator.add_table_stats("Company", super::super::cardinality::TableStats::new(100));
711
712        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
713        let plan = dpccp.optimize();
714
715        assert!(plan.is_some());
716        let plan = plan.unwrap();
717        assert_eq!(plan.nodes.len(), 3);
718    }
719
720    #[test]
721    fn test_dpccp_prefers_smaller_intermediate() {
722        // Test that DPccp prefers joining smaller tables first
723        // Setup: Small (100) -[r1]-> Medium (1000) -[r2]-> Large (10000)
724        // Without cost-based ordering, might get (Small ⋈ Large) ⋈ Medium
725        // With cost-based ordering, should get (Small ⋈ Medium) ⋈ Large
726
727        let mut builder = JoinGraphBuilder::new();
728        builder.add_relation("s", create_node_scan("s", "Small"));
729        builder.add_relation("m", create_node_scan("m", "Medium"));
730        builder.add_relation("l", create_node_scan("l", "Large"));
731
732        // Connect all three (star schema)
733        builder.add_join_condition(
734            "s",
735            "m",
736            LogicalExpression::Property {
737                variable: "s".to_string(),
738                property: "id".to_string(),
739            },
740            LogicalExpression::Property {
741                variable: "m".to_string(),
742                property: "s_id".to_string(),
743            },
744        );
745
746        builder.add_join_condition(
747            "m",
748            "l",
749            LogicalExpression::Property {
750                variable: "m".to_string(),
751                property: "id".to_string(),
752            },
753            LogicalExpression::Property {
754                variable: "l".to_string(),
755                property: "m_id".to_string(),
756            },
757        );
758
759        let graph = builder.build();
760
761        let cost_model = CostModel::new();
762        let mut card_estimator = CardinalityEstimator::new();
763        card_estimator.add_table_stats("Small", super::super::cardinality::TableStats::new(100));
764        card_estimator.add_table_stats("Medium", super::super::cardinality::TableStats::new(1000));
765        card_estimator.add_table_stats("Large", super::super::cardinality::TableStats::new(10000));
766
767        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
768        let plan = dpccp.optimize();
769
770        assert!(plan.is_some());
771        let plan = plan.unwrap();
772
773        // The plan should cover all three relations
774        assert_eq!(plan.nodes.len(), 3);
775
776        // We can't easily verify the exact join order without inspecting the tree,
777        // but we can verify the plan was created successfully
778        assert!(plan.cost.total() > 0.0);
779    }
780
781    // Additional BitSet tests
782
783    #[test]
784    fn test_bitset_empty() {
785        let empty = BitSet::empty();
786        assert!(empty.is_empty());
787        assert_eq!(empty.len(), 0);
788        assert!(!empty.contains(0));
789    }
790
791    #[test]
792    fn test_bitset_insert_remove() {
793        let mut set = BitSet::empty();
794        set.insert(3);
795        assert!(set.contains(3));
796        assert_eq!(set.len(), 1);
797
798        set.insert(5);
799        assert!(set.contains(5));
800        assert_eq!(set.len(), 2);
801
802        set.remove(3);
803        assert!(!set.contains(3));
804        assert_eq!(set.len(), 1);
805    }
806
807    #[test]
808    fn test_bitset_intersection() {
809        let a = BitSet::from_iter([0, 1, 2].into_iter());
810        let b = BitSet::from_iter([1, 2, 3].into_iter());
811        let intersection = a.intersection(b);
812
813        assert!(intersection.contains(1));
814        assert!(intersection.contains(2));
815        assert!(!intersection.contains(0));
816        assert!(!intersection.contains(3));
817        assert_eq!(intersection.len(), 2);
818    }
819
820    #[test]
821    fn test_bitset_difference() {
822        let a = BitSet::from_iter([0, 1, 2].into_iter());
823        let b = BitSet::from_iter([1, 2, 3].into_iter());
824        let diff = a.difference(b);
825
826        assert!(diff.contains(0));
827        assert!(!diff.contains(1));
828        assert!(!diff.contains(2));
829        assert_eq!(diff.len(), 1);
830    }
831
832    #[test]
833    fn test_bitset_is_subset_of() {
834        let a = BitSet::from_iter([1, 2].into_iter());
835        let b = BitSet::from_iter([0, 1, 2, 3].into_iter());
836
837        assert!(a.is_subset_of(b));
838        assert!(!b.is_subset_of(a));
839        assert!(a.is_subset_of(a));
840    }
841
842    #[test]
843    fn test_bitset_iter() {
844        let set = BitSet::from_iter([0, 2, 5].into_iter());
845        let elements: Vec<_> = set.iter().collect();
846
847        assert_eq!(elements, vec![0, 2, 5]);
848    }
849
850    // Additional JoinGraph tests
851
852    #[test]
853    fn test_join_graph_empty() {
854        let graph = JoinGraph::new();
855        assert_eq!(graph.node_count(), 0);
856    }
857
858    #[test]
859    fn test_join_graph_neighbors() {
860        let mut builder = JoinGraphBuilder::new();
861        builder.add_relation("a", create_node_scan("a", "A"));
862        builder.add_relation("b", create_node_scan("b", "B"));
863        builder.add_relation("c", create_node_scan("c", "C"));
864
865        builder.add_join_condition(
866            "a",
867            "b",
868            LogicalExpression::Variable("a".to_string()),
869            LogicalExpression::Variable("b".to_string()),
870        );
871        builder.add_join_condition(
872            "a",
873            "c",
874            LogicalExpression::Variable("a".to_string()),
875            LogicalExpression::Variable("c".to_string()),
876        );
877
878        let graph = builder.build();
879
880        // 'a' should have neighbors 'b' and 'c' (indices 1 and 2)
881        let neighbors_a: Vec<_> = graph.neighbors(0).collect();
882        assert_eq!(neighbors_a.len(), 2);
883        assert!(neighbors_a.contains(&1));
884        assert!(neighbors_a.contains(&2));
885
886        // 'b' should have only neighbor 'a'
887        let neighbors_b: Vec<_> = graph.neighbors(1).collect();
888        assert_eq!(neighbors_b.len(), 1);
889        assert!(neighbors_b.contains(&0));
890    }
891
892    #[test]
893    fn test_join_graph_are_connected() {
894        let mut builder = JoinGraphBuilder::new();
895        builder.add_relation("a", create_node_scan("a", "A"));
896        builder.add_relation("b", create_node_scan("b", "B"));
897        builder.add_relation("c", create_node_scan("c", "C"));
898
899        builder.add_join_condition(
900            "a",
901            "b",
902            LogicalExpression::Variable("a".to_string()),
903            LogicalExpression::Variable("b".to_string()),
904        );
905
906        let graph = builder.build();
907
908        let set_a = BitSet::singleton(0);
909        let set_b = BitSet::singleton(1);
910        let set_c = BitSet::singleton(2);
911
912        assert!(graph.are_connected(&set_a, &set_b));
913        assert!(graph.are_connected(&set_b, &set_a));
914        assert!(!graph.are_connected(&set_a, &set_c));
915        assert!(!graph.are_connected(&set_b, &set_c));
916    }
917
918    #[test]
919    fn test_join_graph_get_conditions() {
920        let mut builder = JoinGraphBuilder::new();
921        builder.add_relation("a", create_node_scan("a", "A"));
922        builder.add_relation("b", create_node_scan("b", "B"));
923
924        builder.add_join_condition(
925            "a",
926            "b",
927            LogicalExpression::Property {
928                variable: "a".to_string(),
929                property: "id".to_string(),
930            },
931            LogicalExpression::Property {
932                variable: "b".to_string(),
933                property: "a_id".to_string(),
934            },
935        );
936
937        let graph = builder.build();
938
939        let set_a = BitSet::singleton(0);
940        let set_b = BitSet::singleton(1);
941
942        let conditions = graph.get_conditions(&set_a, &set_b);
943        assert_eq!(conditions.len(), 1);
944    }
945
946    // Additional DPccp tests
947
948    #[test]
949    fn test_dpccp_empty_graph() {
950        let graph = JoinGraph::new();
951        let cost_model = CostModel::new();
952        let card_estimator = CardinalityEstimator::new();
953
954        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
955        let plan = dpccp.optimize();
956
957        assert!(plan.is_none());
958    }
959
960    #[test]
961    fn test_dpccp_star_query() {
962        // Star schema: center connected to all others
963        // center -> a, center -> b, center -> c
964        let mut builder = JoinGraphBuilder::new();
965        builder.add_relation("center", create_node_scan("center", "Center"));
966        builder.add_relation("a", create_node_scan("a", "A"));
967        builder.add_relation("b", create_node_scan("b", "B"));
968        builder.add_relation("c", create_node_scan("c", "C"));
969
970        builder.add_join_condition(
971            "center",
972            "a",
973            LogicalExpression::Variable("center".to_string()),
974            LogicalExpression::Variable("a".to_string()),
975        );
976        builder.add_join_condition(
977            "center",
978            "b",
979            LogicalExpression::Variable("center".to_string()),
980            LogicalExpression::Variable("b".to_string()),
981        );
982        builder.add_join_condition(
983            "center",
984            "c",
985            LogicalExpression::Variable("center".to_string()),
986            LogicalExpression::Variable("c".to_string()),
987        );
988
989        let graph = builder.build();
990
991        let cost_model = CostModel::new();
992        let mut card_estimator = CardinalityEstimator::new();
993        card_estimator.add_table_stats("Center", super::super::cardinality::TableStats::new(100));
994        card_estimator.add_table_stats("A", super::super::cardinality::TableStats::new(1000));
995        card_estimator.add_table_stats("B", super::super::cardinality::TableStats::new(500));
996        card_estimator.add_table_stats("C", super::super::cardinality::TableStats::new(200));
997
998        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
999        let plan = dpccp.optimize();
1000
1001        assert!(plan.is_some());
1002        let plan = plan.unwrap();
1003        assert_eq!(plan.nodes.len(), 4);
1004        assert!(plan.cost.total() > 0.0);
1005    }
1006
1007    #[test]
1008    fn test_dpccp_cycle_query() {
1009        // Cycle: a -> b -> c -> a
1010        let mut builder = JoinGraphBuilder::new();
1011        builder.add_relation("a", create_node_scan("a", "A"));
1012        builder.add_relation("b", create_node_scan("b", "B"));
1013        builder.add_relation("c", create_node_scan("c", "C"));
1014
1015        builder.add_join_condition(
1016            "a",
1017            "b",
1018            LogicalExpression::Variable("a".to_string()),
1019            LogicalExpression::Variable("b".to_string()),
1020        );
1021        builder.add_join_condition(
1022            "b",
1023            "c",
1024            LogicalExpression::Variable("b".to_string()),
1025            LogicalExpression::Variable("c".to_string()),
1026        );
1027        builder.add_join_condition(
1028            "c",
1029            "a",
1030            LogicalExpression::Variable("c".to_string()),
1031            LogicalExpression::Variable("a".to_string()),
1032        );
1033
1034        let graph = builder.build();
1035
1036        let cost_model = CostModel::new();
1037        let mut card_estimator = CardinalityEstimator::new();
1038        card_estimator.add_table_stats("A", super::super::cardinality::TableStats::new(100));
1039        card_estimator.add_table_stats("B", super::super::cardinality::TableStats::new(100));
1040        card_estimator.add_table_stats("C", super::super::cardinality::TableStats::new(100));
1041
1042        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
1043        let plan = dpccp.optimize();
1044
1045        assert!(plan.is_some());
1046        let plan = plan.unwrap();
1047        assert_eq!(plan.nodes.len(), 3);
1048    }
1049
1050    #[test]
1051    fn test_dpccp_four_relations() {
1052        // Chain: a -> b -> c -> d
1053        let mut builder = JoinGraphBuilder::new();
1054        builder.add_relation("a", create_node_scan("a", "A"));
1055        builder.add_relation("b", create_node_scan("b", "B"));
1056        builder.add_relation("c", create_node_scan("c", "C"));
1057        builder.add_relation("d", create_node_scan("d", "D"));
1058
1059        builder.add_join_condition(
1060            "a",
1061            "b",
1062            LogicalExpression::Variable("a".to_string()),
1063            LogicalExpression::Variable("b".to_string()),
1064        );
1065        builder.add_join_condition(
1066            "b",
1067            "c",
1068            LogicalExpression::Variable("b".to_string()),
1069            LogicalExpression::Variable("c".to_string()),
1070        );
1071        builder.add_join_condition(
1072            "c",
1073            "d",
1074            LogicalExpression::Variable("c".to_string()),
1075            LogicalExpression::Variable("d".to_string()),
1076        );
1077
1078        let graph = builder.build();
1079
1080        let cost_model = CostModel::new();
1081        let mut card_estimator = CardinalityEstimator::new();
1082        card_estimator.add_table_stats("A", super::super::cardinality::TableStats::new(100));
1083        card_estimator.add_table_stats("B", super::super::cardinality::TableStats::new(200));
1084        card_estimator.add_table_stats("C", super::super::cardinality::TableStats::new(300));
1085        card_estimator.add_table_stats("D", super::super::cardinality::TableStats::new(400));
1086
1087        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
1088        let plan = dpccp.optimize();
1089
1090        assert!(plan.is_some());
1091        let plan = plan.unwrap();
1092        assert_eq!(plan.nodes.len(), 4);
1093    }
1094
1095    #[test]
1096    fn test_join_plan_cardinality_and_cost() {
1097        let mut builder = JoinGraphBuilder::new();
1098        builder.add_relation("a", create_node_scan("a", "A"));
1099        builder.add_relation("b", create_node_scan("b", "B"));
1100
1101        builder.add_join_condition(
1102            "a",
1103            "b",
1104            LogicalExpression::Variable("a".to_string()),
1105            LogicalExpression::Variable("b".to_string()),
1106        );
1107
1108        let graph = builder.build();
1109
1110        let cost_model = CostModel::new();
1111        let mut card_estimator = CardinalityEstimator::new();
1112        card_estimator.add_table_stats("A", super::super::cardinality::TableStats::new(100));
1113        card_estimator.add_table_stats("B", super::super::cardinality::TableStats::new(200));
1114
1115        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
1116        let plan = dpccp.optimize().unwrap();
1117
1118        // Plan should have non-zero cardinality and cost
1119        assert!(plan.cardinality > 0.0);
1120        assert!(plan.cost.total() > 0.0);
1121    }
1122
1123    #[test]
1124    fn test_join_graph_default() {
1125        let graph = JoinGraph::default();
1126        assert_eq!(graph.node_count(), 0);
1127    }
1128
1129    #[test]
1130    fn test_join_graph_builder_default() {
1131        let builder = JoinGraphBuilder::default();
1132        let graph = builder.build();
1133        assert_eq!(graph.node_count(), 0);
1134    }
1135
1136    #[test]
1137    fn test_join_graph_nodes_accessor() {
1138        let mut builder = JoinGraphBuilder::new();
1139        builder.add_relation("a", create_node_scan("a", "A"));
1140        builder.add_relation("b", create_node_scan("b", "B"));
1141
1142        let graph = builder.build();
1143        let nodes = graph.nodes();
1144
1145        assert_eq!(nodes.len(), 2);
1146        assert_eq!(nodes[0].variable, "a");
1147        assert_eq!(nodes[1].variable, "b");
1148    }
1149
1150    #[test]
1151    fn test_join_node_equality() {
1152        let node1 = JoinNode {
1153            id: 0,
1154            variable: "a".to_string(),
1155            relation: create_node_scan("a", "A"),
1156        };
1157        let node2 = JoinNode {
1158            id: 0,
1159            variable: "a".to_string(),
1160            relation: create_node_scan("a", "A"),
1161        };
1162        let node3 = JoinNode {
1163            id: 1,
1164            variable: "a".to_string(),
1165            relation: create_node_scan("a", "A"),
1166        };
1167
1168        assert_eq!(node1, node2);
1169        assert_ne!(node1, node3);
1170    }
1171
1172    #[test]
1173    fn test_join_node_hash() {
1174        use std::collections::HashSet;
1175
1176        let node1 = JoinNode {
1177            id: 0,
1178            variable: "a".to_string(),
1179            relation: create_node_scan("a", "A"),
1180        };
1181        let node2 = JoinNode {
1182            id: 0,
1183            variable: "a".to_string(),
1184            relation: create_node_scan("a", "A"),
1185        };
1186
1187        let mut set = HashSet::new();
1188        set.insert(node1.clone());
1189
1190        // Same id and variable should be considered equal
1191        assert!(set.contains(&node2));
1192    }
1193
1194    #[test]
1195    fn test_add_join_condition_unknown_variable() {
1196        let mut builder = JoinGraphBuilder::new();
1197        builder.add_relation("a", create_node_scan("a", "A"));
1198
1199        // Adding condition with unknown variable should do nothing (no panic)
1200        builder.add_join_condition(
1201            "a",
1202            "unknown",
1203            LogicalExpression::Variable("a".to_string()),
1204            LogicalExpression::Variable("unknown".to_string()),
1205        );
1206
1207        let graph = builder.build();
1208        assert_eq!(graph.node_count(), 1);
1209    }
1210
1211    #[test]
1212    fn test_dpccp_with_different_cardinalities() {
1213        // Test that DPccp handles vastly different cardinalities
1214        let mut builder = JoinGraphBuilder::new();
1215        builder.add_relation("tiny", create_node_scan("tiny", "Tiny"));
1216        builder.add_relation("huge", create_node_scan("huge", "Huge"));
1217
1218        builder.add_join_condition(
1219            "tiny",
1220            "huge",
1221            LogicalExpression::Variable("tiny".to_string()),
1222            LogicalExpression::Variable("huge".to_string()),
1223        );
1224
1225        let graph = builder.build();
1226
1227        let cost_model = CostModel::new();
1228        let mut card_estimator = CardinalityEstimator::new();
1229        card_estimator.add_table_stats("Tiny", super::super::cardinality::TableStats::new(10));
1230        card_estimator.add_table_stats("Huge", super::super::cardinality::TableStats::new(1000000));
1231
1232        let mut dpccp = DPccp::new(&graph, &cost_model, &card_estimator);
1233        let plan = dpccp.optimize();
1234
1235        assert!(plan.is_some());
1236        let plan = plan.unwrap();
1237        assert_eq!(plan.nodes.len(), 2);
1238    }
1239}