alopex_dataframe/ops/
unique.rs1use std::collections::HashSet;
2use std::sync::Arc;
3
4use arrow::array::{Array, UInt32Builder};
5use arrow::datatypes::{DataType, Schema};
6use arrow::record_batch::RecordBatch;
7
8use crate::{DataFrameError, Result};
9
10#[derive(Debug, Clone, PartialEq, Eq, Hash)]
11struct RowKey(Vec<KeyValue>);
12
13#[derive(Debug, Clone, PartialEq, Eq, Hash)]
14enum KeyValue {
15 Null { dtype: DataType },
16 Boolean(bool),
17 Signed(i128),
18 Unsigned(u128),
19 Float32(u32),
20 Float64(u64),
21 Utf8(String),
22}
23
24pub fn unique_batches(
25 input: Vec<RecordBatch>,
26 subset: Option<&[String]>,
27) -> Result<Vec<RecordBatch>> {
28 let batch = concat_batches(&input)?;
29 if batch.num_rows() == 0 {
30 return Ok(vec![batch]);
31 }
32
33 let indices = resolve_subset(&batch, subset)?;
34
35 let mut seen = HashSet::<RowKey>::new();
36 let mut selected = Vec::new();
37 for row in 0..batch.num_rows() {
38 let key = build_key(&batch, &indices, row)?;
39 if seen.insert(key) {
40 selected.push(row);
41 }
42 }
43
44 let index_array = build_indices(&selected)?;
45 let mut arrays = Vec::with_capacity(batch.num_columns());
46 for col in batch.columns() {
47 let array = arrow::compute::take(col.as_ref(), &index_array, None)
48 .map_err(|source| DataFrameError::Arrow { source })?;
49 arrays.push(array);
50 }
51
52 let batch = RecordBatch::try_new(batch.schema(), arrays).map_err(|e| {
53 DataFrameError::schema_mismatch(format!("failed to build RecordBatch: {e}"))
54 })?;
55 Ok(vec![batch])
56}
57
58fn concat_batches(batches: &[RecordBatch]) -> Result<RecordBatch> {
59 if batches.is_empty() {
60 return Ok(RecordBatch::new_empty(Arc::new(Schema::empty())));
61 }
62 let schema = batches[0].schema();
63 if batches.len() == 1 {
64 return Ok(batches[0].clone());
65 }
66 arrow::compute::concat_batches(&schema, batches)
67 .map_err(|source| DataFrameError::Arrow { source })
68}
69
70fn resolve_subset(batch: &RecordBatch, subset: Option<&[String]>) -> Result<Vec<usize>> {
71 let schema = batch.schema();
72 let indices = match subset {
73 Some(cols) => {
74 if cols.is_empty() {
75 return Err(DataFrameError::invalid_operation(
76 "unique subset must be non-empty",
77 ));
78 }
79 cols.iter()
80 .map(|name| {
81 schema
82 .fields()
83 .iter()
84 .position(|f| f.name() == name)
85 .ok_or_else(|| DataFrameError::column_not_found(name.clone()))
86 })
87 .collect::<Result<Vec<_>>>()?
88 }
89 None => (0..schema.fields().len()).collect(),
90 };
91 Ok(indices)
92}
93
94fn build_indices(indices: &[usize]) -> Result<arrow::array::UInt32Array> {
95 let mut builder = UInt32Builder::with_capacity(indices.len());
96 for idx in indices {
97 let value = u32::try_from(*idx)
98 .map_err(|_| DataFrameError::invalid_operation("row index exceeds u32 range"))?;
99 builder.append_value(value);
100 }
101 Ok(builder.finish())
102}
103
104fn build_key(batch: &RecordBatch, indices: &[usize], row: usize) -> Result<RowKey> {
105 let mut values = Vec::with_capacity(indices.len());
106 for idx in indices {
107 let array = batch.column(*idx).as_ref();
108 values.push(key_value_from_array(array, row)?);
109 }
110 Ok(RowKey(values))
111}
112
113fn key_value_from_array(array: &dyn Array, row: usize) -> Result<KeyValue> {
114 if array.is_null(row) {
115 return Ok(KeyValue::Null {
116 dtype: array.data_type().clone(),
117 });
118 }
119
120 use arrow::datatypes::DataType::*;
121 match array.data_type() {
122 Boolean => Ok(KeyValue::Boolean(
123 array
124 .as_any()
125 .downcast_ref::<arrow::array::BooleanArray>()
126 .ok_or_else(|| DataFrameError::invalid_operation("bad BooleanArray downcast"))?
127 .value(row),
128 )),
129 Int8 => Ok(KeyValue::Signed(
130 array
131 .as_any()
132 .downcast_ref::<arrow::array::Int8Array>()
133 .ok_or_else(|| DataFrameError::invalid_operation("bad Int8Array downcast"))?
134 .value(row) as i128,
135 )),
136 Int16 => Ok(KeyValue::Signed(
137 array
138 .as_any()
139 .downcast_ref::<arrow::array::Int16Array>()
140 .ok_or_else(|| DataFrameError::invalid_operation("bad Int16Array downcast"))?
141 .value(row) as i128,
142 )),
143 Int32 => Ok(KeyValue::Signed(
144 array
145 .as_any()
146 .downcast_ref::<arrow::array::Int32Array>()
147 .ok_or_else(|| DataFrameError::invalid_operation("bad Int32Array downcast"))?
148 .value(row) as i128,
149 )),
150 Int64 => Ok(KeyValue::Signed(
151 array
152 .as_any()
153 .downcast_ref::<arrow::array::Int64Array>()
154 .ok_or_else(|| DataFrameError::invalid_operation("bad Int64Array downcast"))?
155 .value(row) as i128,
156 )),
157 UInt8 => Ok(KeyValue::Unsigned(
158 array
159 .as_any()
160 .downcast_ref::<arrow::array::UInt8Array>()
161 .ok_or_else(|| DataFrameError::invalid_operation("bad UInt8Array downcast"))?
162 .value(row) as u128,
163 )),
164 UInt16 => Ok(KeyValue::Unsigned(
165 array
166 .as_any()
167 .downcast_ref::<arrow::array::UInt16Array>()
168 .ok_or_else(|| DataFrameError::invalid_operation("bad UInt16Array downcast"))?
169 .value(row) as u128,
170 )),
171 UInt32 => Ok(KeyValue::Unsigned(
172 array
173 .as_any()
174 .downcast_ref::<arrow::array::UInt32Array>()
175 .ok_or_else(|| DataFrameError::invalid_operation("bad UInt32Array downcast"))?
176 .value(row) as u128,
177 )),
178 UInt64 => Ok(KeyValue::Unsigned(
179 array
180 .as_any()
181 .downcast_ref::<arrow::array::UInt64Array>()
182 .ok_or_else(|| DataFrameError::invalid_operation("bad UInt64Array downcast"))?
183 .value(row) as u128,
184 )),
185 Float32 => Ok(KeyValue::Float32(
186 array
187 .as_any()
188 .downcast_ref::<arrow::array::Float32Array>()
189 .ok_or_else(|| DataFrameError::invalid_operation("bad Float32Array downcast"))?
190 .value(row)
191 .to_bits(),
192 )),
193 Float64 => Ok(KeyValue::Float64(
194 array
195 .as_any()
196 .downcast_ref::<arrow::array::Float64Array>()
197 .ok_or_else(|| DataFrameError::invalid_operation("bad Float64Array downcast"))?
198 .value(row)
199 .to_bits(),
200 )),
201 Utf8 => Ok(KeyValue::Utf8(
202 array
203 .as_any()
204 .downcast_ref::<arrow::array::StringArray>()
205 .ok_or_else(|| DataFrameError::invalid_operation("bad StringArray downcast"))?
206 .value(row)
207 .to_string(),
208 )),
209 other => Err(DataFrameError::invalid_operation(format!(
210 "unsupported unique key type {other:?}",
211 ))),
212 }
213}