use std::collections::{HashMap, HashSet};
use chrono::{DateTime, Days};
use delta_kernel_derive::internal_api;
use tracing::debug;
use crate::engine::arrow_utils::RowIndexBuilder;
use crate::expressions::{ColumnName, DecimalData, Predicate, Scalar};
use crate::kernel_predicates::parquet_stats_skipping::ParquetStatsProvider;
use crate::parquet::arrow::arrow_reader::ArrowReaderBuilder;
use crate::parquet::file::metadata::RowGroupMetaData;
use crate::parquet::file::statistics::Statistics;
use crate::parquet::schema::types::ColumnDescPtr;
use crate::schema::{DataType, DecimalType, PrimitiveType};
#[cfg(test)]
mod tests;
/// Extension trait that lets a parquet reader builder drop row groups based on a predicate.
///
/// The implementation in this file (for `ArrowReaderBuilder`) keeps a row group unless the
/// predicate evaluates to a definite `false` against that row group's statistics.
#[internal_api]
pub(crate) trait ParquetRowGroupSkipping {
    /// Restricts the reader to the row groups that `predicate` cannot rule out.
    ///
    /// When `row_indexes` is provided, it is informed of the same row group selection.
    fn with_row_group_filter(
        self,
        predicate: &Predicate,
        row_indexes: Option<&mut RowIndexBuilder>,
    ) -> Self;
    /// Same as [`Self::with_row_group_filter`], but evaluates the predicate with the
    /// checkpoint-specific filter, which additionally needs the set of partition column names.
    ///
    /// When `row_indexes` is provided, it is informed of the same row group selection.
    #[allow(dead_code)]
    fn with_checkpoint_row_group_filter(
        self,
        predicate: &Predicate,
        partition_columns: &HashSet<String>,
        row_indexes: Option<&mut RowIndexBuilder>,
    ) -> Self;
}
/// Row group skipping for readers built through `ArrowReaderBuilder`.
impl<T> ParquetRowGroupSkipping for ArrowReaderBuilder<T> {
    /// Keeps only the row groups for which [`RowGroupFilter::apply`] returns `true`.
    ///
    /// The surviving ordinals are logged at debug level, forwarded to `row_indexes` when one is
    /// supplied, and finally installed on the builder.
    fn with_row_group_filter(
        self,
        predicate: &Predicate,
        row_indexes: Option<&mut RowIndexBuilder>,
    ) -> Self {
        let mut ordinals = Vec::new();
        for (ordinal, row_group) in self.metadata().row_groups().iter().enumerate() {
            // A row group survives unless the predicate is provably false for it.
            if RowGroupFilter::apply(row_group, predicate) {
                ordinals.push(ordinal);
            }
        }
        debug!("with_row_group_filter({predicate:#?}) = {ordinals:?})");
        if let Some(index_builder) = row_indexes {
            index_builder.select_row_groups(&ordinals);
        }
        self.with_row_groups(ordinals)
    }

    /// Keeps only the row groups for which [`CheckpointRowGroupFilter::apply`] returns `true`.
    ///
    /// The surviving ordinals are logged at debug level, forwarded to `row_indexes` when one is
    /// supplied, and finally installed on the builder.
    fn with_checkpoint_row_group_filter(
        self,
        predicate: &Predicate,
        partition_columns: &HashSet<String>,
        row_indexes: Option<&mut RowIndexBuilder>,
    ) -> Self {
        let mut ordinals = Vec::new();
        for (ordinal, row_group) in self.metadata().row_groups().iter().enumerate() {
            // A row group survives unless the predicate is provably false for it.
            if CheckpointRowGroupFilter::apply(row_group, predicate, partition_columns) {
                ordinals.push(ordinal);
            }
        }
        debug!("with_checkpoint_row_group_filter({predicate:#?}) = {ordinals:?})");
        if let Some(index_builder) = row_indexes {
            index_builder.select_row_groups(&ordinals);
        }
        self.with_row_groups(ordinals)
    }
}
/// Evaluates a predicate against the footer statistics of a single parquet row group.
struct RowGroupFilter<'a> {
    /// The row group whose column statistics are consulted.
    row_group: &'a RowGroupMetaData,
    /// Maps each predicate-referenced column found in the file to its leaf column index.
    field_indices: HashMap<ColumnName, usize>,
}

impl<'a> RowGroupFilter<'a> {
    /// Builds a filter for `row_group`, resolving the columns referenced by `predicate` to
    /// their positions in the row group's schema.
    fn new(row_group: &'a RowGroupMetaData, predicate: &Predicate) -> Self {
        let columns = row_group.schema_descr().columns();
        let field_indices = compute_field_indices(columns, predicate);
        Self {
            row_group,
            field_indices,
        }
    }

    /// Returns `true` when the row group must be read, i.e. whenever evaluating `predicate`
    /// against the statistics does not produce a definite `false`.
    fn apply(row_group: &'a RowGroupMetaData, predicate: &Predicate) -> bool {
        use crate::kernel_predicates::KernelPredicateEvaluator as _;
        let verdict = Self::new(row_group, predicate).eval_sql_where(predicate);
        !matches!(verdict, Some(false))
    }

    /// Looks up the statistics for `col`.
    ///
    /// Returns `None` when the column was not resolved to an index, and `Some(None)` when the
    /// column is known but its chunk carries no statistics.
    fn get_stats(&self, col: &ColumnName) -> Option<Option<&Statistics>> {
        let &index = self.field_indices.get(col)?;
        Some(self.row_group.column(index).statistics())
    }
}
/// Supplies parquet footer statistics of one row group to the predicate evaluator.
impl ParquetStatsProvider for RowGroupFilter<'_> {
    /// Minimum value of `col` in this row group, if the column and its stats are available.
    fn get_parquet_min_stat(&self, col: &ColumnName, data_type: &DataType) -> Option<Scalar> {
        let stats = self.get_stats(col).flatten()?;
        extract_min_scalar(data_type, stats)
    }

    /// Maximum value of `col` in this row group, if the column and its stats are available.
    fn get_parquet_max_stat(&self, col: &ColumnName, data_type: &DataType) -> Option<Scalar> {
        let stats = self.get_stats(col).flatten()?;
        extract_max_scalar(data_type, stats)
    }

    /// Null count of `col` in this row group.
    fn get_parquet_nullcount_stat(&self, col: &ColumnName) -> Option<i64> {
        match self.get_stats(col) {
            Some(stats) => extract_nullcount(stats),
            // Column is absent from the file. The always-false filter makes this branch yield
            // `None` unconditionally.
            // NOTE(review): looks like a deliberately disabled "missing column => nullcount ==
            // rowcount" shortcut -- confirm before re-enabling.
            None => self.get_parquet_rowcount_stat().filter(|_| false),
        }
    }

    /// Number of rows in this row group.
    fn get_parquet_rowcount_stat(&self) -> Option<i64> {
        let rows = self.row_group.num_rows();
        Some(rows)
    }
}
fn extract_min_scalar(data_type: &DataType, stats: &Statistics) -> Option<Scalar> {
use PrimitiveType::*;
let value = match (data_type.as_primitive_opt()?, stats) {
(String, Statistics::ByteArray(s)) => s.min_opt()?.as_utf8().ok()?.into(),
(String, Statistics::FixedLenByteArray(s)) => s.min_opt()?.as_utf8().ok()?.into(),
(String, _) => return None,
(Long, Statistics::Int64(s)) => s.min_opt()?.into(),
(Long, Statistics::Int32(s)) => (*s.min_opt()? as i64).into(),
(Long, _) => return None,
(Integer, Statistics::Int32(s)) => s.min_opt()?.into(),
(Integer, _) => return None,
(Short, Statistics::Int32(s)) => (*s.min_opt()? as i16).into(),
(Short, _) => return None,
(Byte, Statistics::Int32(s)) => (*s.min_opt()? as i8).into(),
(Byte, _) => return None,
(Float, Statistics::Float(s)) => s.min_opt()?.into(),
(Float, _) => return None,
(Double, Statistics::Double(s)) => s.min_opt()?.into(),
(Double, Statistics::Float(s)) => (*s.min_opt()? as f64).into(),
(Double, _) => return None,
(Boolean, Statistics::Boolean(s)) => s.min_opt()?.into(),
(Boolean, _) => return None,
(Binary, Statistics::ByteArray(s)) => s.min_opt()?.data().into(),
(Binary, Statistics::FixedLenByteArray(s)) => s.min_opt()?.data().into(),
(Binary, _) => return None,
(Date, Statistics::Int32(s)) => Scalar::Date(*s.min_opt()?),
(Date, _) => return None,
#[cfg(feature = "nanosecond-timestamps")]
(TimestampNanos, Statistics::Int64(s)) => Scalar::TimestampNanos(*s.min_opt()?),
#[cfg(feature = "nanosecond-timestamps")]
(TimestampNanos, _) => return None, (Timestamp, Statistics::Int64(s)) => Scalar::Timestamp(*s.min_opt()?),
(Timestamp, _) => return None, (TimestampNtz, Statistics::Int64(s)) => Scalar::TimestampNtz(*s.min_opt()?),
(TimestampNtz, Statistics::Int32(s)) => timestamp_from_date(s.min_opt())?,
(TimestampNtz, _) => return None, (Decimal(d), Statistics::Int32(i)) => DecimalData::try_new(*i.min_opt()?, *d).ok()?.into(),
(Decimal(d), Statistics::Int64(i)) => DecimalData::try_new(*i.min_opt()?, *d).ok()?.into(),
(Decimal(d), Statistics::FixedLenByteArray(b)) => {
decimal_from_bytes(b.min_bytes_opt(), *d)?
}
(Decimal(..), _) => return None,
};
Some(value)
}
fn extract_max_scalar(data_type: &DataType, stats: &Statistics) -> Option<Scalar> {
use PrimitiveType::*;
let value = match (data_type.as_primitive_opt()?, stats) {
(String, Statistics::ByteArray(s)) => s.max_opt()?.as_utf8().ok()?.into(),
(String, Statistics::FixedLenByteArray(s)) => s.max_opt()?.as_utf8().ok()?.into(),
(String, _) => return None,
(Long, Statistics::Int64(s)) => s.max_opt()?.into(),
(Long, Statistics::Int32(s)) => (*s.max_opt()? as i64).into(),
(Long, _) => return None,
(Integer, Statistics::Int32(s)) => s.max_opt()?.into(),
(Integer, _) => return None,
(Short, Statistics::Int32(s)) => (*s.max_opt()? as i16).into(),
(Short, _) => return None,
(Byte, Statistics::Int32(s)) => (*s.max_opt()? as i8).into(),
(Byte, _) => return None,
(Float, Statistics::Float(s)) => s.max_opt()?.into(),
(Float, _) => return None,
(Double, Statistics::Double(s)) => s.max_opt()?.into(),
(Double, Statistics::Float(s)) => (*s.max_opt()? as f64).into(),
(Double, _) => return None,
(Boolean, Statistics::Boolean(s)) => s.max_opt()?.into(),
(Boolean, _) => return None,
(Binary, Statistics::ByteArray(s)) => s.max_opt()?.data().into(),
(Binary, Statistics::FixedLenByteArray(s)) => s.max_opt()?.data().into(),
(Binary, _) => return None,
(Date, Statistics::Int32(s)) => Scalar::Date(*s.max_opt()?),
(Date, _) => return None,
#[cfg(feature = "nanosecond-timestamps")]
(TimestampNanos, Statistics::Int64(s)) => Scalar::TimestampNanos(*s.max_opt()?),
#[cfg(feature = "nanosecond-timestamps")]
(TimestampNanos, _) => return None, (Timestamp, Statistics::Int64(s)) => Scalar::Timestamp(*s.max_opt()?),
(Timestamp, _) => return None, (TimestampNtz, Statistics::Int64(s)) => Scalar::TimestampNtz(*s.max_opt()?),
(TimestampNtz, Statistics::Int32(s)) => timestamp_from_date(s.max_opt())?,
(TimestampNtz, _) => return None, (Decimal(d), Statistics::Int32(i)) => DecimalData::try_new(*i.max_opt()?, *d).ok()?.into(),
(Decimal(d), Statistics::Int64(i)) => DecimalData::try_new(*i.max_opt()?, *d).ok()?.into(),
(Decimal(d), Statistics::FixedLenByteArray(b)) => {
decimal_from_bytes(b.max_bytes_opt(), *d)?
}
(Decimal(..), _) => return None,
};
Some(value)
}
/// Returns the null count recorded in `stats`, but only when it is strictly positive.
///
/// Missing statistics, a missing null count, and a recorded count of zero all yield `None`.
fn extract_nullcount(stats: Option<&Statistics>) -> Option<i64> {
    let count = stats?.null_count_opt()?;
    // A zero count is treated the same as "unknown" and never reported.
    (count > 0).then_some(count as i64)
}
/// Decodes a decimal statistic stored as big-endian two's-complement bytes into a [`Scalar`].
///
/// Returns `None` when `bytes` is absent, is longer than 16 bytes, or when the decoded value is
/// rejected by [`DecimalData::try_new`] for `dtype`. An empty slice decodes to zero.
fn decimal_from_bytes(bytes: Option<&[u8]>, dtype: DecimalType) -> Option<Scalar> {
    const WIDTH: usize = std::mem::size_of::<i128>();
    let bytes = bytes.filter(|b| b.len() <= WIDTH)?;
    // The encoding is two's complement, so values narrower than 16 bytes must be sign-extended:
    // padding with zeros would turn every negative value into a large positive one.
    let fill = match bytes.first() {
        Some(&msb) if (msb & 0x80) != 0 => 0xFF_u8,
        _ => 0x00_u8,
    };
    let mut widened = [fill; WIDTH];
    // Big-endian: the supplied bytes occupy the least-significant (trailing) positions.
    widened[WIDTH - bytes.len()..].copy_from_slice(bytes);
    let value = DecimalData::try_new(i128::from_be_bytes(widened), dtype).ok()?;
    Some(value.into())
}
/// Converts a day count relative to the Unix epoch into a `Scalar::TimestampNtz` holding the
/// equivalent number of microseconds.
///
/// Returns `None` when `days` is absent or negative, when adding the days to the epoch fails,
/// or when the microsecond count is not available from the resulting duration.
fn timestamp_from_date(days: Option<&i32>) -> Option<Scalar> {
    let day_count = days.and_then(|&d| u64::try_from(d).ok())?;
    let since_epoch = DateTime::UNIX_EPOCH
        .checked_add_days(Days::new(day_count))?
        .signed_duration_since(DateTime::UNIX_EPOCH);
    since_epoch.num_microseconds().map(Scalar::TimestampNtz)
}
/// Reports whether the column chunk at `col_index` may contain nulls.
///
/// Returns `true` when the chunk has no statistics, when the statistics carry no null count, or
/// when the recorded null count is positive; returns `false` only for a recorded count of zero.
#[allow(dead_code)]
fn column_has_nulls(row_group: &RowGroupMetaData, col_index: usize) -> bool {
    let recorded = row_group
        .column(col_index)
        .statistics()
        .and_then(|stats| stats.null_count_opt());
    match recorded {
        Some(null_count) => null_count > 0,
        // Unknown null count: conservatively assume nulls are present.
        None => true,
    }
}
/// Leaf column positions of the parsed statistics recorded for one data column.
///
/// `compute_checkpoint_field_indices` fills in a field when it encounters the matching
/// `add.stats_parsed.<kind>.<column>` leaf; a field stays `None` when no such leaf was seen.
#[derive(Default)]
#[allow(dead_code)]
struct StatsColumnIndices {
    /// Index of the `minValues` leaf for the column, if found.
    min_index: Option<usize>,
    /// Index of the `maxValues` leaf for the column, if found.
    max_index: Option<usize>,
    /// Index of the `nullCount` leaf for the column, if found.
    nullcount_index: Option<usize>,
}
/// Evaluates a predicate against one row group of a checkpoint file, using the row group's
/// statistics over the `add.stats_parsed.*` and `add.partitionValues_parsed.*` columns.
#[allow(dead_code)]
pub(crate) struct CheckpointRowGroupFilter<'a> {
    /// The row group whose column statistics are consulted.
    row_group: &'a RowGroupMetaData,
    /// Per data column: leaf indices of its parsed min/max/nullCount statistics.
    stats_column_indices: HashMap<ColumnName, StatsColumnIndices>,
    /// Per partition column: leaf index of its parsed partition value.
    partition_column_indices: HashMap<ColumnName, usize>,
    /// Names of the table's partition columns.
    partition_columns: &'a HashSet<String>,
}

#[allow(dead_code)]
impl<'a> CheckpointRowGroupFilter<'a> {
    /// Builds a filter for `row_group`, resolving the columns referenced by `predicate` to the
    /// checkpoint leaf columns that hold their statistics or partition values.
    pub(crate) fn new(
        row_group: &'a RowGroupMetaData,
        predicate: &Predicate,
        partition_columns: &'a HashSet<String>,
    ) -> Self {
        let columns = row_group.schema_descr().columns();
        let (stats_column_indices, partition_column_indices) =
            compute_checkpoint_field_indices(columns, predicate, partition_columns);
        Self {
            row_group,
            stats_column_indices,
            partition_column_indices,
            partition_columns,
        }
    }

    /// Returns `true` when the row group must be read, i.e. whenever evaluating `predicate`
    /// does not produce a definite `false`.
    pub(crate) fn apply(
        row_group: &'a RowGroupMetaData,
        predicate: &Predicate,
        partition_columns: &'a HashSet<String>,
    ) -> bool {
        use crate::kernel_predicates::KernelPredicateEvaluator as _;
        let filter = Self::new(row_group, predicate, partition_columns);
        !matches!(filter.eval_sql_where(predicate), Some(false))
    }

    /// Whether `col` names one of this table's partition columns.
    fn is_partition_column(&self, col: &ColumnName) -> bool {
        is_partition_column(col, self.partition_columns)
    }

    /// Statistics of the column chunk at `index`, if the chunk has any.
    fn get_stats_at(&self, index: usize) -> Option<&Statistics> {
        self.row_group.column(index).statistics()
    }

    /// Extracts a statistic for data column `col`, refusing to answer when the underlying
    /// stats column may contain nulls.
    ///
    /// `get_index` selects which stats leaf to read and `extract` converts its statistics into
    /// a scalar. Returns `None` when the column or the leaf is unknown, when the leaf may
    /// contain nulls, when the leaf has no statistics, or when `extract` returns `None`.
    fn get_guarded_stat(
        &self,
        col: &ColumnName,
        data_type: &DataType,
        get_index: impl FnOnce(&StatsColumnIndices) -> Option<usize>,
        extract: impl FnOnce(&DataType, &Statistics) -> Option<Scalar>,
    ) -> Option<Scalar> {
        let stat_index = self.stats_column_indices.get(col).and_then(get_index)?;
        match column_has_nulls(self.row_group, stat_index) {
            true => None,
            false => extract(data_type, self.get_stats_at(stat_index)?),
        }
    }
}
impl ParquetStatsProvider for CheckpointRowGroupFilter<'_> {
fn get_parquet_min_stat(&self, col: &ColumnName, data_type: &DataType) -> Option<Scalar> {
if self.is_partition_column(col) {
let &idx = self.partition_column_indices.get(col)?;
return extract_min_scalar(data_type, self.get_stats_at(idx)?);
}
self.get_guarded_stat(col, data_type, |i| i.min_index, extract_min_scalar)
}
fn get_parquet_max_stat(&self, col: &ColumnName, data_type: &DataType) -> Option<Scalar> {
if self.is_partition_column(col) {
let &idx = self.partition_column_indices.get(col)?;
return extract_max_scalar(data_type, self.get_stats_at(idx)?);
}
let max = self.get_guarded_stat(col, data_type, |i| i.max_index, extract_max_scalar)?;
Some(adjust_stats_for_truncation(max))
}
fn get_parquet_nullcount_stat(&self, col: &ColumnName) -> Option<i64> {
if self.is_partition_column(col) {
let &idx = self.partition_column_indices.get(col)?;
return extract_nullcount(self.get_stats_at(idx));
}
let indices = self.stats_column_indices.get(col)?;
let nullcount_index = indices.nullcount_index?;
if column_has_nulls(self.row_group, nullcount_index) {
return None;
}
let stats = self.get_stats_at(nullcount_index)?;
extract_max_i64(stats)
}
fn get_parquet_rowcount_stat(&self) -> Option<i64> {
None
}
}
/// Widens a timestamp upper bound by 999 (saturating); every other scalar is returned as is.
///
/// NOTE(review): presumably compensates for stats values that were truncated to millisecond
/// precision while the scalar is in microseconds -- confirm against the stats writer.
#[allow(dead_code)]
fn adjust_stats_for_truncation(val: Scalar) -> Scalar {
    const SLACK: i64 = 999;
    match val {
        Scalar::Timestamp(micros) => Scalar::Timestamp(micros.saturating_add(SLACK)),
        Scalar::TimestampNtz(micros) => Scalar::TimestampNtz(micros.saturating_add(SLACK)),
        unchanged => unchanged,
    }
}
/// Returns the maximum recorded in integer statistics as an `i64`.
///
/// Only `Int64` and `Int32` statistics are supported; any other kind, or a missing maximum,
/// yields `None`.
#[allow(dead_code)]
fn extract_max_i64(stats: &Statistics) -> Option<i64> {
    let max = match stats {
        Statistics::Int64(s) => *s.max_opt()?,
        Statistics::Int32(s) => i64::from(*s.max_opt()?),
        _ => return None,
    };
    Some(max)
}
/// Maps every column referenced by `predicate` that also appears in `fields` to its position
/// in `fields`.
///
/// Referenced columns that are not present in `fields` are simply absent from the result.
pub(crate) fn compute_field_indices(
    fields: &[ColumnDescPtr],
    predicate: &Predicate,
) -> HashMap<ColumnName, usize> {
    // Columns still waiting to be matched; each one is removed once its field is found, so the
    // returned reference can be cloned exactly once as the map key.
    let mut pending = predicate.references();
    let mut indices = HashMap::new();
    for (index, field) in fields.iter().enumerate() {
        if let Some(name) = pending.take(field.path().parts()) {
            indices.insert(name.clone(), index);
        }
    }
    indices
}
/// Returns `true` when `col` is a single-segment name contained in `partition_columns`.
#[allow(dead_code)]
fn is_partition_column(col: &ColumnName, partition_columns: &HashSet<String>) -> bool {
    let segments = col.path();
    match segments.first() {
        // Only top-level (non-nested) names can be partition columns.
        Some(name) if segments.len() == 1 => partition_columns.contains(name.as_str()),
        _ => false,
    }
}
/// Locates the checkpoint leaf columns that carry statistics or partition values for the
/// columns referenced by `predicate`.
///
/// Returns two maps:
/// - data column -> indices of its `add.stats_parsed.{minValues,maxValues,nullCount}.<col>`
///   leaves (partition columns are excluded);
/// - partition column -> index of its `add.partitionValues_parsed.<col>` leaf.
#[allow(dead_code)]
fn compute_checkpoint_field_indices(
    fields: &[ColumnDescPtr],
    predicate: &Predicate,
    partition_columns: &HashSet<String>,
) -> (
    HashMap<ColumnName, StatsColumnIndices>,
    HashMap<ColumnName, usize>,
) {
    let referenced_columns = predicate.references();
    let mut stats_indices: HashMap<ColumnName, StatsColumnIndices> = HashMap::new();
    let mut partition_indices: HashMap<ColumnName, usize> = HashMap::new();
    for (leaf_index, field) in fields.iter().enumerate() {
        match field.path().parts() {
            // Exactly `add.partitionValues_parsed.<column>`.
            [root, group, partition] if root == "add" && group == "partitionValues_parsed" => {
                let col_name = ColumnName::new([partition]);
                let wanted = referenced_columns.contains(&col_name)
                    && is_partition_column(&col_name, partition_columns);
                if wanted {
                    partition_indices.insert(col_name, leaf_index);
                }
            }
            // `add.stats_parsed.<stat kind>.<column path...>` with a non-empty column path.
            [root, group, stat_type, column_path @ ..]
                if !column_path.is_empty() && root == "add" && group == "stats_parsed" =>
            {
                let col_name = ColumnName::new(column_path);
                let wanted = referenced_columns.contains(&col_name)
                    && !is_partition_column(&col_name, partition_columns);
                if !wanted {
                    continue;
                }
                // The entry is created even for stat kinds that are not recorded below.
                let entry = stats_indices.entry(col_name).or_default();
                match stat_type.as_str() {
                    "minValues" => entry.min_index = Some(leaf_index),
                    "maxValues" => entry.max_index = Some(leaf_index),
                    "nullCount" => entry.nullcount_index = Some(leaf_index),
                    _ => {}
                }
            }
            _ => {}
        }
    }
    (stats_indices, partition_indices)
}