
robin_sparkless_polars/dataframe/joins.rs

//! Join operations for DataFrame.

use super::DataFrame;
use crate::type_coercion::coerce_expr_pair;
use polars::prelude::Expr;
use polars::prelude::JoinType as PlJoinType;
use polars::prelude::PolarsError;

/// Join type for DataFrame joins (PySpark-compatible).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    Inner,
    Left,
    Right,
    Outer,
    /// Rows from left that have a match in right; only left columns (PySpark left_semi).
    LeftSemi,
    /// Rows from left that have no match in right; only left columns (PySpark left_anti).
    LeftAnti,
}

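// Illustrative sketch (hypothetical helper, not part of this module's API): one way
// the PySpark string form of `how` could map onto `JoinType`. The accepted aliases
// are assumptions modeled on PySpark's `DataFrame.join`, not confirmed by this crate.
#[allow(dead_code)]
fn join_type_from_pyspark_str(how: &str) -> Option<JoinType> {
    match how.to_ascii_lowercase().as_str() {
        "inner" => Some(JoinType::Inner),
        "left" | "leftouter" | "left_outer" => Some(JoinType::Left),
        "right" | "rightouter" | "right_outer" => Some(JoinType::Right),
        "outer" | "full" | "fullouter" | "full_outer" => Some(JoinType::Outer),
        "semi" | "leftsemi" | "left_semi" => Some(JoinType::LeftSemi),
        "anti" | "leftanti" | "left_anti" => Some(JoinType::LeftAnti),
        _ => None,
    }
}
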
/// Join with another DataFrame on the given columns. Preserves `case_sensitive` on the result.
/// When join key types differ (e.g. str vs int), coerces both sides to a common type (PySpark parity #274).
/// When both sides use the same join key column name(s), joins on those keys and coalesces them,
/// so Polars does not error with "duplicate column" (issue #580, PySpark parity).
/// When left/right key names differ in casing (e.g. "id" vs "ID"), aliases right keys to left names
/// and drops the right originals, so the result has one key column name and col("ID")/col("id")
/// both resolve (PySpark parity #604, #614).
/// For Right and Outer, reorders columns to match PySpark: key(s), then left non-key, then right non-key.
pub fn join(
    left: &DataFrame,
    right: &DataFrame,
    on: Vec<&str>,
    how: JoinType,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::{JoinBuilder, JoinCoalesce, col};
    let mut left_lf = left.lazy_frame();
    let mut right_lf = right.lazy_frame();

    // Resolve key names on both sides so we can alias right keys to left names (#604).
    let left_key_names: Vec<String> = on
        .iter()
        .map(|k| {
            left.resolve_column_name(k).map_err(|e| {
                PolarsError::ComputeError(format!("join key '{k}' on left: {e}").into())
            })
        })
        .collect::<Result<_, _>>()?;
    let right_key_names: Vec<String> = on
        .iter()
        .map(|k| {
            right.resolve_column_name(k).map_err(|e| {
                PolarsError::ComputeError(format!("join key '{k}' on right: {e}").into())
            })
        })
        .collect::<Result<_, _>>()?;

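    // Worked example (#604): with case-insensitive resolution, on = ["id"] can resolve
    // to "id" on the left and "ID" on the right; the aliasing below unifies both sides
    // under the left name.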
    // Coerce join keys to a common type when left/right dtypes differ (PySpark #274).
    // Alias right keys to left key names so result has one key column name (#604).
    let mut left_casts: Vec<Expr> = Vec::new();
    let mut right_casts: Vec<Expr> = Vec::new();
    for (i, key) in on.iter().enumerate() {
        let left_name = &left_key_names[i];
        let right_name = &right_key_names[i];
        let left_dtype = left.get_column_dtype(left_name.as_str()).ok_or_else(|| {
            PolarsError::ComputeError(format!("join key '{key}' not found on left").into())
        })?;
        let right_dtype = right.get_column_dtype(right_name.as_str()).ok_or_else(|| {
            PolarsError::ComputeError(format!("join key '{key}' not found on right").into())
        })?;
        let target_name = left_name.as_str();
        if left_dtype != right_dtype {
            let (l, r) = coerce_expr_pair(
                left_name.as_str(),
                right_name.as_str(),
                &left_dtype,
                &right_dtype,
                target_name,
            )?;
            left_casts.push(l);
            right_casts.push(r);
        } else if left_name != right_name {
            right_casts.push(col(right_name.as_str()).alias(target_name));
        }
    }
    if !left_casts.is_empty() {
        left_lf = left_lf.with_columns(left_casts);
    }
    if !right_casts.is_empty() {
        right_lf = right_lf.with_columns(right_casts);
        // #614: Drop right's original key columns when we aliased to left names, so the result
        // has only the left key name (e.g. "id") and collect does not fail with "not found: ID".
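        // Worked example: right_casts added `col("ID").alias("id")` above, so the
        // original "ID" is dropped here and only the left-named "id" is kept.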
        let drop_right: std::collections::HashSet<String> = on
            .iter()
            .enumerate()
            .filter(|(i, _)| left_key_names[*i] != right_key_names[*i])
            .map(|(i, _)| right_key_names[i].clone())
            .collect();
        if !drop_right.is_empty() {
            let right_names = right.columns()?;
            let mut keep_names: Vec<&str> = right_names
                .iter()
                .filter(|n| !drop_right.contains(*n))
                .map(String::as_str)
                .collect();
            for (i, name) in left_key_names.iter().enumerate() {
                if left_key_names[i] != right_key_names[i] {
                    keep_names.push(name.as_str());
                }
            }
            let keep: Vec<Expr> = keep_names.iter().map(|s| col(*s)).collect();
            right_lf = right_lf.select(&keep);
        }
    }

    let on_set: std::collections::HashSet<String> = left_key_names.iter().cloned().collect();
    let on_exprs: Vec<Expr> = left_key_names
        .iter()
        .map(|name| col(name.as_str()))
        .collect();
    let polars_how: PlJoinType = match how {
        JoinType::Inner => PlJoinType::Inner,
        JoinType::Left => PlJoinType::Left,
        JoinType::Right => PlJoinType::Right,
        JoinType::Outer => PlJoinType::Full, // PySpark Outer = Polars Full
        JoinType::LeftSemi => PlJoinType::Semi,
        JoinType::LeftAnti => PlJoinType::Anti,
    };

    let mut joined = JoinBuilder::new(left_lf)
        .with(right_lf)
        .how(polars_how)
        .on(&on_exprs)
        .coalesce(JoinCoalesce::CoalesceColumns)
        .finish();
    // For Right/Outer, reorder columns: keys, left non-keys, right non-keys (PySpark order).
    let result_lf = if matches!(how, JoinType::Right | JoinType::Outer) {
        let left_names = left.columns()?;
        let right_names = right.columns()?;
        let result_schema = joined.collect_schema()?;
        let result_names: std::collections::HashSet<String> =
            result_schema.iter_names().map(|s| s.to_string()).collect();
        let mut order: Vec<String> = Vec::new();
        for k in &left_key_names {
            order.push(k.clone());
        }
        for n in &left_names {
            if !on_set.contains(n) {
                order.push(n.clone());
            }
        }
        for n in &right_names {
            let use_name = if left_names.iter().any(|l| l == n) {
                format!("{n}_right")
            } else {
                n.clone()
            };
            if result_names.contains(&use_name) {
                order.push(use_name);
            }
        }
        if order.len() == result_names.len() {
            let select_exprs: Vec<Expr> =
                order.iter().map(|s| col(s.as_str())).collect();
            joined.select(select_exprs.as_slice())
        } else {
            joined
        }
    } else {
        joined
    };
    Ok(super::DataFrame::from_lazy_with_options(
        result_lf,
        case_sensitive,
    ))
}

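// Minimal usage sketch (hypothetical input frames; the tests below are the concrete
// examples):
//
//     let out = join(&left, &right, vec!["id"], JoinType::Inner, false)?;
//     let rows = out.collect()?;
//
// `case_sensitive` is propagated via `from_lazy_with_options`, so column resolution
// on the joined frame behaves the same as on the inputs.
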
#[cfg(test)]
mod tests {
    use super::{JoinType, join};
    use crate::{DataFrame, SparkSession};

    fn left_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        spark
            .create_dataframe(
                vec![
                    (1i64, 10i64, "a".to_string()),
                    (2i64, 20i64, "b".to_string()),
                ],
                vec!["id", "v", "label"],
            )
            .unwrap()
    }

    fn right_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        spark
            .create_dataframe(
                vec![
                    (1i64, 100i64, "x".to_string()),
                    (3i64, 300i64, "z".to_string()),
                ],
                vec!["id", "w", "tag"],
            )
            .unwrap()
    }

    #[test]
    fn inner_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::Inner, false).unwrap();
        assert_eq!(out.count().unwrap(), 1);
        let cols = out.columns().unwrap();
        // Join keys are coalesced, so exactly one "id" column remains.
        assert_eq!(cols.iter().filter(|c| c.as_str() == "id").count(), 1);
    }

    #[test]
    fn left_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::Left, false).unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }

    #[test]
    fn right_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::Right, false).unwrap();
        assert_eq!(out.count().unwrap(), 2); // right has id 1,3; left matches 1
    }

    #[test]
    fn outer_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::Outer, false).unwrap();
        assert_eq!(out.count().unwrap(), 3);
    }

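    /// Sketch check for the Right/Outer reordering documented on `join`: key(s),
    /// then left non-keys, then right non-keys. Assumes `columns()` reports the
    /// result's column order.
    #[test]
    fn outer_join_column_order() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::Outer, false).unwrap();
        assert_eq!(out.columns().unwrap(), vec!["id", "v", "label", "w", "tag"]);
    }
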
    #[test]
    fn left_semi_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::LeftSemi, false).unwrap();
        assert_eq!(out.count().unwrap(), 1); // left rows with match in right (id 1)
    }

    #[test]
    fn left_anti_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::LeftAnti, false).unwrap();
        assert_eq!(out.count().unwrap(), 1); // left rows with no match (id 2)
    }

    #[test]
    fn join_empty_right() {
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        let left = left_df();
        let right = spark
            .create_dataframe(vec![] as Vec<(i64, i64, String)>, vec!["id", "w", "tag"])
            .unwrap();
        let out = join(&left, &right, vec!["id"], JoinType::Inner, false).unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    /// Join when key types differ (str on left, int on right): coerces to common type (#274).
    #[test]
    fn join_key_type_coercion_str_int() {
        use polars::prelude::df;
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        let left_pl = df!("id" => &["1"], "label" => &["a"]).unwrap();
        let right_pl = df!("id" => &[1i64], "x" => &[10i64]).unwrap();
        let left = spark.create_dataframe_from_polars(left_pl);
        let right = spark.create_dataframe_from_polars(right_pl);
        let out = join(&left, &right, vec!["id"], JoinType::Inner, false).unwrap();
        assert_eq!(out.count().unwrap(), 1);
        let rows = out.collect().unwrap();
        assert_eq!(rows.height(), 1);
        // Join key was coerced to common type (string); row matched id "1" with id 1.
        assert!(rows.column("label").is_ok());
        assert!(rows.column("x").is_ok());
    }

    /// Issue #604: join when key names differ in case (left "id", right "ID"); collect must not fail with "not found: ID".
    #[test]
    fn join_column_resolution_case_insensitive() {
        use polars::prelude::df;
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        let left_pl = df!("id" => &[1i64, 2i64], "val" => &["a", "b"]).unwrap();
        let right_pl = df!("ID" => &[1i64], "other" => &["x"]).unwrap();
        let left = spark.create_dataframe_from_polars(left_pl);
        let right = spark.create_dataframe_from_polars(right_pl);
        let out = join(&left, &right, vec!["id"], JoinType::Inner, false)
            .expect("issue #604: join on id/ID must succeed");
        assert_eq!(out.count().unwrap(), 1);
        let rows = out
            .collect()
            .expect("issue #604: collect must not fail with 'not found: ID'");
        assert_eq!(rows.height(), 1);
        assert!(rows.column("id").is_ok());
        assert!(rows.column("val").is_ok());
        assert!(rows.column("other").is_ok());
        // Resolving "ID" (case-insensitive) must work.
        assert!(out.resolve_column_name("ID").is_ok());
    }
}