use std::{collections::HashSet, hash::Hash};
use arrow::{
array::{
as_primitive_array, as_string_array, ArrayRef, BooleanArray, BooleanBufferBuilder,
PrimitiveArray, StringArray,
},
compute::{
and, filter, filter_record_batch,
kernels::cmp::{distinct, eq},
},
datatypes::{ArrowPrimitiveType, DataType, Int32Type, Int64Type},
error::ArrowError,
record_batch::RecordBatch,
};
use itertools::{iproduct, Itertools};
use iceberg_rust_spec::{partition::BoundPartitionField, spec::values::Value};
use super::transform::transform_arrow;
/// Splits `record_batch` into one record batch per distinct combination of
/// partition values.
///
/// For every entry in `partition_fields` the source column is looked up by
/// name and passed through `transform_arrow` with the field's transform. The
/// distinct values of each transformed column are collected, an equality mask
/// is built per distinct value, and the masks of all partition columns are
/// combined (cartesian product + logical AND). Combinations that select no
/// rows are discarded.
///
/// # Returns
///
/// An iterator yielding, for every non-empty combination, the partition
/// values (in the order of `partition_fields`) together with the rows of
/// `record_batch` that belong to that combination. The row filtering itself
/// is performed lazily while the iterator is consumed.
///
/// # Errors
///
/// * `ArrowError::SchemaError` if a partition source column is not present in
///   `record_batch`.
/// * Any error produced by `transform_arrow`, by `distinct_values` (e.g. an
///   unsupported transformed data type) or by the Arrow compute kernels.
///
/// NOTE(review): rows whose transformed partition value is null presumably
/// match none of the equality masks and therefore end up in no output batch —
/// confirm this is intended.
pub fn partition_record_batch<'a>(
    record_batch: &'a RecordBatch,
    partition_fields: &[BoundPartitionField<'_>],
) -> Result<impl Iterator<Item = Result<(Vec<Value>, RecordBatch), ArrowError>> + 'a, ArrowError> {
    // Evaluate every partition transform on its source column.
    let partition_columns: Vec<ArrayRef> = partition_fields
        .iter()
        .map(|field| {
            let array = record_batch
                .column_by_name(field.source_name())
                // Build the error lazily so the success path does not allocate.
                .ok_or_else(|| ArrowError::SchemaError("Column doesn't exist".to_string()))?;
            transform_arrow(array.clone(), field.transform())
        })
        .collect::<Result<_, ArrowError>>()?;

    // Distinct values of every transformed partition column.
    let distinct_values: Vec<DistinctValues> = partition_columns
        .iter()
        .map(|x| distinct_values(x.clone()))
        .collect::<Result<Vec<_>, ArrowError>>()?;

    // All-true mask used as the neutral starting point for the AND-fold below.
    let mut true_buffer = BooleanBufferBuilder::new(record_batch.num_rows());
    true_buffer.append_n(record_batch.num_rows(), true);

    let predicates = distinct_values
        .into_iter()
        .zip(partition_columns.iter())
        // For each partition column: one (value, "column == value" mask) pair
        // per distinct value.
        .map(|(distinct, value)| match distinct {
            DistinctValues::Int(set) => set
                .into_iter()
                .map(|x| {
                    Ok((
                        Value::Int(x),
                        eq(&PrimitiveArray::<Int32Type>::new_scalar(x), value)?,
                    ))
                })
                .collect::<Result<Vec<_>, ArrowError>>(),
            DistinctValues::Long(set) => set
                .into_iter()
                .map(|x| {
                    Ok((
                        Value::LongInt(x),
                        eq(&PrimitiveArray::<Int64Type>::new_scalar(x), value)?,
                    ))
                })
                .collect::<Result<Vec<_>, ArrowError>>(),
            DistinctValues::String(set) => set
                .into_iter()
                .map(|x| {
                    let res = eq(&StringArray::new_scalar(&x), value)?;
                    Ok((Value::String(x), res))
                })
                .collect::<Result<Vec<_>, ArrowError>>(),
        })
        // Combine the per-column masks: every accumulated combination is
        // paired with every value of the next column and the masks are ANDed.
        .try_fold(
            vec![(vec![], BooleanArray::new(true_buffer.finish(), None))],
            |acc, predicates| {
                iproduct!(acc, predicates?.iter())
                    .map(|((mut values, x), (value, y))| {
                        values.push(value.clone());
                        Ok((values, and(&x, y)?))
                    })
                    // Drop combinations that do not select any row.
                    .filter_ok(|x| x.1.true_count() != 0)
                    .collect::<Result<Vec<(Vec<Value>, _)>, ArrowError>>()
            },
        )?;

    // Materialise each partition's rows only when the caller asks for them.
    Ok(predicates.into_iter().map(move |(values, predicate)| {
        Ok((values, filter_record_batch(record_batch, &predicate)?))
    }))
}
/// Collects the distinct values of `array`, dispatching on its Arrow data type.
///
/// `Int32`, `Int64` and `Utf8` arrays are supported; every other data type is
/// rejected with an `ArrowError::ComputeError`.
fn distinct_values(array: ArrayRef) -> Result<DistinctValues, ArrowError> {
    match array.data_type() {
        DataType::Int32 => {
            distinct_values_primitive::<i32, Int32Type>(array).map(DistinctValues::Int)
        }
        DataType::Int64 => {
            distinct_values_primitive::<i64, Int64Type>(array).map(DistinctValues::Long)
        }
        DataType::Utf8 => distinct_values_string(array).map(DistinctValues::String),
        _unsupported => {
            let message = "Datatype not supported for transform.".to_string();
            Err(ArrowError::ComputeError(message))
        }
    }
}
/// Collects the distinct non-null values of a primitive array.
///
/// `array` must hold values of the primitive type `P`; it is downcast with
/// `as_primitive_array::<P>`.
///
/// The array is compared against itself shifted by one element using the
/// `distinct` kernel, so only values that differ from their predecessor are
/// inserted into the hash set. The first element has no predecessor and is
/// handled separately.
///
/// Returns an empty set for an empty array. Null entries never contribute a
/// value to the set.
///
/// # Errors
///
/// Propagates any error returned by the Arrow `distinct` or `filter` kernels.
fn distinct_values_primitive<T: Eq + Hash, P: ArrowPrimitiveType<Native = T>>(
    array: ArrayRef,
) -> Result<HashSet<P::Native>, ArrowError> {
    let array = as_primitive_array::<P>(&array);

    // `iter()` is null-aware: `None` means the array is empty, `Some(None)`
    // means the first slot is null. Reading the slot with `value(0)` would
    // panic on an empty array and return a meaningless value for a null slot.
    let first = match array.iter().next() {
        Some(first) => first,
        None => return Ok(HashSet::new()),
    };
    let mut set: HashSet<P::Native> = first.into_iter().collect();

    // Safe: the array is known to contain at least one element here.
    let slice_len = array.len() - 1;
    if slice_len == 0 {
        return Ok(set);
    }

    // v1 = elements [0, n-1), v2 = elements [1, n): comparing them pairwise
    // marks every position whose value differs from the preceding one.
    let v1 = array.slice(0, slice_len);
    let v2 = array.slice(1, slice_len);
    let mask = distinct(&v1, &v2)?;
    let unique = filter(&v2, &mask)?;
    let unique = as_primitive_array::<P>(&unique);

    // `flatten` skips null entries.
    set.extend(unique.iter().flatten());
    Ok(set)
}
/// Collects the distinct non-null values of a UTF-8 string array.
///
/// `array` is downcast with `as_string_array`.
///
/// The array is compared against itself shifted by one element using the
/// `distinct` kernel, so only values that differ from their predecessor are
/// copied into the hash set. The first element has no predecessor and is
/// handled separately.
///
/// Returns an empty set for an empty array. Null entries never contribute a
/// value to the set.
///
/// # Errors
///
/// Propagates any error returned by the Arrow `distinct` or `filter` kernels.
fn distinct_values_string(array: ArrayRef) -> Result<HashSet<String>, ArrowError> {
    let array = as_string_array(&array);

    // `iter()` is null-aware: `None` means the array is empty, `Some(None)`
    // means the first slot is null. Computing `len() - 1` or calling
    // `value(0)` up front would underflow / panic on an empty array.
    let first = match array.iter().next() {
        Some(first) => first,
        None => return Ok(HashSet::new()),
    };
    let mut set: HashSet<String> = first.map(str::to_owned).into_iter().collect();

    // Safe: the array is known to contain at least one element here.
    let slice_len = array.len() - 1;
    if slice_len == 0 {
        return Ok(set);
    }

    // v1 = elements [0, n-1), v2 = elements [1, n): comparing them pairwise
    // marks every position whose value differs from the preceding one.
    let v1 = array.slice(0, slice_len);
    let v2 = array.slice(1, slice_len);
    let mask = distinct(&v1, &v2)?;
    let unique = filter(&v2, &mask)?;
    let unique = as_string_array(&unique);

    // `flatten` skips null entries.
    set.extend(unique.iter().flatten().map(str::to_owned));
    Ok(set)
}
/// Set of distinct values found in one partition column, tagged by the
/// column's value type.
enum DistinctValues {
// Distinct values of an `Int32` column.
Int(HashSet<i32>),
// Distinct values of an `Int64` column.
Long(HashSet<i64>),
// Distinct values of a `Utf8` column.
String(HashSet<String>),
}