use std::collections::HashMap;
use chrono::{DateTime, Days};
use delta_kernel_derive::internal_api;
use tracing::debug;
use crate::engine::arrow_utils::RowIndexBuilder;
use crate::expressions::{ColumnName, DecimalData, Predicate, Scalar};
use crate::kernel_predicates::parquet_stats_skipping::ParquetStatsProvider;
use crate::parquet::arrow::arrow_reader::ArrowReaderBuilder;
use crate::parquet::file::metadata::RowGroupMetaData;
use crate::parquet::file::statistics::Statistics;
use crate::parquet::schema::types::ColumnDescPtr;
use crate::schema::{DataType, DecimalType, PrimitiveType};
#[cfg(test)]
mod tests;
/// Extension trait adding stats-based row-group skipping to the parquet reader builder.
#[internal_api]
pub(crate) trait ParquetRowGroupSkipping {
    /// Instructs the reader to skip every row group whose statistics prove `predicate` cannot
    /// be satisfied. If `row_indexes` is provided, it is narrowed to the surviving row groups
    /// so that computed row indexes stay aligned with the rows actually read.
    fn with_row_group_filter(
        self,
        predicate: &Predicate,
        row_indexes: Option<&mut RowIndexBuilder>,
    ) -> Self;
}
impl<T> ParquetRowGroupSkipping for ArrowReaderBuilder<T> {
    /// Keeps only the row groups the `predicate` cannot disprove, and narrows `row_indexes`
    /// (if given) to those same row groups.
    fn with_row_group_filter(
        self,
        predicate: &Predicate,
        row_indexes: Option<&mut RowIndexBuilder>,
    ) -> Self {
        // Collect the ordinal of every row group the filter does NOT rule out. `apply` errs on
        // the side of keeping a row group whenever stats are missing or inconclusive.
        let ordinals: Vec<_> = self
            .metadata()
            .row_groups()
            .iter()
            .enumerate()
            .filter_map(|(ordinal, row_group)| {
                RowGroupFilter::apply(row_group, predicate).then_some(ordinal)
            })
            .collect();
        // Fixed: the format string previously ended with a stray ')' not matched by any '('.
        debug!("with_row_group_filter({predicate:#?}) = {ordinals:?}");
        if let Some(row_indexes) = row_indexes {
            // Row index computation must skip the same row groups the reader skips.
            row_indexes.select_row_groups(&ordinals);
        }
        self.with_row_groups(ordinals)
    }
}
/// Evaluates a predicate against a single parquet row group's column statistics.
struct RowGroupFilter<'a> {
    /// The row group whose statistics back the evaluation.
    row_group: &'a RowGroupMetaData,
    /// Maps each predicate-referenced column name to its leaf-column ordinal in `row_group`.
    field_indices: HashMap<ColumnName, usize>,
}
impl<'a> RowGroupFilter<'a> {
    /// Builds a filter for `row_group`, pre-resolving the leaf-column ordinal of every column
    /// the predicate references.
    fn new(row_group: &'a RowGroupMetaData, predicate: &Predicate) -> Self {
        let columns = row_group.schema_descr().columns();
        let field_indices = compute_field_indices(columns, predicate);
        Self {
            row_group,
            field_indices,
        }
    }

    /// Returns true unless the predicate provably evaluates to FALSE for this row group;
    /// inconclusive results (`None` or `Some(true)`) keep the row group.
    fn apply(row_group: &'a RowGroupMetaData, predicate: &Predicate) -> bool {
        use crate::kernel_predicates::KernelPredicateEvaluator as _;
        let verdict = Self::new(row_group, predicate).eval_sql_where(predicate);
        !matches!(verdict, Some(false))
    }

    /// Outer `None`: the column is not tracked (it was not resolved into `field_indices`).
    /// Inner `None`: the column exists in this row group but carries no statistics.
    fn get_stats(&self, col: &ColumnName) -> Option<Option<&Statistics>> {
        let &ordinal = self.field_indices.get(col)?;
        Some(self.row_group.column(ordinal).statistics())
    }
}
impl ParquetStatsProvider for RowGroupFilter<'_> {
    /// Min stat for `col`, converted to a kernel [`Scalar`] of `data_type` (None if unavailable).
    fn get_parquet_min_stat(&self, col: &ColumnName, data_type: &DataType) -> Option<Scalar> {
        extract_min_scalar(data_type, self.get_stats(col)??)
    }

    /// Max stat for `col`, converted to a kernel [`Scalar`] of `data_type` (None if unavailable).
    fn get_parquet_max_stat(&self, col: &ColumnName, data_type: &DataType) -> Option<Scalar> {
        extract_max_scalar(data_type, self.get_stats(col)??)
    }

    /// Null-count stat for `col`.
    fn get_parquet_nullcount_stat(&self, col: &ColumnName) -> Option<i64> {
        let Some(stats) = self.get_stats(col) else {
            // The column was not resolved for this row group, so all of its values are
            // implicitly NULL: its nullcount equals the rowcount. The previous code
            // (`.filter(|_| false)`) unconditionally discarded this and returned None,
            // defeating null-based skipping for missing columns.
            return self.get_parquet_rowcount_stat();
        };
        extract_nullcount(stats)
    }

    /// Total number of rows in this row group.
    fn get_parquet_rowcount_stat(&self) -> Option<i64> {
        Some(self.row_group.num_rows())
    }
}
/// Converts a parquet min statistic into a kernel [`Scalar`] of the requested `data_type`.
///
/// Returns `None` when the target type is not primitive, the stat is absent, or the parquet
/// physical type does not match the requested logical type, so the caller treats the stat as
/// unavailable rather than failing.
fn extract_min_scalar(data_type: &DataType, stats: &Statistics) -> Option<Scalar> {
    use PrimitiveType::*;
    let value = match (data_type.as_primitive_opt()?, stats) {
        // Strings arrive as (possibly fixed-length) byte arrays and must decode as UTF-8.
        (String, Statistics::ByteArray(v)) => v.min_opt()?.as_utf8().ok()?.into(),
        (String, Statistics::FixedLenByteArray(v)) => v.min_opt()?.as_utf8().ok()?.into(),
        (String, _) => return None,
        // INT32-backed stats widen losslessly to Long.
        (Long, Statistics::Int64(v)) => v.min_opt()?.into(),
        (Long, Statistics::Int32(v)) => (*v.min_opt()? as i64).into(),
        (Long, _) => return None,
        (Integer, Statistics::Int32(v)) => v.min_opt()?.into(),
        (Integer, _) => return None,
        // Short/Byte are physically INT32 in parquet; the narrowing cast assumes the writer
        // kept values in range — TODO confirm against writer guarantees.
        (Short, Statistics::Int32(v)) => (*v.min_opt()? as i16).into(),
        (Short, _) => return None,
        (Byte, Statistics::Int32(v)) => (*v.min_opt()? as i8).into(),
        (Byte, _) => return None,
        (Float, Statistics::Float(v)) => v.min_opt()?.into(),
        (Float, _) => return None,
        (Double, Statistics::Double(v)) => v.min_opt()?.into(),
        (Double, Statistics::Float(v)) => (*v.min_opt()? as f64).into(),
        (Double, _) => return None,
        (Boolean, Statistics::Boolean(v)) => v.min_opt()?.into(),
        (Boolean, _) => return None,
        (Binary, Statistics::ByteArray(v)) => v.min_opt()?.data().into(),
        (Binary, Statistics::FixedLenByteArray(v)) => v.min_opt()?.data().into(),
        (Binary, _) => return None,
        // Dates are stored as INT32 days since the Unix epoch.
        (Date, Statistics::Int32(v)) => Scalar::Date(*v.min_opt()?),
        (Date, _) => return None,
        (Timestamp, Statistics::Int64(v)) => Scalar::Timestamp(*v.min_opt()?),
        (Timestamp, _) => return None,
        (TimestampNtz, Statistics::Int64(v)) => Scalar::TimestampNtz(*v.min_opt()?),
        // An INT32 stat here is a date; convert it to a microsecond timestamp.
        (TimestampNtz, Statistics::Int32(v)) => timestamp_from_date(v.min_opt())?,
        (TimestampNtz, _) => return None,
        // Decimals may be stored as INT32, INT64, or big-endian fixed-length byte arrays.
        (Decimal(dtype), Statistics::Int32(v)) => {
            DecimalData::try_new(*v.min_opt()?, *dtype).ok()?.into()
        }
        (Decimal(dtype), Statistics::Int64(v)) => {
            DecimalData::try_new(*v.min_opt()?, *dtype).ok()?.into()
        }
        (Decimal(dtype), Statistics::FixedLenByteArray(v)) => {
            decimal_from_bytes(v.min_bytes_opt(), *dtype)?
        }
        (Decimal(..), _) => return None,
    };
    Some(value)
}
/// Converts a parquet max statistic into a kernel [`Scalar`] of the requested `data_type`.
///
/// Mirrors [`extract_min_scalar`]: returns `None` when the target type is not primitive, the
/// stat is absent, or the parquet physical type does not match the requested logical type.
fn extract_max_scalar(data_type: &DataType, stats: &Statistics) -> Option<Scalar> {
    use PrimitiveType::*;
    let value = match (data_type.as_primitive_opt()?, stats) {
        // Strings arrive as (possibly fixed-length) byte arrays and must decode as UTF-8.
        (String, Statistics::ByteArray(v)) => v.max_opt()?.as_utf8().ok()?.into(),
        (String, Statistics::FixedLenByteArray(v)) => v.max_opt()?.as_utf8().ok()?.into(),
        (String, _) => return None,
        // INT32-backed stats widen losslessly to Long.
        (Long, Statistics::Int64(v)) => v.max_opt()?.into(),
        (Long, Statistics::Int32(v)) => (*v.max_opt()? as i64).into(),
        (Long, _) => return None,
        (Integer, Statistics::Int32(v)) => v.max_opt()?.into(),
        (Integer, _) => return None,
        // Short/Byte are physically INT32 in parquet; the narrowing cast assumes the writer
        // kept values in range — TODO confirm against writer guarantees.
        (Short, Statistics::Int32(v)) => (*v.max_opt()? as i16).into(),
        (Short, _) => return None,
        (Byte, Statistics::Int32(v)) => (*v.max_opt()? as i8).into(),
        (Byte, _) => return None,
        (Float, Statistics::Float(v)) => v.max_opt()?.into(),
        (Float, _) => return None,
        (Double, Statistics::Double(v)) => v.max_opt()?.into(),
        (Double, Statistics::Float(v)) => (*v.max_opt()? as f64).into(),
        (Double, _) => return None,
        (Boolean, Statistics::Boolean(v)) => v.max_opt()?.into(),
        (Boolean, _) => return None,
        (Binary, Statistics::ByteArray(v)) => v.max_opt()?.data().into(),
        (Binary, Statistics::FixedLenByteArray(v)) => v.max_opt()?.data().into(),
        (Binary, _) => return None,
        // Dates are stored as INT32 days since the Unix epoch.
        (Date, Statistics::Int32(v)) => Scalar::Date(*v.max_opt()?),
        (Date, _) => return None,
        (Timestamp, Statistics::Int64(v)) => Scalar::Timestamp(*v.max_opt()?),
        (Timestamp, _) => return None,
        (TimestampNtz, Statistics::Int64(v)) => Scalar::TimestampNtz(*v.max_opt()?),
        // An INT32 stat here is a date; convert it to a microsecond timestamp.
        (TimestampNtz, Statistics::Int32(v)) => timestamp_from_date(v.max_opt())?,
        (TimestampNtz, _) => return None,
        // Decimals may be stored as INT32, INT64, or big-endian fixed-length byte arrays.
        (Decimal(dtype), Statistics::Int32(v)) => {
            DecimalData::try_new(*v.max_opt()?, *dtype).ok()?.into()
        }
        (Decimal(dtype), Statistics::Int64(v)) => {
            DecimalData::try_new(*v.max_opt()?, *dtype).ok()?.into()
        }
        (Decimal(dtype), Statistics::FixedLenByteArray(v)) => {
            decimal_from_bytes(v.max_bytes_opt(), *dtype)?
        }
        (Decimal(..), _) => return None,
    };
    Some(value)
}
/// Extracts a null-count stat as `i64`, treating a missing stat — or a count of zero — as
/// unavailable.
fn extract_nullcount(stats: Option<&Statistics>) -> Option<i64> {
    // NOTE(review): zero counts are deliberately dropped rather than reported — presumably
    // because an absent parquet null-count can surface as Some(0); confirm against the parquet
    // crate's `null_count_opt` semantics before relying on this.
    let count = stats?.null_count_opt()?;
    (count > 0).then_some(count as i64)
}
/// Decodes a parquet FIXED_LEN_BYTE_ARRAY decimal stat (big-endian two's complement bytes)
/// into a kernel decimal [`Scalar`].
///
/// Returns `None` if the stat is absent, wider than 16 bytes (cannot fit an `i128`), or fails
/// decimal validation for `dtype`.
fn decimal_from_bytes(bytes: Option<&[u8]>, dtype: DecimalType) -> Option<Scalar> {
    let bytes = bytes.filter(|b| b.len() <= 16)?;
    // Widening a two's complement value requires sign extension: pad the high bytes with 0xFF
    // when the big-endian sign bit (MSB of the first byte) is set, 0x00 otherwise. The previous
    // code always zero-padded, corrupting negative decimals shorter than 16 bytes.
    let fill = if bytes.first().is_some_and(|&b| b & 0x80 != 0) {
        0xFFu8
    } else {
        0u8
    };
    let mut le_bytes = Vec::from(bytes);
    le_bytes.reverse(); // big-endian -> little-endian
    le_bytes.resize(16, fill);
    let le_bytes: [u8; 16] = le_bytes.try_into().ok()?;
    let value = DecimalData::try_new(i128::from_le_bytes(le_bytes), dtype).ok()?;
    Some(value.into())
}
/// Converts a parquet INT32 date stat (days since the Unix epoch) into a `TimestampNtz` scalar
/// expressed in microseconds. Pre-epoch dates (negative day counts) yield `None`, so the stat
/// is simply treated as unavailable.
fn timestamp_from_date(days: Option<&i32>) -> Option<Scalar> {
    let day_count = Days::new(u64::try_from(*days?).ok()?);
    let date = DateTime::UNIX_EPOCH.checked_add_days(day_count)?;
    let micros = date
        .signed_duration_since(DateTime::UNIX_EPOCH)
        .num_microseconds()?;
    Some(Scalar::TimestampNtz(micros))
}
/// Maps each predicate-referenced column name to its parquet leaf-column ordinal.
///
/// Columns the predicate references but `fields` lacks are simply absent from the result;
/// columns `fields` has but the predicate never mentions are skipped.
pub(crate) fn compute_field_indices(
    fields: &[ColumnDescPtr],
    predicate: &Predicate,
) -> HashMap<ColumnName, usize> {
    let mut remaining = predicate.references();
    let mut indices = HashMap::new();
    for (ordinal, field) in fields.iter().enumerate() {
        // `take` removes the matched name from the set, so a duplicated path resolves to its
        // first occurrence only.
        if let Some(path) = remaining.take(field.path().parts()) {
            indices.insert(path.clone(), ordinal);
        }
    }
    indices
}