Skip to main content

alopex_dataframe/ops/
join.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use arrow::array::{Array, UInt32Builder};
5use arrow::datatypes::{DataType, Field, Schema};
6use arrow::record_batch::RecordBatch;
7
8use crate::ops::{JoinKeys, JoinType};
9use crate::{DataFrameError, Result};
10
11#[derive(Debug, Clone, PartialEq, Eq, Hash)]
12struct JoinKey(Vec<KeyValue>);
13
14#[derive(Debug, Clone, PartialEq, Eq, Hash)]
15enum KeyValue {
16    Null { dtype: DataType },
17    Boolean(bool),
18    Signed(i128),
19    Unsigned(u128),
20    Float32(u32),
21    Float64(u64),
22    Utf8(String),
23}
24
25pub fn join_batches(
26    left_batches: Vec<RecordBatch>,
27    right_batches: Vec<RecordBatch>,
28    keys: &JoinKeys,
29    how: &JoinType,
30) -> Result<Vec<RecordBatch>> {
31    let left_batch = concat_batches(&left_batches)?;
32    let right_batch = concat_batches(&right_batches)?;
33    let left_schema = left_batch.schema();
34    let right_schema = right_batch.schema();
35
36    let resolved = resolve_join_keys(left_schema.as_ref(), right_schema.as_ref(), keys)?;
37    let output = build_output_spec(left_schema.as_ref(), right_schema.as_ref(), &resolved, how)?;
38
39    let left_rows = left_batch.num_rows();
40    let right_rows = right_batch.num_rows();
41
42    let mut right_map = HashMap::<JoinKey, Vec<usize>>::new();
43    for row in 0..right_rows {
44        let key = build_join_key(&right_batch, &resolved.right_indices, row)?;
45        right_map.entry(key).or_default().push(row);
46    }
47
48    let mut left_indices: Vec<Option<usize>> = Vec::new();
49    let mut right_indices: Vec<Option<usize>> = Vec::new();
50    let mut matched_right = vec![false; right_rows];
51
52    for row in 0..left_rows {
53        let key = build_join_key(&left_batch, &resolved.left_indices, row)?;
54        match right_map.get(&key) {
55            Some(matches) => match how {
56                JoinType::Semi => {
57                    left_indices.push(Some(row));
58                }
59                JoinType::Anti => {}
60                _ => {
61                    for &r in matches {
62                        left_indices.push(Some(row));
63                        right_indices.push(Some(r));
64                        matched_right[r] = true;
65                    }
66                }
67            },
68            None => match how {
69                JoinType::Left | JoinType::Full => {
70                    left_indices.push(Some(row));
71                    right_indices.push(None);
72                }
73                JoinType::Anti => {
74                    left_indices.push(Some(row));
75                }
76                _ => {}
77            },
78        }
79    }
80
81    if matches!(how, JoinType::Right | JoinType::Full) {
82        for (r, matched) in matched_right.iter().enumerate() {
83            if !*matched {
84                left_indices.push(None);
85                right_indices.push(Some(r));
86            }
87        }
88    }
89
90    let left_index_array = build_indices(&left_indices)?;
91    let right_index_array = build_indices(&right_indices)?;
92
93    let mut arrays = Vec::with_capacity(output.columns.len());
94    for col in &output.columns {
95        match col {
96            OutputColumn::Left(idx) => {
97                let array = arrow::compute::take(left_batch.column(*idx), &left_index_array, None)
98                    .map_err(|source| DataFrameError::Arrow { source })?;
99                arrays.push(array);
100            }
101            OutputColumn::Right(idx) => {
102                let array =
103                    arrow::compute::take(right_batch.column(*idx), &right_index_array, None)
104                        .map_err(|source| DataFrameError::Arrow { source })?;
105                arrays.push(array);
106            }
107        }
108    }
109
110    let schema = Arc::new(Schema::new(output.fields));
111    let batch = RecordBatch::try_new(schema, arrays).map_err(|e| {
112        DataFrameError::schema_mismatch(format!("failed to build RecordBatch: {e}"))
113    })?;
114    Ok(vec![batch])
115}
116
117fn concat_batches(batches: &[RecordBatch]) -> Result<RecordBatch> {
118    if batches.is_empty() {
119        return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
120    }
121    let schema = batches[0].schema();
122    if batches.len() == 1 {
123        return Ok(batches[0].clone());
124    }
125    arrow::compute::concat_batches(&schema, batches)
126        .map_err(|source| DataFrameError::Arrow { source })
127}
128
/// Join keys resolved from column names to column positions in the two input schemas.
struct ResolvedJoinKeys {
    /// Position of each key column in the left schema, in key order.
    left_indices: Vec<usize>,
    /// Position of each key column in the right schema, pairwise with `left_indices`.
    right_indices: Vec<usize>,
    /// Set view of `right_indices`, for quick "is this right column a key?" checks.
    right_key_indices: HashSet<usize>,
    /// `true` when the keys came from `JoinKeys::On`, i.e. both sides use the same names.
    on_same_names: bool,
}
135
136fn resolve_join_keys(left: &Schema, right: &Schema, keys: &JoinKeys) -> Result<ResolvedJoinKeys> {
137    let (left_names, right_names, on_same_names) = match keys {
138        JoinKeys::On(cols) => (cols.clone(), cols.clone(), true),
139        JoinKeys::LeftRight { left_on, right_on } => (left_on.clone(), right_on.clone(), false),
140    };
141
142    if left_names.is_empty() {
143        return Err(DataFrameError::invalid_operation(
144            "join keys must be non-empty",
145        ));
146    }
147
148    if left_names.len() != right_names.len() {
149        return Err(DataFrameError::invalid_operation(
150            "join key lengths do not match",
151        ));
152    }
153
154    let mut left_indices = Vec::with_capacity(left_names.len());
155    let mut right_indices = Vec::with_capacity(right_names.len());
156    let mut right_key_indices = HashSet::with_capacity(right_names.len());
157
158    for (l_name, r_name) in left_names.iter().zip(right_names.iter()) {
159        let l_idx = left
160            .fields()
161            .iter()
162            .position(|f| f.name() == l_name)
163            .ok_or_else(|| DataFrameError::column_not_found(l_name.clone()))?;
164        let r_idx = right
165            .fields()
166            .iter()
167            .position(|f| f.name() == r_name)
168            .ok_or_else(|| DataFrameError::column_not_found(r_name.clone()))?;
169
170        let l_type = left.fields()[l_idx].data_type();
171        let r_type = right.fields()[r_idx].data_type();
172        if l_type != r_type {
173            return Err(DataFrameError::type_mismatch(
174                Some(l_name.clone()),
175                l_type.to_string(),
176                r_type.to_string(),
177            ));
178        }
179
180        left_indices.push(l_idx);
181        right_indices.push(r_idx);
182        right_key_indices.insert(r_idx);
183    }
184
185    Ok(ResolvedJoinKeys {
186        left_indices,
187        right_indices,
188        right_key_indices,
189        on_same_names,
190    })
191}
192
193struct OutputSpec {
194    fields: Vec<Field>,
195    columns: Vec<OutputColumn>,
196}
197
/// Identifies which input batch an output column is gathered from.
enum OutputColumn {
    /// Column at this position in the left batch.
    Left(usize),
    /// Column at this position in the right batch.
    Right(usize),
}
202
203fn build_output_spec(
204    left: &Schema,
205    right: &Schema,
206    keys: &ResolvedJoinKeys,
207    how: &JoinType,
208) -> Result<OutputSpec> {
209    let mut fields = Vec::new();
210    let mut columns = Vec::new();
211    let mut seen = HashSet::<String>::new();
212
213    let left_nullable = matches!(how, JoinType::Right | JoinType::Full);
214    for (idx, f) in left.fields().iter().enumerate() {
215        let field = Field::new(
216            f.name(),
217            f.data_type().clone(),
218            f.is_nullable() || left_nullable,
219        );
220        seen.insert(field.name().to_string());
221        fields.push(field);
222        columns.push(OutputColumn::Left(idx));
223    }
224
225    if matches!(how, JoinType::Semi | JoinType::Anti) {
226        return Ok(OutputSpec { fields, columns });
227    }
228
229    let right_nullable = matches!(how, JoinType::Left | JoinType::Full);
230    for (idx, f) in right.fields().iter().enumerate() {
231        if keys.on_same_names && keys.right_key_indices.contains(&idx) {
232            continue;
233        }
234
235        let mut name = f.name().to_string();
236        if seen.contains(&name) {
237            if keys.right_key_indices.contains(&idx) {
238                return Err(DataFrameError::schema_mismatch(format!(
239                    "duplicate column name '{name}'",
240                )));
241            }
242            let suffixed = format!("{name}_right");
243            if seen.contains(&suffixed) {
244                return Err(DataFrameError::schema_mismatch(format!(
245                    "duplicate column name '{suffixed}'",
246                )));
247            }
248            name = suffixed;
249        }
250
251        seen.insert(name.clone());
252        fields.push(Field::new(
253            &name,
254            f.data_type().clone(),
255            f.is_nullable() || right_nullable,
256        ));
257        columns.push(OutputColumn::Right(idx));
258    }
259
260    Ok(OutputSpec { fields, columns })
261}
262
263fn build_indices(indices: &[Option<usize>]) -> Result<arrow::array::UInt32Array> {
264    let mut builder = UInt32Builder::with_capacity(indices.len());
265    for idx in indices {
266        match idx {
267            Some(value) => {
268                let value = u32::try_from(*value).map_err(|_| {
269                    DataFrameError::invalid_operation("row index exceeds u32 range")
270                })?;
271                builder.append_value(value);
272            }
273            None => {
274                builder.append_null();
275            }
276        }
277    }
278    Ok(builder.finish())
279}
280
281fn build_join_key(batch: &RecordBatch, indices: &[usize], row: usize) -> Result<JoinKey> {
282    let mut values = Vec::with_capacity(indices.len());
283    for idx in indices {
284        let array = batch.column(*idx).as_ref();
285        values.push(key_value_from_array(array, row)?);
286    }
287    Ok(JoinKey(values))
288}
289
290fn key_value_from_array(array: &dyn Array, row: usize) -> Result<KeyValue> {
291    if array.is_null(row) {
292        return Ok(KeyValue::Null {
293            dtype: array.data_type().clone(),
294        });
295    }
296
297    use arrow::datatypes::DataType::*;
298    match array.data_type() {
299        Boolean => Ok(KeyValue::Boolean(
300            array
301                .as_any()
302                .downcast_ref::<arrow::array::BooleanArray>()
303                .ok_or_else(|| DataFrameError::invalid_operation("bad BooleanArray downcast"))?
304                .value(row),
305        )),
306        Int8 => Ok(KeyValue::Signed(
307            array
308                .as_any()
309                .downcast_ref::<arrow::array::Int8Array>()
310                .ok_or_else(|| DataFrameError::invalid_operation("bad Int8Array downcast"))?
311                .value(row) as i128,
312        )),
313        Int16 => Ok(KeyValue::Signed(
314            array
315                .as_any()
316                .downcast_ref::<arrow::array::Int16Array>()
317                .ok_or_else(|| DataFrameError::invalid_operation("bad Int16Array downcast"))?
318                .value(row) as i128,
319        )),
320        Int32 => Ok(KeyValue::Signed(
321            array
322                .as_any()
323                .downcast_ref::<arrow::array::Int32Array>()
324                .ok_or_else(|| DataFrameError::invalid_operation("bad Int32Array downcast"))?
325                .value(row) as i128,
326        )),
327        Int64 => Ok(KeyValue::Signed(
328            array
329                .as_any()
330                .downcast_ref::<arrow::array::Int64Array>()
331                .ok_or_else(|| DataFrameError::invalid_operation("bad Int64Array downcast"))?
332                .value(row) as i128,
333        )),
334        UInt8 => Ok(KeyValue::Unsigned(
335            array
336                .as_any()
337                .downcast_ref::<arrow::array::UInt8Array>()
338                .ok_or_else(|| DataFrameError::invalid_operation("bad UInt8Array downcast"))?
339                .value(row) as u128,
340        )),
341        UInt16 => Ok(KeyValue::Unsigned(
342            array
343                .as_any()
344                .downcast_ref::<arrow::array::UInt16Array>()
345                .ok_or_else(|| DataFrameError::invalid_operation("bad UInt16Array downcast"))?
346                .value(row) as u128,
347        )),
348        UInt32 => Ok(KeyValue::Unsigned(
349            array
350                .as_any()
351                .downcast_ref::<arrow::array::UInt32Array>()
352                .ok_or_else(|| DataFrameError::invalid_operation("bad UInt32Array downcast"))?
353                .value(row) as u128,
354        )),
355        UInt64 => Ok(KeyValue::Unsigned(
356            array
357                .as_any()
358                .downcast_ref::<arrow::array::UInt64Array>()
359                .ok_or_else(|| DataFrameError::invalid_operation("bad UInt64Array downcast"))?
360                .value(row) as u128,
361        )),
362        Float32 => Ok(KeyValue::Float32(
363            array
364                .as_any()
365                .downcast_ref::<arrow::array::Float32Array>()
366                .ok_or_else(|| DataFrameError::invalid_operation("bad Float32Array downcast"))?
367                .value(row)
368                .to_bits(),
369        )),
370        Float64 => Ok(KeyValue::Float64(
371            array
372                .as_any()
373                .downcast_ref::<arrow::array::Float64Array>()
374                .ok_or_else(|| DataFrameError::invalid_operation("bad Float64Array downcast"))?
375                .value(row)
376                .to_bits(),
377        )),
378        Utf8 => Ok(KeyValue::Utf8(
379            array
380                .as_any()
381                .downcast_ref::<arrow::array::StringArray>()
382                .ok_or_else(|| DataFrameError::invalid_operation("bad StringArray downcast"))?
383                .value(row)
384                .to_string(),
385        )),
386        other => Err(DataFrameError::invalid_operation(format!(
387            "unsupported join key type {other:?}",
388        ))),
389    }
390}