robin-sparkless 0.11.9

PySpark-like DataFrame API in Rust on Polars; no JVM.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
//! Join operations for DataFrame.

use super::DataFrame;
use crate::type_coercion::find_common_type;
use polars::prelude::Expr;
use polars::prelude::IntoLazy;
use polars::prelude::JoinType as PlJoinType;
use polars::prelude::PolarsError;

/// Join type for DataFrame joins (PySpark-compatible)
///
/// Each variant is translated to a Polars join type inside `join`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Executed as a Polars `Inner` join.
    Inner,
    /// Executed as a Polars `Left` join.
    Left,
    /// Executed as a Polars `Right` join; `join` reorders the result columns for this variant.
    Right,
    /// PySpark "outer"; executed as a Polars `Full` join. `join` reorders the result columns for this variant.
    Outer,
    /// Rows from left that have a match in right; only left columns (PySpark left_semi).
    LeftSemi,
    /// Rows from left that have no match in right; only left columns (PySpark left_anti).
    LeftAnti,
}

/// Join `left` with `right` on the key columns listed in `on`. Preserves `case_sensitive` on the result.
///
/// Behavior, in order:
/// * Each key is resolved on the right via `resolve_column_name`, so the right column may carry a
///   different literal name than the entry in `on`.
/// * When any resolved right key has exactly the same name as its `on` entry, all right keys are
///   renamed to `__right_join_key_{i}` before joining so Polars does not error with
///   "duplicate column" (issue #580, PySpark parity). Temp names that survive in the output are
///   aliased back to the `on` names.
/// * When a key's dtype differs between the two sides (e.g. str vs int), both sides are cast to the
///   type returned by `find_common_type` (PySpark parity #274).
/// * The join always uses explicit `left_on` / `right_on` expressions with column coalescing, so the
///   resolved right key names are honored even when they differ from `on`.
/// * For `Right` and `Outer`, columns are reordered to key(s), then left non-key, then right non-key
///   (PySpark order) whenever that order accounts for exactly the columns present in the result;
///   otherwise the Polars order is kept.
///
/// NOTE(review): left keys are used verbatim from `on` (they are not passed through
/// `resolve_column_name`); confirm this is intended for case-insensitive sessions.
///
/// # Errors
///
/// Returns `PolarsError::ComputeError` when a key cannot be resolved on the right or has no dtype
/// on either side, and propagates errors from `find_common_type`, schema resolution, `columns()`,
/// and (for `Right`/`Outer` with renamed keys, which are executed eagerly) from collecting the join.
pub fn join(
    left: &DataFrame,
    right: &DataFrame,
    on: Vec<&str>,
    how: JoinType,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::{col, JoinBuilder, JoinCoalesce};
    let mut left_lf = left.lazy_frame();
    let mut right_lf = right.lazy_frame();

    // Resolve right-side key column names.
    let right_key_names: Vec<String> = on
        .iter()
        .map(|key| {
            right.resolve_column_name(key).map_err(|_| {
                PolarsError::ComputeError(format!("join key '{key}' not found on right").into())
            })
        })
        .collect::<Result<Vec<_>, PolarsError>>()?;

    // When both tables have the same key names, rename right's keys to avoid Polars "duplicate column" (#580).
    let right_join_key_temps: Vec<String> = (0..on.len())
        .map(|i| format!("__right_join_key_{i}"))
        .collect();
    let right_has_same_key_names = on
        .iter()
        .zip(right_key_names.iter())
        .any(|(l, r)| *l == r.as_str());
    if right_has_same_key_names {
        right_lf = right_lf.rename(
            right_key_names
                .iter()
                .map(|s| s.as_str())
                .collect::<Vec<_>>(),
            right_join_key_temps.clone(),
            true,
        );
    }

    // Names of the right key columns as they exist in `right_lf` from this point on.
    let right_join_keys: &[String] = if right_has_same_key_names {
        right_join_key_temps.as_slice()
    } else {
        right_key_names.as_slice()
    };

    // Coerce join keys to a common type when left/right dtypes differ (PySpark #274).
    let mut left_casts: Vec<Expr> = Vec::new();
    let mut right_casts: Vec<Expr> = Vec::new();
    for (key, right_key) in on.iter().zip(right_join_keys.iter()) {
        let left_dtype = left.get_column_dtype(key).ok_or_else(|| {
            PolarsError::ComputeError(format!("join key '{key}' not found on left").into())
        })?;
        let right_dtype = right.get_column_dtype(key).ok_or_else(|| {
            PolarsError::ComputeError(format!("join key '{key}' not found on right").into())
        })?;
        if left_dtype != right_dtype {
            let common = find_common_type(&left_dtype, &right_dtype)?;
            left_casts.push(col(*key).cast(common.clone()).alias(*key));
            right_casts.push(
                col(right_key.as_str())
                    .cast(common)
                    .alias(right_key.as_str()),
            );
        }
    }
    // Casts are always pushed in pairs, so checking one side is enough.
    if !left_casts.is_empty() {
        left_lf = left_lf.with_columns(left_casts);
        right_lf = right_lf.with_columns(right_casts);
    }

    let polars_how: PlJoinType = match how {
        JoinType::Inner => PlJoinType::Inner,
        JoinType::Left => PlJoinType::Left,
        JoinType::Right => PlJoinType::Right,
        JoinType::Outer => PlJoinType::Full, // PySpark Outer = Polars Full
        JoinType::LeftSemi => PlJoinType::Semi,
        JoinType::LeftAnti => PlJoinType::Anti,
    };

    let left_on_exprs: Vec<Expr> = on.iter().map(|name| col(*name)).collect();
    let right_on_exprs: Vec<Expr> = right_join_keys
        .iter()
        .map(|name| col(name.as_str()))
        .collect();

    // Always join with explicit left/right keys: the right key names can differ from `on`
    // (temp names after the rename above, or a differently-spelled resolved name), so joining
    // on the left names alone would reference columns the right frame does not have.
    let mut joined = JoinBuilder::new(left_lf)
        .with(right_lf)
        .how(polars_how)
        .left_on(left_on_exprs)
        .right_on(right_on_exprs)
        .coalesce(JoinCoalesce::CoalesceColumns)
        .finish();

    // When we renamed right keys, result may have __right_join_key_* (e.g. Right join); alias them back to key names.
    // For Right/Outer, lazy collect_schema() can report the left key name while execution outputs __right_join_key_*,
    // so we use the executed schema and return an eager DataFrame to avoid schema/plan mismatch.
    if right_has_same_key_names && matches!(how, JoinType::Right | JoinType::Outer) {
        let pl_df = joined.clone().collect()?;
        let executed_names: Vec<String> = pl_df
            .schema()
            .iter_names()
            .map(|n| n.to_string())
            .collect();
        if executed_names
            .iter()
            .any(|n| n.starts_with("__right_join_key_"))
        {
            let exprs = alias_temp_keys_back(executed_names, &right_join_key_temps, &on);
            let fixed = pl_df.lazy().select(exprs.as_slice()).collect()?;
            // Reorder to PySpark order: keys, left non-keys, right non-keys.
            let left_names = left.columns()?;
            let right_names = right.columns()?;
            let fixed_names: std::collections::HashSet<String> = fixed
                .get_column_names()
                .iter()
                .map(|s| s.to_string())
                .collect();
            let result = match pyspark_column_order(&on, &left_names, &right_names, &fixed_names)
            {
                Some(order) => fixed.select(order.iter().map(|s| s.as_str()))?,
                None => fixed,
            };
            return Ok(super::DataFrame::from_polars_with_options(
                result,
                case_sensitive,
            ));
        }
    }

    // When we renamed right keys, alias temp key columns if present in lazy schema.
    if right_has_same_key_names {
        let result_schema = joined.collect_schema()?;
        let lazy_names: Vec<String> = result_schema
            .iter_names()
            .map(|n| n.to_string())
            .collect();
        if lazy_names
            .iter()
            .any(|n| n.starts_with("__right_join_key_"))
        {
            let exprs = alias_temp_keys_back(lazy_names, &right_join_key_temps, &on);
            joined = joined.select(exprs.as_slice());
        }
    }

    // For Right/Outer, reorder columns: keys, left non-keys, right non-keys (PySpark order).
    let result_lf = if matches!(how, JoinType::Right | JoinType::Outer) && !right_has_same_key_names
    {
        let left_names = left.columns()?;
        let right_names = right.columns()?;
        let result_schema = joined.collect_schema()?;
        let result_names: std::collections::HashSet<String> =
            result_schema.iter_names().map(|s| s.to_string()).collect();
        match pyspark_column_order(&on, &left_names, &right_names, &result_names) {
            Some(order) => {
                let select_exprs: Vec<Expr> = order.iter().map(|s| col(s.as_str())).collect();
                joined.select(select_exprs.as_slice())
            }
            None => joined,
        }
    } else {
        joined
    };
    Ok(super::DataFrame::from_lazy_with_options(
        result_lf,
        case_sensitive,
    ))
}

/// Builds one select expression per output column, aliasing each temp key `temps[i]` back to `on[i]`.
fn alias_temp_keys_back(names: Vec<String>, temps: &[String], on: &[&str]) -> Vec<Expr> {
    names
        .into_iter()
        .map(|name| {
            let column = polars::prelude::col(name.as_str());
            match temps
                .iter()
                .zip(on.iter())
                .find(|(temp, _)| temp.as_str() == name.as_str())
            {
                Some((_, key)) => column.alias(*key),
                None => column,
            }
        })
        .collect()
}

/// Computes the PySpark column order (keys, left non-keys, right non-keys); returns `None` unless
/// that order has as many entries as the result and every entry is a column of the result.
fn pyspark_column_order(
    on: &[&str],
    left_names: &[String],
    right_names: &[String],
    result_names: &std::collections::HashSet<String>,
) -> Option<Vec<String>> {
    let key_set: std::collections::HashSet<&str> = on.iter().copied().collect();
    let mut order: Vec<String> = on.iter().map(|k| (*k).to_string()).collect();
    order.extend(
        left_names
            .iter()
            .filter(|n| !key_set.contains(n.as_str()))
            .cloned(),
    );
    for n in right_names {
        // A right column whose name clashes with a left column is looked up under the `_right` suffix.
        let use_name = if left_names.iter().any(|l| l == n) {
            format!("{n}_right")
        } else {
            n.clone()
        };
        if result_names.contains(&use_name) {
            order.push(use_name);
        }
    }
    // A matching length alone does not prove every name exists; selecting a missing column
    // would fail, so both conditions are required before the order is applied.
    if order.len() == result_names.len() && order.iter().all(|name| result_names.contains(name)) {
        Some(order)
    } else {
        None
    }
}

#[cfg(test)]
mod tests {
    use super::{join, JoinType};
    use crate::{DataFrame, SparkSession};

    /// Creates a DataFrame with two `i64` columns and one `String` column via a `join_tests` session.
    fn three_column_frame(rows: Vec<(i64, i64, String)>, columns: Vec<&str>) -> DataFrame {
        SparkSession::builder()
            .app_name("join_tests")
            .get_or_create()
            .create_dataframe(rows, columns)
            .unwrap()
    }

    /// Left fixture: ids 1 and 2 with columns `id`, `v`, `label`.
    fn left_df() -> DataFrame {
        let rows = vec![
            (1i64, 10i64, "a".to_string()),
            (2i64, 20i64, "b".to_string()),
        ];
        three_column_frame(rows, vec!["id", "v", "label"])
    }

    /// Right fixture: ids 1 and 3 with columns `id`, `w`, `tag`.
    fn right_df() -> DataFrame {
        let rows = vec![
            (1i64, 100i64, "x".to_string()),
            (3i64, 300i64, "z".to_string()),
        ];
        three_column_frame(rows, vec!["id", "w", "tag"])
    }

    /// Joins the two fixtures on `id` with the given join type (case-insensitive result).
    fn join_fixtures_on_id(how: JoinType) -> DataFrame {
        let lhs = left_df();
        let rhs = right_df();
        join(&lhs, &rhs, vec!["id"], how, false).unwrap()
    }

    #[test]
    fn inner_join() {
        let joined = join_fixtures_on_id(JoinType::Inner);
        assert_eq!(joined.count().unwrap(), 1);
        let names = joined.columns().unwrap();
        let has_key_or_suffixed = names
            .iter()
            .any(|name| name == "id" || name.ends_with("_right"));
        assert!(has_key_or_suffixed);
    }

    #[test]
    fn left_join() {
        let joined = join_fixtures_on_id(JoinType::Left);
        assert_eq!(joined.count().unwrap(), 2);
    }

    #[test]
    fn right_join() {
        // Right side holds ids 1 and 3; only id 1 has a partner on the left.
        let joined = join_fixtures_on_id(JoinType::Right);
        assert_eq!(joined.count().unwrap(), 2);
    }

    #[test]
    fn outer_join() {
        let joined = join_fixtures_on_id(JoinType::Outer);
        assert_eq!(joined.count().unwrap(), 3);
    }

    #[test]
    fn left_semi_join() {
        // Only the left row with id 1 has a match on the right.
        let joined = join_fixtures_on_id(JoinType::LeftSemi);
        assert_eq!(joined.count().unwrap(), 1);
    }

    #[test]
    fn left_anti_join() {
        // Only the left row with id 2 has no match on the right.
        let joined = join_fixtures_on_id(JoinType::LeftAnti);
        assert_eq!(joined.count().unwrap(), 1);
    }

    #[test]
    fn join_empty_right() {
        let lhs = left_df();
        let empty_rhs = three_column_frame(Vec::new(), vec!["id", "w", "tag"]);
        let joined = join(&lhs, &empty_rhs, vec!["id"], JoinType::Inner, false).unwrap();
        assert_eq!(joined.count().unwrap(), 0);
    }

    /// Keys of different types (string on the left, integer on the right) are coerced to a common type (#274).
    #[test]
    fn join_key_type_coercion_str_int() {
        use polars::prelude::df;
        let session = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        let lhs = session
            .create_dataframe_from_polars(df!("id" => &["1"], "label" => &["a"]).unwrap());
        let rhs = session
            .create_dataframe_from_polars(df!("id" => &[1i64], "x" => &[10i64]).unwrap());
        let joined = join(&lhs, &rhs, vec!["id"], JoinType::Inner, false).unwrap();
        assert_eq!(joined.count().unwrap(), 1);
        let collected = joined.collect().unwrap();
        assert_eq!(collected.height(), 1);
        // The string id "1" matched the integer id 1, so columns from both sides are present.
        assert!(collected.column("label").is_ok());
        assert!(collected.column("x").is_ok());
    }

    /// Issue #580: a key name shared by both tables (here `dept_id`) must not fail with "duplicate column".
    #[test]
    fn join_same_key_name_both_tables() {
        use polars::prelude::df;
        let session = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        let employees = df![
            "id" => [1i64, 2i64, 3i64, 4i64],
            "name" => ["Alice", "Bob", "Charlie", "David"],
            "dept_id" => [10i64, 20i64, 10i64, 30i64],
            "salary" => [50000i64, 60000i64, 70000i64, 55000i64],
        ]
        .unwrap();
        let departments = df![
            "dept_id" => [10i64, 20i64, 40i64],
            "name" => ["IT", "HR", "Finance"],
            "location" => ["NYC", "LA", "Chicago"],
        ]
        .unwrap();
        let lhs = session.create_dataframe_from_polars(employees);
        let rhs = session.create_dataframe_from_polars(departments);
        let joined = join(&lhs, &rhs, vec!["dept_id"], JoinType::Inner, false).unwrap();
        assert_eq!(
            joined.count().unwrap(),
            3,
            "Alice, Bob, Charlie match dept 10, 20"
        );
        let names = joined.columns().unwrap();
        assert!(
            names.iter().any(|name| name == "dept_id"),
            "one dept_id column in result"
        );
        assert!(names.iter().any(|name| name == "location"));
    }
}