robin_sparkless/dataframe/
joins.rs1use super::DataFrame;
4use polars::prelude::JoinType as PlJoinType;
5use polars::prelude::PolarsError;
6
/// Join strategies accepted by [`join`].
///
/// Each variant is translated by [`join`] into the corresponding Polars
/// `JoinType` before the join is executed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum JoinType {
    /// Translated to Polars `Inner`.
    Inner,
    /// Translated to Polars `Left`.
    Left,
    /// Translated to Polars `Right`; the result columns are re-ordered by [`join`].
    Right,
    /// Translated to Polars `Full`; the result columns are re-ordered by [`join`].
    Outer,
    /// Translated to Polars `Semi`.
    LeftSemi,
    /// Translated to Polars `Anti`.
    LeftAnti,
}
19
20pub fn join(
23 left: &DataFrame,
24 right: &DataFrame,
25 on: Vec<&str>,
26 how: JoinType,
27 case_sensitive: bool,
28) -> Result<DataFrame, PolarsError> {
29 use polars::prelude::{col, IntoLazy, JoinBuilder, JoinCoalesce};
30 let left_lf = left.df.as_ref().clone().lazy();
31 let right_lf = right.df.as_ref().clone().lazy();
32 let on_set: std::collections::HashSet<&str> = on.iter().copied().collect();
33 let on_exprs: Vec<polars::prelude::Expr> = on.iter().map(|name| col(*name)).collect();
34 let polars_how: PlJoinType = match how {
35 JoinType::Inner => PlJoinType::Inner,
36 JoinType::Left => PlJoinType::Left,
37 JoinType::Right => PlJoinType::Right,
38 JoinType::Outer => PlJoinType::Full, JoinType::LeftSemi => PlJoinType::Semi,
40 JoinType::LeftAnti => PlJoinType::Anti,
41 };
42 let joined = JoinBuilder::new(left_lf)
43 .with(right_lf)
44 .how(polars_how)
45 .on(&on_exprs)
46 .coalesce(JoinCoalesce::CoalesceColumns)
47 .finish();
48 let mut pl_df = joined.collect()?;
49 if matches!(how, JoinType::Right | JoinType::Outer) {
50 let left_names: Vec<String> = left
51 .df
52 .get_column_names()
53 .iter()
54 .map(|s| s.to_string())
55 .collect();
56 let right_names: Vec<String> = right
57 .df
58 .get_column_names()
59 .iter()
60 .map(|s| s.to_string())
61 .collect();
62 let result_names: std::collections::HashSet<String> = pl_df
63 .get_column_names()
64 .iter()
65 .map(|s| s.to_string())
66 .collect();
67 let mut order: Vec<String> = Vec::new();
68 for k in &on {
69 order.push((*k).to_string());
70 }
71 for n in &left_names {
72 if !on_set.contains(n.as_str()) {
73 order.push(n.clone());
74 }
75 }
76 for n in &right_names {
77 let use_name = if left_names.iter().any(|l| l == n) {
78 format!("{n}_right")
79 } else {
80 n.clone()
81 };
82 if result_names.contains(&use_name) {
83 order.push(use_name);
84 }
85 }
86 if order.len() == result_names.len() {
87 let select_refs: Vec<&str> = order.iter().map(String::as_str).collect();
88 pl_df = pl_df.select(select_refs).map_err(|e| {
89 PolarsError::ComputeError(format!("join column reorder: {e}").into())
90 })?;
91 }
92 }
93 Ok(super::DataFrame::from_polars_with_options(
94 pl_df,
95 case_sensitive,
96 ))
97}
98
#[cfg(test)]
mod tests {
    use super::{join, JoinType};
    use crate::{DataFrame, SparkSession};

    /// Left fixture: columns `id`, `v`, `label` with ids 1 and 2.
    fn left_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        spark
            .create_dataframe(
                vec![
                    (1i64, 10i64, "a".to_string()),
                    (2i64, 20i64, "b".to_string()),
                ],
                vec!["id", "v", "label"],
            )
            .unwrap()
    }

    /// Right fixture: columns `id`, `w`, `tag` with ids 1 and 3.
    /// Only id 1 is shared with the left fixture.
    fn right_df() -> DataFrame {
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        spark
            .create_dataframe(
                vec![
                    (1i64, 100i64, "x".to_string()),
                    (3i64, 300i64, "z".to_string()),
                ],
                vec!["id", "w", "tag"],
            )
            .unwrap()
    }

    /// Joins the two fixtures on `id` with the given strategy, case-insensitively.
    fn joined(how: JoinType) -> DataFrame {
        join(&left_df(), &right_df(), vec!["id"], how, false).unwrap()
    }

    #[test]
    fn inner_join() {
        let out = joined(JoinType::Inner);
        // Only id 1 exists on both sides.
        assert_eq!(out.count().unwrap(), 1);
        let cols = out.columns().unwrap();
        assert!(cols.iter().any(|c| c == "id" || c.ends_with("_right")));
    }

    #[test]
    fn left_join() {
        // Every left row is kept.
        let out = joined(JoinType::Left);
        assert_eq!(out.count().unwrap(), 2);
    }

    #[test]
    fn right_join() {
        // Every right row (ids 1 and 3) is kept; this also runs the
        // column re-ordering branch of `join`.
        let out = joined(JoinType::Right);
        assert_eq!(out.count().unwrap(), 2);
    }

    #[test]
    fn outer_join() {
        // Union of ids 1, 2 and 3.
        let out = joined(JoinType::Outer);
        assert_eq!(out.count().unwrap(), 3);
    }

    #[test]
    fn left_semi_join() {
        // Only the left row with a match on the right (id 1) remains.
        let out = joined(JoinType::LeftSemi);
        assert_eq!(out.count().unwrap(), 1);
    }

    #[test]
    fn left_anti_join() {
        // Only the left row without a match on the right (id 2) remains.
        let out = joined(JoinType::LeftAnti);
        assert_eq!(out.count().unwrap(), 1);
    }

    #[test]
    fn join_empty_right() {
        let spark = SparkSession::builder()
            .app_name("join_tests")
            .get_or_create();
        let left = left_df();
        let right = spark
            .create_dataframe(vec![] as Vec<(i64, i64, String)>, vec!["id", "w", "tag"])
            .unwrap();
        // An inner join against an empty frame yields no rows.
        let out = join(&left, &right, vec!["id"], JoinType::Inner, false).unwrap();
        assert_eq!(out.count().unwrap(), 0);
    }
}