
robin_sparkless/dataframe/joins.rs

//! Join operations for DataFrame.

use super::DataFrame;
use crate::type_coercion::find_common_type;
use polars::prelude::Expr;
use polars::prelude::JoinType as PlJoinType;
use polars::prelude::PolarsError;

/// Join type for DataFrame joins (PySpark-compatible).
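///
/// A hedged sketch of how these variants could correspond to PySpark `how`
/// strings; `parse_how` is hypothetical and not part of this module's API,
/// it only illustrates the intended mapping:
///
/// ```ignore
/// fn parse_how(how: &str) -> Option<JoinType> {
///     match how {
///         "inner" => Some(JoinType::Inner),
///         "left" | "leftouter" | "left_outer" => Some(JoinType::Left),
///         "right" | "rightouter" | "right_outer" => Some(JoinType::Right),
///         "outer" | "full" | "fullouter" | "full_outer" => Some(JoinType::Outer),
///         "semi" | "leftsemi" | "left_semi" => Some(JoinType::LeftSemi),
///         "anti" | "leftanti" | "left_anti" => Some(JoinType::LeftAnti),
///         _ => None,
///     }
/// }
/// ```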
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    Inner,
    Left,
    Right,
    Outer,
    /// Rows from left that have a match in right; only left columns (PySpark left_semi).
    LeftSemi,
    /// Rows from left that have no match in right; only left columns (PySpark left_anti).
    LeftAnti,
}

/// Join with another DataFrame on the given columns. Preserves `case_sensitive` on the result.
/// When join key types differ (e.g. str vs int), coerces both sides to a common type (PySpark parity #274).
/// When both tables have the same join key column name(s), joins on those keys and coalesces them so
/// Polars does not error with "duplicate column" (issue #580, PySpark parity).
/// When left/right key names differ in casing (e.g. "id" vs "ID"), aliases the right keys to the left
/// names so the result has one key column name and col("ID")/col("id") both resolve (PySpark parity #604).
/// For Right and Outer, reorders columns to match PySpark: key(s), then left non-key, then right non-key.
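///
/// # Example
///
/// A minimal usage sketch mirroring the test fixtures below (marked `ignore`
/// because it assumes `SparkSession` and `DataFrame` are re-exported at the
/// crate root, as the tests' `use crate::{DataFrame, SparkSession}` suggests):
///
/// ```ignore
/// let spark = SparkSession::builder().app_name("doc").get_or_create();
/// let left = spark
///     .create_dataframe(
///         vec![(1i64, 10i64, "a".to_string()), (2i64, 20i64, "b".to_string())],
///         vec!["id", "v", "label"],
///     )
///     .unwrap();
/// let right = spark
///     .create_dataframe(
///         vec![(1i64, 100i64, "x".to_string())],
///         vec!["id", "w", "tag"],
///     )
///     .unwrap();
/// // Inner join on "id": only the row with id == 1 survives.
/// let out = join(&left, &right, vec!["id"], JoinType::Inner, false).unwrap();
/// assert_eq!(out.count().unwrap(), 1);
/// ```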
pub fn join(
    left: &DataFrame,
    right: &DataFrame,
    on: Vec<&str>,
    how: JoinType,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::{JoinBuilder, JoinCoalesce, col};
    let mut left_lf = left.lazy_frame();
    let mut right_lf = right.lazy_frame();

    // Resolve key names on both sides so we can alias right keys to left names (#604).
    let left_key_names: Vec<String> = on
        .iter()
        .map(|k| {
            left.resolve_column_name(k).map_err(|e| {
                PolarsError::ComputeError(format!("join key '{k}' on left: {e}").into())
            })
        })
        .collect::<Result<_, _>>()?;
    let right_key_names: Vec<String> = on
        .iter()
        .map(|k| {
            right.resolve_column_name(k).map_err(|e| {
                PolarsError::ComputeError(format!("join key '{k}' on right: {e}").into())
            })
        })
        .collect::<Result<_, _>>()?;

    // Coerce join keys to a common type when left/right dtypes differ (PySpark #274).
    // Alias right keys to left key names so result has one key column name (#604).
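    // e.g. a str "id" on the left joined with an i64 "id" on the right: both sides
    // are cast to the common type (see join_key_type_coercion_str_int below).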
    let mut left_casts: Vec<Expr> = Vec::new();
    let mut right_casts: Vec<Expr> = Vec::new();
    for (i, key) in on.iter().enumerate() {
        let left_name = &left_key_names[i];
        let right_name = &right_key_names[i];
        let left_dtype = left.get_column_dtype(left_name.as_str()).ok_or_else(|| {
            PolarsError::ComputeError(format!("join key '{key}' not found on left").into())
        })?;
        let right_dtype = right.get_column_dtype(right_name.as_str()).ok_or_else(|| {
            PolarsError::ComputeError(format!("join key '{key}' not found on right").into())
        })?;
        let target_name = left_name.as_str();
        if left_dtype != right_dtype {
            let common = find_common_type(&left_dtype, &right_dtype)?;
            left_casts.push(
                col(left_name.as_str())
                    .cast(common.clone())
                    .alias(target_name),
            );
            right_casts.push(col(right_name.as_str()).cast(common).alias(target_name));
        } else if left_name != right_name {
            right_casts.push(col(right_name.as_str()).alias(target_name));
        }
    }
    if !left_casts.is_empty() {
        left_lf = left_lf.with_columns(left_casts);
    }
    if !right_casts.is_empty() {
        right_lf = right_lf.with_columns(right_casts);
        // #614: Drop right's original key columns when we aliased to left names, so the result
        // has only the left key name (e.g. "id") and collect does not fail with "not found: ID".
        let drop_right: std::collections::HashSet<String> = on
            .iter()
            .enumerate()
            .filter(|(i, _)| left_key_names[*i] != right_key_names[*i])
            .map(|(i, _)| right_key_names[i].clone())
            .collect();
        if !drop_right.is_empty() {
            let right_names = right.columns()?;
            let mut keep_names: Vec<&str> = right_names
                .iter()
                .filter(|n| !drop_right.contains(*n))
                .map(String::as_str)
                .collect();
            for (i, name) in left_key_names.iter().enumerate() {
                if left_key_names[i] != right_key_names[i] {
                    keep_names.push(name.as_str());
                }
            }
            let keep: Vec<Expr> = keep_names.iter().map(|s| col(*s)).collect();
            right_lf = right_lf.select(&keep);
        }
    }

    let on_set: std::collections::HashSet<String> = left_key_names.iter().cloned().collect();
    let on_exprs: Vec<polars::prelude::Expr> = left_key_names
        .iter()
        .map(|name| col(name.as_str()))
        .collect();
    let polars_how: PlJoinType = match how {
        JoinType::Inner => PlJoinType::Inner,
        JoinType::Left => PlJoinType::Left,
        JoinType::Right => PlJoinType::Right,
        JoinType::Outer => PlJoinType::Full, // PySpark Outer = Polars Full
        JoinType::LeftSemi => PlJoinType::Semi,
        JoinType::LeftAnti => PlJoinType::Anti,
    };

    let mut joined = JoinBuilder::new(left_lf)
        .with(right_lf)
        .how(polars_how)
        .on(&on_exprs)
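        // CoalesceColumns merges the left/right key columns into one output column,
        // matching PySpark's single key column for outer/right joins.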
        .coalesce(JoinCoalesce::CoalesceColumns)
        .finish();
    // For Right/Outer, reorder columns: keys, left non-keys, right non-keys (PySpark order).
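    // e.g. keys ["id"], left [id, v, label], right [id, w, tag]
    // -> result order [id, v, label, w, tag], as in the right/outer tests below.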
    let result_lf = if matches!(how, JoinType::Right | JoinType::Outer) {
        let left_names = left.columns()?;
        let right_names = right.columns()?;
        let result_schema = joined.collect_schema()?;
        let result_names: std::collections::HashSet<String> =
            result_schema.iter_names().map(|s| s.to_string()).collect();
        let mut order: Vec<String> = Vec::new();
        for k in &left_key_names {
            order.push(k.clone());
        }
        for n in &left_names {
            if !on_set.contains(n) {
                order.push(n.clone());
            }
        }
        for n in &right_names {
            let use_name = if left_names.iter().any(|l| l == n) {
                format!("{n}_right")
            } else {
                n.clone()
            };
            if result_names.contains(&use_name) {
                order.push(use_name);
            }
        }
        if order.len() == result_names.len() {
            let select_exprs: Vec<polars::prelude::Expr> =
                order.iter().map(|s| col(s.as_str())).collect();
            joined.select(select_exprs.as_slice())
        } else {
            joined
        }
    } else {
        joined
    };
    Ok(super::DataFrame::from_lazy_with_options(
        result_lf,
        case_sensitive,
    ))
}

#[cfg(test)]
mod tests {
    use super::{JoinType, join};
    use crate::{DataFrame, SparkSession};

    fn left_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        spark
            .create_dataframe(
                vec![
                    (1i64, 10i64, "a".to_string()),
                    (2i64, 20i64, "b".to_string()),
                ],
                vec!["id", "v", "label"],
            )
            .unwrap()
    }

    fn right_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        spark
            .create_dataframe(
                vec![
                    (1i64, 100i64, "x".to_string()),
                    (3i64, 300i64, "z".to_string()),
                ],
                vec!["id", "w", "tag"],
            )
            .unwrap()
    }

    #[test]
    fn inner_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::Inner, false).unwrap();
        assert_eq!(out.count().unwrap(), 1);
        let cols = out.columns().unwrap();
        assert!(cols.iter().any(|c| c == "id" || c.ends_with("_right")));
    }

    #[test]
    fn left_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::Left, false).unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }

    #[test]
    fn right_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::Right, false).unwrap();
        assert_eq!(out.count().unwrap(), 2); // right has id 1,3; left matches 1
    }

    #[test]
    fn outer_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::Outer, false).unwrap();
        assert_eq!(out.count().unwrap(), 3);
    }

    #[test]
    fn left_semi_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::LeftSemi, false).unwrap();
        assert_eq!(out.count().unwrap(), 1); // left rows with match in right (id 1)
    }

    #[test]
    fn left_anti_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::LeftAnti, false).unwrap();
        assert_eq!(out.count().unwrap(), 1); // left rows with no match (id 2)
    }

    #[test]
    fn join_empty_right() {
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        let left = left_df();
        let right = spark
            .create_dataframe(vec![] as Vec<(i64, i64, String)>, vec!["id", "w", "tag"])
            .unwrap();
        let out = join(&left, &right, vec!["id"], JoinType::Inner, false).unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    /// Join when key types differ (str on left, int on right): coerces to common type (#274).
    #[test]
    fn join_key_type_coercion_str_int() {
        use polars::prelude::df;
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        let left_pl = df!("id" => &["1"], "label" => &["a"]).unwrap();
        let right_pl = df!("id" => &[1i64], "x" => &[10i64]).unwrap();
        let left = spark.create_dataframe_from_polars(left_pl);
        let right = spark.create_dataframe_from_polars(right_pl);
        let out = join(&left, &right, vec!["id"], JoinType::Inner, false).unwrap();
        assert_eq!(out.count().unwrap(), 1);
        let rows = out.collect().unwrap();
        assert_eq!(rows.height(), 1);
        // Join key was coerced to common type (string); row matched id "1" with id 1.
        assert!(rows.column("label").is_ok());
        assert!(rows.column("x").is_ok());
    }

    /// Issue #604: join when key names differ in case (left "id", right "ID");
    /// collect must not fail with "not found: ID".
    #[test]
    fn join_column_resolution_case_insensitive() {
        use polars::prelude::df;
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        let left_pl = df!("id" => &[1i64, 2i64], "val" => &["a", "b"]).unwrap();
        let right_pl = df!("ID" => &[1i64], "other" => &["x"]).unwrap();
        let left = spark.create_dataframe_from_polars(left_pl);
        let right = spark.create_dataframe_from_polars(right_pl);
        let out = join(&left, &right, vec!["id"], JoinType::Inner, false)
            .expect("issue #604: join on id/ID must succeed");
        assert_eq!(out.count().unwrap(), 1);
        let rows = out
            .collect()
            .expect("issue #604: collect must not fail with 'not found: ID'");
        assert_eq!(rows.height(), 1);
        assert!(rows.column("id").is_ok());
        assert!(rows.column("val").is_ok());
        assert!(rows.column("other").is_ok());
        // Resolving "ID" (case-insensitive) must work.
        assert!(out.resolve_column_name("ID").is_ok());
    }
}