robin_sparkless/dataframe/joins.rs

//! Join operations for DataFrame.

use super::DataFrame;
use polars::prelude::JoinType as PlJoinType;
use polars::prelude::PolarsError;

/// Join type for DataFrame joins (PySpark-compatible)
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Rows whose join keys match in both sides (PySpark inner).
    Inner,
    /// All rows from left plus matching rows from right (PySpark left).
    Left,
    /// All rows from right plus matching rows from left (PySpark right).
    Right,
    /// All rows from both sides, matched where possible (PySpark full/outer).
    Outer,
    /// Rows from left that have a match in right; only left columns (PySpark left_semi).
    LeftSemi,
    /// Rows from left that have no match in right; only left columns (PySpark left_anti).
    LeftAnti,
}

/// Join `left` with `right` on the given columns, carrying `case_sensitive` over to the result.
/// For Right and Outer joins, columns are reordered to match PySpark: key(s) first,
/// then left non-key columns, then right non-key columns.
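///
/// # Example
///
/// A minimal sketch, assuming `left` and `right` were built with
/// `SparkSession::create_dataframe` (as in the tests below) and share an `id`
/// column; marked `ignore` because it needs those fixtures:
///
/// ```ignore
/// let out = join(&left, &right, vec!["id"], JoinType::Inner, false)?;
/// assert!(out.columns()?.contains(&"id".to_string()));
/// ```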
pub fn join(
    left: &DataFrame,
    right: &DataFrame,
    on: Vec<&str>,
    how: JoinType,
    case_sensitive: bool,
) -> Result<DataFrame, PolarsError> {
    use polars::prelude::{col, IntoLazy, JoinBuilder, JoinCoalesce};
    let left_lf = left.df.as_ref().clone().lazy();
    let right_lf = right.df.as_ref().clone().lazy();
    let on_set: std::collections::HashSet<&str> = on.iter().copied().collect();
    let on_exprs: Vec<polars::prelude::Expr> = on.iter().map(|name| col(*name)).collect();
    let polars_how: PlJoinType = match how {
        JoinType::Inner => PlJoinType::Inner,
        JoinType::Left => PlJoinType::Left,
        JoinType::Right => PlJoinType::Right,
        JoinType::Outer => PlJoinType::Full, // PySpark Outer = Polars Full
        JoinType::LeftSemi => PlJoinType::Semi,
        JoinType::LeftAnti => PlJoinType::Anti,
    };
    let joined = JoinBuilder::new(left_lf)
        .with(right_lf)
        .how(polars_how)
        .on(&on_exprs)
        // Coalesce the join key(s) into a single column, matching PySpark's
        // behavior when joining on column names.
        .coalesce(JoinCoalesce::CoalesceColumns)
        .finish();
    let mut pl_df = joined.collect()?;
    if matches!(how, JoinType::Right | JoinType::Outer) {
        let left_names: Vec<String> = left
            .df
            .get_column_names()
            .iter()
            .map(|s| s.to_string())
            .collect();
        let right_names: Vec<String> = right
            .df
            .get_column_names()
            .iter()
            .map(|s| s.to_string())
            .collect();
        let result_names: std::collections::HashSet<String> = pl_df
            .get_column_names()
            .iter()
            .map(|s| s.to_string())
            .collect();
        // Build the PySpark order: join key(s), left non-key columns, then right
        // non-key columns (renamed with the `_right` suffix when they collide
        // with a left column name).
        let mut order: Vec<String> = Vec::new();
        for k in &on {
            order.push((*k).to_string());
        }
        for n in &left_names {
            if !on_set.contains(n.as_str()) {
                order.push(n.clone());
            }
        }
        for n in &right_names {
            let use_name = if left_names.iter().any(|l| l == n) {
                format!("{n}_right")
            } else {
                n.clone()
            };
            if result_names.contains(&use_name) {
                order.push(use_name);
            }
        }
        // Only reorder when every result column is accounted for; otherwise keep
        // the Polars output as-is rather than silently dropping columns.
        if order.len() == result_names.len() {
            let select_refs: Vec<&str> = order.iter().map(String::as_str).collect();
            pl_df = pl_df.select(select_refs).map_err(|e| {
                PolarsError::ComputeError(format!("join column reorder: {e}").into())
            })?;
        }
    }
    Ok(super::DataFrame::from_polars_with_options(
        pl_df,
        case_sensitive,
    ))
}

#[cfg(test)]
mod tests {
    use super::{join, JoinType};
    use crate::{DataFrame, SparkSession};

    fn left_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        spark
            .create_dataframe(
                vec![
                    (1i64, 10i64, "a".to_string()),
                    (2i64, 20i64, "b".to_string()),
                ],
                vec!["id", "v", "label"],
            )
            .unwrap()
    }

    fn right_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        spark
            .create_dataframe(
                vec![
                    (1i64, 100i64, "x".to_string()),
                    (3i64, 300i64, "z".to_string()),
                ],
                vec!["id", "w", "tag"],
            )
            .unwrap()
    }

    #[test]
    fn inner_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::Inner, false).unwrap();
        assert_eq!(out.count().unwrap(), 1);
        let cols = out.columns().unwrap();
        assert!(cols.iter().any(|c| c == "id" || c.ends_with("_right")));
    }

    #[test]
    fn left_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::Left, false).unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }

    #[test]
    fn outer_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::Outer, false).unwrap();
        assert_eq!(out.count().unwrap(), 3);
    }
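
    // Sketch tests for the remaining variants, assuming the same fixtures as
    // above (left ids {1, 2}, right ids {1, 3}). Right keeps all right rows;
    // LeftSemi keeps left rows with a match (left columns only); LeftAnti keeps
    // left rows without a match.
    #[test]
    fn right_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::Right, false).unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }

    #[test]
    fn left_semi_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::LeftSemi, false).unwrap();
        assert_eq!(out.count().unwrap(), 1);
        assert_eq!(out.columns().unwrap(), vec!["id", "v", "label"]);
    }

    #[test]
    fn left_anti_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::LeftAnti, false).unwrap();
        assert_eq!(out.count().unwrap(), 1);
    }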

    #[test]
    fn join_empty_right() {
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        let left = left_df();
        let right = spark
            .create_dataframe(vec![] as Vec<(i64, i64, String)>, vec!["id", "w", "tag"])
            .unwrap();
        let out = join(&left, &right, vec!["id"], JoinType::Inner, false).unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }
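
    // A sketch check of the Right/Outer column-order contract documented on
    // `join` (assuming `columns()` returns names in frame order): key first,
    // then left non-key columns, then right non-key columns.
    #[test]
    fn outer_join_column_order() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::Outer, false).unwrap();
        assert_eq!(out.columns().unwrap(), vec!["id", "v", "label", "w", "tag"]);
    }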
}