robin_sparkless_polars/dataframe/
joins.rs1use super::DataFrame;
4use crate::type_coercion::coerce_expr_pair;
5use polars::prelude::Expr;
6use polars::prelude::JoinType as PlJoinType;
7use polars::prelude::PolarsError;
8
/// Join strategies accepted by [`join`].
///
/// Each variant is translated to the corresponding Polars join type when the
/// join is built.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Inner join (Polars `Inner`).
    Inner,
    /// Left join (Polars `Left`).
    Left,
    /// Right join (Polars `Right`).
    Right,
    /// Outer join; executed as a Polars `Full` join.
    Outer,
    /// Left semi join (Polars `Semi`).
    LeftSemi,
    /// Left anti join (Polars `Anti`).
    LeftAnti,
}
21
22pub fn join(
30 left: &DataFrame,
31 right: &DataFrame,
32 on: Vec<&str>,
33 how: JoinType,
34 case_sensitive: bool,
35) -> Result<DataFrame, PolarsError> {
36 use polars::prelude::{JoinBuilder, JoinCoalesce, col};
37 let mut left_lf = left.lazy_frame();
38 let mut right_lf = right.lazy_frame();
39
40 let left_key_names: Vec<String> = on
42 .iter()
43 .map(|k| {
44 left.resolve_column_name(k).map_err(|e| {
45 PolarsError::ComputeError(format!("join key '{k}' on left: {e}").into())
46 })
47 })
48 .collect::<Result<_, _>>()?;
49 let right_key_names: Vec<String> = on
50 .iter()
51 .map(|k| {
52 right.resolve_column_name(k).map_err(|e| {
53 PolarsError::ComputeError(format!("join key '{k}' on right: {e}").into())
54 })
55 })
56 .collect::<Result<_, _>>()?;
57
58 let mut left_casts: Vec<Expr> = Vec::new();
61 let mut right_casts: Vec<Expr> = Vec::new();
62 for (i, key) in on.iter().enumerate() {
63 let left_name = &left_key_names[i];
64 let right_name = &right_key_names[i];
65 let left_dtype = left.get_column_dtype(left_name.as_str()).ok_or_else(|| {
66 PolarsError::ComputeError(format!("join key '{key}' not found on left").into())
67 })?;
68 let right_dtype = right.get_column_dtype(right_name.as_str()).ok_or_else(|| {
69 PolarsError::ComputeError(format!("join key '{key}' not found on right").into())
70 })?;
71 let target_name = left_name.as_str();
72 if left_dtype != right_dtype {
73 let (l, r) = coerce_expr_pair(
74 left_name.as_str(),
75 right_name.as_str(),
76 &left_dtype,
77 &right_dtype,
78 target_name,
79 )?;
80 left_casts.push(l);
81 right_casts.push(r);
82 } else if left_name != right_name {
83 right_casts.push(col(right_name.as_str()).alias(target_name));
84 }
85 }
86 if !left_casts.is_empty() {
87 left_lf = left_lf.with_columns(left_casts);
88 }
89 if !right_casts.is_empty() {
90 right_lf = right_lf.with_columns(right_casts);
91 let drop_right: std::collections::HashSet<String> = on
94 .iter()
95 .enumerate()
96 .filter(|(i, _)| left_key_names[*i] != right_key_names[*i])
97 .map(|(i, _)| right_key_names[i].clone())
98 .collect();
99 if !drop_right.is_empty() {
100 let right_names = right.columns()?;
101 let mut keep_names: Vec<&str> = right_names
102 .iter()
103 .filter(|n| !drop_right.contains(*n))
104 .map(String::as_str)
105 .collect();
106 for (i, name) in left_key_names.iter().enumerate() {
107 if left_key_names[i] != right_key_names[i] {
108 keep_names.push(name.as_str());
109 }
110 }
111 let keep: Vec<Expr> = keep_names.iter().map(|s| col(*s)).collect();
112 right_lf = right_lf.select(&keep);
113 }
114 }
115
116 let on_set: std::collections::HashSet<String> = left_key_names.iter().cloned().collect();
117 let on_exprs: Vec<polars::prelude::Expr> = left_key_names
118 .iter()
119 .map(|name| col(name.as_str()))
120 .collect();
121 let polars_how: PlJoinType = match how {
122 JoinType::Inner => PlJoinType::Inner,
123 JoinType::Left => PlJoinType::Left,
124 JoinType::Right => PlJoinType::Right,
125 JoinType::Outer => PlJoinType::Full, JoinType::LeftSemi => PlJoinType::Semi,
127 JoinType::LeftAnti => PlJoinType::Anti,
128 };
129
130 let mut joined = JoinBuilder::new(left_lf)
131 .with(right_lf)
132 .how(polars_how)
133 .on(&on_exprs)
134 .coalesce(JoinCoalesce::CoalesceColumns)
135 .finish();
136 let result_lf = if matches!(how, JoinType::Right | JoinType::Outer) {
138 let left_names = left.columns()?;
139 let right_names = right.columns()?;
140 let result_schema = joined.collect_schema()?;
141 let result_names: std::collections::HashSet<String> =
142 result_schema.iter_names().map(|s| s.to_string()).collect();
143 let mut order: Vec<String> = Vec::new();
144 for k in &left_key_names {
145 order.push(k.clone());
146 }
147 for n in &left_names {
148 if !on_set.contains(n) {
149 order.push(n.clone());
150 }
151 }
152 for n in &right_names {
153 let use_name = if left_names.iter().any(|l| l == n) {
154 format!("{n}_right")
155 } else {
156 n.clone()
157 };
158 if result_names.contains(&use_name) {
159 order.push(use_name);
160 }
161 }
162 if order.len() == result_names.len() {
163 let select_exprs: Vec<polars::prelude::Expr> =
164 order.iter().map(|s| col(s.as_str())).collect();
165 joined.select(select_exprs.as_slice())
166 } else {
167 joined
168 }
169 } else {
170 joined
171 };
172 Ok(super::DataFrame::from_lazy_with_options(
173 result_lf,
174 case_sensitive,
175 ))
176}
177
#[cfg(test)]
mod tests {
    use super::{JoinType, join};
    use crate::{DataFrame, SparkSession};

    /// Left fixture: columns `id`, `v`, `label` with ids 1 and 2.
    fn left_df() -> DataFrame {
        let rows = vec![
            (1i64, 10i64, "a".to_string()),
            (2i64, 20i64, "b".to_string()),
        ];
        SparkSession::builder()
            .app_name("join_tests")
            .get_or_create()
            .create_dataframe(rows, vec!["id", "v", "label"])
            .unwrap()
    }

    /// Right fixture: columns `id`, `w`, `tag` with ids 1 and 3.
    fn right_df() -> DataFrame {
        let rows = vec![
            (1i64, 100i64, "x".to_string()),
            (3i64, 300i64, "z".to_string()),
        ];
        SparkSession::builder()
            .app_name("join_tests")
            .get_or_create()
            .create_dataframe(rows, vec!["id", "w", "tag"])
            .unwrap()
    }

    /// Joins the two fixtures on `id` with the given join type.
    fn joined_on_id(how: JoinType) -> DataFrame {
        let lhs = left_df();
        let rhs = right_df();
        join(&lhs, &rhs, vec!["id"], how, false).unwrap()
    }

    #[test]
    fn inner_join() {
        let out = joined_on_id(JoinType::Inner);
        // Only id 1 is present on both sides.
        assert_eq!(out.count().unwrap(), 1);
        let names = out.columns().unwrap();
        assert!(
            names
                .iter()
                .any(|name| name == "id" || name.ends_with("_right"))
        );
    }

    #[test]
    fn left_join() {
        // Both left rows are kept.
        let out = joined_on_id(JoinType::Left);
        assert_eq!(out.count().unwrap(), 2);
    }

    #[test]
    fn right_join() {
        // Both right rows are kept.
        let out = joined_on_id(JoinType::Right);
        assert_eq!(out.count().unwrap(), 2);
    }

    #[test]
    fn outer_join() {
        // Ids 1, 2 and 3 each yield one row.
        let out = joined_on_id(JoinType::Outer);
        assert_eq!(out.count().unwrap(), 3);
    }

    #[test]
    fn left_semi_join() {
        // Only the left row with id 1 has a match on the right.
        let out = joined_on_id(JoinType::LeftSemi);
        assert_eq!(out.count().unwrap(), 1);
    }

    #[test]
    fn left_anti_join() {
        // Only the left row with id 2 has no match on the right.
        let out = joined_on_id(JoinType::LeftAnti);
        assert_eq!(out.count().unwrap(), 1);
    }

    #[test]
    fn join_empty_right() {
        let lhs = left_df();
        let rhs = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create()
            .create_dataframe(Vec::<(i64, i64, String)>::new(), vec!["id", "w", "tag"])
            .unwrap();
        let out = join(&lhs, &rhs, vec!["id"], JoinType::Inner, false).unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    /// A string key on the left must still match an integer key on the right.
    #[test]
    fn join_key_type_coercion_str_int() {
        use polars::prelude::df;
        let session = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        let lhs = session
            .create_dataframe_from_polars(df!("id" => &["1"], "label" => &["a"]).unwrap());
        let rhs = session
            .create_dataframe_from_polars(df!("id" => &[1i64], "x" => &[10i64]).unwrap());

        let out = join(&lhs, &rhs, vec!["id"], JoinType::Inner, false).unwrap();
        assert_eq!(out.count().unwrap(), 1);

        let collected = out.collect().unwrap();
        assert_eq!(collected.height(), 1);
        assert!(collected.column("label").is_ok());
        assert!(collected.column("x").is_ok());
    }

    /// Key `id` on the left must resolve against `ID` on the right.
    #[test]
    fn join_column_resolution_case_insensitive() {
        use polars::prelude::df;
        let session = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        let lhs = session.create_dataframe_from_polars(
            df!("id" => &[1i64, 2i64], "val" => &["a", "b"]).unwrap(),
        );
        let rhs = session
            .create_dataframe_from_polars(df!("ID" => &[1i64], "other" => &["x"]).unwrap());

        let out = join(&lhs, &rhs, vec!["id"], JoinType::Inner, false)
            .expect("issue #604: join on id/ID must succeed");
        assert_eq!(out.count().unwrap(), 1);

        let collected = out
            .collect()
            .expect("issue #604: collect must not fail with 'not found: ID'");
        assert_eq!(collected.height(), 1);
        for expected in ["id", "val", "other"] {
            assert!(collected.column(expected).is_ok());
        }
        // The joined frame must still resolve the upper-case spelling of the key.
        assert!(out.resolve_column_name("ID").is_ok());
    }
}