use std::borrow::Cow;
use std::sync::Arc;
use delta_kernel_derive::internal_api;
use crate::arrow::array::{
new_null_array, Array, ArrayRef, AsArray, BooleanArray, Decimal128Array, Int64Array,
LargeStringArray, PrimitiveArray, RecordBatch, StringArray, StringViewArray, StructArray,
};
use crate::arrow::compute::kernels::aggregate::{max, max_string, min, min_string};
use crate::arrow::datatypes::{
ArrowPrimitiveType, DataType, Date32Type, Date64Type, Decimal128Type, Field, Float32Type,
Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, TimeUnit, TimestampMicrosecondType,
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type,
UInt64Type, UInt8Type,
};
use crate::column_trie::ColumnTrie;
use crate::engine::arrow_utils::fix_nested_null_masks;
use crate::expressions::ColumnName;
use crate::{DeltaResult, Error};
/// Delta truncates string min/max statistics to this many characters.
const STRING_PREFIX_LENGTH: usize = 32;
/// When searching for a valid truncated upper bound, give up after scanning
/// this many characters (see [`truncate_max_string`]).
const STRING_EXPANSION_LIMIT: usize = STRING_PREFIX_LENGTH * 2;
/// Largest single-byte (ASCII) character, used as a cheap tie-breaker suffix.
const ASCII_MAX_CHAR: char = '\x7F';
/// Largest valid Unicode scalar value; no character compares above it.
const UTF8_MAX_CHAR: char = '\u{10FFFF}';

/// Truncates `s` to at most `STRING_PREFIX_LENGTH` *characters* for use as a
/// `minValues` statistic.
///
/// A strict prefix always sorts lexicographically at or before the original
/// string, so a plain prefix is a valid lower bound. The cut is made on a
/// character boundary, keeping the result valid UTF-8.
fn truncate_min_string(s: &str) -> &str {
    // Fast path: byte length bounds character count, so a string of at most
    // STRING_PREFIX_LENGTH bytes cannot exceed the character limit either.
    if s.len() <= STRING_PREFIX_LENGTH {
        return s;
    }
    // Cut at the byte offset of character STRING_PREFIX_LENGTH (0-based). If
    // no such character exists, the string holds at most STRING_PREFIX_LENGTH
    // multi-byte characters and must be returned whole. (The previous
    // implementation's fallback used the *start* of the last character and
    // wrongly dropped the final character in that case.)
    match s.char_indices().nth(STRING_PREFIX_LENGTH) {
        Some((byte_idx, _)) => &s[..byte_idx],
        None => s,
    }
}

/// Truncates `s` for use as a `maxValues` statistic.
///
/// A plain prefix would sort *below* the original string — invalid for an
/// upper bound — so the kept prefix is followed by a "tie-breaker" character
/// strictly greater than the first character that was cut off. Characters
/// equal to `UTF8_MAX_CHAR` cannot be exceeded, so the cut point slides right
/// past them, up to `STRING_EXPANSION_LIMIT` characters.
///
/// Returns:
/// - `Some(Cow::Borrowed)` when the string needs no truncation,
/// - `Some(Cow::Owned)` when a truncated upper bound was constructed,
/// - `None` when no valid upper bound exists within the expansion limit.
fn truncate_max_string(s: &str) -> Option<Cow<'_, str>> {
    // Fast path: within the byte budget implies within the character budget.
    if s.len() <= STRING_PREFIX_LENGTH {
        return Some(Cow::Borrowed(s));
    }
    let char_indices: Vec<(usize, char)> = s.char_indices().collect();
    let max_chars = char_indices.len().min(STRING_EXPANSION_LIMIT);
    for len in STRING_PREFIX_LENGTH..=max_chars {
        // Ran off the end of the string: everything fits, keep it whole.
        if len >= char_indices.len() {
            return Some(Cow::Borrowed(s));
        }
        let (truncation_byte_idx, next_char) = char_indices[len];
        // U+10FFFF cannot be beaten by any tie-breaker; extend the prefix.
        if next_char == UTF8_MAX_CHAR {
            continue;
        }
        let truncated = &s[..truncation_byte_idx];
        // Smallest suffix that still exceeds the dropped character: ASCII DEL
        // when the dropped character is below it, else the max code point.
        let tie_breaker = if next_char < ASCII_MAX_CHAR {
            ASCII_MAX_CHAR
        } else {
            UTF8_MAX_CHAR
        };
        return Some(Cow::Owned(format!("{truncated}{tie_breaker}")));
    }
    None
}
/// Which aggregate to compute over a column.
#[derive(Clone, Copy)]
enum Agg {
    // Smallest non-null value.
    Min,
    // Largest non-null value.
    Max,
}
fn agg_primitive<T>(column: &ArrayRef, agg: Agg) -> DeltaResult<Option<ArrayRef>>
where
T: ArrowPrimitiveType,
T::Native: PartialOrd,
PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
{
let array = column.as_primitive_opt::<T>().ok_or_else(|| {
Error::generic(format!(
"Failed to downcast column to PrimitiveArray<{}>",
std::any::type_name::<T>()
))
})?;
let result = match agg {
Agg::Min => min(array),
Agg::Max => max(array),
};
Ok(result.map(|v| Arc::new(PrimitiveArray::<T>::from(vec![Some(v)])) as ArrayRef))
}
fn agg_timestamp<T>(
column: &ArrayRef,
tz: Option<Arc<str>>,
agg: Agg,
) -> DeltaResult<Option<ArrayRef>>
where
T: crate::arrow::datatypes::ArrowTimestampType,
PrimitiveArray<T>: From<Vec<Option<i64>>>,
{
let array = column.as_primitive_opt::<T>().ok_or_else(|| {
Error::generic(format!(
"Failed to downcast column to PrimitiveArray<{}>",
std::any::type_name::<T>()
))
})?;
let result = match agg {
Agg::Min => min(array),
Agg::Max => max(array),
};
Ok(result.map(|v| {
Arc::new(PrimitiveArray::<T>::from(vec![Some(v)]).with_timezone_opt(tz)) as ArrayRef
}))
}
fn agg_decimal(
column: &ArrayRef,
precision: u8,
scale: i8,
agg: Agg,
) -> DeltaResult<Option<ArrayRef>> {
let array = column
.as_primitive_opt::<Decimal128Type>()
.ok_or_else(|| Error::generic("Failed to downcast column to Decimal128Array"))?;
let result = match agg {
Agg::Min => min(array),
Agg::Max => max(array),
};
result
.map(|v| {
Decimal128Array::from(vec![Some(v)])
.with_precision_and_scale(precision, scale)
.map(|arr| Arc::new(arr) as ArrayRef)
})
.transpose()
.map_err(|e| Error::generic(format!("Invalid decimal precision/scale: {e}")))
}
fn agg_string(column: &ArrayRef, agg: Agg) -> DeltaResult<Option<ArrayRef>> {
let array = column
.as_string_opt::<i32>()
.ok_or_else(|| Error::generic("Failed to downcast column to StringArray"))?;
let result = match agg {
Agg::Min => min_string(array),
Agg::Max => max_string(array),
};
match (result, agg) {
(Some(s), Agg::Min) => {
let truncated = truncate_min_string(s);
Ok(Some(
Arc::new(StringArray::from(vec![Some(truncated)])) as ArrayRef
))
}
(Some(s), Agg::Max) => Ok(truncate_max_string(s)
.map(|t| Arc::new(StringArray::from(vec![Some(&*t)])) as ArrayRef)),
(None, _) => Ok(None),
}
}
fn agg_large_string(column: &ArrayRef, agg: Agg) -> DeltaResult<Option<ArrayRef>> {
let array = column
.as_string_opt::<i64>()
.ok_or_else(|| Error::generic("Failed to downcast column to LargeStringArray"))?;
let result = match agg {
Agg::Min => array.iter().flatten().min(),
Agg::Max => array.iter().flatten().max(),
};
match (result, agg) {
(Some(s), Agg::Min) => {
let truncated = truncate_min_string(s);
Ok(Some(
Arc::new(LargeStringArray::from(vec![Some(truncated)])) as ArrayRef,
))
}
(Some(s), Agg::Max) => Ok(truncate_max_string(s)
.map(|t| Arc::new(LargeStringArray::from(vec![Some(&*t)])) as ArrayRef)),
(None, _) => Ok(None),
}
}
fn agg_string_view(column: &ArrayRef, agg: Agg) -> DeltaResult<Option<ArrayRef>> {
let array = column
.as_string_view_opt()
.ok_or_else(|| Error::generic("Failed to downcast column to StringViewArray"))?;
let result: Option<&str> = match agg {
Agg::Min => array.iter().flatten().min(),
Agg::Max => array.iter().flatten().max(),
};
match (result, agg) {
(Some(s), Agg::Min) => {
let truncated = truncate_min_string(s);
Ok(Some(
Arc::new(StringViewArray::from(vec![Some(truncated)])) as ArrayRef
))
}
(Some(s), Agg::Max) => Ok(truncate_max_string(s)
.map(|t| Arc::new(StringViewArray::from(vec![Some(&*t)])) as ArrayRef)),
(None, _) => Ok(None),
}
}
/// Dispatches the min/max computation to a type-specific helper based on the
/// column's arrow data type.
///
/// Returns `Ok(None)` for types where no min/max statistic is collected (the
/// catch-all arm), or when the aggregate itself produces no value.
fn compute_leaf_agg(column: &ArrayRef, agg: Agg) -> DeltaResult<Option<ArrayRef>> {
    match column.data_type() {
        DataType::Int8 => agg_primitive::<Int8Type>(column, agg),
        DataType::Int16 => agg_primitive::<Int16Type>(column, agg),
        DataType::Int32 => agg_primitive::<Int32Type>(column, agg),
        DataType::Int64 => agg_primitive::<Int64Type>(column, agg),
        DataType::UInt8 => agg_primitive::<UInt8Type>(column, agg),
        DataType::UInt16 => agg_primitive::<UInt16Type>(column, agg),
        DataType::UInt32 => agg_primitive::<UInt32Type>(column, agg),
        DataType::UInt64 => agg_primitive::<UInt64Type>(column, agg),
        DataType::Float32 => agg_primitive::<Float32Type>(column, agg),
        DataType::Float64 => agg_primitive::<Float64Type>(column, agg),
        DataType::Date32 => agg_primitive::<Date32Type>(column, agg),
        DataType::Date64 => agg_primitive::<Date64Type>(column, agg),
        // Timestamps carry their timezone through to the stat value.
        DataType::Timestamp(TimeUnit::Second, tz) => {
            agg_timestamp::<TimestampSecondType>(column, tz.clone(), agg)
        }
        DataType::Timestamp(TimeUnit::Millisecond, tz) => {
            agg_timestamp::<TimestampMillisecondType>(column, tz.clone(), agg)
        }
        DataType::Timestamp(TimeUnit::Microsecond, tz) => {
            agg_timestamp::<TimestampMicrosecondType>(column, tz.clone(), agg)
        }
        DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
            agg_timestamp::<TimestampNanosecondType>(column, tz.clone(), agg)
        }
        DataType::Decimal128(p, s) => agg_decimal(column, *p, *s, agg),
        DataType::Utf8 => agg_string(column, agg),
        DataType::LargeUtf8 => agg_large_string(column, agg),
        DataType::Utf8View => agg_string_view(column, agg),
        // Unsupported leaf types produce no min/max stat rather than an error.
        _ => Ok(None),
    }
}
/// Per-column statistics, each as a single-row array (or `None` when the
/// stat was not computed, e.g. the column was filtered out).
#[derive(Default)]
struct ColumnStats {
    // Number of null rows, as a one-element Int64Array.
    null_count: Option<ArrayRef>,
    // Minimum non-null value, as a one-element array of the column's type.
    min_value: Option<ArrayRef>,
    // Maximum non-null value, as a one-element array of the column's type.
    max_value: Option<ArrayRef>,
}
fn compute_column_stats(
column: &ArrayRef,
path: &mut Vec<String>,
filter: &ColumnTrie<'_>,
) -> DeltaResult<ColumnStats> {
match column.data_type() {
DataType::Struct(fields) => {
let struct_array = column
.as_struct_opt()
.ok_or_else(|| Error::generic("Failed to downcast column to StructArray"))?;
let fixed_struct = fix_nested_null_masks(struct_array.clone());
let mut null_fields: Vec<Field> = Vec::new();
let mut null_arrays: Vec<ArrayRef> = Vec::new();
let mut min_fields: Vec<Field> = Vec::new();
let mut min_arrays: Vec<ArrayRef> = Vec::new();
let mut max_fields: Vec<Field> = Vec::new();
let mut max_arrays: Vec<ArrayRef> = Vec::new();
for (i, field) in fields.iter().enumerate() {
path.push(field.name().to_string());
let child_stats = compute_column_stats(fixed_struct.column(i), path, filter)?;
if let Some(arr) = child_stats.null_count {
null_fields.push(Field::new(field.name(), arr.data_type().clone(), true));
null_arrays.push(arr);
}
if let Some(arr) = child_stats.min_value {
min_fields.push(Field::new(field.name(), arr.data_type().clone(), true));
min_arrays.push(arr);
}
if let Some(arr) = child_stats.max_value {
max_fields.push(Field::new(field.name(), arr.data_type().clone(), true));
max_arrays.push(arr);
}
path.pop();
}
let build_struct =
|fields: Vec<Field>, arrays: Vec<ArrayRef>| -> DeltaResult<Option<ArrayRef>> {
if fields.is_empty() {
Ok(None)
} else {
Ok(Some(Arc::new(
StructArray::try_new(fields.into(), arrays, None)
.map_err(|e| Error::generic(format!("stats struct: {e}")))?,
) as ArrayRef))
}
};
Ok(ColumnStats {
null_count: build_struct(null_fields, null_arrays)?,
min_value: build_struct(min_fields, min_arrays)?,
max_value: build_struct(max_fields, max_arrays)?,
})
}
DataType::Map(_, _)
| DataType::List(_)
| DataType::LargeList(_)
| DataType::FixedSizeList(_, _)
| DataType::ListView(_)
| DataType::LargeListView(_) => {
if !filter.contains_prefix_of(path) {
return Ok(ColumnStats::default());
}
Ok(ColumnStats {
null_count: Some(Arc::new(Int64Array::from(vec![column.null_count() as i64]))),
min_value: None,
max_value: None,
})
}
_ => {
if !filter.contains_prefix_of(path) {
return Ok(ColumnStats::default());
}
let null_fallback = || -> ArrayRef { Arc::new(new_null_array(column.data_type(), 1)) };
Ok(ColumnStats {
null_count: Some(Arc::new(Int64Array::from(vec![column.null_count() as i64]))),
min_value: Some(compute_leaf_agg(column, Agg::Min)?.unwrap_or_else(&null_fallback)),
max_value: Some(compute_leaf_agg(column, Agg::Max)?.unwrap_or_else(null_fallback)),
})
}
}
}
/// Accumulates per-column stat arrays and assembles them into one top-level
/// stats section (e.g. the "minValues" entry of the stats struct).
struct StatsAccumulator {
    // Name of the resulting top-level stats field, e.g. "nullCount".
    name: &'static str,
    fields: Vec<Field>,
    arrays: Vec<ArrayRef>,
}
impl StatsAccumulator {
    /// Creates an empty accumulator for the stats section `name`.
    fn new(name: &'static str) -> Self {
        Self {
            name,
            fields: Vec::new(),
            arrays: Vec::new(),
        }
    }
    /// Records `array` as the stat value for the column `field_name`.
    fn push(&mut self, field_name: &str, array: ArrayRef) {
        let field = Field::new(field_name, array.data_type().clone(), true);
        self.fields.push(field);
        self.arrays.push(array);
    }
    /// Builds the final `(Field, StructArray)` pair, or `None` when nothing
    /// was accumulated (the section is then omitted from the stats struct).
    fn build(self) -> DeltaResult<Option<(Field, Arc<dyn Array>)>> {
        if self.fields.is_empty() {
            return Ok(None);
        }
        let inner = StructArray::try_new(self.fields.into(), self.arrays, None)
            .map_err(|e| Error::generic(format!("Failed to create {}: {e}", self.name)))?;
        let field = Field::new(self.name, inner.data_type().clone(), true);
        Ok(Some((field, Arc::new(inner) as Arc<dyn Array>)))
    }
}
/// Computes Delta file statistics (`numRecords`, `nullCount`, `minValues`,
/// `maxValues`, `tightBounds`) for `batch`, restricted to `stats_columns`.
///
/// Sections that end up empty (e.g. when `stats_columns` is empty) are
/// omitted from the resulting struct. `tightBounds` is always `true` here
/// because the stats are computed directly from the data.
#[internal_api]
pub(crate) fn collect_stats(
    batch: &RecordBatch,
    stats_columns: &[ColumnName],
) -> DeltaResult<StructArray> {
    // Trie over the requested column paths, used to skip unrequested leaves.
    let filter = ColumnTrie::from_columns(stats_columns);
    let schema = batch.schema();
    let mut null_counts = StatsAccumulator::new("nullCount");
    let mut min_values = StatsAccumulator::new("minValues");
    let mut max_values = StatsAccumulator::new("maxValues");
    for (col_idx, field) in schema.fields().iter().enumerate() {
        let mut path = vec![field.name().to_string()];
        let column = batch.column(col_idx);
        let stats = compute_column_stats(column, &mut path, &filter)?;
        if let Some(arr) = stats.null_count {
            null_counts.push(field.name(), arr);
        }
        if let Some(arr) = stats.min_value {
            min_values.push(field.name(), arr);
        }
        if let Some(arr) = stats.max_value {
            max_values.push(field.name(), arr);
        }
    }
    // numRecords always comes first; the per-column sections follow only if
    // they are non-empty.
    let mut fields = vec![Field::new("numRecords", DataType::Int64, true)];
    let mut arrays: Vec<Arc<dyn Array>> =
        vec![Arc::new(Int64Array::from(vec![batch.num_rows() as i64]))];
    for acc in [null_counts, min_values, max_values] {
        if let Some((field, array)) = acc.build()? {
            fields.push(field);
            arrays.push(array);
        }
    }
    fields.push(Field::new("tightBounds", DataType::Boolean, true));
    arrays.push(Arc::new(BooleanArray::from(vec![true])));
    StructArray::try_new(fields.into(), arrays, None)
        .map_err(|e| Error::generic(format!("Failed to create stats struct: {e}")))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::arrow::array::{Array, AsArray, Int32Array, Int64Array, StringArray};
use crate::arrow::buffer::NullBuffer;
use crate::arrow::compute::concat_batches;
use crate::arrow::datatypes::{Fields, Int32Type, Int64Type, Schema};
use crate::engine::arrow_expression::evaluate_expression::to_json;
use crate::expressions::column_name;
use crate::parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
#[test]
fn test_collect_stats_single_batch() {
    // A three-row batch must report numRecords == 3.
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
    let ids: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
    let batch = RecordBatch::try_new(schema, vec![ids]).unwrap();
    let stats = collect_stats(&batch, &[column_name!("id")]).unwrap();
    assert_eq!(stats.len(), 1);
    let num_records = stats
        .column_by_name("numRecords")
        .unwrap()
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(num_records.value(0), 3);
}
#[test]
fn test_collect_stats_null_counts() {
    // nullCount must be reported per column: 0 for "id", 1 for "value".
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, false),
        Field::new("value", DataType::Utf8, true),
    ]));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int64Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec![Some("a"), None, Some("c")])),
        ],
    )
    .unwrap();
    let stats = collect_stats(&batch, &[column_name!("id"), column_name!("value")]).unwrap();
    let null_count = stats
        .column_by_name("nullCount")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let id_null_count = null_count
        .column_by_name("id")
        .unwrap()
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(id_null_count.value(0), 0);
    let value_null_count = null_count
        .column_by_name("value")
        .unwrap()
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(value_null_count.value(0), 1);
}
#[test]
fn test_collect_stats_respects_stats_columns() {
    // Columns absent from `stats_columns` must not appear in the output.
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, false),
        Field::new("value", DataType::Utf8, true),
    ]));
    let ids: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
    let values: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), None, Some("c")]));
    let batch = RecordBatch::try_new(schema, vec![ids, values]).unwrap();
    // Only request stats for "id".
    let stats = collect_stats(&batch, &[column_name!("id")]).unwrap();
    let null_count = stats
        .column_by_name("nullCount")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    assert!(null_count.column_by_name("id").is_some());
    assert!(null_count.column_by_name("value").is_none());
}
#[test]
fn test_collect_stats_min_max() {
    // min/max must skip nulls and be reported per column for both numeric
    // and string types.
    let schema = Arc::new(Schema::new(vec![
        Field::new("number", DataType::Int64, false),
        Field::new("name", DataType::Utf8, true),
    ]));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int64Array::from(vec![5, 1, 9, 3])),
            Arc::new(StringArray::from(vec![
                Some("banana"),
                Some("apple"),
                Some("cherry"),
                None,
            ])),
        ],
    )
    .unwrap();
    let stats = collect_stats(&batch, &[column_name!("number"), column_name!("name")]).unwrap();
    let min_values = stats
        .column_by_name("minValues")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let number_min = min_values
        .column_by_name("number")
        .unwrap()
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(number_min.value(0), 1);
    let name_min = min_values
        .column_by_name("name")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(name_min.value(0), "apple");
    let max_values = stats
        .column_by_name("maxValues")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let number_max = max_values
        .column_by_name("number")
        .unwrap()
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(number_max.value(0), 9);
    let name_max = max_values
        .column_by_name("name")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(name_max.value(0), "cherry");
}
#[test]
fn test_collect_stats_all_nulls() {
    // An all-null column still gets min/max entries, but as null values.
    let schema = Arc::new(Schema::new(vec![Field::new(
        "value",
        DataType::Int64,
        true,
    )]));
    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(Int64Array::from(vec![
            None as Option<i64>,
            None,
            None,
        ]))],
    )
    .unwrap();
    let stats = collect_stats(&batch, &[column_name!("value")]).unwrap();
    let num_records = stats
        .column_by_name("numRecords")
        .unwrap()
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(num_records.value(0), 3);
    let null_count = stats
        .column_by_name("nullCount")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let value_null_count = null_count
        .column_by_name("value")
        .unwrap()
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(value_null_count.value(0), 3);
    let min_values = stats
        .column_by_name("minValues")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let min_col = min_values.column_by_name("value").unwrap();
    assert!(min_col.is_null(0));
    let max_values = stats
        .column_by_name("maxValues")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let max_col = max_values.column_by_name("value").unwrap();
    assert!(max_col.is_null(0));
}
#[test]
fn test_collect_stats_empty_stats_columns() {
    // With no stats columns requested, only numRecords and tightBounds are
    // emitted; the per-column sections are omitted entirely.
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
    let ids: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
    let batch = RecordBatch::try_new(schema, vec![ids]).unwrap();
    let stats = collect_stats(&batch, &[]).unwrap();
    assert!(stats.column_by_name("numRecords").is_some());
    assert!(stats.column_by_name("tightBounds").is_some());
    for section in ["nullCount", "minValues", "maxValues"] {
        assert!(stats.column_by_name(section).is_none());
    }
}
#[test]
fn test_collect_stats_string_truncation_ascii() {
    // Long ASCII strings: min truncates to a 32-char prefix; max appends
    // \x7F as a tie-breaker to keep a valid upper bound.
    let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
    let long_string = "a".repeat(50);
    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(StringArray::from(vec![long_string.as_str()]))],
    )
    .unwrap();
    let stats = collect_stats(&batch, &[column_name!("text")]).unwrap();
    let min_values = stats
        .column_by_name("minValues")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let text_min = min_values
        .column_by_name("text")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(text_min.value(0).len(), 32);
    assert_eq!(text_min.value(0), "a".repeat(32));
    let max_values = stats
        .column_by_name("maxValues")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let text_max = max_values
        .column_by_name("text")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    let expected_max = format!("{}\x7F", "a".repeat(32));
    assert_eq!(text_max.value(0), expected_max);
}
#[test]
fn test_collect_stats_string_truncation_non_ascii() {
    // When the first dropped character is non-ASCII, the max tie-breaker is
    // the maximum code point U+10FFFF instead of \x7F.
    let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
    let long_string = format!("{}À{}", "a".repeat(32), "b".repeat(20));
    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(StringArray::from(vec![long_string.as_str()]))],
    )
    .unwrap();
    let stats = collect_stats(&batch, &[column_name!("text")]).unwrap();
    let max_values = stats
        .column_by_name("maxValues")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let text_max = max_values
        .column_by_name("text")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    let expected_max = format!("{}\u{10FFFF}", "a".repeat(32));
    assert_eq!(text_max.value(0), expected_max);
}
#[test]
fn test_collect_stats_string_no_truncation_needed() {
    // Strings within the 32-character prefix limit pass through unchanged
    // for both min and max.
    let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
    let short_string = "hello world";
    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(StringArray::from(vec![short_string]))],
    )
    .unwrap();
    let stats = collect_stats(&batch, &[column_name!("text")]).unwrap();
    let min_values = stats
        .column_by_name("minValues")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let text_min = min_values
        .column_by_name("text")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(text_min.value(0), short_string);
    let max_values = stats
        .column_by_name("maxValues")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let text_max = max_values
        .column_by_name("text")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(text_max.value(0), short_string);
}
#[test]
fn test_truncate_min_string() {
    // Short strings pass through untouched.
    assert_eq!(truncate_min_string("hello"), "hello");
    // Exactly at the limit: no truncation.
    let s32 = "a".repeat(32);
    assert_eq!(truncate_min_string(&s32), s32);
    // Over the limit: cut down to a 32-character prefix.
    let s50 = "a".repeat(50);
    assert_eq!(truncate_min_string(&s50), "a".repeat(32));
    // Multi-byte input: truncation counts characters, not bytes.
    let multi = format!("{}À", "a".repeat(35));
    assert_eq!(truncate_min_string(&multi).chars().count(), 32);
}
#[test]
fn test_truncate_max_string() {
    // Short strings pass through untouched.
    assert_eq!(truncate_max_string("hello").as_deref(), Some("hello"));
    // Exactly at the limit: no truncation.
    let s32 = "a".repeat(32);
    assert_eq!(truncate_max_string(&s32).as_deref(), Some(s32.as_str()));
    // ASCII tail: truncate and append \x7F as the tie-breaker.
    let s50 = "a".repeat(50);
    let expected = format!("{}\x7F", "a".repeat(32));
    assert_eq!(truncate_max_string(&s50).as_deref(), Some(expected.as_str()));
    // Non-ASCII dropped character: tie-breaker is the maximum code point.
    let non_ascii = format!("{}À{}", "a".repeat(32), "b".repeat(20));
    let expected = format!("{}\u{10FFFF}", "a".repeat(32));
    assert_eq!(
        truncate_max_string(&non_ascii).as_deref(),
        Some(expected.as_str())
    );
    // A U+10FFFF at the cut point cannot be exceeded, so the cut slides past
    // it before the tie-breaker is appended.
    let with_max_char = format!("{}\u{10FFFF}b{}", "a".repeat(32), "c".repeat(10));
    let expected = format!("{}\u{10FFFF}\x7F", "a".repeat(32));
    assert_eq!(
        truncate_max_string(&with_max_char).as_deref(),
        Some(expected.as_str())
    );
}
#[test]
fn test_collect_stats_nested_struct() {
    // Struct columns must produce nested stat structs mirroring the schema,
    // with per-leaf null counts and min/max values.
    let nested_fields = Fields::from(vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Utf8, true),
    ]);
    let schema = Arc::new(Schema::new(vec![Field::new(
        "nested",
        DataType::Struct(nested_fields.clone()),
        false,
    )]));
    let a_array = Arc::new(Int64Array::from(vec![10, 5, 20]));
    let b_array = Arc::new(StringArray::from(vec![Some("zebra"), Some("apple"), None]));
    let nested_struct = StructArray::try_new(
        nested_fields,
        vec![a_array as ArrayRef, b_array as ArrayRef],
        None,
    )
    .unwrap();
    let batch =
        RecordBatch::try_new(schema, vec![Arc::new(nested_struct) as ArrayRef]).unwrap();
    let stats = collect_stats(&batch, &[column_name!("nested")]).unwrap();
    let null_count = stats
        .column_by_name("nullCount")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let nested_null = null_count
        .column_by_name("nested")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let a_null = nested_null
        .column_by_name("a")
        .unwrap()
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(a_null.value(0), 0);
    let b_null = nested_null
        .column_by_name("b")
        .unwrap()
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(b_null.value(0), 1);
    let min_values = stats
        .column_by_name("minValues")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let nested_min = min_values
        .column_by_name("nested")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let a_min = nested_min
        .column_by_name("a")
        .unwrap()
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(a_min.value(0), 5);
    let b_min = nested_min
        .column_by_name("b")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(b_min.value(0), "apple");
    let max_values = stats
        .column_by_name("maxValues")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let nested_max = max_values
        .column_by_name("nested")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let a_max = nested_max
        .column_by_name("a")
        .unwrap()
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(a_max.value(0), 20);
    let b_max = nested_max
        .column_by_name("b")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(b_max.value(0), "zebra");
}
#[test]
fn test_collect_stats_complex_types_null_count_only() {
    // List (and other nested collection) columns get a null count but no
    // min/max entries.
    use crate::arrow::array::ListArray;
    use crate::arrow::buffer::OffsetBuffer;
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, false),
        Field::new(
            "list_col",
            DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
            true,
        ),
    ]));
    let values = Int64Array::from(vec![1, 2, 4, 5, 6]);
    let offsets = OffsetBuffer::new(vec![0, 2, 2, 5].into());
    // Middle list entry is null (validity [true, false, true]).
    let list_array = ListArray::new(
        Arc::new(Field::new("item", DataType::Int64, true)),
        offsets,
        Arc::new(values),
        Some(vec![true, false, true].into()),
    );
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int64Array::from(vec![1, 2, 3])),
            Arc::new(list_array),
        ],
    )
    .unwrap();
    let stats = collect_stats(&batch, &[column_name!("id"), column_name!("list_col")]).unwrap();
    let null_count = stats
        .column_by_name("nullCount")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    let id_nulls = null_count
        .column_by_name("id")
        .unwrap()
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(id_nulls.value(0), 0);
    let list_nulls = null_count
        .column_by_name("list_col")
        .unwrap()
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(list_nulls.value(0), 1);
    let min_values = stats
        .column_by_name("minValues")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    assert!(min_values.column_by_name("id").is_some());
    assert!(min_values.column_by_name("list_col").is_none());
    let max_values = stats
        .column_by_name("maxValues")
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap();
    assert!(max_values.column_by_name("id").is_some());
    assert!(max_values.column_by_name("list_col").is_none());
}
#[test]
fn test_collect_stats_struct_with_nulls_at_struct_level() {
    // Rows where the whole struct is null must count toward the children's
    // null counts and be excluded from min/max (via fix_nested_null_masks).
    let child_fields = Fields::from(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, true),
    ]);
    let a_values = Int32Array::from(vec![1, 2, 3, 4]);
    let b_values = Int32Array::from(vec![None, Some(20), None, Some(40)]);
    // Struct-level validity: rows 0 and 3 are null at the struct level.
    let nulls = NullBuffer::from(vec![false, true, true, false]);
    let struct_array = StructArray::new(
        child_fields.clone(),
        vec![Arc::new(a_values), Arc::new(b_values)],
        Some(nulls),
    );
    let schema = Schema::new(vec![Field::new(
        "my_struct",
        DataType::Struct(child_fields),
        true,
    )]);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)]).unwrap();
    let stats = collect_stats(&batch, &[column_name!("my_struct")]).unwrap();
    assert_eq!(
        get_stat::<Int64Type>(&stats, "nullCount", "my_struct", "a"),
        2
    );
    assert_eq!(
        get_stat::<Int64Type>(&stats, "nullCount", "my_struct", "b"),
        3
    );
    assert_eq!(
        get_stat::<Int32Type>(&stats, "minValues", "my_struct", "a"),
        2
    );
    assert_eq!(
        get_stat::<Int32Type>(&stats, "minValues", "my_struct", "b"),
        20
    );
    assert_eq!(
        get_stat::<Int32Type>(&stats, "maxValues", "my_struct", "a"),
        3
    );
    assert_eq!(
        get_stat::<Int32Type>(&stats, "maxValues", "my_struct", "b"),
        20
    );
}
/// Test helper: extracts `stats[stat_name][struct_name][field_name]` as a
/// primitive scalar (row 0). Panics if any level is missing or of the wrong
/// type, which is the desired behavior in tests.
fn get_stat<T>(
    stats: &StructArray,
    stat_name: &str,
    struct_name: &str,
    field_name: &str,
) -> T::Native
where
    T: crate::arrow::datatypes::ArrowPrimitiveType,
{
    stats
        .column_by_name(stat_name)
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap()
        .column_by_name(struct_name)
        .unwrap()
        .as_any()
        .downcast_ref::<StructArray>()
        .unwrap()
        .column_by_name(field_name)
        .unwrap()
        .as_primitive::<T>()
        .value(0)
}
/// Test helper: recursively flattens a schema into the full list of leaf
/// column names, descending into structs and prefixing each leaf with its
/// ancestors' names.
fn extract_leaf_columns(fields: &Fields, prefix: &[String]) -> Vec<ColumnName> {
    let mut leaves = Vec::new();
    for field in fields.iter() {
        let mut path = prefix.to_vec();
        path.push(field.name().clone());
        if let DataType::Struct(children) = field.data_type() {
            leaves.extend(extract_leaf_columns(children, &path));
        } else {
            leaves.push(ColumnName::new(path));
        }
    }
    leaves
}
/// Test helper: recursively asserts that every stat Spark wrote is matched
/// by the kernel's stats (the kernel may produce extra keys). Numbers are
/// compared with a small tolerance; timestamp strings are normalized
/// (trailing zone suffix and zero fractional seconds) before comparison.
fn assert_stats_match(
    spark_val: &serde_json::Value,
    kernel_val: &serde_json::Value,
    path: &str,
) {
    match (spark_val, kernel_val) {
        (serde_json::Value::Object(spark_map), serde_json::Value::Object(kernel_map)) => {
            // Every Spark key must exist in the kernel output; recurse.
            for (key, spark_child) in spark_map {
                let child_path = if path.is_empty() {
                    key.clone()
                } else {
                    format!("{path}.{key}")
                };
                let kernel_child = kernel_map
                    .get(key)
                    .unwrap_or_else(|| panic!("Kernel stats missing key: {child_path}"));
                assert_stats_match(spark_child, kernel_child, &child_path);
            }
        }
        (serde_json::Value::Number(s), serde_json::Value::Number(k)) => {
            // Compare as f64 with tolerance to absorb float formatting noise.
            let sv = s.as_f64().unwrap();
            let kv = k.as_f64().unwrap();
            assert!(
                (sv - kv).abs() < 1e-6,
                "Numeric mismatch at {path}: spark={sv}, kernel={kv}"
            );
        }
        (serde_json::Value::String(s), serde_json::Value::String(k)) => {
            let s_normalized = s.trim_end_matches('Z').trim_end_matches("+00:00");
            let k_normalized = k.trim_end_matches('Z').trim_end_matches("+00:00");
            // Heuristic: a 'T' in both strings means an ISO-8601 timestamp.
            if s_normalized.contains('T') && k_normalized.contains('T') {
                // Drop all-zero fractional seconds and trailing zeros so
                // "12:30:00.000" and "12:30:00" compare equal.
                let normalize_ts = |ts: &str| -> String {
                    if let Some(dot_pos) = ts.rfind('.') {
                        let frac = &ts[dot_pos + 1..];
                        if frac.chars().all(|c| c == '0') {
                            return ts[..dot_pos].to_string();
                        }
                        let trimmed = frac.trim_end_matches('0');
                        return format!("{}.{trimmed}", &ts[..dot_pos]);
                    }
                    ts.to_string()
                };
                let s_norm = normalize_ts(s_normalized);
                let k_norm = normalize_ts(k_normalized);
                assert_eq!(
                    s_norm, k_norm,
                    "Timestamp mismatch at {path}: spark={s}, kernel={k}"
                );
            } else {
                assert_eq!(s, k, "String mismatch at {path}: spark={s}, kernel={k}");
            }
        }
        _ => {
            assert_eq!(
                spark_val, kernel_val,
                "Value mismatch at {path}: spark={spark_val}, kernel={kernel_val}"
            );
        }
    }
}
#[test]
fn test_assert_stats_match_accepts_equivalent_values() {
    // The comparison helper must tolerate extra kernel keys and equivalent
    // timestamp formattings.
    let spark = serde_json::json!({"a": 1, "b": "hello"});
    let kernel = serde_json::json!({"a": 1, "b": "hello", "extra": true});
    assert_stats_match(&spark, &kernel, "");
    let spark = serde_json::json!({"outer": {"inner": 42}});
    let kernel = serde_json::json!({"outer": {"inner": 42, "extra": 0}});
    assert_stats_match(&spark, &kernel, "");
    let spark = serde_json::json!({"ts": "2023-06-15T12:30:00.000Z"});
    let kernel = serde_json::json!({"ts": "2023-06-15T12:30:00Z"});
    assert_stats_match(&spark, &kernel, "");
    let spark = serde_json::json!({"ts": "2023-06-15T12:30:00.000"});
    let kernel = serde_json::json!({"ts": "2023-06-15T12:30:00"});
    assert_stats_match(&spark, &kernel, "");
    let spark = serde_json::json!({"ts": "2023-06-15T12:30:00.500Z"});
    let kernel = serde_json::json!({"ts": "2023-06-15T12:30:00.5Z"});
    assert_stats_match(&spark, &kernel, "");
}
#[test]
fn test_assert_stats_match_rejects_mismatches() {
    // The comparison helper must panic on genuinely different stats.
    let result = std::panic::catch_unwind(|| {
        let spark = serde_json::json!({"a": 1});
        let kernel = serde_json::json!({"b": 1});
        assert_stats_match(&spark, &kernel, "");
    });
    assert!(result.is_err(), "should panic on missing key");
    let result = std::panic::catch_unwind(|| {
        let spark = serde_json::json!({"val": 1.0});
        let kernel = serde_json::json!({"val": 2.0});
        assert_stats_match(&spark, &kernel, "");
    });
    assert!(result.is_err(), "should panic on numeric mismatch");
    let result = std::panic::catch_unwind(|| {
        let spark = serde_json::json!({"s": "alpha"});
        let kernel = serde_json::json!({"s": "beta"});
        assert_stats_match(&spark, &kernel, "");
    });
    assert!(result.is_err(), "should panic on string mismatch");
}
#[test]
fn test_collect_stats_matches_spark() {
    // End-to-end check: read a parquet file from a Spark-written Delta table
    // fixture, recompute stats with collect_stats, and compare against the
    // stats string Spark recorded in the add action of commit 1.
    let test_path =
        std::fs::canonicalize("./tests/data/stats-writing-all-types/delta").unwrap();
    let commit_path = test_path
        .join("_delta_log")
        .join("00000000000000000001.json");
    let commit_data = std::fs::read_to_string(&commit_path).expect("read commit 1 json");
    let mut spark_stats_json = None;
    let mut parquet_path = None;
    // Find the first add action; it carries both the data file path and the
    // stats Spark computed for it.
    for line in commit_data.lines() {
        let action: serde_json::Value = serde_json::from_str(line).expect("parse JSON line");
        if let Some(add) = action.get("add") {
            spark_stats_json = Some(
                add["stats"]
                    .as_str()
                    .expect("stats should be a string")
                    .to_string(),
            );
            parquet_path = Some(
                add["path"]
                    .as_str()
                    .expect("path should be a string")
                    .to_string(),
            );
            break;
        }
    }
    let spark_stats_json = spark_stats_json.expect("should find add action with stats");
    let parquet_path = parquet_path.expect("should find add action with path");
    let spark_stats: serde_json::Value =
        serde_json::from_str(&spark_stats_json).expect("parse Spark stats JSON");
    let parquet_file_path = test_path.join(&parquet_path);
    let file = std::fs::File::open(&parquet_file_path).expect("open parquet file");
    let builder =
        ParquetRecordBatchReaderBuilder::try_new(file).expect("create parquet reader builder");
    let schema = builder.schema().clone();
    let reader = builder.build().expect("build parquet reader");
    let batches: Vec<RecordBatch> = reader.map(|b| b.expect("read batch")).collect();
    let record_batch = concat_batches(&schema, &batches).expect("concat batches");
    // Request stats for every leaf column in the file's schema.
    let stats_columns = extract_leaf_columns(schema.fields(), &[]);
    let stats_struct = collect_stats(&record_batch, &stats_columns).expect("collect stats");
    let json_array = to_json(&stats_struct).expect("convert stats to JSON");
    let json_strings = json_array.as_string::<i32>();
    assert_eq!(json_strings.len(), 1, "should have exactly one stats row");
    let kernel_stats_json_str = json_strings.value(0);
    let kernel_stats: serde_json::Value =
        serde_json::from_str(kernel_stats_json_str).expect("parse kernel stats JSON");
    assert_eq!(
        spark_stats["numRecords"], kernel_stats["numRecords"],
        "numRecords mismatch"
    );
    // Compare each stats section leniently (see assert_stats_match).
    for section in &["nullCount", "minValues", "maxValues"] {
        if let Some(spark_section) = spark_stats.get(*section) {
            let kernel_section = kernel_stats
                .get(*section)
                .unwrap_or_else(|| panic!("Kernel stats missing {section}"));
            assert_stats_match(spark_section, kernel_section, section);
        }
    }
}
}