1use std::collections::HashSet;
4
5use super::DataFrame;
6use crate::schema_conv::data_type_to_polars_type;
7use crate::type_coercion::coerce_expr_pair_for_join;
8use polars::prelude::{
9 DataType as PlDataType, Expr, JoinType as PlJoinType, Operator, PolarsError,
10 SchemaNamesAndDtypes, coalesce as pl_coalesce,
11};
12use polars_plan::dsl::functions::nth;
13
14fn expr_to_column_name(expr: &Expr) -> Option<String> {
15 use polars::prelude::Expr as PlExpr;
16 let mut e = expr;
17 loop {
18 match e {
19 PlExpr::Column(n) => return Some(n.as_str().to_string()),
20 PlExpr::Alias(inner, _) | PlExpr::Cast { expr: inner, .. } => e = inner.as_ref(),
21 _ => return None,
22 }
23 }
24}
25
26pub fn try_extract_join_eq_columns(expr: &Expr) -> Option<(String, String)> {
32 try_extract_join_eq_columns_all(expr).into_iter().next()
33}
34
35pub fn try_extract_join_eq_columns_all(expr: &Expr) -> Vec<(String, String)> {
38 use polars::prelude::Expr as PlExpr;
39
40 fn inner_extract_all(e: &Expr, out: &mut Vec<(String, String)>) {
41 let mut current = e;
42 while let PlExpr::Alias(inner, _) = current {
43 current = inner.as_ref();
44 }
45 match current {
46 PlExpr::BinaryExpr {
47 left,
48 op: Operator::Eq | Operator::EqValidity,
49 right,
50 } => {
51 if let (Some(l), Some(r)) = (
52 expr_to_column_name(left.as_ref()),
53 expr_to_column_name(right.as_ref()),
54 ) {
55 out.push((l, r));
56 }
57 }
58 PlExpr::BinaryExpr {
59 left,
60 op: Operator::And,
61 right,
62 } => {
63 inner_extract_all(left.as_ref(), out);
64 inner_extract_all(right.as_ref(), out);
65 }
66 _ => {}
67 }
68 }
69
70 let mut pairs = Vec::new();
71 inner_extract_all(expr, &mut pairs);
72 pairs
73}
74
75pub fn expr_contains_only_join_key_equalities(expr: &Expr) -> bool {
79 use polars::prelude::Expr as PlExpr;
80 fn only_join_equalities(e: &Expr) -> bool {
81 let mut current = e;
82 while let PlExpr::Alias(inner, _) = current {
83 current = inner.as_ref();
84 }
85 match current {
86 PlExpr::BinaryExpr {
87 left,
88 op: Operator::Eq | Operator::EqValidity,
89 right,
90 } => {
91 expr_to_column_name(left.as_ref()).is_some()
92 && expr_to_column_name(right.as_ref()).is_some()
93 }
94 PlExpr::BinaryExpr {
95 left,
96 op: Operator::And,
97 right,
98 } => only_join_equalities(left.as_ref()) && only_join_equalities(right.as_ref()),
99 _ => false,
100 }
101 }
102 only_join_equalities(expr)
103}
104
/// Join strategies accepted by [`join`].
///
/// Each variant is translated to a Polars join type inside [`join`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Translated to Polars' `Inner` join.
    Inner,
    /// Translated to Polars' `Left` join.
    Left,
    /// Translated to Polars' `Right` join.
    Right,
    /// Translated to Polars' `Full` join.
    Outer,
    /// Translated to Polars' `Semi` join.
    LeftSemi,
    /// Translated to Polars' `Anti` join.
    LeftAnti,
}
117
/// Describes how the join keys handed to [`join`] were obtained.
///
/// [`join`] only consults this for outer joins, where it selects how same-named key
/// columns are prepared before the join is executed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinOrigin {
    /// With an outer join and `coalesce_same_name_keys`, [`join`] keeps temporary copies
    /// of the left keys and rebuilds the key columns from them after the join.
    // NOTE(review): presumably the "join on column names" call path; confirm against callers.
    ColumnOn,
    /// With an outer join whose left/right key names match ignoring ASCII case, [`join`]
    /// renames the right-hand keys to `<name>_right` before joining.
    // NOTE(review): presumably keys extracted from a join condition expression; confirm.
    Condition,
}
127
128pub struct JoinOptions {
129 pub case_sensitive: bool,
130 pub coalesce_same_name_keys: bool,
131 pub mark_join_keys_ambiguous: bool,
132 pub origin: JoinOrigin,
133}
134
/// Joins `left` with `right` on the given key columns and returns the joined frame.
///
/// The work is expressed on lazy frames; only schemas are collected along the way.
/// Broadly, the function:
///
/// 1. resolves `left_on` / `right_on` to the frames' actual column names;
/// 2. for outer joins, prepares the key columns (temporary left-key copies and/or
///    `<name>_right` renames of the right keys, depending on `options`);
/// 3. when the key names match (exactly, or ignoring ASCII case if allowed), aligns key
///    names and dtypes on both sides;
/// 4. builds and runs the Polars join;
/// 5. post-processes the result: rebuilds outer-join keys, restores column dtypes,
///    removes duplicate names, reorders columns for right/outer joins and, when
///    `case_sensitive` is `false`, suffixes names that clash ignoring case with `_right`.
///
/// # Errors
///
/// Returns `PolarsError::ComputeError` when `left_on` and `right_on` differ in length,
/// when a key cannot be resolved on its frame, or when a key is missing from a schema.
/// Errors from schema collection, `columns()`, and the type-coercion helpers are
/// propagated unchanged.
pub fn join(
    left: &DataFrame,
    right: &DataFrame,
    left_on: Vec<&str>,
    right_on: Vec<&str>,
    how: JoinType,
    options: JoinOptions,
) -> Result<DataFrame, PolarsError> {
    let JoinOptions {
        case_sensitive,
        coalesce_same_name_keys,
        mark_join_keys_ambiguous,
        origin,
    } = options;
    use polars::prelude::{JoinBuilder, JoinCoalesce, col};
    // Keys are paired positionally, so both lists must have the same length.
    if left_on.len() != right_on.len() {
        return Err(PolarsError::ComputeError(
            "join: left_on and right_on must have the same length".into(),
        ));
    }
    let mut left_lf = left.lazy_frame();
    let mut right_lf = right.lazy_frame();
    // (left key name, temporary copy name) pairs created for column-name outer joins.
    let mut outer_left_key_copies: Vec<(String, String)> = Vec::new();
    // Right keys renamed to `<name>_right` for column-name outer joins; dropped at the end
    // unless the keys are to be marked ambiguous.
    let mut outer_join_renamed_right_keys: Vec<String> = Vec::new();

    // Resolve the requested key names to the names actually present on each frame.
    let left_key_names: Vec<String> = left_on
        .iter()
        .map(|k| {
            left.resolve_column_name(k).map_err(|e| {
                PolarsError::ComputeError(format!("join key '{k}' on left: {e}").into())
            })
        })
        .collect::<Result<_, _>>()?;
    let mut right_key_names: Vec<String> = right_on
        .iter()
        .map(|k| {
            right.resolve_column_name(k).map_err(|e| {
                PolarsError::ComputeError(format!("join key '{k}' on right: {e}").into())
            })
        })
        .collect::<Result<_, _>>()?;
    // Column-name outer join with key coalescing: copy every left key into a temporary
    // `__rs_outer_key_<name>` column. The copies are consumed after the join.
    if matches!(how, JoinType::Outer)
        && coalesce_same_name_keys
        && matches!(origin, JoinOrigin::ColumnOn)
    {
        use polars::prelude::col;
        let mut copy_exprs: Vec<Expr> = Vec::new();
        for name in &left_key_names {
            let temp = format!("__rs_outer_key_{}", name);
            outer_left_key_copies.push((name.clone(), temp.clone()));
            copy_exprs.push(col(name.as_str()).alias(temp.as_str()));
        }
        if !copy_exprs.is_empty() {
            left_lf = left_lf.with_columns(copy_exprs);
        }
    }
    // Condition-origin outer join whose key names match ignoring ASCII case: rename the
    // right-hand keys to `<name>_right` and join on the renamed columns.
    if matches!(origin, JoinOrigin::Condition)
        && matches!(how, JoinType::Outer)
        && left_key_names.len() == right_key_names.len()
        && left_key_names
            .iter()
            .zip(right_key_names.iter())
            .all(|(a, b)| a.eq_ignore_ascii_case(b))
    {
        use polars::prelude::col;
        use std::collections::HashMap;
        let mut rename_map: HashMap<String, String> = HashMap::new();
        for name in &right_key_names {
            rename_map.insert(name.clone(), format!("{name}_right"));
        }
        if !rename_map.is_empty() {
            // Re-select every right column, aliasing only the keys.
            let current_names: Vec<String> = right.columns()?.into_iter().collect();
            let exprs: Vec<Expr> = current_names
                .iter()
                .map(|n| {
                    if let Some(new_name) = rename_map.get(n) {
                        col(n.as_str()).alias(new_name.as_str())
                    } else {
                        col(n.as_str())
                    }
                })
                .collect();
            right_lf = right_lf.select(&exprs);
            for rk in &mut right_key_names {
                if let Some(new_name) = rename_map.get(rk) {
                    *rk = new_name.clone();
                }
            }
        }
    }

    // Column-name outer join with identical key names on both sides: same rename of the
    // right keys as above, additionally remembering the renamed keys for later removal.
    if matches!(how, JoinType::Outer)
        && coalesce_same_name_keys
        && left_key_names == right_key_names
        && matches!(origin, JoinOrigin::ColumnOn)
    {
        use polars::prelude::col;
        use std::collections::HashMap;
        let mut rename_map: HashMap<String, String> = HashMap::new();
        for name in &right_key_names {
            rename_map.insert(name.clone(), format!("{name}_right"));
        }
        if !rename_map.is_empty() {
            let current_names: Vec<String> = right.columns()?.into_iter().collect();
            let exprs: Vec<Expr> = current_names
                .iter()
                .map(|n| {
                    if let Some(new_name) = rename_map.get(n) {
                        col(n.as_str()).alias(new_name.as_str())
                    } else {
                        col(n.as_str())
                    }
                })
                .collect();
            right_lf = right_lf.select(&exprs);
            for rk in &mut right_key_names {
                if let Some(new_name) = rename_map.get(rk) {
                    *rk = new_name.clone();
                }
            }
            outer_join_renamed_right_keys = right_key_names.clone();
        }
    }

    // Computed after the renames above, so renamed right keys count as differing.
    let keys_differ = left_key_names != right_key_names;
    // Keys can be coalesced when the names are identical, or (if enabled and the frame is
    // case-insensitive) when they are equal ignoring ASCII case.
    let keys_match_for_coalesce = !keys_differ
        || (coalesce_same_name_keys
            && !case_sensitive
            && left_key_names.len() == right_key_names.len()
            && left_key_names
                .iter()
                .zip(right_key_names.iter())
                .all(|(a, b)| a.eq_ignore_ascii_case(b)));

    if keys_match_for_coalesce {
        let right_names: Vec<String> = right.columns()?.into_iter().collect();
        // If the right key will be renamed to the left key's name and the right frame
        // already has a column with that exact name, move that column to `<name>_right`.
        let mut renames: std::collections::HashMap<String, String> =
            std::collections::HashMap::new();
        for (i, _) in left_on.iter().enumerate() {
            let target_name = &left_key_names[i];
            let right_key = &right_key_names[i];
            if target_name != right_key && right_names.iter().any(|n| n == target_name) {
                renames.insert(target_name.clone(), format!("{}_right", target_name));
            }
        }
        if !renames.is_empty() {
            let exprs: Vec<Expr> = right_names
                .iter()
                .map(|n| {
                    if let Some(suffix) = renames.get(n) {
                        col(n.as_str()).alias(suffix.as_str())
                    } else {
                        col(n.as_str())
                    }
                })
                .collect();
            right_lf = right_lf.select(&exprs);
        }

        // Align key dtypes and names across both sides. Dtypes are read from the input
        // frames' schemas (not from the lazy frames modified above).
        let left_schema = left.polars_schema()?;
        let right_schema = right.polars_schema()?;
        let mut left_casts: Vec<Expr> = Vec::new();
        let mut right_casts: Vec<Expr> = Vec::new();
        for (i, key) in left_on.iter().enumerate() {
            let left_name = &left_key_names[i];
            let right_name = &right_key_names[i];
            let left_dtype = left_schema
                .get(left_name.as_str())
                .cloned()
                .ok_or_else(|| {
                    PolarsError::ComputeError(format!("join key '{key}' not found on left").into())
                })?;
            let right_dtype = right_schema
                .get(right_name.as_str())
                .cloned()
                .ok_or_else(|| {
                    PolarsError::ComputeError(format!("join key '{key}' not found on right").into())
                })?;
            // The left key's name is the name both sides end up using.
            let target_name = left_name.as_str();
            if left_dtype != right_dtype {
                // Differing dtypes: obtain a pair of coercion expressions for the key.
                let (l, r) = coerce_expr_pair_for_join(
                    left_name.as_str(),
                    right_name.as_str(),
                    &left_dtype,
                    &right_dtype,
                    target_name,
                )?;
                left_casts.push(l);
                right_casts.push(r);
            } else if left_name != right_name {
                // Same dtype, different spelling: expose the right key under the left name.
                right_casts.push(col(right_name.as_str()).alias(target_name));
            }
        }
        if !left_casts.is_empty() {
            left_lf = left_lf.with_columns(left_casts);
        }
        if !right_casts.is_empty() {
            right_lf = right_lf.with_columns(right_casts);
            // Right keys whose name differs from the left key are dropped here; from then
            // on the right side is addressed by the left key names.
            let drop_right: std::collections::HashSet<String> = left_on
                .iter()
                .enumerate()
                .filter(|(i, _)| left_key_names[*i] != right_key_names[*i])
                .map(|(i, _)| right_key_names[i].clone())
                .collect();
            if !drop_right.is_empty() {
                let current_right_names: Vec<String> = right_lf
                    .collect_schema()
                    .map(|s| s.iter_names().map(|n| n.to_string()).collect())?;
                let keep_names: Vec<&str> = current_right_names
                    .iter()
                    .filter(|n| !drop_right.contains(*n))
                    .map(String::as_str)
                    .collect();
                let keep: Vec<Expr> = keep_names.iter().map(|s| col(*s)).collect();
                right_lf = right_lf.select(&keep);
                right_key_names = left_key_names.clone();
            }
        }
    }

    // Left key names as a set; used below to tell key columns from non-key columns.
    let on_set: std::collections::HashSet<String> = left_key_names.iter().cloned().collect();
    let polars_how: PlJoinType = match how {
        JoinType::Inner => PlJoinType::Inner,
        JoinType::Left => PlJoinType::Left,
        JoinType::Right => PlJoinType::Right,
        JoinType::Outer => PlJoinType::Full,
        JoinType::LeftSemi => PlJoinType::Semi,
        JoinType::LeftAnti => PlJoinType::Anti,
    };

    let mut left_on_exprs: Vec<Expr> = Vec::with_capacity(left_key_names.len());
    let mut right_on_exprs: Vec<Expr> = Vec::with_capacity(right_key_names.len());

    if keys_differ {
        use crate::type_coercion::find_common_type_for_join;
        // Key names differed when `keys_differ` was computed: build per-key join
        // expressions, casting both sides to a common dtype where the dtypes disagree.
        // The right dtype is read from the (possibly modified) right lazy frame.
        let right_schema = right_lf.collect_schema()?;
        for i in 0..left_key_names.len() {
            let left_name = &left_key_names[i];
            let right_name = &right_key_names[i];
            let left_dtype = left.get_column_dtype(left_name.as_str()).ok_or_else(|| {
                PolarsError::ComputeError(
                    format!("join key '{}' not found on left", left_name).into(),
                )
            })?;
            let right_dtype = right_schema
                .get(right_name.as_str())
                .cloned()
                .ok_or_else(|| {
                    PolarsError::ComputeError(
                        format!("join key '{}' not found on right", right_name).into(),
                    )
                })?;
            if left_dtype == right_dtype {
                left_on_exprs.push(col(left_name.as_str()));
                right_on_exprs.push(col(right_name.as_str()));
            } else {
                let common = find_common_type_for_join(&left_dtype, &right_dtype)?;
                left_on_exprs.push(col(left_name.as_str()).cast(common.clone()));
                right_on_exprs.push(col(right_name.as_str()).cast(common));
            }
        }
    } else {
        left_on_exprs = left_key_names.iter().map(|n| col(n.as_str())).collect();
        right_on_exprs = right_key_names.iter().map(|n| col(n.as_str())).collect();
    }

    // Key coalescing policy handed to Polars: keep both key columns when the keys do not
    // match or for outer joins; coalesce them for every other join type.
    let coalesce = if !keys_match_for_coalesce {
        JoinCoalesce::KeepColumns
    } else if matches!(how, JoinType::Inner | JoinType::Left | JoinType::Right) {
        JoinCoalesce::CoalesceColumns
    } else if matches!(how, JoinType::Outer) {
        JoinCoalesce::KeepColumns
    } else {
        JoinCoalesce::CoalesceColumns
    };
    let mut joined = JoinBuilder::new(left_lf)
        .with(right_lf)
        .how(polars_how)
        .left_on(&left_on_exprs)
        .right_on(&right_on_exprs)
        .coalesce(coalesce)
        .finish();

    // Outer join with temporary left-key copies: rebuild each key column from its copy.
    if matches!(how, JoinType::Outer) && !outer_left_key_copies.is_empty() {
        use polars::prelude::col;
        let left_names_full: Vec<String> = left.columns()?.into_iter().collect();
        let right_names_full: Vec<String> = right.columns()?.into_iter().collect();
        // True when some non-key left column shares its name (ignoring ASCII case) with a
        // right column.
        let has_non_key_overlap = left_names_full.iter().any(|ln| {
            !on_set.contains(ln.as_str())
                && right_names_full
                    .iter()
                    .any(|rn| rn.eq_ignore_ascii_case(ln.as_str()))
        });
        for (i, (left_name, temp)) in outer_left_key_copies.iter().enumerate() {
            // An empty string stands for "no right key at this position".
            let right_key_name = right_key_names.get(i).map(|s| s.as_str()).unwrap_or("");
            // The key becomes the left copy alone when keys are marked ambiguous, when
            // non-key names overlap, or when there is no right key; otherwise it is the
            // left copy coalesced with the right key.
            let expr = if mark_join_keys_ambiguous || has_non_key_overlap {
                col(temp.as_str())
            } else if right_key_name.is_empty() {
                col(temp.as_str())
            } else {
                pl_coalesce(&[col(temp.as_str()), col(right_key_name)])
            };
            joined = joined.with_column(expr.alias(left_name.as_str()));
        }
        // Remove the temporary copies from the result.
        let schema = joined.collect_schema()?;
        let all_names: Vec<String> = schema.iter_names().map(|n| n.to_string()).collect();
        let temp_set: std::collections::HashSet<&str> = outer_left_key_copies
            .iter()
            .map(|(_, t)| t.as_str())
            .collect();
        let keep_exprs: Vec<Expr> = all_names
            .iter()
            .filter(|n| !temp_set.contains(n.as_str()))
            .map(|n| col(n.as_str()))
            .collect();
        joined = joined.select(&keep_exprs);

        // Unless keys are marked ambiguous, also remove the renamed right-hand keys.
        if !outer_join_renamed_right_keys.is_empty() && !mark_join_keys_ambiguous {
            let schema = joined.collect_schema()?;
            let all_names: Vec<String> = schema.iter_names().map(|n| n.to_string()).collect();
            let drop_right_keys: std::collections::HashSet<&str> = outer_join_renamed_right_keys
                .iter()
                .map(|s| s.as_str())
                .collect();
            let keep_exprs: Vec<Expr> = all_names
                .iter()
                .filter(|n| !drop_right_keys.contains(n.as_str()))
                .map(|n| col(n.as_str()))
                .collect();
            joined = joined.select(&keep_exprs);
        }
    }

    let result_schema = joined.collect_schema()?;
    let mut names: Vec<String> = result_schema.iter_names().map(|s| s.to_string()).collect();
    // Inner/left/right joins with coalesced keys: re-select every output column with an
    // explicit cast to its source dtype.
    if keys_match_for_coalesce && matches!(how, JoinType::Inner | JoinType::Left | JoinType::Right)
    {
        let left_names: Vec<String> = left.columns()?.into_iter().collect();
        let right_names: Vec<String> = right.columns()?.into_iter().collect();
        let key_set: std::collections::HashSet<&str> =
            left_key_names.iter().map(|s| s.as_str()).collect();
        let result_schema_ref = joined.collect_schema()?;
        let result_names_vec: Vec<String> = result_schema_ref
            .iter_names()
            .map(|s| s.to_string())
            .collect();
        let result_names_set: std::collections::HashSet<String> =
            result_names_vec.iter().cloned().collect();
        let cast_exprs: Vec<Expr> = if !case_sensitive {
            // Case-insensitive frames: match result columns to source columns ignoring
            // ASCII case. Schema lookups are best-effort (`ok()`).
            let left_struct = left.schema().ok();
            let right_struct = right.schema().ok();
            let mut exprs: Vec<Expr> = Vec::new();
            for left_name in &left_names {
                // Every result column whose name equals this left column ignoring case.
                let matches: Vec<&String> = result_names_vec
                    .iter()
                    .filter(|r| r.eq_ignore_ascii_case(left_name))
                    .collect();
                if matches.is_empty() {
                    continue;
                }
                // Key columns prefer the dtype declared in the left struct schema; every
                // column falls back to `get_column_dtype`.
                let dtype = key_set
                    .contains(left_name.as_str())
                    .then(|| {
                        left_struct
                            .as_ref()
                            .and_then(|s| {
                                s.fields()
                                    .iter()
                                    .find(|f| f.name.as_str() == left_name.as_str())
                                    .map(|f| data_type_to_polars_type(&f.data_type))
                            })
                            .or_else(|| left.get_column_dtype(left_name.as_str()))
                    })
                    .flatten()
                    .or_else(|| left.get_column_dtype(left_name.as_str()));
                // Several case-variants are merged into one column via coalesce.
                let parts: Vec<Expr> = matches.iter().map(|m| col(m.as_str())).collect();
                let e = if parts.len() == 1 {
                    col(matches[0].as_str())
                } else {
                    pl_coalesce(&parts)
                };
                let e = match dtype {
                    Some(dt) => e.cast(dt),
                    None => e,
                };
                exprs.push(e.alias(left_name.as_str()));
            }
            // Counts right non-key columns seen so far; used to address them by position.
            let mut right_non_key_pos = 0_usize;
            for right_name in &right_names {
                if key_set.contains(right_name.as_str()) {
                    continue;
                }
                let matches_left = left_names
                    .iter()
                    .any(|l| l.eq_ignore_ascii_case(right_name));
                if matches_left {
                    // Name clashes with a left column: take the column by position and
                    // emit it as `<name>_right`.
                    // NOTE(review): assumes right non-key columns follow all left columns
                    // in the joined frame, in order; confirm against Polars' join layout.
                    let result_idx = left_names.len() + right_non_key_pos;
                    if result_idx < result_names_vec.len() {
                        let dtype = right_struct
                            .as_ref()
                            .and_then(|s| {
                                s.fields()
                                    .iter()
                                    .find(|f| f.name.as_str() == right_name.as_str())
                                    .map(|f| data_type_to_polars_type(&f.data_type))
                            })
                            .or_else(|| right.get_column_dtype(right_name.as_str()));
                        let alias_name = format!("{}_right", right_name);
                        let e = nth(result_idx as i64).as_expr();
                        let e = match dtype {
                            Some(dt) => e.cast(dt),
                            None => e,
                        };
                        exprs.push(e.alias(alias_name.as_str()));
                    }
                    right_non_key_pos += 1;
                    continue;
                }
                // Columns absent from the result are skipped without advancing the counter.
                if !result_names_set.contains(right_name) {
                    continue;
                }
                let dtype = right_struct
                    .as_ref()
                    .and_then(|s| {
                        s.fields()
                            .iter()
                            .find(|f| f.name.as_str() == right_name.as_str())
                            .map(|f| data_type_to_polars_type(&f.data_type))
                    })
                    .or_else(|| right.get_column_dtype(right_name.as_str()));
                let e = match dtype {
                    Some(dt) => col(right_name.as_str()).cast(dt),
                    None => col(right_name.as_str()),
                };
                exprs.push(e.alias(right_name.as_str()));
                right_non_key_pos += 1;
            }
            Ok(exprs)
        } else {
            // Case-sensitive frames: address every result column by position.
            let schema_before = joined.collect_schema()?;
            let dtypes_by_index: Vec<PlDataType> = schema_before
                .iter_names_and_dtypes()
                .map(|(_name, dt): (_, &PlDataType)| dt.clone())
                .collect();
            // Output names: a name whose lowercase form was already seen gets `_right`.
            let mut seen_lower: std::collections::HashSet<String> =
                std::collections::HashSet::new();
            let desired: Vec<String> = result_names_vec
                .iter()
                .map(|name| {
                    let name_lower = name.to_lowercase();
                    let alias = if seen_lower.contains(&name_lower) {
                        format!("{}_right", name)
                    } else {
                        seen_lower.insert(name_lower);
                        name.clone()
                    };
                    alias
                })
                .collect();
            let left_struct = left.schema().ok();
            let right_struct = right.schema().ok();
            let exprs: Vec<Expr> = desired
                .iter()
                .enumerate()
                .map(|(idx, alias_name)| {
                    let result_name = &result_names_vec[idx];
                    // Dtype source: the left frame for the first `left_names.len()`
                    // positions; otherwise the right frame, looked up by the alias minus
                    // `_right` when suffixed, else by the result name.
                    let dtype = if idx < left_names.len() {
                        left_struct
                            .as_ref()
                            .and_then(|s| {
                                s.fields()
                                    .iter()
                                    .find(|f| f.name.as_str() == result_name.as_str())
                                    .map(|f| data_type_to_polars_type(&f.data_type))
                            })
                            .or_else(|| left.get_column_dtype(result_name.as_str()))
                    } else if let Some(base) = alias_name.strip_suffix("_right") {
                        right_struct
                            .as_ref()
                            .and_then(|s| {
                                s.fields()
                                    .iter()
                                    .find(|f| f.name.as_str() == base)
                                    .map(|f| data_type_to_polars_type(&f.data_type))
                            })
                            .or_else(|| right.get_column_dtype(base))
                    } else {
                        right_struct
                            .as_ref()
                            .and_then(|s| {
                                s.fields()
                                    .iter()
                                    .find(|f| f.name.as_str() == result_name.as_str())
                                    .map(|f| data_type_to_polars_type(&f.data_type))
                            })
                            .or_else(|| right.get_column_dtype(result_name.as_str()))
                    };
                    let e = nth(idx as i64).as_expr();
                    // Prefer the source dtype; fall back to the joined frame's own dtype.
                    match (dtype, dtypes_by_index.get(idx)) {
                        (Some(dt), _) => e.cast(dt).alias(alias_name.as_str()),
                        (_, Some(dt)) => e.cast(dt.clone()).alias(alias_name.as_str()),
                        _ => e.alias(alias_name.as_str()),
                    }
                })
                .collect();
            Ok::<_, PolarsError>(exprs)
        }?;
        if !cast_exprs.is_empty() {
            joined = joined.select(&cast_exprs);
            let result_schema = joined.collect_schema()?;
            names = result_schema.iter_names().map(|s| s.to_string()).collect();
        }
    }
    // Remove exact duplicate column names, keeping the first occurrence of each.
    let mut seen = std::collections::HashSet::new();
    let mut unique_order: Vec<String> = Vec::new();
    for n in &names {
        if seen.insert(n.clone()) {
            unique_order.push(n.clone());
        }
    }
    if unique_order.len() < names.len() {
        let schema_before_nth = joined.collect_schema()?;
        let dtypes_by_index: Vec<PlDataType> = schema_before_nth
            .iter_names_and_dtypes()
            .map(|(_name, dt): (_, &PlDataType)| dt.clone())
            .collect();
        let exprs: Vec<Expr> = unique_order
            .iter()
            .map(|name| {
                // Cannot fail: every entry of `unique_order` was taken from `names`.
                let idx = names.iter().position(|n| n == name).unwrap();
                let e = nth(idx as i64).as_expr();
                if let Some(dt) = dtypes_by_index.get(idx) {
                    e.cast(dt.clone()).alias(name.as_str())
                } else {
                    e.alias(name.as_str())
                }
            })
            .collect();
        joined = joined.select(&exprs);
    }
    // Right/outer joins: reorder to left keys, remaining left columns, then right columns
    // (as `<name>_right` when the name also exists on the left).
    let mut result_lf = if matches!(how, JoinType::Right | JoinType::Outer) {
        let left_names = left.columns()?;
        let right_names = right.columns()?;
        let result_schema = joined.collect_schema()?;
        let result_names: std::collections::HashSet<String> =
            result_schema.iter_names().map(|s| s.to_string()).collect();
        let mut order: Vec<String> = Vec::new();
        for k in &left_key_names {
            order.push(k.clone());
        }
        for n in &left_names {
            if !on_set.contains(n) {
                order.push(n.clone());
            }
        }
        for n in &right_names {
            let use_name = if left_names.iter().any(|l| l == n) {
                format!("{n}_right")
            } else {
                n.clone()
            };
            if result_names.contains(&use_name) {
                order.push(use_name);
            }
        }
        // Apply the reordering only when it has as many entries as the result has columns.
        if order.len() == result_names.len() {
            let select_exprs: Vec<polars::prelude::Expr> =
                order.iter().map(|s| col(s.as_str())).collect();
            joined.select(select_exprs.as_slice())
        } else {
            joined
        }
    } else {
        joined
    };
    // Case-insensitive frames: a column whose lowercase name repeats an earlier column is
    // renamed to `<name>_right`, selecting by position and keeping each column's dtype.
    let result_lf = if !case_sensitive {
        let schema = result_lf.collect_schema()?;
        let result_names: Vec<String> = schema.iter_names().map(|s| s.to_string()).collect();
        let mut seen_lower: std::collections::HashSet<String> = std::collections::HashSet::new();
        let mut need_rename = false;
        let aliases: Vec<String> = result_names
            .iter()
            .map(|name| {
                let name_lower = name.to_lowercase();
                if seen_lower.contains(&name_lower) {
                    need_rename = true;
                    format!("{}_right", name)
                } else {
                    seen_lower.insert(name_lower);
                    name.clone()
                }
            })
            .collect();
        if need_rename {
            let dtypes: Vec<PlDataType> = schema
                .iter_names_and_dtypes()
                .map(|(_, dt)| dt.clone())
                .collect();
            let exprs: Vec<Expr> = aliases
                .iter()
                .enumerate()
                .map(|(idx, alias)| {
                    let e = nth(idx as i64).as_expr();
                    if let Some(dt) = dtypes.get(idx) {
                        e.cast(dt.clone()).alias(alias.as_str())
                    } else {
                        e.alias(alias.as_str())
                    }
                })
                .collect();
            result_lf.select(&exprs)
        } else {
            result_lf
        }
    } else {
        result_lf
    };
    // When requested, pass the left key names along as the result's ambiguous columns.
    let ambiguous_columns = if mark_join_keys_ambiguous {
        Some(left_key_names.iter().cloned().collect::<HashSet<String>>())
    } else {
        None
    };
    Ok(super::DataFrame::from_lazy_with_options_and_ambiguous(
        result_lf,
        case_sensitive,
        ambiguous_columns,
    ))
}
860
#[cfg(test)]
mod tests {
    use super::{
        JoinOptions, JoinOrigin, JoinType, expr_contains_only_join_key_equalities, join,
        try_extract_join_eq_columns, try_extract_join_eq_columns_all,
    };
    use crate::functions::col;
    use crate::{DataFrame, SparkSession};
    use std::collections::HashMap;

    /// Options shared by the join tests: case-insensitive, column-name origin, keys not
    /// marked ambiguous. Only the key-coalescing flag varies between tests.
    fn column_on_options(coalesce_same_name_keys: bool) -> JoinOptions {
        JoinOptions {
            case_sensitive: false,
            coalesce_same_name_keys,
            mark_join_keys_ambiguous: false,
            origin: JoinOrigin::ColumnOn,
        }
    }

    /// A single equality between two columns yields that pair.
    #[test]
    fn extract_join_eq_columns_from_eq_expr() {
        let left = col("dept_id");
        let right = col("dept_id");
        let eq_expr = left.eq(right.into_expr());
        let expr = eq_expr.into_expr();
        let out = try_extract_join_eq_columns(&expr);
        assert_eq!(out, Some(("dept_id".to_string(), "dept_id".to_string())));
    }

    /// An `And` of two equalities yields both pairs, left operand first.
    #[test]
    fn extract_join_eq_columns_all_from_and_of_equalities() {
        let right = col("b").eq(col("b").into_expr());
        let expr = col("a").eq(col("a").into_expr()).and_(&right).into_expr();
        let out = try_extract_join_eq_columns_all(&expr);
        assert_eq!(
            out,
            vec![
                ("a".to_string(), "a".to_string()),
                ("b".to_string(), "b".to_string()),
            ]
        );
    }

    /// An equality between differently named columns keeps both names in order.
    #[test]
    fn extract_join_eq_columns_from_aliased_eq() {
        let eq_expr = col("a").eq(col("b").into_expr());
        let expr = eq_expr.into_expr();
        let out = try_extract_join_eq_columns(&expr);
        assert_eq!(out, Some(("a".to_string(), "b".to_string())));
    }

    /// Pure equalities (single or `And`-combined) are accepted; a `>` comparison is not.
    #[test]
    fn expr_contains_only_join_key_equalities_simple_and_compound() {
        let eq_expr = col("Key").eq(col("Name").into_expr()).into_expr();
        assert!(expr_contains_only_join_key_equalities(&eq_expr));
        let and_expr = col("a")
            .eq(col("b").into_expr())
            .and_(&col("c").eq(col("d").into_expr()))
            .into_expr();
        assert!(expr_contains_only_join_key_equalities(&and_expr));
        let gt_expr = col("a")
            .eq(col("b").into_expr())
            .and_(&col("x").gt(col("y").into_expr()))
            .into_expr();
        // Fixed: the argument was garbled to `>_expr`; it must borrow `gt_expr`.
        assert!(!expr_contains_only_join_key_equalities(&gt_expr));
    }

    /// Left fixture: ids 1 and 2 with columns `id`, `v`, `label`.
    fn left_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        spark
            .create_dataframe(
                vec![
                    (1i64, 10i64, "a".to_string()),
                    (2i64, 20i64, "b".to_string()),
                ],
                vec!["id", "v", "label"],
            )
            .unwrap()
    }

    /// Right fixture: ids 1 and 3 with columns `id`, `w`, `tag`.
    fn right_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        spark
            .create_dataframe(
                vec![
                    (1i64, 100i64, "x".to_string()),
                    (3i64, 300i64, "z".to_string()),
                ],
                vec!["id", "w", "tag"],
            )
            .unwrap()
    }

    /// Only id 1 exists on both sides.
    #[test]
    fn inner_join() {
        let left = left_df();
        let right = right_df();
        let out = join(
            &left,
            &right,
            vec!["id"],
            vec!["id"],
            JoinType::Inner,
            column_on_options(false),
        )
        .unwrap();
        assert_eq!(out.count().unwrap(), 1);
        let cols = out.columns().unwrap();
        assert!(cols.iter().any(|c| c == "id" || c.ends_with("_right")));
    }

    /// Non-key columns keep their `Long` type through a coalescing inner join.
    #[test]
    fn join_coalesce_preserves_non_key_column_types() {
        use robin_sparkless_core::DataType as CoreDataType;
        let left = left_df();
        let right = right_df();
        let out = join(
            &left,
            &right,
            vec!["id"],
            vec!["id"],
            JoinType::Inner,
            column_on_options(true),
        )
        .unwrap();
        assert_eq!(out.count().unwrap(), 1);
        let schema = out.schema().unwrap();
        let v_field = schema.fields().iter().find(|f| f.name == "v");
        let w_field = schema.fields().iter().find(|f| f.name == "w");
        assert!(
            matches!(v_field.map(|f| &f.data_type), Some(CoreDataType::Long)),
            "v should be Long"
        );
        assert!(
            matches!(w_field.map(|f| &f.data_type), Some(CoreDataType::Long)),
            "w should be Long"
        );
        let rows = out.collect_as_json_rows().unwrap();
        assert_eq!(rows.len(), 1);
        let row = &rows[0];
        assert!(
            row.get("v").and_then(|v| v.as_i64()).is_some(),
            "v should be number in JSON"
        );
        assert!(
            row.get("w").and_then(|v| v.as_i64()).is_some(),
            "w should be number in JSON"
        );
    }

    /// Both left rows survive a left join.
    #[test]
    fn left_join() {
        let left = left_df();
        let right = right_df();
        let out = join(
            &left,
            &right,
            vec!["id"],
            vec!["id"],
            JoinType::Left,
            column_on_options(false),
        )
        .unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }

    /// Both right rows survive a right join.
    #[test]
    fn right_join() {
        let left = left_df();
        let right = right_df();
        let out = join(
            &left,
            &right,
            vec!["id"],
            vec!["id"],
            JoinType::Right,
            column_on_options(false),
        )
        .unwrap();
        assert_eq!(out.count().unwrap(), 2);
    }

    /// Ids 1, 2 and 3 all appear in an outer join.
    #[test]
    fn outer_join() {
        let left = left_df();
        let right = right_df();
        let out = join(
            &left,
            &right,
            vec!["id"],
            vec!["id"],
            JoinType::Outer,
            column_on_options(false),
        )
        .unwrap();
        assert_eq!(out.count().unwrap(), 3);
    }

    /// Only the left row with a match (id 1) is kept.
    #[test]
    fn left_semi_join() {
        let left = left_df();
        let right = right_df();
        let out = join(
            &left,
            &right,
            vec!["id"],
            vec!["id"],
            JoinType::LeftSemi,
            column_on_options(false),
        )
        .unwrap();
        assert_eq!(out.count().unwrap(), 1);
    }

    /// Only the left row without a match (id 2) is kept.
    #[test]
    fn left_anti_join() {
        let left = left_df();
        let right = right_df();
        let out = join(
            &left,
            &right,
            vec!["id"],
            vec!["id"],
            JoinType::LeftAnti,
            column_on_options(false),
        )
        .unwrap();
        assert_eq!(out.count().unwrap(), 1);
    }

    /// An inner join against an empty right frame yields no rows.
    #[test]
    fn join_empty_right() {
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        let left = left_df();
        let right = spark
            .create_dataframe(vec![] as Vec<(i64, i64, String)>, vec!["id", "w", "tag"])
            .unwrap();
        let out = join(
            &left,
            &right,
            vec!["id"],
            vec!["id"],
            JoinType::Inner,
            column_on_options(false),
        )
        .unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }

    /// A string key on the left joins against an integer key on the right.
    #[test]
    fn join_key_type_coercion_str_int() {
        use polars::prelude::df;
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        let left_pl = df!("id" => &["1"], "label" => &["a"]).unwrap();
        let right_pl = df!("id" => &[1i64], "x" => &[10i64]).unwrap();
        let left = spark.create_dataframe_from_polars(left_pl);
        let right = spark.create_dataframe_from_polars(right_pl);
        let out = join(
            &left,
            &right,
            vec!["id"],
            vec!["id"],
            JoinType::Inner,
            column_on_options(false),
        )
        .unwrap();
        assert_eq!(out.count().unwrap(), 1);
        let rows = out.collect().unwrap();
        assert_eq!(rows.height(), 1);
        assert!(rows.column("label").is_ok());
        assert!(rows.column("x").is_ok());
    }

    /// An integer key on the left joins against a string key on the right.
    #[test]
    fn join_key_type_coercion_int_str() {
        use polars::prelude::df;
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        let left_pl = df!("id" => &[1i64, 2i64], "name" => &["alice", "bob"]).unwrap();
        let right_pl = df!("id" => &["1", "3"], "value" => &[100i64, 300i64]).unwrap();
        let left = spark.create_dataframe_from_polars(left_pl);
        let right = spark.create_dataframe_from_polars(right_pl);
        let out = join(
            &left,
            &right,
            vec!["id"],
            vec!["id"],
            JoinType::Inner,
            column_on_options(false),
        )
        .unwrap();
        assert_eq!(out.count().unwrap(), 1, "inner join on id: 1 match (id=1)");
        let rows = out.collect().unwrap();
        assert_eq!(rows.height(), 1);
        assert!(rows.column("id").is_ok());
        assert!(rows.column("name").is_ok());
        assert!(rows.column("value").is_ok());
    }

    /// After a coalescing outer join, grouping by the key sees each of the three keys once.
    #[test]
    fn outer_join_then_groupby_on_key_matches_pyspark_semantics() {
        let spark = SparkSession::builder()
            .app_name("outer_join_groupby_tests")
            .get_or_create();

        let left_tuples = vec![
            (1i64, 0i64, "L1".to_string()),
            (3i64, 0i64, "L3".to_string()),
        ];
        let right_tuples = vec![
            (1i64, 0i64, "R1".to_string()),
            (2i64, 0i64, "R2".to_string()),
        ];

        let left = spark
            .create_dataframe(left_tuples, vec!["key", "extra_left", "left_val"])
            .unwrap();
        let right = spark
            .create_dataframe(right_tuples, vec!["key", "extra_right", "right_val"])
            .unwrap();

        let joined = join(
            &left,
            &right,
            vec!["key"],
            vec!["key"],
            JoinType::Outer,
            column_on_options(true),
        )
        .unwrap();

        let grouped = joined.group_by(vec!["key"]).unwrap();
        let out = grouped.count().unwrap();
        let pl_df = out.collect().unwrap();

        let key_col = pl_df.column("key").unwrap().i64().unwrap();
        let count_col = pl_df.column("count").unwrap().i64().unwrap();

        let mut by_key: HashMap<Option<i64>, i64> = HashMap::new();
        for idx in 0..key_col.len() {
            let key = key_col.get(idx);
            let cnt = count_col.get(idx).unwrap_or(0);
            by_key.insert(key, cnt);
        }

        assert_eq!(by_key.len(), 3);
        assert_eq!(by_key.get(&Some(1)).copied(), Some(1));
        assert_eq!(by_key.get(&Some(2)).copied(), Some(1));
        assert_eq!(by_key.get(&Some(3)).copied(), Some(1));
    }

    /// Keys spelled `id` and `ID` join successfully on a case-insensitive frame.
    #[test]
    fn join_column_resolution_case_insensitive() {
        use polars::prelude::df;
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        let left_pl = df!("id" => &[1i64, 2i64], "val" => &["a", "b"]).unwrap();
        let right_pl = df!("ID" => &[1i64], "other" => &["x"]).unwrap();
        let left = spark.create_dataframe_from_polars(left_pl);
        let right = spark.create_dataframe_from_polars(right_pl);
        let out = join(
            &left,
            &right,
            vec!["id"],
            vec!["id"],
            JoinType::Inner,
            column_on_options(false),
        )
        .expect("issue #604: join on id/ID must succeed");
        assert_eq!(out.count().unwrap(), 1);
        let rows = out
            .collect()
            .expect("issue #604: collect must not fail with 'not found: ID'");
        assert_eq!(rows.height(), 1);
        assert!(rows.column("id").is_ok());
        assert!(rows.column("val").is_ok());
        assert!(rows.column("other").is_ok());
        assert!(out.resolve_column_name("ID").is_ok());
    }
}