1use std::{collections::HashSet, hash::Hash};
12
13use arrow::{
14 array::{
15 as_primitive_array, as_string_array, ArrayRef, BooleanArray, BooleanBufferBuilder,
16 PrimitiveArray, StringArray,
17 },
18 compute::{
19 and, filter, filter_record_batch,
20 kernels::cmp::{distinct, eq},
21 },
22 datatypes::{ArrowPrimitiveType, DataType, Int32Type, Int64Type},
23 error::ArrowError,
24 record_batch::RecordBatch,
25};
26use itertools::{iproduct, Itertools};
27
28use iceberg_rust_spec::{partition::BoundPartitionField, spec::values::Value};
29
30use super::transform::transform_arrow;
31
32pub fn partition_record_batch<'a>(
52 record_batch: &'a RecordBatch,
53 partition_fields: &[BoundPartitionField<'_>],
54) -> Result<impl Iterator<Item = Result<(Vec<Value>, RecordBatch), ArrowError>> + 'a, ArrowError> {
55 let partition_columns: Vec<ArrayRef> = partition_fields
56 .iter()
57 .map(|field| {
58 let array = record_batch
59 .column_by_name(field.source_name())
60 .ok_or(ArrowError::SchemaError("Column doesn't exist".to_string()))?;
61 transform_arrow(array.clone(), field.transform())
62 })
63 .collect::<Result<_, ArrowError>>()?;
64 let distinct_values: Vec<DistinctValues> = partition_columns
65 .iter()
66 .map(|x| distinct_values(x.clone()))
67 .collect::<Result<Vec<_>, ArrowError>>()?;
68 let mut true_buffer = BooleanBufferBuilder::new(record_batch.num_rows());
69 true_buffer.append_n(record_batch.num_rows(), true);
70 let predicates = distinct_values
71 .into_iter()
72 .zip(partition_columns.iter())
73 .map(|(distinct, value)| match distinct {
74 DistinctValues::Int(set) => set
75 .into_iter()
76 .map(|x| {
77 Ok((
78 Value::Int(x),
79 eq(&PrimitiveArray::<Int32Type>::new_scalar(x), value)?,
80 ))
81 })
82 .collect::<Result<Vec<_>, ArrowError>>(),
83 DistinctValues::Long(set) => set
84 .into_iter()
85 .map(|x| {
86 Ok((
87 Value::LongInt(x),
88 eq(&PrimitiveArray::<Int64Type>::new_scalar(x), value)?,
89 ))
90 })
91 .collect::<Result<Vec<_>, ArrowError>>(),
92 DistinctValues::String(set) => set
93 .into_iter()
94 .map(|x| {
95 let res = eq(&StringArray::new_scalar(&x), value)?;
96 Ok((Value::String(x), res))
97 })
98 .collect::<Result<Vec<_>, ArrowError>>(),
99 })
100 .try_fold(
101 vec![(vec![], BooleanArray::new(true_buffer.finish(), None))],
102 |acc, predicates| {
103 iproduct!(acc, predicates?.iter())
104 .map(|((mut values, x), (value, y))| {
105 values.push(value.clone());
106 Ok((values, and(&x, y)?))
107 })
108 .filter_ok(|x| x.1.true_count() != 0)
109 .collect::<Result<Vec<(Vec<Value>, _)>, ArrowError>>()
110 },
111 )?;
112 Ok(predicates.into_iter().map(move |(values, predicate)| {
113 Ok((values, filter_record_batch(record_batch, &predicate)?))
114 }))
115}
116
117fn distinct_values(array: ArrayRef) -> Result<DistinctValues, ArrowError> {
131 match array.data_type() {
132 DataType::Int32 => Ok(DistinctValues::Int(distinct_values_primitive::<
133 i32,
134 Int32Type,
135 >(array)?)),
136 DataType::Int64 => Ok(DistinctValues::Long(distinct_values_primitive::<
137 i64,
138 Int64Type,
139 >(array)?)),
140 DataType::Utf8 => Ok(DistinctValues::String(distinct_values_string(array)?)),
141 _ => Err(ArrowError::ComputeError(
142 "Datatype not supported for transform.".to_string(),
143 )),
144 }
145}
146
147fn distinct_values_primitive<T: Eq + Hash, P: ArrowPrimitiveType<Native = T>>(
159 array: ArrayRef,
160) -> Result<HashSet<P::Native>, ArrowError> {
161 let array = as_primitive_array::<P>(&array);
162
163 let first = array.value(0);
164
165 let slice_len = array.len() - 1;
166
167 if slice_len == 0 {
168 return Ok(HashSet::from_iter([first]));
169 }
170
171 let v1 = array.slice(0, slice_len);
172 let v2 = array.slice(1, slice_len);
173
174 let mask = distinct(&v1, &v2)?;
176
177 let unique = filter(&v2, &mask)?;
178
179 let unique = as_primitive_array::<P>(&unique);
180
181 let set = unique
182 .iter()
183 .fold(HashSet::from_iter([first]), |mut acc, x| {
184 if let Some(x) = x {
185 acc.insert(x);
186 }
187 acc
188 });
189 Ok(set)
190}
191
192fn distinct_values_string(array: ArrayRef) -> Result<HashSet<String>, ArrowError> {
200 let slice_len = array.len() - 1;
201
202 let array = as_string_array(&array);
203
204 let first = array.value(0).to_owned();
205
206 if slice_len == 0 {
207 return Ok(HashSet::from_iter([first]));
208 }
209
210 let v1 = array.slice(0, slice_len);
211 let v2 = array.slice(1, slice_len);
212
213 let mask = distinct(&v1, &v2)?;
215
216 let unique = filter(&v2, &mask)?;
217
218 let unique = as_string_array(&unique);
219
220 let set = unique
221 .iter()
222 .fold(HashSet::from_iter([first]), |mut acc, x| {
223 if let Some(x) = x {
224 acc.insert(x.to_owned());
225 }
226 acc
227 });
228 Ok(set)
229}
230
/// Set of distinct values found in a partition column, tagged by value type.
///
/// One variant exists per column type handled when collecting distinct
/// values.
#[derive(Debug, Clone, PartialEq, Eq)]
enum DistinctValues {
    /// Distinct values of a 32-bit integer column.
    Int(HashSet<i32>),
    /// Distinct values of a 64-bit integer column.
    Long(HashSet<i64>),
    /// Distinct values of a string column.
    String(HashSet<String>),
}