use std::{
collections::{hash_map::Entry, HashMap, HashSet},
hash::Hash,
pin::Pin,
sync::Arc,
};
use arrow::{
array::{
as_primitive_array, as_string_array, ArrayIter, ArrayRef, BooleanArray,
BooleanBufferBuilder, PrimitiveArray, StringArray,
},
compute::kernels::cmp::eq,
compute::{and, filter_record_batch},
datatypes::{ArrowPrimitiveType, DataType, Int32Type, Int64Type},
error::ArrowError,
record_batch::RecordBatch,
};
use futures::{
channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender},
lock::Mutex,
stream, SinkExt, Stream, StreamExt, TryStreamExt,
};
use itertools::{iproduct, Itertools};
use iceberg_rust_spec::spec::{partition::PartitionSpec, schema::Schema, values::Value};
use super::transform::transform_arrow;
type SendableRecordBatchStream =
Pin<Box<dyn Stream<Item = Result<RecordBatch, ArrowError>> + Send>>;
type RecordBatchSender = UnboundedSender<Result<RecordBatch, ArrowError>>;
pub async fn partition_record_batches(
record_batches: impl Stream<Item = Result<RecordBatch, ArrowError>> + Send,
partition_spec: &PartitionSpec,
schema: &Schema,
) -> Result<
Vec<(
Vec<Value>,
impl Stream<Item = Result<RecordBatch, ArrowError>> + Send,
)>,
ArrowError,
> {
let (partition_sender, partition_reciever): (
UnboundedSender<(Vec<Value>, SendableRecordBatchStream)>,
UnboundedReceiver<(Vec<Value>, SendableRecordBatchStream)>,
) = unbounded();
let partition_streams: Arc<Mutex<HashMap<Vec<Value>, RecordBatchSender>>> =
Arc::new(Mutex::new(HashMap::new()));
record_batches
.try_for_each_concurrent(None, |record_batch| {
let partition_streams = partition_streams.clone();
let partition_sender = partition_sender.clone();
async move {
let partition_columns: Vec<ArrayRef> = partition_spec
.fields()
.iter()
.map(|field| {
let column_name = &schema
.fields()
.get(*field.source_id() as usize)
.ok_or(ArrowError::SchemaError("Column doesn't exist".to_string()))?
.name;
let array = record_batch
.column_by_name(column_name)
.ok_or(ArrowError::SchemaError("Column doesn't exist".to_string()))?;
transform_arrow(array.clone(), field.transform())
})
.collect::<Result<_, ArrowError>>()?;
let distinct_values: Vec<DistinctValues> = partition_columns
.iter()
.map(|x| distinct_values(x.clone()))
.collect::<Result<Vec<_>, ArrowError>>()?;
let mut true_buffer = BooleanBufferBuilder::new(record_batch.num_rows());
true_buffer.append_n(record_batch.num_rows(), true);
let predicates = distinct_values
.into_iter()
.zip(partition_columns.iter())
.map(|(distinct, value)| match distinct {
DistinctValues::Int(set) => set
.into_iter()
.map(|x| {
Ok((
Value::Int(x),
eq(&PrimitiveArray::<Int32Type>::new_scalar(x), value)?,
))
})
.collect::<Result<Vec<_>, ArrowError>>(),
DistinctValues::Long(set) => set
.into_iter()
.map(|x| {
Ok((
Value::LongInt(x),
eq(&PrimitiveArray::<Int64Type>::new_scalar(x), value)?,
))
})
.collect::<Result<Vec<_>, ArrowError>>(),
DistinctValues::String(set) => set
.into_iter()
.map(|x| {
let res = eq(&StringArray::new_scalar(&x), value)?;
Ok((Value::String(x), res))
})
.collect::<Result<Vec<_>, ArrowError>>(),
})
.try_fold(
vec![(vec![], BooleanArray::new(true_buffer.finish(), None))],
|acc, predicates| {
iproduct!(acc, predicates?.iter())
.map(|((mut values, x), (value, y))| {
values.push(value.clone());
Ok((values, and(&x, y)?))
})
.filter_ok(|x| x.1.true_count() != 0)
.collect::<Result<Vec<(Vec<Value>, _)>, ArrowError>>()
},
)?;
stream::iter(predicates.into_iter())
.map(|(values, predicate)| {
Ok((values, filter_record_batch(&record_batch, &predicate)?))
})
.try_for_each_concurrent(None, |(values, batch)| {
let partition_streams = partition_streams.clone();
let mut partition_sender = partition_sender.clone();
async move {
let mut sender = {
let mut partition_streams = partition_streams.lock().await;
let entry = partition_streams.entry(values.clone());
match entry {
Entry::Occupied(entry) => entry.get().clone(),
Entry::Vacant(entry) => {
let (sender, reciever) = unbounded();
entry.insert(sender.clone());
partition_sender.send((values,Box::pin(reciever))).await.map_err(
|err| ArrowError::ExternalError(Box::new(err)),
)?;
sender
}
}
};
sender
.send(Ok(batch))
.await
.map_err(|err| ArrowError::ExternalError(Box::new(err)))?;
Ok::<_, ArrowError>(())
}
})
.await?;
Ok(())
}
})
.await?;
Arc::into_inner(partition_streams)
.ok_or(ArrowError::MemoryError(
"Couldn't close partition streams.".to_string(),
))?
.into_inner()
.into_values()
.for_each(|sender| sender.close_channel());
partition_sender.close_channel();
let recievers = partition_reciever.collect().await;
Ok(recievers)
}
fn distinct_values(array: ArrayRef) -> Result<DistinctValues, ArrowError> {
match array.data_type() {
DataType::Int32 => Ok(DistinctValues::Int(distinct_values_primitive::<
i32,
Int32Type,
>(array))),
DataType::Int64 => Ok(DistinctValues::Long(distinct_values_primitive::<
i64,
Int64Type,
>(array))),
DataType::Utf8 => Ok(DistinctValues::String(distinct_values_string(array))),
_ => Err(ArrowError::ComputeError(
"Datatype not supported for transform.".to_string(),
)),
}
}
fn distinct_values_primitive<T: Eq + Hash, P: ArrowPrimitiveType<Native = T>>(
array: ArrayRef,
) -> HashSet<P::Native> {
let mut set = HashSet::new();
let array = as_primitive_array::<P>(&array);
for value in ArrayIter::new(array).flatten() {
if !set.contains(&value) {
set.insert(value);
}
}
set
}
fn distinct_values_string(array: ArrayRef) -> HashSet<String> {
let mut set = HashSet::new();
let array = as_string_array(&array);
for value in ArrayIter::new(array).flatten() {
if !set.contains(value) {
set.insert(value.to_owned());
}
}
set
}
enum DistinctValues {
Int(HashSet<i32>),
Long(HashSet<i64>),
String(HashSet<String>),
}
#[cfg(test)]
mod tests {
use futures::{stream, StreamExt};
use std::sync::Arc;
use arrow::{
array::{ArrayRef, Int64Array, StringArray},
error::ArrowError,
record_batch::RecordBatch,
};
use iceberg_rust_spec::spec::{
partition::{PartitionField, PartitionSpecBuilder, Transform},
schema::Schema,
types::{PrimitiveType, StructField, StructType, Type},
};
use super::partition_record_batches;
#[tokio::test]
async fn test_partition() {
let batch1 = RecordBatch::try_from_iter(vec![
(
"x",
Arc::new(Int64Array::from(vec![1, 1, 1, 1, 2, 3])) as ArrayRef,
),
(
"y",
Arc::new(Int64Array::from(vec![1, 2, 2, 2, 1, 1])) as ArrayRef,
),
(
"z",
Arc::new(StringArray::from(vec!["A", "B", "C", "D", "E", "F"])) as ArrayRef,
),
])
.unwrap();
let batch2 = RecordBatch::try_from_iter(vec![
(
"x",
Arc::new(Int64Array::from(vec![1, 1, 2, 2, 2, 3])) as ArrayRef,
),
(
"y",
Arc::new(Int64Array::from(vec![1, 2, 2, 2, 1, 1])) as ArrayRef,
),
(
"z",
Arc::new(StringArray::from(vec!["A", "B", "C", "D", "E", "F"])) as ArrayRef,
),
])
.unwrap();
let record_batches = stream::iter(
vec![Ok::<_, ArrowError>(batch1), Ok::<_, ArrowError>(batch2)].into_iter(),
);
let schema = Schema::builder()
.with_schema_id(0)
.with_fields(
StructType::builder()
.with_struct_field(StructField {
id: 1,
name: "x".to_string(),
field_type: Type::Primitive(PrimitiveType::Int),
required: true,
doc: None,
})
.with_struct_field(StructField {
id: 2,
name: "y".to_string(),
field_type: Type::Primitive(PrimitiveType::Int),
required: true,
doc: None,
})
.with_struct_field(StructField {
id: 3,
name: "z".to_string(),
field_type: Type::Primitive(PrimitiveType::String),
required: true,
doc: None,
})
.build()
.unwrap(),
)
.build()
.unwrap();
let partition_spec = PartitionSpecBuilder::default()
.with_spec_id(0)
.with_partition_field(PartitionField::new(1, 1001, "x", Transform::Identity))
.build()
.unwrap();
let streams = partition_record_batches(record_batches, &partition_spec, &schema)
.await
.unwrap();
let output = stream::iter(streams.into_iter())
.then(|s| async move { s.1.collect::<Vec<_>>().await })
.collect::<Vec<_>>()
.await;
for x in output {
for y in x {
y.unwrap();
}
}
}
}