robin_sparkless/dataframe/
joins.rs1use super::DataFrame;
4use crate::type_coercion::find_common_type;
5use polars::prelude::Expr;
6use polars::prelude::JoinType as PlJoinType;
7use polars::prelude::PolarsError;
8
/// Join strategy accepted by [`join`], named after Spark's `how` values.
///
/// [`join`] maps each variant onto a Polars join type; the only renames are
/// `Outer` (Polars full join), `LeftSemi` (semi join) and `LeftAnti`
/// (anti join).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum JoinType {
    /// Rows whose keys match on both sides (Polars inner join).
    Inner,
    /// Every left row is kept (Polars left join).
    Left,
    /// Every right row is kept (Polars right join).
    Right,
    /// Rows from both sides are kept (Polars full join).
    Outer,
    /// Left rows that have a match on the right (Polars semi join).
    LeftSemi,
    /// Left rows that have no match on the right (Polars anti join).
    LeftAnti,
}
21
22pub fn join(
30 left: &DataFrame,
31 right: &DataFrame,
32 on: Vec<&str>,
33 how: JoinType,
34 case_sensitive: bool,
35) -> Result<DataFrame, PolarsError> {
36 use polars::prelude::{JoinBuilder, JoinCoalesce, col};
37 let mut left_lf = left.lazy_frame();
38 let mut right_lf = right.lazy_frame();
39
40 let left_key_names: Vec<String> = on
42 .iter()
43 .map(|k| {
44 left.resolve_column_name(k).map_err(|e| {
45 PolarsError::ComputeError(format!("join key '{k}' on left: {e}").into())
46 })
47 })
48 .collect::<Result<_, _>>()?;
49 let right_key_names: Vec<String> = on
50 .iter()
51 .map(|k| {
52 right.resolve_column_name(k).map_err(|e| {
53 PolarsError::ComputeError(format!("join key '{k}' on right: {e}").into())
54 })
55 })
56 .collect::<Result<_, _>>()?;
57
58 let mut left_casts: Vec<Expr> = Vec::new();
61 let mut right_casts: Vec<Expr> = Vec::new();
62 for (i, key) in on.iter().enumerate() {
63 let left_name = &left_key_names[i];
64 let right_name = &right_key_names[i];
65 let left_dtype = left.get_column_dtype(left_name.as_str()).ok_or_else(|| {
66 PolarsError::ComputeError(format!("join key '{key}' not found on left").into())
67 })?;
68 let right_dtype = right.get_column_dtype(right_name.as_str()).ok_or_else(|| {
69 PolarsError::ComputeError(format!("join key '{key}' not found on right").into())
70 })?;
71 let target_name = left_name.as_str();
72 if left_dtype != right_dtype {
73 let common = find_common_type(&left_dtype, &right_dtype)?;
74 left_casts.push(
75 col(left_name.as_str())
76 .cast(common.clone())
77 .alias(target_name),
78 );
79 right_casts.push(col(right_name.as_str()).cast(common).alias(target_name));
80 } else if left_name != right_name {
81 right_casts.push(col(right_name.as_str()).alias(target_name));
82 }
83 }
84 if !left_casts.is_empty() {
85 left_lf = left_lf.with_columns(left_casts);
86 }
87 if !right_casts.is_empty() {
88 right_lf = right_lf.with_columns(right_casts);
89 let drop_right: std::collections::HashSet<String> = on
92 .iter()
93 .enumerate()
94 .filter(|(i, _)| left_key_names[*i] != right_key_names[*i])
95 .map(|(i, _)| right_key_names[i].clone())
96 .collect();
97 if !drop_right.is_empty() {
98 let right_names = right.columns()?;
99 let mut keep_names: Vec<&str> = right_names
100 .iter()
101 .filter(|n| !drop_right.contains(*n))
102 .map(String::as_str)
103 .collect();
104 for (i, name) in left_key_names.iter().enumerate() {
105 if left_key_names[i] != right_key_names[i] {
106 keep_names.push(name.as_str());
107 }
108 }
109 let keep: Vec<Expr> = keep_names.iter().map(|s| col(*s)).collect();
110 right_lf = right_lf.select(&keep);
111 }
112 }
113
114 let on_set: std::collections::HashSet<String> = left_key_names.iter().cloned().collect();
115 let on_exprs: Vec<polars::prelude::Expr> = left_key_names
116 .iter()
117 .map(|name| col(name.as_str()))
118 .collect();
119 let polars_how: PlJoinType = match how {
120 JoinType::Inner => PlJoinType::Inner,
121 JoinType::Left => PlJoinType::Left,
122 JoinType::Right => PlJoinType::Right,
123 JoinType::Outer => PlJoinType::Full, JoinType::LeftSemi => PlJoinType::Semi,
125 JoinType::LeftAnti => PlJoinType::Anti,
126 };
127
128 let mut joined = JoinBuilder::new(left_lf)
129 .with(right_lf)
130 .how(polars_how)
131 .on(&on_exprs)
132 .coalesce(JoinCoalesce::CoalesceColumns)
133 .finish();
134 let result_lf = if matches!(how, JoinType::Right | JoinType::Outer) {
136 let left_names = left.columns()?;
137 let right_names = right.columns()?;
138 let result_schema = joined.collect_schema()?;
139 let result_names: std::collections::HashSet<String> =
140 result_schema.iter_names().map(|s| s.to_string()).collect();
141 let mut order: Vec<String> = Vec::new();
142 for k in &left_key_names {
143 order.push(k.clone());
144 }
145 for n in &left_names {
146 if !on_set.contains(n) {
147 order.push(n.clone());
148 }
149 }
150 for n in &right_names {
151 let use_name = if left_names.iter().any(|l| l == n) {
152 format!("{n}_right")
153 } else {
154 n.clone()
155 };
156 if result_names.contains(&use_name) {
157 order.push(use_name);
158 }
159 }
160 if order.len() == result_names.len() {
161 let select_exprs: Vec<polars::prelude::Expr> =
162 order.iter().map(|s| col(s.as_str())).collect();
163 joined.select(select_exprs.as_slice())
164 } else {
165 joined
166 }
167 } else {
168 joined
169 };
170 Ok(super::DataFrame::from_lazy_with_options(
171 result_lf,
172 case_sensitive,
173 ))
174}
175
#[cfg(test)]
mod tests {
    use super::{JoinType, join};
    use crate::{DataFrame, SparkSession};

    /// Left fixture: ids 1 and 2 with columns `id`, `v`, `label`.
    fn left_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        spark
            .create_dataframe(
                vec![
                    (1i64, 10i64, "a".to_string()),
                    (2i64, 20i64, "b".to_string()),
                ],
                vec!["id", "v", "label"],
            )
            .unwrap()
    }

    /// Right fixture: ids 1 and 3 with columns `id`, `w`, `tag`.
    fn right_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        spark
            .create_dataframe(
                vec![
                    (1i64, 100i64, "x".to_string()),
                    (3i64, 300i64, "z".to_string()),
                ],
                vec!["id", "w", "tag"],
            )
            .unwrap()
    }

    #[test]
    fn inner_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::Inner, false).unwrap();
        assert_eq!(out.count().unwrap(), 1);
        let cols = out.columns().unwrap();
        // The key is coalesced: exactly one `id` column and no `id_right`.
        assert_eq!(cols.iter().filter(|c| c.as_str() == "id").count(), 1);
        assert!(!cols.iter().any(|c| c == "id_right"));
        for expected in ["v", "label", "w", "tag"] {
            assert!(cols.iter().any(|c| c == expected), "missing column {expected}");
        }
    }

    #[test]
    fn left_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::Left, false).unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }

    #[test]
    fn right_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::Right, false).unwrap();
        assert_eq!(out.count().unwrap(), 2);
        // Right joins are re-selected into key / left / right column order.
        let cols = out.columns().unwrap();
        let names: Vec<&str> = cols.iter().map(String::as_str).collect();
        assert_eq!(names, ["id", "v", "label", "w", "tag"]);
    }

    #[test]
    fn outer_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::Outer, false).unwrap();
        assert_eq!(out.count().unwrap(), 3);
        // Outer joins are re-selected into key / left / right column order.
        let cols = out.columns().unwrap();
        let names: Vec<&str> = cols.iter().map(String::as_str).collect();
        assert_eq!(names, ["id", "v", "label", "w", "tag"]);
    }

    #[test]
    fn left_semi_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::LeftSemi, false).unwrap();
        assert_eq!(out.count().unwrap(), 1);
    }

    #[test]
    fn left_anti_join() {
        let left = left_df();
        let right = right_df();
        let out = join(&left, &right, vec!["id"], JoinType::LeftAnti, false).unwrap();
        assert_eq!(out.count().unwrap(), 1);
    }

    #[test]
    fn join_empty_right() {
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        let left = left_df();
        let right = spark
            .create_dataframe(vec![] as Vec<(i64, i64, String)>, vec!["id", "w", "tag"])
            .unwrap();
        let out = join(&left, &right, vec!["id"], JoinType::Inner, false).unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    /// A string key on the left joins an integer key on the right after both
    /// are cast to their common type.
    #[test]
    fn join_key_type_coercion_str_int() {
        use polars::prelude::df;
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        let left_pl = df!("id" => &["1"], "label" => &["a"]).unwrap();
        let right_pl = df!("id" => &[1i64], "x" => &[10i64]).unwrap();
        let left = spark.create_dataframe_from_polars(left_pl);
        let right = spark.create_dataframe_from_polars(right_pl);
        let out = join(&left, &right, vec!["id"], JoinType::Inner, false).unwrap();
        assert_eq!(out.count().unwrap(), 1);
        let rows = out.collect().unwrap();
        assert_eq!(rows.height(), 1);
        assert!(rows.column("label").is_ok());
        assert!(rows.column("x").is_ok());
    }

    /// Keys spelled `id` on the left and `ID` on the right resolve to the same
    /// join key and the output exposes it under the left name.
    #[test]
    fn join_column_resolution_case_insensitive() {
        use polars::prelude::df;
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        let left_pl = df!("id" => &[1i64, 2i64], "val" => &["a", "b"]).unwrap();
        let right_pl = df!("ID" => &[1i64], "other" => &["x"]).unwrap();
        let left = spark.create_dataframe_from_polars(left_pl);
        let right = spark.create_dataframe_from_polars(right_pl);
        let out = join(&left, &right, vec!["id"], JoinType::Inner, false)
            .expect("issue #604: join on id/ID must succeed");
        assert_eq!(out.count().unwrap(), 1);
        let rows = out
            .collect()
            .expect("issue #604: collect must not fail with 'not found: ID'");
        assert_eq!(rows.height(), 1);
        assert!(rows.column("id").is_ok());
        assert!(rows.column("val").is_ok());
        assert!(rows.column("other").is_ok());
        assert!(out.resolve_column_name("ID").is_ok());
    }
}