iceberg_rust/arrow/
partition.rs

1//! Arrow-based partitioning implementation for Iceberg tables
2//!
3//! This module provides functionality to partition Arrow record batches according to Iceberg partition
4//! specifications. It includes:
5//!
6//! * Streaming partition implementation that processes record batches asynchronously
7//! * Support for different partition transforms (identity, bucket, truncate)
8//! * Efficient handling of distinct partition values
9//! * Automatic management of partition streams and channels
10
11use std::{collections::HashSet, hash::Hash};
12
13use arrow::{
14    array::{
15        as_primitive_array, as_string_array, ArrayRef, BooleanArray, BooleanBufferBuilder,
16        PrimitiveArray, StringArray,
17    },
18    compute::{
19        and, filter, filter_record_batch,
20        kernels::cmp::{distinct, eq},
21    },
22    datatypes::{ArrowPrimitiveType, DataType, Int32Type, Int64Type},
23    error::ArrowError,
24    record_batch::RecordBatch,
25};
26use itertools::{iproduct, Itertools};
27
28use iceberg_rust_spec::{partition::BoundPartitionField, spec::values::Value};
29
30use super::transform::transform_arrow;
31
32/// Partitions a record batch according to the given partition fields.
33///
34/// This function takes a record batch and partition field specifications, then splits the batch into
35/// multiple record batches based on unique combinations of partition values.
36///
37/// # Arguments
38/// * `record_batch` - The input record batch to partition
39/// * `partition_fields` - The partition field specifications that define how to split the data
40///
41/// # Returns
42/// An iterator over results containing:
43/// * A vector of partition values that identify the partition
44/// * The record batch containing only rows matching those partition values
45///
46/// # Errors
47/// Returns an ArrowError if:
48/// * Required columns are missing from the record batch
49/// * Transformation operations fail
50/// * Data type conversions fail
51pub fn partition_record_batch<'a>(
52    record_batch: &'a RecordBatch,
53    partition_fields: &[BoundPartitionField<'_>],
54) -> Result<impl Iterator<Item = Result<(Vec<Value>, RecordBatch), ArrowError>> + 'a, ArrowError> {
55    let partition_columns: Vec<ArrayRef> = partition_fields
56        .iter()
57        .map(|field| {
58            let array = record_batch
59                .column_by_name(field.source_name())
60                .ok_or(ArrowError::SchemaError("Column doesn't exist".to_string()))?;
61            transform_arrow(array.clone(), field.transform())
62        })
63        .collect::<Result<_, ArrowError>>()?;
64    let distinct_values: Vec<DistinctValues> = partition_columns
65        .iter()
66        .map(|x| distinct_values(x.clone()))
67        .collect::<Result<Vec<_>, ArrowError>>()?;
68    let mut true_buffer = BooleanBufferBuilder::new(record_batch.num_rows());
69    true_buffer.append_n(record_batch.num_rows(), true);
70    let predicates = distinct_values
71        .into_iter()
72        .zip(partition_columns.iter())
73        .map(|(distinct, value)| match distinct {
74            DistinctValues::Int(set) => set
75                .into_iter()
76                .map(|x| {
77                    Ok((
78                        Value::Int(x),
79                        eq(&PrimitiveArray::<Int32Type>::new_scalar(x), value)?,
80                    ))
81                })
82                .collect::<Result<Vec<_>, ArrowError>>(),
83            DistinctValues::Long(set) => set
84                .into_iter()
85                .map(|x| {
86                    Ok((
87                        Value::LongInt(x),
88                        eq(&PrimitiveArray::<Int64Type>::new_scalar(x), value)?,
89                    ))
90                })
91                .collect::<Result<Vec<_>, ArrowError>>(),
92            DistinctValues::String(set) => set
93                .into_iter()
94                .map(|x| {
95                    let res = eq(&StringArray::new_scalar(&x), value)?;
96                    Ok((Value::String(x), res))
97                })
98                .collect::<Result<Vec<_>, ArrowError>>(),
99        })
100        .try_fold(
101            vec![(vec![], BooleanArray::new(true_buffer.finish(), None))],
102            |acc, predicates| {
103                iproduct!(acc, predicates?.iter())
104                    .map(|((mut values, x), (value, y))| {
105                        values.push(value.clone());
106                        Ok((values, and(&x, y)?))
107                    })
108                    .filter_ok(|x| x.1.true_count() != 0)
109                    .collect::<Result<Vec<(Vec<Value>, _)>, ArrowError>>()
110            },
111        )?;
112    Ok(predicates.into_iter().map(move |(values, predicate)| {
113        Ok((values, filter_record_batch(record_batch, &predicate)?))
114    }))
115}
116
117/// Extracts distinct values from an Arrow array into a DistinctValues enum
118///
119/// # Arguments
120/// * `array` - The Arrow array to extract distinct values from
121///
122/// # Returns
123/// * `Ok(DistinctValues)` - An enum containing a HashSet of the distinct values
124/// * `Err(ArrowError)` - If the array's data type is not supported
125///
126/// # Supported Data Types
127/// * Int32 - Converted to DistinctValues::Int
128/// * Int64 - Converted to DistinctValues::Long
129/// * Utf8 - Converted to DistinctValues::String
130fn distinct_values(array: ArrayRef) -> Result<DistinctValues, ArrowError> {
131    match array.data_type() {
132        DataType::Int32 => Ok(DistinctValues::Int(distinct_values_primitive::<
133            i32,
134            Int32Type,
135        >(array)?)),
136        DataType::Int64 => Ok(DistinctValues::Long(distinct_values_primitive::<
137            i64,
138            Int64Type,
139        >(array)?)),
140        DataType::Utf8 => Ok(DistinctValues::String(distinct_values_string(array)?)),
141        _ => Err(ArrowError::ComputeError(
142            "Datatype not supported for transform.".to_string(),
143        )),
144    }
145}
146
147/// Extracts distinct primitive values from an Arrow array into a HashSet
148///
149/// # Type Parameters
150/// * `T` - The Rust native type that implements Eq + Hash
151/// * `P` - The Arrow primitive type corresponding to T
152///
153/// # Arguments
154/// * `array` - The Arrow array to extract distinct values from
155///
156/// # Returns
157/// A HashSet containing all unique values from the array
158fn distinct_values_primitive<T: Eq + Hash, P: ArrowPrimitiveType<Native = T>>(
159    array: ArrayRef,
160) -> Result<HashSet<P::Native>, ArrowError> {
161    let array = as_primitive_array::<P>(&array);
162
163    let first = array.value(0);
164
165    let slice_len = array.len() - 1;
166
167    if slice_len == 0 {
168        return Ok(HashSet::from_iter([first]));
169    }
170
171    let v1 = array.slice(0, slice_len);
172    let v2 = array.slice(1, slice_len);
173
174    // Which consecutive entries are different
175    let mask = distinct(&v1, &v2)?;
176
177    let unique = filter(&v2, &mask)?;
178
179    let unique = as_primitive_array::<P>(&unique);
180
181    let set = unique
182        .iter()
183        .fold(HashSet::from_iter([first]), |mut acc, x| {
184            if let Some(x) = x {
185                acc.insert(x);
186            }
187            acc
188        });
189    Ok(set)
190}
191
192/// Extracts distinct string values from an Arrow array into a HashSet
193///
194/// # Arguments
195/// * `array` - The Arrow array to extract distinct values from
196///
197/// # Returns
198/// A HashSet containing all unique string values from the array
199fn distinct_values_string(array: ArrayRef) -> Result<HashSet<String>, ArrowError> {
200    let slice_len = array.len() - 1;
201
202    let array = as_string_array(&array);
203
204    let first = array.value(0).to_owned();
205
206    if slice_len == 0 {
207        return Ok(HashSet::from_iter([first]));
208    }
209
210    let v1 = array.slice(0, slice_len);
211    let v2 = array.slice(1, slice_len);
212
213    // Which consecutive entries are different
214    let mask = distinct(&v1, &v2)?;
215
216    let unique = filter(&v2, &mask)?;
217
218    let unique = as_string_array(&unique);
219
220    let set = unique
221        .iter()
222        .fold(HashSet::from_iter([first]), |mut acc, x| {
223            if let Some(x) = x {
224                acc.insert(x.to_owned());
225            }
226            acc
227        });
228    Ok(set)
229}
230
/// The set of distinct values found in one Arrow array during partitioning
///
/// Each variant corresponds to one supported Arrow value type and owns the unique values
/// extracted from an array of that type.
enum DistinctValues {
    /// Distinct 32-bit integer values
    Int(HashSet<i32>),
    /// Distinct 64-bit integer values
    Long(HashSet<i64>),
    /// Distinct string values
    String(HashSet<String>),
}