1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use arrow::array::{Array, UInt32Builder};
5use arrow::datatypes::{DataType, Field, Schema};
6use arrow::record_batch::RecordBatch;
7
8use crate::ops::{JoinKeys, JoinType};
9use crate::{DataFrameError, Result};
10
/// Composite hash key for a single row: one [`KeyValue`] per join column, in
/// the order the join columns were resolved.
///
/// Two rows join when their `JoinKey`s compare equal, which is why the type
/// derives `Eq` and `Hash` and is used as the key of the build-side hash map.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct JoinKey(Vec<KeyValue>);
13
/// A single join-column value reduced to a form that implements `Eq` and
/// `Hash`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum KeyValue {
    /// A null slot. The column's data type is kept, so two nulls are equal
    /// only when they come from columns of the same type.
    Null { dtype: DataType },
    /// Value of a `Boolean` column.
    Boolean(bool),
    /// Value of an `Int8`/`Int16`/`Int32`/`Int64` column, widened to `i128`.
    Signed(i128),
    /// Value of a `UInt8`/`UInt16`/`UInt32`/`UInt64` column, widened to `u128`.
    Unsigned(u128),
    /// Bit pattern (`to_bits`) of a `Float32` value; floats themselves are
    /// neither `Eq` nor `Hash`.
    Float32(u32),
    /// Bit pattern (`to_bits`) of a `Float64` value.
    Float64(u64),
    /// Owned copy of a `Utf8` string value.
    Utf8(String),
}
24
25pub fn join_batches(
26 left_batches: Vec<RecordBatch>,
27 right_batches: Vec<RecordBatch>,
28 keys: &JoinKeys,
29 how: &JoinType,
30) -> Result<Vec<RecordBatch>> {
31 let left_batch = concat_batches(&left_batches)?;
32 let right_batch = concat_batches(&right_batches)?;
33 let left_schema = left_batch.schema();
34 let right_schema = right_batch.schema();
35
36 let resolved = resolve_join_keys(left_schema.as_ref(), right_schema.as_ref(), keys)?;
37 let output = build_output_spec(left_schema.as_ref(), right_schema.as_ref(), &resolved, how)?;
38
39 let left_rows = left_batch.num_rows();
40 let right_rows = right_batch.num_rows();
41
42 let mut right_map = HashMap::<JoinKey, Vec<usize>>::new();
43 for row in 0..right_rows {
44 let key = build_join_key(&right_batch, &resolved.right_indices, row)?;
45 right_map.entry(key).or_default().push(row);
46 }
47
48 let mut left_indices: Vec<Option<usize>> = Vec::new();
49 let mut right_indices: Vec<Option<usize>> = Vec::new();
50 let mut matched_right = vec![false; right_rows];
51
52 for row in 0..left_rows {
53 let key = build_join_key(&left_batch, &resolved.left_indices, row)?;
54 match right_map.get(&key) {
55 Some(matches) => match how {
56 JoinType::Semi => {
57 left_indices.push(Some(row));
58 }
59 JoinType::Anti => {}
60 _ => {
61 for &r in matches {
62 left_indices.push(Some(row));
63 right_indices.push(Some(r));
64 matched_right[r] = true;
65 }
66 }
67 },
68 None => match how {
69 JoinType::Left | JoinType::Full => {
70 left_indices.push(Some(row));
71 right_indices.push(None);
72 }
73 JoinType::Anti => {
74 left_indices.push(Some(row));
75 }
76 _ => {}
77 },
78 }
79 }
80
81 if matches!(how, JoinType::Right | JoinType::Full) {
82 for (r, matched) in matched_right.iter().enumerate() {
83 if !*matched {
84 left_indices.push(None);
85 right_indices.push(Some(r));
86 }
87 }
88 }
89
90 let left_index_array = build_indices(&left_indices)?;
91 let right_index_array = build_indices(&right_indices)?;
92
93 let mut arrays = Vec::with_capacity(output.columns.len());
94 for col in &output.columns {
95 match col {
96 OutputColumn::Left(idx) => {
97 let array = arrow::compute::take(left_batch.column(*idx), &left_index_array, None)
98 .map_err(|source| DataFrameError::Arrow { source })?;
99 arrays.push(array);
100 }
101 OutputColumn::Right(idx) => {
102 let array =
103 arrow::compute::take(right_batch.column(*idx), &right_index_array, None)
104 .map_err(|source| DataFrameError::Arrow { source })?;
105 arrays.push(array);
106 }
107 }
108 }
109
110 let schema = Arc::new(Schema::new(output.fields));
111 let batch = RecordBatch::try_new(schema, arrays).map_err(|e| {
112 DataFrameError::schema_mismatch(format!("failed to build RecordBatch: {e}"))
113 })?;
114 Ok(vec![batch])
115}
116
117fn concat_batches(batches: &[RecordBatch]) -> Result<RecordBatch> {
118 if batches.is_empty() {
119 return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
120 }
121 let schema = batches[0].schema();
122 if batches.len() == 1 {
123 return Ok(batches[0].clone());
124 }
125 arrow::compute::concat_batches(&schema, batches)
126 .map_err(|source| DataFrameError::Arrow { source })
127}
128
/// Join keys resolved from column names to column positions.
struct ResolvedJoinKeys {
    /// Position of each key column in the left schema, in key order.
    left_indices: Vec<usize>,
    /// Position of each key column in the right schema; entry `i` pairs with
    /// `left_indices[i]`.
    right_indices: Vec<usize>,
    /// The same positions as `right_indices`, as a set for membership tests.
    right_key_indices: HashSet<usize>,
    /// `true` when the keys were given as `JoinKeys::On`, i.e. both sides use
    /// the same column names.
    on_same_names: bool,
}
135
136fn resolve_join_keys(left: &Schema, right: &Schema, keys: &JoinKeys) -> Result<ResolvedJoinKeys> {
137 let (left_names, right_names, on_same_names) = match keys {
138 JoinKeys::On(cols) => (cols.clone(), cols.clone(), true),
139 JoinKeys::LeftRight { left_on, right_on } => (left_on.clone(), right_on.clone(), false),
140 };
141
142 if left_names.is_empty() {
143 return Err(DataFrameError::invalid_operation(
144 "join keys must be non-empty",
145 ));
146 }
147
148 if left_names.len() != right_names.len() {
149 return Err(DataFrameError::invalid_operation(
150 "join key lengths do not match",
151 ));
152 }
153
154 let mut left_indices = Vec::with_capacity(left_names.len());
155 let mut right_indices = Vec::with_capacity(right_names.len());
156 let mut right_key_indices = HashSet::with_capacity(right_names.len());
157
158 for (l_name, r_name) in left_names.iter().zip(right_names.iter()) {
159 let l_idx = left
160 .fields()
161 .iter()
162 .position(|f| f.name() == l_name)
163 .ok_or_else(|| DataFrameError::column_not_found(l_name.clone()))?;
164 let r_idx = right
165 .fields()
166 .iter()
167 .position(|f| f.name() == r_name)
168 .ok_or_else(|| DataFrameError::column_not_found(r_name.clone()))?;
169
170 let l_type = left.fields()[l_idx].data_type();
171 let r_type = right.fields()[r_idx].data_type();
172 if l_type != r_type {
173 return Err(DataFrameError::type_mismatch(
174 Some(l_name.clone()),
175 l_type.to_string(),
176 r_type.to_string(),
177 ));
178 }
179
180 left_indices.push(l_idx);
181 right_indices.push(r_idx);
182 right_key_indices.insert(r_idx);
183 }
184
185 Ok(ResolvedJoinKeys {
186 left_indices,
187 right_indices,
188 right_key_indices,
189 on_same_names,
190 })
191}
192
/// Layout of the joined batch.
///
/// `fields` and `columns` are parallel: `fields[i]` describes the output
/// column whose data is sourced from `columns[i]`.
struct OutputSpec {
    /// Schema fields of the output, in output order.
    fields: Vec<Field>,
    /// Source (side and column position) of each output column.
    columns: Vec<OutputColumn>,
}
197
/// Source of one output column: the side it is taken from and the column
/// position within that side's batch.
enum OutputColumn {
    /// Column at this position of the left batch.
    Left(usize),
    /// Column at this position of the right batch.
    Right(usize),
}
202
203fn build_output_spec(
204 left: &Schema,
205 right: &Schema,
206 keys: &ResolvedJoinKeys,
207 how: &JoinType,
208) -> Result<OutputSpec> {
209 let mut fields = Vec::new();
210 let mut columns = Vec::new();
211 let mut seen = HashSet::<String>::new();
212
213 let left_nullable = matches!(how, JoinType::Right | JoinType::Full);
214 for (idx, f) in left.fields().iter().enumerate() {
215 let field = Field::new(
216 f.name(),
217 f.data_type().clone(),
218 f.is_nullable() || left_nullable,
219 );
220 seen.insert(field.name().to_string());
221 fields.push(field);
222 columns.push(OutputColumn::Left(idx));
223 }
224
225 if matches!(how, JoinType::Semi | JoinType::Anti) {
226 return Ok(OutputSpec { fields, columns });
227 }
228
229 let right_nullable = matches!(how, JoinType::Left | JoinType::Full);
230 for (idx, f) in right.fields().iter().enumerate() {
231 if keys.on_same_names && keys.right_key_indices.contains(&idx) {
232 continue;
233 }
234
235 let mut name = f.name().to_string();
236 if seen.contains(&name) {
237 if keys.right_key_indices.contains(&idx) {
238 return Err(DataFrameError::schema_mismatch(format!(
239 "duplicate column name '{name}'",
240 )));
241 }
242 let suffixed = format!("{name}_right");
243 if seen.contains(&suffixed) {
244 return Err(DataFrameError::schema_mismatch(format!(
245 "duplicate column name '{suffixed}'",
246 )));
247 }
248 name = suffixed;
249 }
250
251 seen.insert(name.clone());
252 fields.push(Field::new(
253 &name,
254 f.data_type().clone(),
255 f.is_nullable() || right_nullable,
256 ));
257 columns.push(OutputColumn::Right(idx));
258 }
259
260 Ok(OutputSpec { fields, columns })
261}
262
263fn build_indices(indices: &[Option<usize>]) -> Result<arrow::array::UInt32Array> {
264 let mut builder = UInt32Builder::with_capacity(indices.len());
265 for idx in indices {
266 match idx {
267 Some(value) => {
268 let value = u32::try_from(*value).map_err(|_| {
269 DataFrameError::invalid_operation("row index exceeds u32 range")
270 })?;
271 builder.append_value(value);
272 }
273 None => {
274 builder.append_null();
275 }
276 }
277 }
278 Ok(builder.finish())
279}
280
281fn build_join_key(batch: &RecordBatch, indices: &[usize], row: usize) -> Result<JoinKey> {
282 let mut values = Vec::with_capacity(indices.len());
283 for idx in indices {
284 let array = batch.column(*idx).as_ref();
285 values.push(key_value_from_array(array, row)?);
286 }
287 Ok(JoinKey(values))
288}
289
290fn key_value_from_array(array: &dyn Array, row: usize) -> Result<KeyValue> {
291 if array.is_null(row) {
292 return Ok(KeyValue::Null {
293 dtype: array.data_type().clone(),
294 });
295 }
296
297 use arrow::datatypes::DataType::*;
298 match array.data_type() {
299 Boolean => Ok(KeyValue::Boolean(
300 array
301 .as_any()
302 .downcast_ref::<arrow::array::BooleanArray>()
303 .ok_or_else(|| DataFrameError::invalid_operation("bad BooleanArray downcast"))?
304 .value(row),
305 )),
306 Int8 => Ok(KeyValue::Signed(
307 array
308 .as_any()
309 .downcast_ref::<arrow::array::Int8Array>()
310 .ok_or_else(|| DataFrameError::invalid_operation("bad Int8Array downcast"))?
311 .value(row) as i128,
312 )),
313 Int16 => Ok(KeyValue::Signed(
314 array
315 .as_any()
316 .downcast_ref::<arrow::array::Int16Array>()
317 .ok_or_else(|| DataFrameError::invalid_operation("bad Int16Array downcast"))?
318 .value(row) as i128,
319 )),
320 Int32 => Ok(KeyValue::Signed(
321 array
322 .as_any()
323 .downcast_ref::<arrow::array::Int32Array>()
324 .ok_or_else(|| DataFrameError::invalid_operation("bad Int32Array downcast"))?
325 .value(row) as i128,
326 )),
327 Int64 => Ok(KeyValue::Signed(
328 array
329 .as_any()
330 .downcast_ref::<arrow::array::Int64Array>()
331 .ok_or_else(|| DataFrameError::invalid_operation("bad Int64Array downcast"))?
332 .value(row) as i128,
333 )),
334 UInt8 => Ok(KeyValue::Unsigned(
335 array
336 .as_any()
337 .downcast_ref::<arrow::array::UInt8Array>()
338 .ok_or_else(|| DataFrameError::invalid_operation("bad UInt8Array downcast"))?
339 .value(row) as u128,
340 )),
341 UInt16 => Ok(KeyValue::Unsigned(
342 array
343 .as_any()
344 .downcast_ref::<arrow::array::UInt16Array>()
345 .ok_or_else(|| DataFrameError::invalid_operation("bad UInt16Array downcast"))?
346 .value(row) as u128,
347 )),
348 UInt32 => Ok(KeyValue::Unsigned(
349 array
350 .as_any()
351 .downcast_ref::<arrow::array::UInt32Array>()
352 .ok_or_else(|| DataFrameError::invalid_operation("bad UInt32Array downcast"))?
353 .value(row) as u128,
354 )),
355 UInt64 => Ok(KeyValue::Unsigned(
356 array
357 .as_any()
358 .downcast_ref::<arrow::array::UInt64Array>()
359 .ok_or_else(|| DataFrameError::invalid_operation("bad UInt64Array downcast"))?
360 .value(row) as u128,
361 )),
362 Float32 => Ok(KeyValue::Float32(
363 array
364 .as_any()
365 .downcast_ref::<arrow::array::Float32Array>()
366 .ok_or_else(|| DataFrameError::invalid_operation("bad Float32Array downcast"))?
367 .value(row)
368 .to_bits(),
369 )),
370 Float64 => Ok(KeyValue::Float64(
371 array
372 .as_any()
373 .downcast_ref::<arrow::array::Float64Array>()
374 .ok_or_else(|| DataFrameError::invalid_operation("bad Float64Array downcast"))?
375 .value(row)
376 .to_bits(),
377 )),
378 Utf8 => Ok(KeyValue::Utf8(
379 array
380 .as_any()
381 .downcast_ref::<arrow::array::StringArray>()
382 .ok_or_else(|| DataFrameError::invalid_operation("bad StringArray downcast"))?
383 .value(row)
384 .to_string(),
385 )),
386 other => Err(DataFrameError::invalid_operation(format!(
387 "unsupported join key type {other:?}",
388 ))),
389 }
390}