Skip to main content

robin_sparkless_polars/dataframe/
joins.rs

1//! Join operations for DataFrame.
2
3use std::collections::HashSet;
4
5use super::DataFrame;
6use crate::schema_conv::data_type_to_polars_type;
7use crate::type_coercion::coerce_expr_pair_for_join;
8use polars::prelude::{
9    DataType as PlDataType, Expr, JoinType as PlJoinType, Operator, PolarsError,
10    SchemaNamesAndDtypes, coalesce as pl_coalesce,
11};
12use polars_plan::dsl::functions::nth;
13
14fn expr_to_column_name(expr: &Expr) -> Option<String> {
15    use polars::prelude::Expr as PlExpr;
16    let mut e = expr;
17    loop {
18        match e {
19            PlExpr::Column(n) => return Some(n.as_str().to_string()),
20            PlExpr::Alias(inner, _) | PlExpr::Cast { expr: inner, .. } => e = inner.as_ref(),
21            _ => return None,
22        }
23    }
24}
25
26/// If `expr` contains an equality between two column refs (e.g. left.dept_id == right.dept_id),
27/// returns Some((left_col_name, right_col_name)) so the caller can use key-based join.
28/// Peels Alias and matches Eq or EqValidity, and also walks simple AND trees so that
29/// compound conditions like (a.id == b.id) & (a.amount > 30) still yield the key pair.
30/// Used for PySpark parity (#1049, #380).
31pub fn try_extract_join_eq_columns(expr: &Expr) -> Option<(String, String)> {
32    try_extract_join_eq_columns_all(expr).into_iter().next()
33}
34
35/// Collects all (left_col, right_col) equality pairs from an expression (e.g. AND of (a.id == b.id) & (a.x == b.x)).
36/// Used so condition joins on multiple keys use a single join with all keys (#1148).
37pub fn try_extract_join_eq_columns_all(expr: &Expr) -> Vec<(String, String)> {
38    use polars::prelude::Expr as PlExpr;
39
40    fn inner_extract_all(e: &Expr, out: &mut Vec<(String, String)>) {
41        let mut current = e;
42        while let PlExpr::Alias(inner, _) = current {
43            current = inner.as_ref();
44        }
45        match current {
46            PlExpr::BinaryExpr {
47                left,
48                op: Operator::Eq | Operator::EqValidity,
49                right,
50            } => {
51                if let (Some(l), Some(r)) = (
52                    expr_to_column_name(left.as_ref()),
53                    expr_to_column_name(right.as_ref()),
54                ) {
55                    out.push((l, r));
56                }
57            }
58            PlExpr::BinaryExpr {
59                left,
60                op: Operator::And,
61                right,
62            } => {
63                inner_extract_all(left.as_ref(), out);
64                inner_extract_all(right.as_ref(), out);
65            }
66            _ => {}
67        }
68    }
69
70    let mut pairs = Vec::new();
71    inner_extract_all(expr, &mut pairs);
72    pairs
73}
74
75/// Returns true if the expression is only AND and Eq (column refs). When true, a key-based join
76/// already enforces the condition, so we must not filter after the join (left/right/outer would
77/// otherwise lose unmatched rows). Used for PySpark parity (#1242).
78pub fn expr_contains_only_join_key_equalities(expr: &Expr) -> bool {
79    use polars::prelude::Expr as PlExpr;
80    fn only_join_equalities(e: &Expr) -> bool {
81        let mut current = e;
82        while let PlExpr::Alias(inner, _) = current {
83            current = inner.as_ref();
84        }
85        match current {
86            PlExpr::BinaryExpr {
87                left,
88                op: Operator::Eq | Operator::EqValidity,
89                right,
90            } => {
91                expr_to_column_name(left.as_ref()).is_some()
92                    && expr_to_column_name(right.as_ref()).is_some()
93            }
94            PlExpr::BinaryExpr {
95                left,
96                op: Operator::And,
97                right,
98            } => only_join_equalities(left.as_ref()) && only_join_equalities(right.as_ref()),
99            _ => false,
100        }
101    }
102    only_join_equalities(expr)
103}
104
/// Kind of join to perform between two DataFrames (PySpark-compatible naming).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum JoinType {
    /// Inner join.
    Inner,
    /// Left join.
    Left,
    /// Right join.
    Right,
    /// Full outer join (PySpark `outer`; executed as a Polars `Full` join).
    Outer,
    /// Left semi join (PySpark `left_semi`): left rows that have a match in right,
    /// with only the left columns in the result.
    LeftSemi,
    /// Left anti join (PySpark `left_anti`): left rows that have no match in right,
    /// with only the left columns in the result.
    LeftAnti,
}
117
/// How a join was requested: by column name(s) or by a boolean condition.
///
/// Derives `PartialEq`/`Eq` (like [`JoinType`]) so values can be compared directly with
/// `==` in addition to being pattern-matched.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinOrigin {
    /// `join(on = [...])` style joins (column-name based).
    ColumnOn,
    /// Condition-based joins (e.g. `left.col == right.col`) where both key sides
    /// must remain addressable separately (`dept_id`, `dept_id_right`).
    Condition,
}
127
128pub struct JoinOptions {
129    pub case_sensitive: bool,
130    pub coalesce_same_name_keys: bool,
131    pub mark_join_keys_ambiguous: bool,
132    pub origin: JoinOrigin,
133}
134
135/// Join with another DataFrame on the given columns. Preserves case_sensitive on result.
136/// When join key types differ (e.g. str vs int), coerces both sides to a common type (PySpark parity #274).
137/// When both tables have the same join key column name(s), renames the right's keys to temp names and
138/// uses left_on/right_on so Polars does not error with "duplicate column" (issue #580, PySpark parity).
139/// When left/right key names differ in casing or name (e.g. "id" vs "ID" or "id" vs "other_id"), aliases right keys to left names
140/// so the result has one key column name (PySpark parity #604, #743).
141/// For Right and Outer, reorders columns to match PySpark: key(s), then left non-key, then right non-key.
142/// `left_on` and `right_on` must have the same length; keys are matched by position.
143///
144/// When `coalesce_same_name_keys` is true (e.g. join(right, "id") or join(right, [col("id")])), duplicate
145/// key columns are coalesced into one so the result has a single key column (PySpark parity #1049, #353).
146/// When false (e.g. condition join left.x == right.x), both key columns are kept (dept_id, dept_id_right).
147///
148/// When `options.mark_join_keys_ambiguous` is true and left/right key names are the same, unqualified
149/// references to those key names are treated as ambiguous (PySpark parity for condition join: #1230).
150pub fn join(
151    left: &DataFrame,
152    right: &DataFrame,
153    left_on: Vec<&str>,
154    right_on: Vec<&str>,
155    how: JoinType,
156    options: JoinOptions,
157) -> Result<DataFrame, PolarsError> {
158    let JoinOptions {
159        case_sensitive,
160        coalesce_same_name_keys,
161        mark_join_keys_ambiguous,
162        origin,
163    } = options;
164    use polars::prelude::{JoinBuilder, JoinCoalesce, col};
165    if left_on.len() != right_on.len() {
166        return Err(PolarsError::ComputeError(
167            "join: left_on and right_on must have the same length".into(),
168        ));
169    }
170    let mut left_lf = left.lazy_frame();
171    let mut right_lf = right.lazy_frame();
172    // For full outer joins we preserve the left-side join key values in temporary columns so we
173    // can use them as the canonical join key after the join (unmatched right rows get null key).
174    let mut outer_left_key_copies: Vec<(String, String)> = Vec::new();
175    // Track any right-side join keys we renamed for outer joins so we can drop the suffixed
176    // key columns from the final result schema (PySpark parity for outer join keys).
177    let mut outer_join_renamed_right_keys: Vec<String> = Vec::new();
178
179    // Resolve key names on both sides so we can alias right keys to left names (#604, #743).
180    let left_key_names: Vec<String> = left_on
181        .iter()
182        .map(|k| {
183            left.resolve_column_name(k).map_err(|e| {
184                PolarsError::ComputeError(format!("join key '{k}' on left: {e}").into())
185            })
186        })
187        .collect::<Result<_, _>>()?;
188    let mut right_key_names: Vec<String> = right_on
189        .iter()
190        .map(|k| {
191            right.resolve_column_name(k).map_err(|e| {
192                PolarsError::ComputeError(format!("join key '{k}' on right: {e}").into())
193            })
194        })
195        .collect::<Result<_, _>>()?;
196    // For outer joins invoked via column-name based join (coalesce_same_name_keys = true),
197    // add temp copies of left join keys so we can restore them as canonical keys after the join
198    // (PySpark parity for grouping/selection on join keys when using join(on=...)).
199    if matches!(how, JoinType::Outer)
200        && coalesce_same_name_keys
201        && matches!(origin, JoinOrigin::ColumnOn)
202    {
203        use polars::prelude::col;
204        let mut copy_exprs: Vec<Expr> = Vec::new();
205        for name in &left_key_names {
206            let temp = format!("__rs_outer_key_{}", name);
207            outer_left_key_copies.push((name.clone(), temp.clone()));
208            copy_exprs.push(col(name.as_str()).alias(temp.as_str()));
209        }
210        if !copy_exprs.is_empty() {
211            left_lf = left_lf.with_columns(copy_exprs);
212        }
213    }
214    // For condition-based full outer joins on same-named keys, keep both key columns by
215    // renaming the right-side keys to a suffixed form (e.g. dept_id -> dept_id_right)
216    // so that left/right keys remain addressable separately (dept_id, dept_id_right).
217    // Inner/left/right condition joins keep the original column-name join semantics
218    // (single key column) so tests like issue #353 that use `on` as a Column do not
219    // see extra *_right key columns.
220    if matches!(origin, JoinOrigin::Condition)
221        && matches!(how, JoinType::Outer)
222        && left_key_names.len() == right_key_names.len()
223        && left_key_names
224            .iter()
225            .zip(right_key_names.iter())
226            .all(|(a, b)| a.eq_ignore_ascii_case(b))
227    {
228        use polars::prelude::col;
229        use std::collections::HashMap;
230        let mut rename_map: HashMap<String, String> = HashMap::new();
231        for name in &right_key_names {
232            rename_map.insert(name.clone(), format!("{name}_right"));
233        }
234        if !rename_map.is_empty() {
235            let current_names: Vec<String> = right.columns()?.into_iter().collect();
236            let exprs: Vec<Expr> = current_names
237                .iter()
238                .map(|n| {
239                    if let Some(new_name) = rename_map.get(n) {
240                        col(n.as_str()).alias(new_name.as_str())
241                    } else {
242                        col(n.as_str())
243                    }
244                })
245                .collect();
246            right_lf = right_lf.select(&exprs);
247            for rk in &mut right_key_names {
248                if let Some(new_name) = rename_map.get(rk) {
249                    *rk = new_name.clone();
250                }
251            }
252        }
253    }
254
255    // For full outer joins (via join(on=...)) where left/right use the same key names,
256    // rename right keys to a suffixed form (e.g. key -> key_right) so that we preserve
257    // both columns internally while building the join, but then drop the suffixed right
258    // key columns from the final result so the public schema matches PySpark (single
259    // join key column). Condition-based joins (on=Column) keep both key columns.
260    if matches!(how, JoinType::Outer)
261        && coalesce_same_name_keys
262        && left_key_names == right_key_names
263        && matches!(origin, JoinOrigin::ColumnOn)
264    {
265        use polars::prelude::col;
266        use std::collections::HashMap;
267        let mut rename_map: HashMap<String, String> = HashMap::new();
268        for name in &right_key_names {
269            rename_map.insert(name.clone(), format!("{name}_right"));
270        }
271        if !rename_map.is_empty() {
272            let current_names: Vec<String> = right.columns()?.into_iter().collect();
273            let exprs: Vec<Expr> = current_names
274                .iter()
275                .map(|n| {
276                    if let Some(new_name) = rename_map.get(n) {
277                        col(n.as_str()).alias(new_name.as_str())
278                    } else {
279                        col(n.as_str())
280                    }
281                })
282                .collect();
283            right_lf = right_lf.select(&exprs);
284            // Update right_key_names to the new suffixed names and remember them so we can
285            // drop the suffixed right key columns from the final result schema.
286            for rk in &mut right_key_names {
287                if let Some(new_name) = rename_map.get(rk) {
288                    *rk = new_name.clone();
289                }
290            }
291            outer_join_renamed_right_keys = right_key_names.clone();
292        }
293    }
294
295    let keys_differ = left_key_names != right_key_names;
296    // When coalesce_same_name_keys and !case_sensitive, treat keys as same if they match case-insensitively (#297).
297    let keys_match_for_coalesce = !keys_differ
298        || (coalesce_same_name_keys
299            && !case_sensitive
300            && left_key_names.len() == right_key_names.len()
301            && left_key_names
302                .iter()
303                .zip(right_key_names.iter())
304                .all(|(a, b)| a.eq_ignore_ascii_case(b)));
305
306    if keys_match_for_coalesce {
307        // #1009, #1019: When aliasing right key to left name, right may already have a column with that name (e.g. self-join).
308        let right_names: Vec<String> = right.columns()?.into_iter().collect();
309        let mut renames: std::collections::HashMap<String, String> =
310            std::collections::HashMap::new();
311        for (i, _) in left_on.iter().enumerate() {
312            let target_name = &left_key_names[i];
313            let right_key = &right_key_names[i];
314            if target_name != right_key && right_names.iter().any(|n| n == target_name) {
315                renames.insert(target_name.clone(), format!("{}_right", target_name));
316            }
317        }
318        if !renames.is_empty() {
319            let exprs: Vec<Expr> = right_names
320                .iter()
321                .map(|n| {
322                    if let Some(suffix) = renames.get(n) {
323                        col(n.as_str()).alias(suffix.as_str())
324                    } else {
325                        col(n.as_str())
326                    }
327                })
328                .collect();
329            right_lf = right_lf.select(&exprs);
330        }
331
332        // Coerce join keys to a common type when left/right dtypes differ (PySpark #274).
333        // Alias right keys to left key names so result has one key column name (#604, #743).
334        // Collect each side's schema once to avoid repeated schema_or_collect in get_column_dtype (performance, e.g. #1430).
335        let left_schema = left.polars_schema()?;
336        let right_schema = right.polars_schema()?;
337        let mut left_casts: Vec<Expr> = Vec::new();
338        let mut right_casts: Vec<Expr> = Vec::new();
339        for (i, key) in left_on.iter().enumerate() {
340            let left_name = &left_key_names[i];
341            let right_name = &right_key_names[i];
342            let left_dtype = left_schema
343                .get(left_name.as_str())
344                .cloned()
345                .ok_or_else(|| {
346                    PolarsError::ComputeError(format!("join key '{key}' not found on left").into())
347                })?;
348            let right_dtype = right_schema
349                .get(right_name.as_str())
350                .cloned()
351                .ok_or_else(|| {
352                    PolarsError::ComputeError(format!("join key '{key}' not found on right").into())
353                })?;
354            let target_name = left_name.as_str();
355            if left_dtype != right_dtype {
356                let (l, r) = coerce_expr_pair_for_join(
357                    left_name.as_str(),
358                    right_name.as_str(),
359                    &left_dtype,
360                    &right_dtype,
361                    target_name,
362                )?;
363                left_casts.push(l);
364                right_casts.push(r);
365            } else if left_name != right_name {
366                right_casts.push(col(right_name.as_str()).alias(target_name));
367            }
368        }
369        if !left_casts.is_empty() {
370            left_lf = left_lf.with_columns(left_casts);
371        }
372        if !right_casts.is_empty() {
373            right_lf = right_lf.with_columns(right_casts);
374            let drop_right: std::collections::HashSet<String> = left_on
375                .iter()
376                .enumerate()
377                .filter(|(i, _)| left_key_names[*i] != right_key_names[*i])
378                .map(|(i, _)| right_key_names[i].clone())
379                .collect();
380            if !drop_right.is_empty() {
381                let current_right_names: Vec<String> = right_lf
382                    .collect_schema()
383                    .map(|s| s.iter_names().map(|n| n.to_string()).collect())?;
384                let keep_names: Vec<&str> = current_right_names
385                    .iter()
386                    .filter(|n| !drop_right.contains(*n))
387                    .map(String::as_str)
388                    .collect();
389                let keep: Vec<Expr> = keep_names.iter().map(|s| col(*s)).collect();
390                right_lf = right_lf.select(&keep);
391                // Right keys were aliased to left names; use left names for join (#297).
392                right_key_names = left_key_names.clone();
393            }
394        }
395    }
396
397    let on_set: std::collections::HashSet<String> = left_key_names.iter().cloned().collect();
398    let polars_how: PlJoinType = match how {
399        JoinType::Inner => PlJoinType::Inner,
400        JoinType::Left => PlJoinType::Left,
401        JoinType::Right => PlJoinType::Right,
402        JoinType::Outer => PlJoinType::Full, // PySpark Outer = Polars Full
403        JoinType::LeftSemi => PlJoinType::Semi,
404        JoinType::LeftAnti => PlJoinType::Anti,
405    };
406
407    // Build join key expressions, coercing types when needed.
408    let mut left_on_exprs: Vec<Expr> = Vec::with_capacity(left_key_names.len());
409    let mut right_on_exprs: Vec<Expr> = Vec::with_capacity(right_key_names.len());
410
411    if keys_differ {
412        // left_on/right_on or condition join: coerce to common type but keep distinct column names
413        // so both key columns remain visible (PySpark parity #241, #1106).
414        use crate::type_coercion::find_common_type_for_join;
415        let right_schema = right_lf.collect_schema()?;
416        for i in 0..left_key_names.len() {
417            let left_name = &left_key_names[i];
418            let right_name = &right_key_names[i];
419            let left_dtype = left.get_column_dtype(left_name.as_str()).ok_or_else(|| {
420                PolarsError::ComputeError(
421                    format!("join key '{}' not found on left", left_name).into(),
422                )
423            })?;
424            let right_dtype = right_schema
425                .get(right_name.as_str())
426                .cloned()
427                .ok_or_else(|| {
428                    PolarsError::ComputeError(
429                        format!("join key '{}' not found on right", right_name).into(),
430                    )
431                })?;
432            if left_dtype == right_dtype {
433                left_on_exprs.push(col(left_name.as_str()));
434                right_on_exprs.push(col(right_name.as_str()));
435            } else {
436                let common = find_common_type_for_join(&left_dtype, &right_dtype)?;
437                left_on_exprs.push(col(left_name.as_str()).cast(common.clone()));
438                right_on_exprs.push(col(right_name.as_str()).cast(common));
439            }
440        }
441    } else {
442        left_on_exprs = left_key_names.iter().map(|n| col(n.as_str())).collect();
443        right_on_exprs = right_key_names.iter().map(|n| col(n.as_str())).collect();
444    }
445
446    // When same-named keys (e.g. join on "id" or left.id == right.id), coalesce so result has one
447    // key column and no _right in row keys (PySpark parity #1049, #353, #1148). Use keys_match_for_coalesce
448    // so case-insensitive match (e.g. name/NAME) also coalesces (#297).
449    // Outer joins keep separate key columns so canonical key comes from left (issue #280).
450    let coalesce = if !keys_match_for_coalesce {
451        JoinCoalesce::KeepColumns
452    } else if matches!(how, JoinType::Inner | JoinType::Left | JoinType::Right) {
453        JoinCoalesce::CoalesceColumns
454    } else if matches!(how, JoinType::Outer) {
455        JoinCoalesce::KeepColumns
456    } else {
457        JoinCoalesce::CoalesceColumns
458    };
459    let mut joined = JoinBuilder::new(left_lf)
460        .with(right_lf)
461        .how(polars_how)
462        .left_on(&left_on_exprs)
463        .right_on(&right_on_exprs)
464        .coalesce(coalesce)
465        .finish();
466
467    if matches!(how, JoinType::Outer) && !outer_left_key_copies.is_empty() {
468        use polars::prelude::col;
469        // For full outer joins with same-named keys, set the canonical join key column
470        // differently depending on how the join was invoked and whether there are
471        // overlapping non-key column names:
472        // - Column-name joins (on = "key") without non-key overlaps keep
473        //   coalesce(left_key, right_key) semantics so unmatched right rows keep
474        //   their key (issue #280, outer_join_then_groupby tests).
475        // - When there are overlapping non-key columns (e.g. "name" on both sides
476        //   in the employees/departments joins), or when we explicitly mark join
477        //   keys as ambiguous (expression-based joins), the canonical key should
478        //   come from the left side only so that condition-based parity fixtures
479        //   (outer_join / left_join / right_join) and Python join parity tests
480        //   see the expected left-key/null pattern.
481        let left_names_full: Vec<String> = left.columns()?.into_iter().collect();
482        let right_names_full: Vec<String> = right.columns()?.into_iter().collect();
483        let has_non_key_overlap = left_names_full.iter().any(|ln| {
484            !on_set.contains(ln.as_str())
485                && right_names_full
486                    .iter()
487                    .any(|rn| rn.eq_ignore_ascii_case(ln.as_str()))
488        });
489        for (i, (left_name, temp)) in outer_left_key_copies.iter().enumerate() {
490            let right_key_name = right_key_names.get(i).map(|s| s.as_str()).unwrap_or("");
491            let expr = if mark_join_keys_ambiguous || has_non_key_overlap {
492                // Expression / condition-style joins or joins with overlapping
493                // non-key column names: canonical key comes from the left side
494                // only; right key is exposed via the *_right column.
495                col(temp.as_str())
496            } else if right_key_name.is_empty() {
497                col(temp.as_str())
498            } else {
499                // Column-name outer join: coalesce(left_key_copy, right_key) so right-only
500                // rows keep their key instead of becoming null (issue #280).
501                pl_coalesce(&[col(temp.as_str()), col(right_key_name)])
502            };
503            joined = joined.with_column(expr.alias(left_name.as_str()));
504        }
505        // Drop the temp columns from the result.
506        let schema = joined.collect_schema()?;
507        let all_names: Vec<String> = schema.iter_names().map(|n| n.to_string()).collect();
508        let temp_set: std::collections::HashSet<&str> = outer_left_key_copies
509            .iter()
510            .map(|(_, t)| t.as_str())
511            .collect();
512        let keep_exprs: Vec<Expr> = all_names
513            .iter()
514            .filter(|n| !temp_set.contains(n.as_str()))
515            .map(|n| col(n.as_str()))
516            .collect();
517        joined = joined.select(&keep_exprs);
518
519        // For outer joins invoked via column-name based join (coalesce_same_name_keys = true)
520        // where join keys came from an explicit `on = ...` (not an expression-based join),
521        // drop the suffixed right-side key columns (e.g. dept_id_right) so the public schema
522        // has a single join key column, matching PySpark and the parity fixtures.
523        //
524        // Expression-based joins (where we mark join keys ambiguous) keep both key columns so
525        // that Python-level code can alias the right key separately (tests/parity/dataframe/test_join.py).
526        if !outer_join_renamed_right_keys.is_empty() && !mark_join_keys_ambiguous {
527            let schema = joined.collect_schema()?;
528            let all_names: Vec<String> = schema.iter_names().map(|n| n.to_string()).collect();
529            let drop_right_keys: std::collections::HashSet<&str> = outer_join_renamed_right_keys
530                .iter()
531                .map(|s| s.as_str())
532                .collect();
533            let keep_exprs: Vec<Expr> = all_names
534                .iter()
535                .filter(|n| !drop_right_keys.contains(n.as_str()))
536                .map(|n| col(n.as_str()))
537                .collect();
538            joined = joined.select(&keep_exprs);
539        }
540    }
541
542    let result_schema = joined.collect_schema()?;
543    let mut names: Vec<String> = result_schema.iter_names().map(|s| s.to_string()).collect();
544    // When same-named keys and Inner/Left/Right, select exactly: keys (once), left non-keys,
545    // Column order: left columns in original order, then right non-keys with _right for overlap
546    // (PySpark parity: same as fixture join_inner_dept_issue510 / join_on_string_issue513). Use
547    // keys_match_for_coalesce so case-insensitive key match (e.g. name/NAME) gets single key (#297).
548    if keys_match_for_coalesce && matches!(how, JoinType::Inner | JoinType::Left | JoinType::Right)
549    {
550        let left_names: Vec<String> = left.columns()?.into_iter().collect();
551        let right_names: Vec<String> = right.columns()?.into_iter().collect();
552        let key_set: std::collections::HashSet<&str> =
553            left_key_names.iter().map(|s| s.as_str()).collect();
554        let result_schema_ref = joined.collect_schema()?;
555        let result_names_vec: Vec<String> = result_schema_ref
556            .iter_names()
557            .map(|s| s.to_string())
558            .collect();
559        let result_names_set: std::collections::HashSet<String> =
560            result_names_vec.iter().cloned().collect();
561        // When !case_sensitive, coalesce columns that match case-insensitively (e.g. age + AGE) so
562        // select("age") and df1["age"] resolve to one column (#297). Same-case keys (id/id) alone
563        // stay ambiguous (#374).
564        let cast_exprs: Vec<Expr> = if !case_sensitive {
565            let left_struct = left.schema().ok();
566            let right_struct = right.schema().ok();
567            let mut exprs: Vec<Expr> = Vec::new();
568            for left_name in &left_names {
569                let matches: Vec<&String> = result_names_vec
570                    .iter()
571                    .filter(|r| r.eq_ignore_ascii_case(left_name))
572                    .collect();
573                if matches.is_empty() {
574                    continue;
575                }
576                let dtype = key_set
577                    .contains(left_name.as_str())
578                    .then(|| {
579                        left_struct
580                            .as_ref()
581                            .and_then(|s| {
582                                s.fields()
583                                    .iter()
584                                    .find(|f| f.name.as_str() == left_name.as_str())
585                                    .map(|f| data_type_to_polars_type(&f.data_type))
586                            })
587                            .or_else(|| left.get_column_dtype(left_name.as_str()))
588                    })
589                    .flatten()
590                    .or_else(|| left.get_column_dtype(left_name.as_str()));
591                let parts: Vec<Expr> = matches.iter().map(|m| col(m.as_str())).collect();
592                let e = if parts.len() == 1 {
593                    col(matches[0].as_str())
594                } else {
595                    pl_coalesce(&parts)
596                };
597                let e = match dtype {
598                    Some(dt) => e.cast(dt),
599                    None => e,
600                };
601                exprs.push(e.alias(left_name.as_str()));
602            }
603            let mut right_non_key_pos = 0_usize;
604            for right_name in &right_names {
605                if key_set.contains(right_name.as_str()) {
606                    continue;
607                }
608                let matches_left = left_names
609                    .iter()
610                    .any(|l| l.eq_ignore_ascii_case(right_name));
611                if matches_left {
612                    // Include right column by position only when it exists (join may coalesce keys).
613                    let result_idx = left_names.len() + right_non_key_pos;
614                    if result_idx < result_names_vec.len() {
615                        let dtype = right_struct
616                            .as_ref()
617                            .and_then(|s| {
618                                s.fields()
619                                    .iter()
620                                    .find(|f| f.name.as_str() == right_name.as_str())
621                                    .map(|f| data_type_to_polars_type(&f.data_type))
622                            })
623                            .or_else(|| right.get_column_dtype(right_name.as_str()));
624                        let alias_name = format!("{}_right", right_name);
625                        let e = nth(result_idx as i64).as_expr();
626                        let e = match dtype {
627                            Some(dt) => e.cast(dt),
628                            None => e,
629                        };
630                        exprs.push(e.alias(alias_name.as_str()));
631                    }
632                    right_non_key_pos += 1;
633                    continue;
634                }
635                if !result_names_set.contains(right_name) {
636                    continue;
637                }
638                let dtype = right_struct
639                    .as_ref()
640                    .and_then(|s| {
641                        s.fields()
642                            .iter()
643                            .find(|f| f.name.as_str() == right_name.as_str())
644                            .map(|f| data_type_to_polars_type(&f.data_type))
645                    })
646                    .or_else(|| right.get_column_dtype(right_name.as_str()));
647                let e = match dtype {
648                    Some(dt) => col(right_name.as_str()).cast(dt),
649                    None => col(right_name.as_str()),
650                };
651                exprs.push(e.alias(right_name.as_str()));
652                right_non_key_pos += 1;
653            }
654            Ok(exprs)
655        } else {
656            // Build desired from actual result schema so we never request a column index that
657            // doesn't exist (join may coalesce keys and produce fewer columns than left+right).
658            let schema_before = joined.collect_schema()?;
659            let dtypes_by_index: Vec<PlDataType> = schema_before
660                .iter_names_and_dtypes()
661                .map(|(_name, dt): (_, &PlDataType)| dt.clone())
662                .collect();
663            // Case-insensitive dedup so "id" and "ID" → "id", "ID_right" (#604 resolve_column_name).
664            let mut seen_lower: std::collections::HashSet<String> =
665                std::collections::HashSet::new();
666            let desired: Vec<String> = result_names_vec
667                .iter()
668                .map(|name| {
669                    let name_lower = name.to_lowercase();
670                    let alias = if seen_lower.contains(&name_lower) {
671                        format!("{}_right", name)
672                    } else {
673                        seen_lower.insert(name_lower);
674                        name.clone()
675                    };
676                    alias
677                })
678                .collect();
679            let left_struct = left.schema().ok();
680            let right_struct = right.schema().ok();
681            let exprs: Vec<Expr> = desired
682                .iter()
683                .enumerate()
684                .map(|(idx, alias_name)| {
685                    let result_name = &result_names_vec[idx];
686                    let dtype = if idx < left_names.len() {
687                        left_struct
688                            .as_ref()
689                            .and_then(|s| {
690                                s.fields()
691                                    .iter()
692                                    .find(|f| f.name.as_str() == result_name.as_str())
693                                    .map(|f| data_type_to_polars_type(&f.data_type))
694                            })
695                            .or_else(|| left.get_column_dtype(result_name.as_str()))
696                    } else if let Some(base) = alias_name.strip_suffix("_right") {
697                        right_struct
698                            .as_ref()
699                            .and_then(|s| {
700                                s.fields()
701                                    .iter()
702                                    .find(|f| f.name.as_str() == base)
703                                    .map(|f| data_type_to_polars_type(&f.data_type))
704                            })
705                            .or_else(|| right.get_column_dtype(base))
706                    } else {
707                        right_struct
708                            .as_ref()
709                            .and_then(|s| {
710                                s.fields()
711                                    .iter()
712                                    .find(|f| f.name.as_str() == result_name.as_str())
713                                    .map(|f| data_type_to_polars_type(&f.data_type))
714                            })
715                            .or_else(|| right.get_column_dtype(result_name.as_str()))
716                    };
717                    let e = nth(idx as i64).as_expr();
718                    match (dtype, dtypes_by_index.get(idx)) {
719                        (Some(dt), _) => e.cast(dt).alias(alias_name.as_str()),
720                        (_, Some(dt)) => e.cast(dt.clone()).alias(alias_name.as_str()),
721                        _ => e.alias(alias_name.as_str()),
722                    }
723                })
724                .collect();
725            Ok::<_, PolarsError>(exprs)
726        }?;
727        if !cast_exprs.is_empty() {
728            joined = joined.select(&cast_exprs);
729            let result_schema = joined.collect_schema()?;
730            names = result_schema.iter_names().map(|s| s.to_string()).collect();
731        }
732    }
733    let mut seen = std::collections::HashSet::new();
734    let mut unique_order: Vec<String> = Vec::new();
735    for n in &names {
736        if seen.insert(n.clone()) {
737            unique_order.push(n.clone());
738        }
739    }
740    if unique_order.len() < names.len() {
741        // Preserve column dtypes when deduplicating by position (#1165). nth(idx) can lose
742        // type in the logical schema; cast to the join result dtype so collect() returns
743        // correct types (e.g. v=10, w=20 as int, not string).
744        let schema_before_nth = joined.collect_schema()?;
745        let dtypes_by_index: Vec<PlDataType> = schema_before_nth
746            .iter_names_and_dtypes()
747            .map(|(_name, dt): (_, &PlDataType)| dt.clone())
748            .collect();
749        let exprs: Vec<Expr> = unique_order
750            .iter()
751            .map(|name| {
752                let idx = names.iter().position(|n| n == name).unwrap();
753                let e = nth(idx as i64).as_expr();
754                if let Some(dt) = dtypes_by_index.get(idx) {
755                    e.cast(dt.clone()).alias(name.as_str())
756                } else {
757                    e.alias(name.as_str())
758                }
759            })
760            .collect();
761        joined = joined.select(&exprs);
762    }
763    // For Right/Outer, reorder columns: keys, left non-keys, right non-keys (PySpark order).
764    let mut result_lf = if matches!(how, JoinType::Right | JoinType::Outer) {
765        let left_names = left.columns()?;
766        let right_names = right.columns()?;
767        let result_schema = joined.collect_schema()?;
768        let result_names: std::collections::HashSet<String> =
769            result_schema.iter_names().map(|s| s.to_string()).collect();
770        let mut order: Vec<String> = Vec::new();
771        for k in &left_key_names {
772            order.push(k.clone());
773        }
774        for n in &left_names {
775            if !on_set.contains(n) {
776                order.push(n.clone());
777            }
778        }
779        for n in &right_names {
780            let use_name = if left_names.iter().any(|l| l == n) {
781                format!("{n}_right")
782            } else {
783                n.clone()
784            };
785            if result_names.contains(&use_name) {
786                order.push(use_name);
787            }
788        }
789        if order.len() == result_names.len() {
790            let select_exprs: Vec<polars::prelude::Expr> =
791                order.iter().map(|s| col(s.as_str())).collect();
792            joined.select(select_exprs.as_slice())
793        } else {
794            joined
795        }
796    } else {
797        joined
798    };
799    // When !case_sensitive and we didn't run the coalesce/select block (keys_match_for_coalesce was
800    // false), the raw join can have both "id" and "ID"; rename duplicates to _right so
801    // resolve_column_name("ID") returns one column (#604).
802    let result_lf = if !case_sensitive {
803        let schema = result_lf.collect_schema()?;
804        let result_names: Vec<String> = schema.iter_names().map(|s| s.to_string()).collect();
805        let mut seen_lower: std::collections::HashSet<String> = std::collections::HashSet::new();
806        let mut need_rename = false;
807        let aliases: Vec<String> = result_names
808            .iter()
809            .map(|name| {
810                let name_lower = name.to_lowercase();
811                if seen_lower.contains(&name_lower) {
812                    need_rename = true;
813                    format!("{}_right", name)
814                } else {
815                    seen_lower.insert(name_lower);
816                    name.clone()
817                }
818            })
819            .collect();
820        if need_rename {
821            let dtypes: Vec<PlDataType> = schema
822                .iter_names_and_dtypes()
823                .map(|(_, dt)| dt.clone())
824                .collect();
825            let exprs: Vec<Expr> = aliases
826                .iter()
827                .enumerate()
828                .map(|(idx, alias)| {
829                    let e = nth(idx as i64).as_expr();
830                    if let Some(dt) = dtypes.get(idx) {
831                        e.cast(dt.clone()).alias(alias.as_str())
832                    } else {
833                        e.alias(alias.as_str())
834                    }
835                })
836                .collect();
837            result_lf.select(&exprs)
838        } else {
839            result_lf
840        }
841    } else {
842        result_lf
843    };
844    // When mark_join_keys_ambiguous is true (condition join on same-named keys), unqualified
845    // references to those key names must be treated as ambiguous (PySpark parity #1230 /
846    // issue #374), regardless of whether we coalesced the physical columns. Column-name
847    // joins (on = "id") never set mark_join_keys_ambiguous, so df1["age"] continues to work
848    // after coalescing same-named columns (#297).
849    let ambiguous_columns = if mark_join_keys_ambiguous {
850        Some(left_key_names.iter().cloned().collect::<HashSet<String>>())
851    } else {
852        None
853    };
854    Ok(super::DataFrame::from_lazy_with_options_and_ambiguous(
855        result_lf,
856        case_sensitive,
857        ambiguous_columns,
858    ))
859}
860
#[cfg(test)]
mod tests {
    // Unit tests for the join-condition key extraction helpers and for the key-based
    // `join` entry point (inner / left / right / outer / semi / anti, key type coercion,
    // coalescing of same-named keys and case-insensitive column resolution).

    use super::{
        JoinOptions, JoinOrigin, JoinType, expr_contains_only_join_key_equalities, join,
        try_extract_join_eq_columns, try_extract_join_eq_columns_all,
    };
    use crate::functions::col;
    use crate::{DataFrame, SparkSession};
    use std::collections::HashMap;

    // ---------------------------------------------------------------------------------
    // Shared fixtures
    // ---------------------------------------------------------------------------------

    /// Builds (or fetches) a session through the builder using `app_name`.
    fn session(app_name: &'static str) -> SparkSession {
        SparkSession::builder().app_name(app_name).get_or_create()
    }

    /// Options used by every column-name join in this module: case-insensitive names,
    /// join keys not marked ambiguous, origin `JoinOrigin::ColumnOn`. The only setting
    /// that differs between tests is whether same-named keys are coalesced.
    fn column_on_options(coalesce_same_name_keys: bool) -> JoinOptions {
        JoinOptions {
            case_sensitive: false,
            coalesce_same_name_keys,
            mark_join_keys_ambiguous: false,
            origin: JoinOrigin::ColumnOn,
        }
    }

    /// Left fixture: columns `id`, `v`, `label` with ids 1 and 2.
    fn left_df() -> DataFrame {
        session("join_tests")
            .create_dataframe(
                vec![
                    (1i64, 10i64, "a".to_string()),
                    (2i64, 20i64, "b".to_string()),
                ],
                vec!["id", "v", "label"],
            )
            .unwrap()
    }

    /// Right fixture: columns `id`, `w`, `tag` with ids 1 and 3.
    fn right_df() -> DataFrame {
        session("join_tests")
            .create_dataframe(
                vec![
                    (1i64, 100i64, "x".to_string()),
                    (3i64, 300i64, "z".to_string()),
                ],
                vec!["id", "w", "tag"],
            )
            .unwrap()
    }

    /// Joins `left_df()` with `right_df()` on `id` using join type `how`.
    /// Panics if the join itself fails, which fails the calling test.
    fn join_fixtures_on_id(how: JoinType, coalesce_same_name_keys: bool) -> DataFrame {
        let left = left_df();
        let right = right_df();
        join(
            &left,
            &right,
            vec!["id"],
            vec!["id"],
            how,
            column_on_options(coalesce_same_name_keys),
        )
        .unwrap()
    }

    // ---------------------------------------------------------------------------------
    // Join-condition expression helpers
    // ---------------------------------------------------------------------------------

    #[test]
    fn extract_join_eq_columns_from_eq_expr() {
        let expr = col("dept_id").eq(col("dept_id").into_expr()).into_expr();
        assert_eq!(
            try_extract_join_eq_columns(&expr),
            Some(("dept_id".to_string(), "dept_id".to_string()))
        );
    }

    #[test]
    fn extract_join_eq_columns_all_from_and_of_equalities() {
        // (a == a) & (b == b) yields both pairs (#1148).
        let b_eq_b = col("b").eq(col("b").into_expr());
        let expr = col("a").eq(col("a").into_expr()).and_(&b_eq_b).into_expr();
        let expected = vec![
            ("a".to_string(), "a".to_string()),
            ("b".to_string(), "b".to_string()),
        ];
        assert_eq!(try_extract_join_eq_columns_all(&expr), expected);
    }

    #[test]
    fn extract_join_eq_columns_from_aliased_eq() {
        // into_expr() adds Alias(..., "<expr>"); extraction must look through it.
        let expr = col("a").eq(col("b").into_expr()).into_expr();
        assert_eq!(
            try_extract_join_eq_columns(&expr),
            Some(("a".to_string(), "b".to_string()))
        );
    }

    #[test]
    fn expr_contains_only_join_key_equalities_simple_and_compound() {
        // Only key equalities -> true (so we skip post-join filter for left/right/outer #1242).
        let single_eq = col("Key").eq(col("Name").into_expr()).into_expr();
        assert!(expr_contains_only_join_key_equalities(&single_eq));

        let eq_and_eq = col("a")
            .eq(col("b").into_expr())
            .and_(&col("c").eq(col("d").into_expr()))
            .into_expr();
        assert!(expr_contains_only_join_key_equalities(&eq_and_eq));

        // Compound (equality + other) -> false so we still apply filter (#380).
        let eq_and_gt = col("a")
            .eq(col("b").into_expr())
            .and_(&col("x").gt(col("y").into_expr()))
            .into_expr();
        assert!(!expr_contains_only_join_key_equalities(&eq_and_gt));
    }

    // ---------------------------------------------------------------------------------
    // Join types on the shared fixtures (left ids: 1, 2; right ids: 1, 3)
    // ---------------------------------------------------------------------------------

    #[test]
    fn inner_join() {
        let out = join_fixtures_on_id(JoinType::Inner, false);
        assert_eq!(out.count().unwrap(), 1);
        let cols = out.columns().unwrap();
        assert!(cols.iter().any(|c| c == "id" || c.ends_with("_right")));
    }

    /// #1165: Join with same-named keys and coalesce: non-key columns keep correct dtypes in schema and collect.
    #[test]
    fn join_coalesce_preserves_non_key_column_types() {
        use robin_sparkless_core::DataType as CoreDataType;

        let out = join_fixtures_on_id(JoinType::Inner, true);
        assert_eq!(out.count().unwrap(), 1);

        // Schema level: both non-key columns must still be reported as Long.
        let schema = out.schema().unwrap();
        for name in ["v", "w"] {
            let field = schema.fields().iter().find(|f| f.name == name);
            assert!(
                matches!(field.map(|f| &f.data_type), Some(CoreDataType::Long)),
                "{name} should be Long"
            );
        }

        // Data level: collected values must come back as JSON numbers, not strings.
        let rows = out.collect_as_json_rows().unwrap();
        assert_eq!(rows.len(), 1);
        let row = &rows[0];
        for name in ["v", "w"] {
            assert!(
                row.get(name).and_then(|v| v.as_i64()).is_some(),
                "{name} should be number in JSON"
            );
        }
    }

    #[test]
    fn left_join() {
        let out = join_fixtures_on_id(JoinType::Left, false);
        assert_eq!(out.count().unwrap(), 2);
    }

    #[test]
    fn right_join() {
        let out = join_fixtures_on_id(JoinType::Right, false);
        assert_eq!(out.count().unwrap(), 2); // right has id 1,3; left matches 1
    }

    #[test]
    fn outer_join() {
        let out = join_fixtures_on_id(JoinType::Outer, false);
        assert_eq!(out.count().unwrap(), 3);
    }

    #[test]
    fn left_semi_join() {
        let out = join_fixtures_on_id(JoinType::LeftSemi, false);
        assert_eq!(out.count().unwrap(), 1); // left rows with match in right (id 1)
    }

    #[test]
    fn left_anti_join() {
        let out = join_fixtures_on_id(JoinType::LeftAnti, false);
        assert_eq!(out.count().unwrap(), 1); // left rows with no match (id 2)
    }

    #[test]
    fn join_empty_right() {
        let left = left_df();
        let right = session("join_tests")
            .create_dataframe(vec![] as Vec<(i64, i64, String)>, vec!["id", "w", "tag"])
            .unwrap();
        let out = join(
            &left,
            &right,
            vec!["id"],
            vec!["id"],
            JoinType::Inner,
            column_on_options(false),
        )
        .unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    // ---------------------------------------------------------------------------------
    // Key type coercion
    // ---------------------------------------------------------------------------------

    /// Join when key types differ (str on left, int on right): coerces to common type (#274).
    #[test]
    fn join_key_type_coercion_str_int() {
        use polars::prelude::df;

        let spark = session("join_tests");
        let left = spark
            .create_dataframe_from_polars(df!("id" => &["1"], "label" => &["a"]).unwrap());
        let right = spark
            .create_dataframe_from_polars(df!("id" => &[1i64], "x" => &[10i64]).unwrap());
        let out = join(
            &left,
            &right,
            vec!["id"],
            vec!["id"],
            JoinType::Inner,
            column_on_options(false),
        )
        .unwrap();
        assert_eq!(out.count().unwrap(), 1);
        let rows = out.collect().unwrap();
        assert_eq!(rows.height(), 1);
        // Join key was coerced to common type (string); row matched id "1" with id 1.
        assert!(rows.column("label").is_ok());
        assert!(rows.column("x").is_ok());
    }

    /// #681: Join when key types differ (Int64 on left, String on right): coerces to common type (String).
    #[test]
    fn join_key_type_coercion_int_str() {
        use polars::prelude::df;

        let spark = session("join_tests");
        let left = spark.create_dataframe_from_polars(
            df!("id" => &[1i64, 2i64], "name" => &["alice", "bob"]).unwrap(),
        );
        let right = spark.create_dataframe_from_polars(
            df!("id" => &["1", "3"], "value" => &[100i64, 300i64]).unwrap(),
        );
        let out = join(
            &left,
            &right,
            vec!["id"],
            vec!["id"],
            JoinType::Inner,
            column_on_options(false),
        )
        .unwrap();
        assert_eq!(out.count().unwrap(), 1, "inner join on id: 1 match (id=1)");
        let rows = out.collect().unwrap();
        assert_eq!(rows.height(), 1);
        for name in ["id", "name", "value"] {
            assert!(rows.column(name).is_ok());
        }
    }

    // ---------------------------------------------------------------------------------
    // Coalesced keys and case-insensitive resolution
    // ---------------------------------------------------------------------------------

    #[test]
    fn outer_join_then_groupby_on_key_matches_pyspark_semantics() {
        // Mirror tests/test_issue_280_join_groupby_ambiguity.py::test_outer_join_then_groupby:
        // left keys: 1, 3; right keys: 1, 2. Canonical join key is coalesce(left, right) so
        // grouping on "key" yields {1: 1, 2: 1, 3: 1} (unmatched right row keeps key=2; #1207).
        let spark = session("outer_join_groupby_tests");

        let left = spark
            .create_dataframe(
                vec![
                    (1i64, 0i64, "L1".to_string()),
                    (3i64, 0i64, "L3".to_string()),
                ],
                vec!["key", "extra_left", "left_val"],
            )
            .unwrap();
        let right = spark
            .create_dataframe(
                vec![
                    (1i64, 0i64, "R1".to_string()),
                    (2i64, 0i64, "R2".to_string()),
                ],
                vec!["key", "extra_right", "right_val"],
            )
            .unwrap();

        let joined = join(
            &left,
            &right,
            vec!["key"],
            vec!["key"],
            JoinType::Outer,
            column_on_options(true),
        )
        .unwrap();

        let counts = joined.group_by(vec!["key"]).unwrap().count().unwrap();
        let pl_df = counts.collect().unwrap();

        let key_col = pl_df.column("key").unwrap().i64().unwrap();
        let count_col = pl_df.column("count").unwrap().i64().unwrap();

        // key -> group size; a missing count is recorded as 0.
        let by_key: HashMap<Option<i64>, i64> = (0..key_col.len())
            .map(|idx| (key_col.get(idx), count_col.get(idx).unwrap_or(0)))
            .collect();

        // Expect exactly three groups: key=1, key=2, key=3 (coalesce(left, right) so right-only row keeps 2).
        assert_eq!(by_key.len(), 3);
        for key in [1i64, 2, 3] {
            assert_eq!(by_key.get(&Some(key)).copied(), Some(1));
        }
    }

    /// Issue #604: join when key names differ in case (left "id", right "ID"); collect must not fail with "not found: ID".
    #[test]
    fn join_column_resolution_case_insensitive() {
        use polars::prelude::df;

        let spark = session("join_tests");
        let left = spark.create_dataframe_from_polars(
            df!("id" => &[1i64, 2i64], "val" => &["a", "b"]).unwrap(),
        );
        let right = spark
            .create_dataframe_from_polars(df!("ID" => &[1i64], "other" => &["x"]).unwrap());
        let out = join(
            &left,
            &right,
            vec!["id"],
            vec!["id"],
            JoinType::Inner,
            column_on_options(false),
        )
        .expect("issue #604: join on id/ID must succeed");
        assert_eq!(out.count().unwrap(), 1);
        let rows = out
            .collect()
            .expect("issue #604: collect must not fail with 'not found: ID'");
        assert_eq!(rows.height(), 1);
        for name in ["id", "val", "other"] {
            assert!(rows.column(name).is_ok());
        }
        // Resolving "ID" (case-insensitive) must work.
        assert!(out.resolve_column_name("ID").is_ok());
    }
}