Skip to main content

cynos_query/executor/join/
merge.rs

1//! Sort-Merge Join implementation.
2
3use crate::executor::{Relation, RelationEntry};
4use alloc::vec::Vec;
5use cynos_core::Value;
6use core::cmp::Ordering;
7
/// Sort-Merge Join executor.
///
/// Joins two relations on a single key column per side by walking both inputs in
/// ascending key order. Efficient for joining pre-sorted relations or when both inputs
/// can be sorted efficiently.
#[derive(Debug, Clone, Copy)]
pub struct SortMergeJoin {
    /// Column index of the join key in the left relation's entries.
    left_key_index: usize,
    /// Column index of the join key in the right relation's entries.
    right_key_index: usize,
    /// `true` for a left outer join (left entries without a match are emitted once,
    /// padded with NULLs for the right columns); `false` for an inner join.
    is_outer_join: bool,
}
20
21impl SortMergeJoin {
22    /// Creates a new sort-merge join executor.
23    pub fn new(left_key_index: usize, right_key_index: usize, is_outer_join: bool) -> Self {
24        Self {
25            left_key_index,
26            right_key_index,
27            is_outer_join,
28        }
29    }
30
31    /// Creates an inner sort-merge join.
32    pub fn inner(left_key_index: usize, right_key_index: usize) -> Self {
33        Self::new(left_key_index, right_key_index, false)
34    }
35
36    /// Creates a left outer sort-merge join.
37    pub fn left_outer(left_key_index: usize, right_key_index: usize) -> Self {
38        Self::new(left_key_index, right_key_index, true)
39    }
40
41    /// Executes the sort-merge join.
42    /// Assumes both inputs are already sorted by their join keys.
43    pub fn execute(&self, left: Relation, right: Relation) -> Relation {
44        let mut result_entries = Vec::new();
45        let left_tables = left.tables().to_vec();
46        let right_tables = right.tables().to_vec();
47        let right_col_count = right
48            .entries
49            .first()
50            .map(|e| e.row.len())
51            .unwrap_or(0);
52
53        let left_entries: Vec<_> = left.entries.iter().collect();
54        let right_entries: Vec<_> = right.entries.iter().collect();
55
56        let mut left_idx = 0;
57        let mut right_idx = 0;
58
59        while left_idx < left_entries.len() {
60            let left_entry = left_entries[left_idx];
61            let left_value = left_entry.get_field(self.left_key_index);
62
63            // Handle null values
64            if left_value.map(|v| v.is_null()).unwrap_or(true) {
65                if self.is_outer_join {
66                    let combined = RelationEntry::combine_with_null(
67                        left_entry,
68                        &left_tables,
69                        right_col_count,
70                        &right_tables,
71                    );
72                    result_entries.push(combined);
73                }
74                left_idx += 1;
75                continue;
76            }
77
78            let left_val = left_value.unwrap();
79
80            // Skip right entries that are smaller than current left
81            while right_idx < right_entries.len() {
82                let right_value = right_entries[right_idx].get_field(self.right_key_index);
83                if right_value.map(|v| v.is_null()).unwrap_or(true) {
84                    right_idx += 1;
85                    continue;
86                }
87                if right_value.unwrap() < left_val {
88                    right_idx += 1;
89                } else {
90                    break;
91                }
92            }
93
94            // Find all matching right entries
95            let mut match_found = false;
96            let mut right_scan = right_idx;
97
98            while right_scan < right_entries.len() {
99                let right_entry = right_entries[right_scan];
100                let right_value = right_entry.get_field(self.right_key_index);
101
102                if right_value.map(|v| v.is_null()).unwrap_or(true) {
103                    right_scan += 1;
104                    continue;
105                }
106
107                let right_val = right_value.unwrap();
108
109                match left_val.cmp(right_val) {
110                    Ordering::Equal => {
111                        match_found = true;
112                        let combined = RelationEntry::combine(
113                            left_entry,
114                            &left_tables,
115                            right_entry,
116                            &right_tables,
117                        );
118                        result_entries.push(combined);
119                        right_scan += 1;
120                    }
121                    Ordering::Less => break,
122                    Ordering::Greater => {
123                        right_scan += 1;
124                    }
125                }
126            }
127
128            // For outer join, add unmatched left entries with nulls
129            if self.is_outer_join && !match_found {
130                let combined = RelationEntry::combine_with_null(
131                    left_entry,
132                    &left_tables,
133                    right_col_count,
134                    &right_tables,
135                );
136                result_entries.push(combined);
137            }
138
139            left_idx += 1;
140        }
141
142        let mut tables = left_tables;
143        tables.extend(right_tables);
144
145        // Compute combined table column counts
146        let mut table_column_counts = left.table_column_counts.clone();
147        table_column_counts.extend(right.table_column_counts.iter().cloned());
148
149        Relation {
150            entries: result_entries,
151            tables,
152            table_column_counts,
153        }
154    }
155
156    /// Executes the sort-merge join, sorting inputs first.
157    pub fn execute_with_sort(&self, mut left: Relation, mut right: Relation) -> Relation {
158        // Sort both relations by their join keys
159        left.entries.sort_by(|a, b| {
160            let a_val = a.get_field(self.left_key_index);
161            let b_val = b.get_field(self.left_key_index);
162            compare_values(a_val, b_val)
163        });
164
165        right.entries.sort_by(|a, b| {
166            let a_val = a.get_field(self.right_key_index);
167            let b_val = b.get_field(self.right_key_index);
168            compare_values(a_val, b_val)
169        });
170
171        self.execute(left, right)
172    }
173}
174
175fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
176    match (a, b) {
177        (None, None) => Ordering::Equal,
178        (None, Some(_)) => Ordering::Less,
179        (Some(_), None) => Ordering::Greater,
180        (Some(av), Some(bv)) => av.cmp(bv),
181    }
182}
183
/// Performs a sort-merge (equi-)join of two slices.
///
/// Both slices are first sorted **in place** by their keys using a stable sort, so the
/// inputs do not need to be pre-sorted. Afterwards the caller observes them in key order.
/// For every pair `(l, r)` with `left_key(l) == right_key(r)`, `output_fn(l, r)` is
/// called and its result collected.
///
/// Results are ordered by the sorted left slice, then by position in the sorted right
/// slice.
pub fn sort_merge_join<L, R, K, O, LK, RK, OF>(
    left: &mut [L],
    right: &mut [R],
    left_key: LK,
    right_key: RK,
    output_fn: OF,
) -> Vec<O>
where
    K: Ord,
    LK: Fn(&L) -> K,
    RK: Fn(&R) -> K,
    OF: Fn(&L, &R) -> O,
{
    left.sort_by_key(&left_key);
    right.sort_by_key(&right_key);

    let mut results = Vec::new();
    // Start of the right run that can still match the current (or any later) left key.
    let mut right_start = 0;

    for l in left.iter() {
        let left_k = left_key(l);

        // `right` is sorted, so keys below `left_k` can never match this or a later row.
        while right_start < right.len() && right_key(&right[right_start]) < left_k {
            right_start += 1;
        }

        // `right_start` is not moved past the equal-key run, so a following left row with
        // the same key joins against the same run again.
        results.extend(
            right[right_start..]
                .iter()
                .take_while(|&r| right_key(r) == left_k)
                .map(|r| output_fn(l, r)),
        );
    }

    results
}
232
#[cfg(test)]
mod tests {
    use super::*;
    use cynos_core::Row;
    use alloc::vec;

    #[test]
    fn test_sort_merge_join_inner() {
        // Pre-sorted inputs
        let left_rows = vec![
            Row::new(0, vec![Value::Int64(1)]),
            Row::new(1, vec![Value::Int64(2)]),
            Row::new(2, vec![Value::Int64(3)]),
        ];
        let right_rows = vec![
            Row::new(10, vec![Value::Int64(1)]),
            Row::new(11, vec![Value::Int64(2)]),
            Row::new(12, vec![Value::Int64(4)]),
        ];

        let left = Relation::from_rows_owned(left_rows, vec!["left".into()]);
        let right = Relation::from_rows_owned(right_rows, vec!["right".into()]);

        let join = SortMergeJoin::inner(0, 0);
        let result = join.execute(left, right);

        // Should match on keys 1 and 2
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_sort_merge_join_with_duplicates() {
        let left_rows = vec![
            Row::new(0, vec![Value::Int64(1)]),
            Row::new(1, vec![Value::Int64(1)]),
            Row::new(2, vec![Value::Int64(2)]),
        ];
        let right_rows = vec![
            Row::new(10, vec![Value::Int64(1)]),
            Row::new(11, vec![Value::Int64(1)]),
        ];

        let left = Relation::from_rows_owned(left_rows, vec!["left".into()]);
        let right = Relation::from_rows_owned(right_rows, vec!["right".into()]);

        let join = SortMergeJoin::inner(0, 0);
        let result = join.execute(left, right);

        // 2 left rows with key 1 * 2 right rows with key 1 = 4 matches
        assert_eq!(result.len(), 4);
    }

    #[test]
    fn test_sort_merge_join_left_outer() {
        let left_rows = vec![
            Row::new(0, vec![Value::Int64(1)]),
            Row::new(1, vec![Value::Int64(2)]),
            Row::new(2, vec![Value::Int64(3)]),
        ];
        let right_rows = vec![
            Row::new(10, vec![Value::Int64(1)]),
        ];

        let left = Relation::from_rows_owned(left_rows, vec!["left".into()]);
        let right = Relation::from_rows_owned(right_rows, vec!["right".into()]);

        let join = SortMergeJoin::left_outer(0, 0);
        let result = join.execute(left, right);

        // 1 match + 2 unmatched with nulls
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn test_sort_merge_join_function() {
        let mut left = vec![(3, "C"), (1, "A"), (2, "B")];
        let mut right = vec![(2, "Y"), (1, "X"), (4, "Z")];

        let result = sort_merge_join(
            &mut left,
            &mut right,
            |l| l.0,
            |r| r.0,
            |l, r| (l.1, r.1),
        );

        assert_eq!(result.len(), 2);
        assert!(result.contains(&("A", "X")));
        assert!(result.contains(&("B", "Y")));
    }

    #[test]
    fn test_sort_merge_join_function_duplicates_and_order() {
        // Duplicate keys on both sides form a cross product; output follows the
        // (stably) sorted left input, then the sorted right input.
        let mut left = vec![(2, "B"), (1, "A1"), (3, "C"), (1, "A2")];
        let mut right = vec![(1, "X1"), (4, "Z"), (2, "Y"), (1, "X2")];

        let result = sort_merge_join(
            &mut left,
            &mut right,
            |l| l.0,
            |r| r.0,
            |l, r| (l.1, r.1),
        );

        assert_eq!(
            result,
            vec![("A1", "X1"), ("A1", "X2"), ("A2", "X1"), ("A2", "X2"), ("B", "Y")]
        );
        // Both inputs are sorted in place as a side effect.
        assert_eq!(left, vec![(1, "A1"), (1, "A2"), (2, "B"), (3, "C")]);
        assert_eq!(right, vec![(1, "X1"), (1, "X2"), (2, "Y"), (4, "Z")]);
    }

    #[test]
    fn test_sort_merge_join_function_empty_inputs() {
        let mut left: Vec<(i32, &str)> = Vec::new();
        let mut right = vec![(1, "X")];
        let result = sort_merge_join(&mut left, &mut right, |l| l.0, |r| r.0, |l, r| (l.1, r.1));
        assert!(result.is_empty());

        let mut left = vec![(1, "A")];
        let mut right: Vec<(i32, &str)> = Vec::new();
        let result = sort_merge_join(&mut left, &mut right, |l| l.0, |r| r.0, |l, r| (l.1, r.1));
        assert!(result.is_empty());
    }

    #[test]
    fn test_sort_merge_join_with_sort() {
        // Unsorted inputs
        let left_rows = vec![
            Row::new(0, vec![Value::Int64(3)]),
            Row::new(1, vec![Value::Int64(1)]),
            Row::new(2, vec![Value::Int64(2)]),
        ];
        let right_rows = vec![
            Row::new(10, vec![Value::Int64(2)]),
            Row::new(11, vec![Value::Int64(1)]),
        ];

        let left = Relation::from_rows_owned(left_rows, vec!["left".into()]);
        let right = Relation::from_rows_owned(right_rows, vec!["right".into()]);

        let join = SortMergeJoin::inner(0, 0);
        let result = join.execute_with_sort(left, right);

        // Should match on keys 1 and 2
        assert_eq!(result.len(), 2);
    }
}