Skip to main content

alopex_dataframe/ops/
unique.rs

1use std::collections::HashSet;
2use std::sync::Arc;
3
4use arrow::array::{Array, UInt32Builder};
5use arrow::datatypes::{DataType, Schema};
6use arrow::record_batch::RecordBatch;
7
8use crate::{DataFrameError, Result};
9
10#[derive(Debug, Clone, PartialEq, Eq, Hash)]
11struct RowKey(Vec<KeyValue>);
12
13#[derive(Debug, Clone, PartialEq, Eq, Hash)]
14enum KeyValue {
15    Null { dtype: DataType },
16    Boolean(bool),
17    Signed(i128),
18    Unsigned(u128),
19    Float32(u32),
20    Float64(u64),
21    Utf8(String),
22}
23
24pub fn unique_batches(
25    input: Vec<RecordBatch>,
26    subset: Option<&[String]>,
27) -> Result<Vec<RecordBatch>> {
28    let batch = concat_batches(&input)?;
29    if batch.num_rows() == 0 {
30        return Ok(vec![batch]);
31    }
32
33    let indices = resolve_subset(&batch, subset)?;
34
35    let mut seen = HashSet::<RowKey>::new();
36    let mut selected = Vec::new();
37    for row in 0..batch.num_rows() {
38        let key = build_key(&batch, &indices, row)?;
39        if seen.insert(key) {
40            selected.push(row);
41        }
42    }
43
44    let index_array = build_indices(&selected)?;
45    let mut arrays = Vec::with_capacity(batch.num_columns());
46    for col in batch.columns() {
47        let array = arrow::compute::take(col.as_ref(), &index_array, None)
48            .map_err(|source| DataFrameError::Arrow { source })?;
49        arrays.push(array);
50    }
51
52    let batch = RecordBatch::try_new(batch.schema(), arrays).map_err(|e| {
53        DataFrameError::schema_mismatch(format!("failed to build RecordBatch: {e}"))
54    })?;
55    Ok(vec![batch])
56}
57
58fn concat_batches(batches: &[RecordBatch]) -> Result<RecordBatch> {
59    if batches.is_empty() {
60        return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
61    }
62    let schema = batches[0].schema();
63    if batches.len() == 1 {
64        return Ok(batches[0].clone());
65    }
66    arrow::compute::concat_batches(&schema, batches)
67        .map_err(|source| DataFrameError::Arrow { source })
68}
69
70fn resolve_subset(batch: &RecordBatch, subset: Option<&[String]>) -> Result<Vec<usize>> {
71    let schema = batch.schema();
72    let indices = match subset {
73        Some(cols) => {
74            if cols.is_empty() {
75                return Err(DataFrameError::invalid_operation(
76                    "unique subset must be non-empty",
77                ));
78            }
79            cols.iter()
80                .map(|name| {
81                    schema
82                        .fields()
83                        .iter()
84                        .position(|f| f.name() == name)
85                        .ok_or_else(|| DataFrameError::column_not_found(name.clone()))
86                })
87                .collect::<Result<Vec<_>>>()?
88        }
89        None => (0..schema.fields().len()).collect(),
90    };
91    Ok(indices)
92}
93
94fn build_indices(indices: &[usize]) -> Result<arrow::array::UInt32Array> {
95    let mut builder = UInt32Builder::with_capacity(indices.len());
96    for idx in indices {
97        let value = u32::try_from(*idx)
98            .map_err(|_| DataFrameError::invalid_operation("row index exceeds u32 range"))?;
99        builder.append_value(value);
100    }
101    Ok(builder.finish())
102}
103
104fn build_key(batch: &RecordBatch, indices: &[usize], row: usize) -> Result<RowKey> {
105    let mut values = Vec::with_capacity(indices.len());
106    for idx in indices {
107        let array = batch.column(*idx).as_ref();
108        values.push(key_value_from_array(array, row)?);
109    }
110    Ok(RowKey(values))
111}
112
113fn key_value_from_array(array: &dyn Array, row: usize) -> Result<KeyValue> {
114    if array.is_null(row) {
115        return Ok(KeyValue::Null {
116            dtype: array.data_type().clone(),
117        });
118    }
119
120    use arrow::datatypes::DataType::*;
121    match array.data_type() {
122        Boolean => Ok(KeyValue::Boolean(
123            array
124                .as_any()
125                .downcast_ref::<arrow::array::BooleanArray>()
126                .ok_or_else(|| DataFrameError::invalid_operation("bad BooleanArray downcast"))?
127                .value(row),
128        )),
129        Int8 => Ok(KeyValue::Signed(
130            array
131                .as_any()
132                .downcast_ref::<arrow::array::Int8Array>()
133                .ok_or_else(|| DataFrameError::invalid_operation("bad Int8Array downcast"))?
134                .value(row) as i128,
135        )),
136        Int16 => Ok(KeyValue::Signed(
137            array
138                .as_any()
139                .downcast_ref::<arrow::array::Int16Array>()
140                .ok_or_else(|| DataFrameError::invalid_operation("bad Int16Array downcast"))?
141                .value(row) as i128,
142        )),
143        Int32 => Ok(KeyValue::Signed(
144            array
145                .as_any()
146                .downcast_ref::<arrow::array::Int32Array>()
147                .ok_or_else(|| DataFrameError::invalid_operation("bad Int32Array downcast"))?
148                .value(row) as i128,
149        )),
150        Int64 => Ok(KeyValue::Signed(
151            array
152                .as_any()
153                .downcast_ref::<arrow::array::Int64Array>()
154                .ok_or_else(|| DataFrameError::invalid_operation("bad Int64Array downcast"))?
155                .value(row) as i128,
156        )),
157        UInt8 => Ok(KeyValue::Unsigned(
158            array
159                .as_any()
160                .downcast_ref::<arrow::array::UInt8Array>()
161                .ok_or_else(|| DataFrameError::invalid_operation("bad UInt8Array downcast"))?
162                .value(row) as u128,
163        )),
164        UInt16 => Ok(KeyValue::Unsigned(
165            array
166                .as_any()
167                .downcast_ref::<arrow::array::UInt16Array>()
168                .ok_or_else(|| DataFrameError::invalid_operation("bad UInt16Array downcast"))?
169                .value(row) as u128,
170        )),
171        UInt32 => Ok(KeyValue::Unsigned(
172            array
173                .as_any()
174                .downcast_ref::<arrow::array::UInt32Array>()
175                .ok_or_else(|| DataFrameError::invalid_operation("bad UInt32Array downcast"))?
176                .value(row) as u128,
177        )),
178        UInt64 => Ok(KeyValue::Unsigned(
179            array
180                .as_any()
181                .downcast_ref::<arrow::array::UInt64Array>()
182                .ok_or_else(|| DataFrameError::invalid_operation("bad UInt64Array downcast"))?
183                .value(row) as u128,
184        )),
185        Float32 => Ok(KeyValue::Float32(
186            array
187                .as_any()
188                .downcast_ref::<arrow::array::Float32Array>()
189                .ok_or_else(|| DataFrameError::invalid_operation("bad Float32Array downcast"))?
190                .value(row)
191                .to_bits(),
192        )),
193        Float64 => Ok(KeyValue::Float64(
194            array
195                .as_any()
196                .downcast_ref::<arrow::array::Float64Array>()
197                .ok_or_else(|| DataFrameError::invalid_operation("bad Float64Array downcast"))?
198                .value(row)
199                .to_bits(),
200        )),
201        Utf8 => Ok(KeyValue::Utf8(
202            array
203                .as_any()
204                .downcast_ref::<arrow::array::StringArray>()
205                .ok_or_else(|| DataFrameError::invalid_operation("bad StringArray downcast"))?
206                .value(row)
207                .to_string(),
208        )),
209        other => Err(DataFrameError::invalid_operation(format!(
210            "unsupported unique key type {other:?}",
211        ))),
212    }
213}