Skip to main content

nexcore_dataframe/
join.rs

1//! Join: hash-based join operations on DataFrames.
2//!
3//! Supports 6 join types: Inner, Left, Right, Outer, Semi, Anti.
4//! Algorithm: hash join — build index on right table, probe from left.
5//!
6//! Primitive composition: μ(Mapping) + κ(Comparison) + ∂(Boundary) + ς(State)
7
8// HashMap is essential for O(1) join-key lookup during the probe phase.
9// Output row order follows left-table order for deterministic results.
10#[allow(
11    clippy::disallowed_types,
12    reason = "HashMap needed for O(1) hash-join probe; output follows left-table order, which is deterministic"
13)]
14use std::collections::HashMap;
15
16use crate::column::Column;
17use crate::dataframe::DataFrame;
18use crate::error::DataFrameError;
19use crate::scalar::Scalar;
20
/// Sentinel string that stands in for a null cell inside a stringified join key.
///
/// Key parts are compared as `String`s, so nulls need a textual stand-in. The
/// embedded NUL bytes make an accidental clash with real data very unlikely
/// (stricter than GroupBy's plain `"null"` literal). They do not make a clash
/// impossible: Rust strings may contain NUL, and a genuine string value equal
/// to this sentinel is treated as null (and therefore never matches).
const NULL_SENTINEL: &str = "\0NULL\0";
24
/// Join type enumeration.
///
/// Rows whose join key contains a null never match anything (SQL semantics).
/// Such rows appear in the output only where the join type keeps unmatched
/// rows from their side.
///
/// `Hash` is derived so a `JoinType` can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum JoinType {
    /// Keep only rows that match in both tables.
    Inner,
    /// Keep all left rows; fill right with nulls where no match.
    Left,
    /// Keep all right rows; fill left with nulls where no match.
    Right,
    /// Keep all rows from both tables; fill with nulls where no match.
    Outer,
    /// Keep left rows that have at least one match in right (no right columns added).
    Semi,
    /// Keep left rows that have NO match in right (no right columns added).
    Anti,
}
42
43impl DataFrame {
44    /// Join two DataFrames on shared key column names.
45    ///
46    /// Equivalent to `self.join_on(other, on, on, how)` — both tables use
47    /// the same column names as join keys.
48    pub fn join(
49        &self,
50        other: &DataFrame,
51        on: &[&str],
52        how: JoinType,
53    ) -> Result<DataFrame, DataFrameError> {
54        self.join_on(other, on, on, how)
55    }
56
57    /// Join two DataFrames with potentially different key column names.
58    ///
59    /// `left_on` columns from `self`, `right_on` columns from `other`.
60    /// Key columns must have the same count. Column name collisions in non-key
61    /// columns are resolved with `_left` / `_right` suffixes.
62    pub fn join_on(
63        &self,
64        other: &DataFrame,
65        left_on: &[&str],
66        right_on: &[&str],
67        how: JoinType,
68    ) -> Result<DataFrame, DataFrameError> {
69        // Validate key counts match
70        if left_on.len() != right_on.len() {
71            return Err(DataFrameError::JoinKeyMismatch {
72                left_count: left_on.len(),
73                right_count: right_on.len(),
74            });
75        }
76        if left_on.is_empty() {
77            return Err(DataFrameError::Other(
78                "join requires at least one key column".to_string(),
79            ));
80        }
81
82        // Validate key columns exist
83        for name in left_on {
84            self.column(name)?;
85        }
86        for name in right_on {
87            other.column(name)?;
88        }
89
90        // Build hash index on RIGHT table: key → Vec<row_index>
91        #[allow(
92            clippy::disallowed_types,
93            reason = "HashMap for O(1) hash-join index; see module-level allow"
94        )]
95        let mut right_index: HashMap<Vec<String>, Vec<usize>> = HashMap::new();
96        for row_idx in 0..other.height() {
97            let key = extract_key(other, right_on, row_idx);
98            // Null keys never match (SQL standard) — skip indexing them
99            if key.iter().any(|k| k == NULL_SENTINEL) {
100                continue;
101            }
102            right_index.entry(key).or_default().push(row_idx);
103        }
104
105        // Probe from LEFT table
106        match how {
107            JoinType::Semi => self.join_semi(left_on, &right_index),
108            JoinType::Anti => self.join_anti(left_on, &right_index),
109            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Outer => {
110                self.join_matching(other, left_on, right_on, how, &right_index)
111            }
112        }
113    }
114
115    /// Inner, Left, Right, Outer joins — produce combined columns.
116    #[allow(
117        clippy::disallowed_types,
118        reason = "HashMap parameter from join_on caller"
119    )]
120    #[allow(
121        clippy::too_many_arguments,
122        reason = "Internal method — 6 args needed for left/right context + index; not part of public API"
123    )]
124    fn join_matching(
125        &self,
126        other: &DataFrame,
127        left_on: &[&str],
128        right_on: &[&str],
129        how: JoinType,
130        right_index: &HashMap<Vec<String>, Vec<usize>>,
131    ) -> Result<DataFrame, DataFrameError> {
132        let mut left_indices: Vec<Option<usize>> = Vec::new();
133        let mut right_indices: Vec<Option<usize>> = Vec::new();
134
135        // Track which right rows were matched (for Right/Outer joins)
136        let mut right_matched = vec![false; other.height()];
137
138        // Probe left table against right index
139        for left_row in 0..self.height() {
140            let key = extract_key(self, left_on, left_row);
141
142            // Null keys never match — but Left/Outer preserves the left row
143            if key.iter().any(|k| k == NULL_SENTINEL) {
144                match how {
145                    JoinType::Left | JoinType::Outer => {
146                        left_indices.push(Some(left_row));
147                        right_indices.push(None);
148                    }
149                    JoinType::Inner | JoinType::Right | JoinType::Semi | JoinType::Anti => {}
150                }
151                continue;
152            }
153
154            match right_index.get(&key) {
155                Some(matches) => {
156                    for &right_row in matches {
157                        left_indices.push(Some(left_row));
158                        right_indices.push(Some(right_row));
159                        // right_row < other.height() by construction (built from 0..other.height())
160                        #[allow(
161                            clippy::indexing_slicing,
162                            reason = "right_row is from right_index which was built from 0..other.height()"
163                        )]
164                        {
165                            right_matched[right_row] = true;
166                        }
167                    }
168                }
169                None => match how {
170                    JoinType::Left | JoinType::Outer => {
171                        left_indices.push(Some(left_row));
172                        right_indices.push(None);
173                    }
174                    JoinType::Inner | JoinType::Right | JoinType::Semi | JoinType::Anti => {}
175                },
176            }
177        }
178
179        // Append unmatched right rows for Right/Outer joins
180        if matches!(how, JoinType::Right | JoinType::Outer) {
181            for (right_row, matched) in right_matched.iter().enumerate() {
182                if !matched {
183                    left_indices.push(None);
184                    right_indices.push(Some(right_row));
185                }
186            }
187        }
188
189        // Assemble result columns
190        assemble_columns(
191            self,
192            other,
193            left_on,
194            right_on,
195            &left_indices,
196            &right_indices,
197        )
198    }
199
200    /// Semi join: keep left rows that have at least one match in right.
201    #[allow(
202        clippy::disallowed_types,
203        reason = "HashMap parameter from join_on caller"
204    )]
205    fn join_semi(
206        &self,
207        left_on: &[&str],
208        right_index: &HashMap<Vec<String>, Vec<usize>>,
209    ) -> Result<DataFrame, DataFrameError> {
210        let mut keep_indices: Vec<usize> = Vec::new();
211
212        for left_row in 0..self.height() {
213            let key = extract_key(self, left_on, left_row);
214            if key.iter().any(|k| k == NULL_SENTINEL) {
215                continue;
216            }
217            if right_index.contains_key(&key) {
218                keep_indices.push(left_row);
219            }
220        }
221
222        Ok(DataFrame::from_columns_unchecked(
223            self.columns()
224                .iter()
225                .map(|c| c.take(&keep_indices))
226                .collect(),
227        ))
228    }
229
230    /// Anti join: keep left rows that have NO match in right.
231    #[allow(
232        clippy::disallowed_types,
233        reason = "HashMap parameter from join_on caller"
234    )]
235    fn join_anti(
236        &self,
237        left_on: &[&str],
238        right_index: &HashMap<Vec<String>, Vec<usize>>,
239    ) -> Result<DataFrame, DataFrameError> {
240        let mut keep_indices: Vec<usize> = Vec::new();
241
242        for left_row in 0..self.height() {
243            let key = extract_key(self, left_on, left_row);
244            // Null keys never match → they are always "unmatched" → kept in anti-join
245            if key.iter().any(|k| k == NULL_SENTINEL) {
246                keep_indices.push(left_row);
247                continue;
248            }
249            if !right_index.contains_key(&key) {
250                keep_indices.push(left_row);
251            }
252        }
253
254        Ok(DataFrame::from_columns_unchecked(
255            self.columns()
256                .iter()
257                .map(|c| c.take(&keep_indices))
258                .collect(),
259        ))
260    }
261}
262
263/// Extract the join key for a row as a Vec<String>.
264/// Null values become the NULL_SENTINEL.
265fn extract_key(df: &DataFrame, key_cols: &[&str], row_idx: usize) -> Vec<String> {
266    key_cols
267        .iter()
268        .map(|name| {
269            df.column(name)
270                .ok()
271                .and_then(|col| col.get(row_idx))
272                .map_or_else(
273                    || NULL_SENTINEL.to_string(),
274                    |s| {
275                        if s.is_null() {
276                            NULL_SENTINEL.to_string()
277                        } else {
278                            s.to_string()
279                        }
280                    },
281                )
282        })
283        .collect()
284}
285
286/// Assemble result columns from left/right index pairs.
287/// Key columns come from whichever side has a value (left preferred).
288/// Non-key columns are included from both sides with collision suffixes.
289#[allow(
290    clippy::too_many_arguments,
291    reason = "Assembling columns requires both table refs, both key slices, and both index vecs — cannot reduce without wrapper struct"
292)]
293fn assemble_columns(
294    left: &DataFrame,
295    right: &DataFrame,
296    left_on: &[&str],
297    right_on: &[&str],
298    left_indices: &[Option<usize>],
299    right_indices: &[Option<usize>],
300) -> Result<DataFrame, DataFrameError> {
301    let mut result_columns: Vec<Column> = Vec::new();
302    let left_names: Vec<&str> = left.column_names();
303    let right_names: Vec<&str> = right.column_names();
304
305    // Build sets for fast lookup
306    let left_key_set: Vec<&str> = left_on.to_vec();
307    let right_key_set: Vec<&str> = right_on.to_vec();
308
309    // Determine name collisions between non-key columns
310    let left_non_keys: Vec<&str> = left_names
311        .iter()
312        .filter(|n| !left_key_set.contains(n))
313        .copied()
314        .collect();
315    let right_non_keys: Vec<&str> = right_names
316        .iter()
317        .filter(|n| !right_key_set.contains(n))
318        .copied()
319        .collect();
320
321    let collisions: Vec<&str> = left_non_keys
322        .iter()
323        .filter(|n| right_non_keys.contains(n))
324        .copied()
325        .collect();
326
327    // 1. Key columns — take from left where available, else from right
328    for (i, &left_key_name) in left_on.iter().enumerate() {
329        let left_col = left.column(left_key_name)?;
330        // i < right_on.len() guaranteed by key count validation in join_on
331        #[allow(
332            clippy::indexing_slicing,
333            reason = "i < left_on.len() == right_on.len() by JoinKeyMismatch validation in join_on"
334        )]
335        let right_key_name = right_on[i];
336        let right_col = right.column(right_key_name)?;
337
338        let merged = merge_key_column(left_col, right_col, left_indices, right_indices);
339        result_columns.push(merged.rename(left_key_name));
340    }
341
342    // 2. Left non-key columns
343    for &name in &left_non_keys {
344        let col = left.column(name)?;
345        let taken = col.take_optional(left_indices);
346        if collisions.contains(&name) {
347            result_columns.push(taken.rename(format!("{name}_left")));
348        } else {
349            result_columns.push(taken);
350        }
351    }
352
353    // 3. Right non-key columns
354    for &name in &right_non_keys {
355        let col = right.column(name)?;
356        let taken = col.take_optional(right_indices);
357        if collisions.contains(&name) {
358            result_columns.push(taken.rename(format!("{name}_right")));
359        } else {
360            result_columns.push(taken);
361        }
362    }
363
364    Ok(DataFrame::from_columns_unchecked(result_columns))
365}
366
367/// Merge a key column from left and right: prefer left value, fall back to right.
368fn merge_key_column(
369    left_col: &Column,
370    right_col: &Column,
371    left_indices: &[Option<usize>],
372    right_indices: &[Option<usize>],
373) -> Column {
374    // Key columns are always string-representable for hashing, but we want to
375    // preserve the original type. Use the left column's type as canonical.
376    let data = left_col.data();
377
378    // Build element-by-element: prefer left value, fall back to right.
379    let len = left_indices.len();
380    match data {
381        crate::column::ColumnData::Bool(_) => {
382            let vals: Vec<Option<bool>> = (0..len)
383                .map(|i| {
384                    // i < left_indices.len() == right_indices.len() by zip construction
385                    #[allow(
386                        clippy::indexing_slicing,
387                        reason = "i iterates 0..len where len = left_indices.len()"
388                    )]
389                    match (left_indices[i], right_indices[i]) {
390                        (Some(li), _) => left_col.get(li).and_then(|s| s.as_bool()),
391                        (None, Some(ri)) => right_col.get(ri).and_then(|s| s.as_bool()),
392                        (None, None) => None,
393                    }
394                })
395                .collect();
396            Column::new_bool(left_col.name(), vals)
397        }
398        crate::column::ColumnData::Int64(_) => {
399            let vals: Vec<Option<i64>> = (0..len)
400                .map(|i| {
401                    #[allow(
402                        clippy::indexing_slicing,
403                        reason = "i iterates 0..len where len = left_indices.len()"
404                    )]
405                    match (left_indices[i], right_indices[i]) {
406                        (Some(li), _) => left_col.get(li).and_then(|s| s.as_i64()),
407                        (None, Some(ri)) => right_col.get(ri).and_then(|s| s.as_i64()),
408                        (None, None) => None,
409                    }
410                })
411                .collect();
412            Column::new_i64(left_col.name(), vals)
413        }
414        crate::column::ColumnData::UInt64(_) => {
415            let vals: Vec<Option<u64>> = (0..len)
416                .map(|i| {
417                    #[allow(
418                        clippy::indexing_slicing,
419                        reason = "i iterates 0..len where len = left_indices.len()"
420                    )]
421                    match (left_indices[i], right_indices[i]) {
422                        (Some(li), _) => left_col.get(li).and_then(|s| s.as_u64()),
423                        (None, Some(ri)) => right_col.get(ri).and_then(|s| s.as_u64()),
424                        (None, None) => None,
425                    }
426                })
427                .collect();
428            Column::new_u64(left_col.name(), vals)
429        }
430        crate::column::ColumnData::Float64(_) => {
431            let vals: Vec<Option<f64>> = (0..len)
432                .map(|i| {
433                    #[allow(
434                        clippy::indexing_slicing,
435                        reason = "i iterates 0..len where len = left_indices.len()"
436                    )]
437                    match (left_indices[i], right_indices[i]) {
438                        (Some(li), _) => left_col.get(li).and_then(|s| s.as_f64()),
439                        (None, Some(ri)) => right_col.get(ri).and_then(|s| s.as_f64()),
440                        (None, None) => None,
441                    }
442                })
443                .collect();
444            Column::new_f64(left_col.name(), vals)
445        }
446        crate::column::ColumnData::String(_) => {
447            let vals: Vec<Option<String>> = (0..len)
448                .map(|i| {
449                    #[allow(
450                        clippy::indexing_slicing,
451                        reason = "i iterates 0..len where len = left_indices.len()"
452                    )]
453                    match (left_indices[i], right_indices[i]) {
454                        (Some(li), _) => left_col
455                            .get(li)
456                            .and_then(|s| s.as_str().map(|s| s.to_string())),
457                        (None, Some(ri)) => right_col
458                            .get(ri)
459                            .and_then(|s| s.as_str().map(|s| s.to_string())),
460                        (None, None) => None,
461                    }
462                })
463                .collect();
464            Column::new_string(left_col.name(), vals)
465        }
466    }
467}
468
469// =============================================================================
470// Tests
471// =============================================================================
472
473#[cfg(test)]
474mod tests {
475    use super::*;
476
477    fn drugs() -> DataFrame {
478        DataFrame::new(vec![
479            Column::from_strs("drug_id", &["D1", "D2", "D3", "D4"]),
480            Column::from_strs(
481                "drug_name",
482                &["aspirin", "metformin", "ibuprofen", "lisinopril"],
483            ),
484        ])
485        .unwrap_or_else(|_| unreachable!())
486    }
487
488    fn events() -> DataFrame {
489        DataFrame::new(vec![
490            Column::from_strs("drug_id", &["D1", "D1", "D2", "D5"]),
491            Column::from_strs("event", &["headache", "nausea", "rash", "dizziness"]),
492            Column::from_i64s("count", vec![10, 5, 3, 7]),
493        ])
494        .unwrap_or_else(|_| unreachable!())
495    }
496
497    // =========================================================================
498    // Inner join
499    // =========================================================================
500
501    #[test]
502    fn inner_join_basic() {
503        let result = drugs()
504            .join(&events(), &["drug_id"], JoinType::Inner)
505            .unwrap_or_else(|_| unreachable!());
506        // D1 matches 2 events, D2 matches 1 = 3 rows
507        assert_eq!(result.height(), 3);
508        // drug_id + drug_name + event + count = 4 columns
509        assert_eq!(result.width(), 4);
510    }
511
512    #[test]
513    fn inner_join_no_matches() {
514        let left = DataFrame::new(vec![
515            Column::from_strs("k", &["a", "b"]),
516            Column::from_i64s("v", vec![1, 2]),
517        ])
518        .unwrap_or_else(|_| unreachable!());
519        let right = DataFrame::new(vec![
520            Column::from_strs("k", &["c", "d"]),
521            Column::from_i64s("w", vec![3, 4]),
522        ])
523        .unwrap_or_else(|_| unreachable!());
524
525        let result = left
526            .join(&right, &["k"], JoinType::Inner)
527            .unwrap_or_else(|_| unreachable!());
528        assert_eq!(result.height(), 0);
529    }
530
531    // =========================================================================
532    // Left join
533    // =========================================================================
534
535    #[test]
536    fn left_join_basic() {
537        let result = drugs()
538            .join(&events(), &["drug_id"], JoinType::Left)
539            .unwrap_or_else(|_| unreachable!());
540        // D1→2, D2→1, D3→1(null right), D4→1(null right) = 5 rows
541        assert_eq!(result.height(), 5);
542        assert_eq!(result.width(), 4);
543
544        // Verify D3 has null event
545        let mut found_d3 = false;
546        for i in 0..result.height() {
547            let id = result
548                .column("drug_id")
549                .unwrap_or_else(|_| unreachable!())
550                .get(i);
551            if id == Some(Scalar::String("D3".into())) {
552                found_d3 = true;
553                assert_eq!(
554                    result
555                        .column("event")
556                        .unwrap_or_else(|_| unreachable!())
557                        .get(i),
558                    Some(Scalar::Null)
559                );
560            }
561        }
562        assert!(found_d3);
563    }
564
565    // =========================================================================
566    // Right join
567    // =========================================================================
568
569    #[test]
570    fn right_join_basic() {
571        let result = drugs()
572            .join(&events(), &["drug_id"], JoinType::Right)
573            .unwrap_or_else(|_| unreachable!());
574        // D1→2, D2→1, D5→1(null left) = 4 rows
575        assert_eq!(result.height(), 4);
576
577        // Verify D5 has null drug_name
578        let mut found_d5 = false;
579        for i in 0..result.height() {
580            let id = result
581                .column("drug_id")
582                .unwrap_or_else(|_| unreachable!())
583                .get(i);
584            if id == Some(Scalar::String("D5".into())) {
585                found_d5 = true;
586                assert_eq!(
587                    result
588                        .column("drug_name")
589                        .unwrap_or_else(|_| unreachable!())
590                        .get(i),
591                    Some(Scalar::Null)
592                );
593            }
594        }
595        assert!(found_d5);
596    }
597
598    // =========================================================================
599    // Outer join
600    // =========================================================================
601
602    #[test]
603    fn outer_join_basic() {
604        let result = drugs()
605            .join(&events(), &["drug_id"], JoinType::Outer)
606            .unwrap_or_else(|_| unreachable!());
607        // D1→2, D2→1, D3→1(null right), D4→1(null right), D5→1(null left) = 6
608        assert_eq!(result.height(), 6);
609        assert_eq!(result.width(), 4);
610    }
611
612    // =========================================================================
613    // Semi join
614    // =========================================================================
615
616    #[test]
617    fn semi_join_basic() {
618        let result = drugs()
619            .join(&events(), &["drug_id"], JoinType::Semi)
620            .unwrap_or_else(|_| unreachable!());
621        // D1 and D2 have matches → 2 rows
622        assert_eq!(result.height(), 2);
623        // Only left columns
624        assert_eq!(result.width(), 2);
625    }
626
627    // =========================================================================
628    // Anti join
629    // =========================================================================
630
631    #[test]
632    fn anti_join_basic() {
633        let result = drugs()
634            .join(&events(), &["drug_id"], JoinType::Anti)
635            .unwrap_or_else(|_| unreachable!());
636        // D3 and D4 have no matches → 2 rows
637        assert_eq!(result.height(), 2);
638        assert_eq!(result.width(), 2);
639    }
640
641    // =========================================================================
642    // Multi-key join
643    // =========================================================================
644
645    #[test]
646    fn multi_key_join() {
647        let left = DataFrame::new(vec![
648            Column::from_strs("drug", &["asp", "asp", "met"]),
649            Column::from_strs("event", &["ha", "na", "ha"]),
650            Column::from_i64s("left_val", vec![1, 2, 3]),
651        ])
652        .unwrap_or_else(|_| unreachable!());
653
654        let right = DataFrame::new(vec![
655            Column::from_strs("drug", &["asp", "met", "met"]),
656            Column::from_strs("event", &["ha", "ha", "na"]),
657            Column::from_i64s("right_val", vec![10, 20, 30]),
658        ])
659        .unwrap_or_else(|_| unreachable!());
660
661        let result = left
662            .join(&right, &["drug", "event"], JoinType::Inner)
663            .unwrap_or_else(|_| unreachable!());
664        // asp+ha → 1 match, met+ha → 1 match = 2 rows
665        assert_eq!(result.height(), 2);
666        assert_eq!(result.width(), 4); // drug + event + left_val + right_val
667    }
668
669    // =========================================================================
670    // Asymmetric keys (join_on)
671    // =========================================================================
672
673    #[test]
674    fn join_on_different_key_names() {
675        let left = DataFrame::new(vec![
676            Column::from_strs("id", &["a", "b", "c"]),
677            Column::from_i64s("val", vec![1, 2, 3]),
678        ])
679        .unwrap_or_else(|_| unreachable!());
680
681        let right = DataFrame::new(vec![
682            Column::from_strs("key", &["b", "c", "d"]),
683            Column::from_i64s("score", vec![10, 20, 30]),
684        ])
685        .unwrap_or_else(|_| unreachable!());
686
687        let result = left
688            .join_on(&right, &["id"], &["key"], JoinType::Inner)
689            .unwrap_or_else(|_| unreachable!());
690        assert_eq!(result.height(), 2); // b, c
691        assert_eq!(result.width(), 3); // id + val + score
692        // Key column uses left name "id"
693        assert!(result.column("id").is_ok());
694    }
695
696    // =========================================================================
697    // Null key handling
698    // =========================================================================
699
700    #[test]
701    fn null_keys_never_match() {
702        let left = DataFrame::new(vec![
703            Column::new_string("k", vec![Some("a".into()), None, Some("c".into())]),
704            Column::from_i64s("v", vec![1, 2, 3]),
705        ])
706        .unwrap_or_else(|_| unreachable!());
707
708        let right = DataFrame::new(vec![
709            Column::new_string("k", vec![Some("a".into()), None]),
710            Column::from_i64s("w", vec![10, 20]),
711        ])
712        .unwrap_or_else(|_| unreachable!());
713
714        // Inner: null keys don't match → only "a" matches
715        let inner = left
716            .join(&right, &["k"], JoinType::Inner)
717            .unwrap_or_else(|_| unreachable!());
718        assert_eq!(inner.height(), 1);
719
720        // Left: null-keyed left row preserved with null right
721        let lj = left
722            .join(&right, &["k"], JoinType::Left)
723            .unwrap_or_else(|_| unreachable!());
724        assert_eq!(lj.height(), 3); // a→1, null→1(null right), c→1(null right)
725    }
726
727    #[test]
728    fn null_keys_kept_in_anti_join() {
729        let left = DataFrame::new(vec![
730            Column::new_string("k", vec![Some("a".into()), None, Some("c".into())]),
731            Column::from_i64s("v", vec![1, 2, 3]),
732        ])
733        .unwrap_or_else(|_| unreachable!());
734
735        let right = DataFrame::new(vec![
736            Column::from_strs("k", &["a"]),
737            Column::from_i64s("w", vec![10]),
738        ])
739        .unwrap_or_else(|_| unreachable!());
740
741        let result = left
742            .join(&right, &["k"], JoinType::Anti)
743            .unwrap_or_else(|_| unreachable!());
744        // null and "c" don't match → 2 rows
745        assert_eq!(result.height(), 2);
746    }
747
748    // =========================================================================
749    // Name collision handling
750    // =========================================================================
751
752    #[test]
753    fn name_collision_suffixes() {
754        let left = DataFrame::new(vec![
755            Column::from_strs("k", &["a", "b"]),
756            Column::from_i64s("value", vec![1, 2]),
757        ])
758        .unwrap_or_else(|_| unreachable!());
759
760        let right = DataFrame::new(vec![
761            Column::from_strs("k", &["a", "b"]),
762            Column::from_i64s("value", vec![10, 20]),
763        ])
764        .unwrap_or_else(|_| unreachable!());
765
766        let result = left
767            .join(&right, &["k"], JoinType::Inner)
768            .unwrap_or_else(|_| unreachable!());
769        assert_eq!(result.height(), 2);
770        assert!(result.column("value_left").is_ok());
771        assert!(result.column("value_right").is_ok());
772    }
773
774    // =========================================================================
775    // Empty DataFrames
776    // =========================================================================
777
778    #[test]
779    fn join_empty_left() {
780        let left = DataFrame::new(vec![
781            Column::from_strs("k", &[]),
782            Column::from_i64s("v", vec![]),
783        ])
784        .unwrap_or_else(|_| unreachable!());
785
786        let right = DataFrame::new(vec![
787            Column::from_strs("k", &["a"]),
788            Column::from_i64s("w", vec![1]),
789        ])
790        .unwrap_or_else(|_| unreachable!());
791
792        let result = left
793            .join(&right, &["k"], JoinType::Inner)
794            .unwrap_or_else(|_| unreachable!());
795        assert_eq!(result.height(), 0);
796
797        let result = left
798            .join(&right, &["k"], JoinType::Left)
799            .unwrap_or_else(|_| unreachable!());
800        assert_eq!(result.height(), 0);
801    }
802
803    #[test]
804    fn join_empty_right() {
805        let left = DataFrame::new(vec![
806            Column::from_strs("k", &["a"]),
807            Column::from_i64s("v", vec![1]),
808        ])
809        .unwrap_or_else(|_| unreachable!());
810
811        let right = DataFrame::new(vec![
812            Column::from_strs("k", &[]),
813            Column::from_i64s("w", vec![]),
814        ])
815        .unwrap_or_else(|_| unreachable!());
816
817        let inner = left
818            .join(&right, &["k"], JoinType::Inner)
819            .unwrap_or_else(|_| unreachable!());
820        assert_eq!(inner.height(), 0);
821
822        let lj = left
823            .join(&right, &["k"], JoinType::Left)
824            .unwrap_or_else(|_| unreachable!());
825        assert_eq!(lj.height(), 1); // left row preserved
826    }
827
828    // =========================================================================
829    // Error cases
830    // =========================================================================
831
832    #[test]
833    fn error_key_mismatch() {
834        let left = DataFrame::new(vec![
835            Column::from_strs("a", &["x"]),
836            Column::from_strs("b", &["y"]),
837        ])
838        .unwrap_or_else(|_| unreachable!());
839
840        let right =
841            DataFrame::new(vec![Column::from_strs("c", &["x"])]).unwrap_or_else(|_| unreachable!());
842
843        let err = left.join_on(&right, &["a", "b"], &["c"], JoinType::Inner);
844        assert!(err.is_err());
845    }
846
847    #[test]
848    fn error_empty_keys() {
849        let left =
850            DataFrame::new(vec![Column::from_strs("a", &["x"])]).unwrap_or_else(|_| unreachable!());
851        let right =
852            DataFrame::new(vec![Column::from_strs("a", &["x"])]).unwrap_or_else(|_| unreachable!());
853
854        let err = left.join(&right, &[], JoinType::Inner);
855        assert!(err.is_err());
856    }
857
858    #[test]
859    fn error_missing_column() {
860        let left =
861            DataFrame::new(vec![Column::from_strs("a", &["x"])]).unwrap_or_else(|_| unreachable!());
862        let right =
863            DataFrame::new(vec![Column::from_strs("b", &["x"])]).unwrap_or_else(|_| unreachable!());
864
865        let err = left.join(&right, &["a"], JoinType::Inner);
866        assert!(err.is_err()); // "a" not found in right
867    }
868
869    // =========================================================================
870    // Type preservation
871    // =========================================================================
872
873    #[test]
874    fn type_preservation_numeric_keys() {
875        let left = DataFrame::new(vec![
876            Column::from_i64s("id", vec![1, 2, 3]),
877            Column::from_strs("name", &["a", "b", "c"]),
878        ])
879        .unwrap_or_else(|_| unreachable!());
880
881        let right = DataFrame::new(vec![
882            Column::from_i64s("id", vec![2, 3, 4]),
883            Column::from_f64s("score", vec![9.5, 8.0, 7.5]),
884        ])
885        .unwrap_or_else(|_| unreachable!());
886
887        let result = left
888            .join(&right, &["id"], JoinType::Inner)
889            .unwrap_or_else(|_| unreachable!());
890        assert_eq!(result.height(), 2); // id 2, 3
891
892        // Verify key column preserved as Int64
893        let id_col = result.column("id").unwrap_or_else(|_| unreachable!());
894        assert_eq!(id_col.dtype(), crate::column::DataType::Int64);
895
896        // Verify score preserved as Float64
897        let score_col = result.column("score").unwrap_or_else(|_| unreachable!());
898        assert_eq!(score_col.dtype(), crate::column::DataType::Float64);
899    }
900
901    #[test]
902    fn many_to_many_join() {
903        let left = DataFrame::new(vec![
904            Column::from_strs("k", &["a", "a", "b"]),
905            Column::from_i64s("lv", vec![1, 2, 3]),
906        ])
907        .unwrap_or_else(|_| unreachable!());
908
909        let right = DataFrame::new(vec![
910            Column::from_strs("k", &["a", "a"]),
911            Column::from_i64s("rv", vec![10, 20]),
912        ])
913        .unwrap_or_else(|_| unreachable!());
914
915        let result = left
916            .join(&right, &["k"], JoinType::Inner)
917            .unwrap_or_else(|_| unreachable!());
918        // 2 left "a" × 2 right "a" = 4 matches, b→0 = total 4
919        assert_eq!(result.height(), 4);
920    }
921
922    #[test]
923    fn semi_join_deduplicates() {
924        // Semi join should produce at most one row per left row, even with many right matches
925        let left = DataFrame::new(vec![
926            Column::from_strs("k", &["a", "b"]),
927            Column::from_i64s("v", vec![1, 2]),
928        ])
929        .unwrap_or_else(|_| unreachable!());
930
931        let right = DataFrame::new(vec![
932            Column::from_strs("k", &["a", "a", "a"]),
933            Column::from_i64s("w", vec![10, 20, 30]),
934        ])
935        .unwrap_or_else(|_| unreachable!());
936
937        let result = left
938            .join(&right, &["k"], JoinType::Semi)
939            .unwrap_or_else(|_| unreachable!());
940        // "a" appears once despite 3 right matches, "b" no match
941        assert_eq!(result.height(), 1);
942    }
943
944    #[test]
945    fn outer_join_preserves_all_keys() {
946        let left = DataFrame::new(vec![
947            Column::from_strs("k", &["a", "b"]),
948            Column::from_i64s("v", vec![1, 2]),
949        ])
950        .unwrap_or_else(|_| unreachable!());
951
952        let right = DataFrame::new(vec![
953            Column::from_strs("k", &["b", "c"]),
954            Column::from_i64s("w", vec![20, 30]),
955        ])
956        .unwrap_or_else(|_| unreachable!());
957
958        let result = left
959            .join(&right, &["k"], JoinType::Outer)
960            .unwrap_or_else(|_| unreachable!());
961        // a(left only), b(both), c(right only) = 3
962        assert_eq!(result.height(), 3);
963
964        // All key values present
965        let keys: Vec<Scalar> = (0..result.height())
966            .filter_map(|i| result.column("k").ok().and_then(|c| c.get(i)))
967            .collect();
968        assert_eq!(keys.len(), 3);
969    }
970
971    #[test]
972    fn right_join_symmetric_to_left() {
973        // right_join(A, B) should have same rows as left_join(B, A) (different column order)
974        let a = DataFrame::new(vec![
975            Column::from_strs("k", &["x", "y"]),
976            Column::from_i64s("a_val", vec![1, 2]),
977        ])
978        .unwrap_or_else(|_| unreachable!());
979
980        let b = DataFrame::new(vec![
981            Column::from_strs("k", &["y", "z"]),
982            Column::from_i64s("b_val", vec![10, 20]),
983        ])
984        .unwrap_or_else(|_| unreachable!());
985
986        let right_result = a
987            .join(&b, &["k"], JoinType::Right)
988            .unwrap_or_else(|_| unreachable!());
989        let left_result = b
990            .join(&a, &["k"], JoinType::Left)
991            .unwrap_or_else(|_| unreachable!());
992
993        assert_eq!(right_result.height(), left_result.height());
994    }
995}