use std::cmp::Ordering;
use std::collections::HashSet;
use std::fs::File;
use std::sync::{Arc, LazyLock};
use super::*;
use crate::arrow::array::{Int64Array, RecordBatch, StringArray, StructArray};
use crate::arrow::datatypes::{DataType as ArrowDataType, Field, Fields, Schema as ArrowSchema};
use crate::expressions::{
column_expr, column_name, column_pred, Expression, OpaquePredicateOp, ScalarExpressionEvaluator,
};
use crate::kernel_predicates::{
DataSkippingPredicateEvaluator as _, DirectDataSkippingPredicateEvaluator,
DirectPredicateEvaluator, IndirectDataSkippingPredicateEvaluator,
KernelPredicateEvaluator as _,
};
use crate::parquet::arrow::arrow_reader::ArrowReaderMetadata;
use crate::parquet::arrow::ArrowWriter;
use crate::parquet::file::properties::WriterProperties;
use crate::parquet::file::reader::FileReader;
use crate::parquet::file::serialized_reader::SerializedFileReader;
use crate::{DeltaResult, Predicate};
/// Shared, always-empty partition-column set for tests that pass no partition columns to the
/// checkpoint filter.
static NO_PARTITIONS: LazyLock<HashSet<String>> = LazyLock::new(HashSet::default);
/// Checks every stat accessor of `RowGroupFilter` against row group 0 of the
/// `parquet_row_group_skipping` fixture file.
///
/// Covers the row count, null counts, and min/max values of each fixture column. It includes
/// reads at a wider logical type than the stored one (int32 as LONG, float32 as DOUBLE, decimals
/// at a larger precision, date32 as TIMESTAMP_NTZ). It also includes requests that must yield
/// `None`.
#[test]
fn test_get_stat_values() {
let file = File::open("./tests/data/parquet_row_group_skipping/part-00000-b92e017a-50ba-4676-8322-48fc371c2b59-c000.snappy.parquet").unwrap();
let metadata = ArrowReaderMetadata::load(&file, Default::default()).unwrap();
// One predicate that references every fixture column.
// NOTE(review): presumably the filter only resolves stats for columns the predicate references.
let columns = Predicate::and_from(vec![
column_pred!("varlen.utf8"),
column_pred!("numeric.ints.int64"),
column_pred!("numeric.ints.int32"),
column_pred!("numeric.ints.int16"),
column_pred!("numeric.ints.int8"),
column_pred!("numeric.floats.float32"),
column_pred!("numeric.floats.float64"),
column_pred!("bool"),
column_pred!("varlen.binary"),
column_pred!("numeric.decimals.decimal32"),
column_pred!("numeric.decimals.decimal64"),
column_pred!("numeric.decimals.decimal128"),
column_pred!("chrono.date32"),
column_pred!("chrono.timestamp"),
column_pred!("chrono.timestamp_ntz"),
]);
let filter = RowGroupFilter::new(metadata.metadata().row_group(0), &columns);
// Row-group level counts.
assert_eq!(filter.get_rowcount_stat(), Some(5i64.into()));
assert_eq!(
filter.get_nullcount_stat(&column_name!("bool")),
Some(3i64.into())
);
// NOTE(review): the fixture apparently has no usable null count for this column — confirm.
assert_eq!(
filter.get_nullcount_stat(&column_name!("varlen.utf8")),
None );
// ---- Min stats ----
assert_eq!(
filter.get_min_stat(&column_name!("varlen.utf8"), &DataType::STRING),
Some("a".into())
);
// Reading decimal128 as STRING yields its raw stored bytes; 0x2b78 == 11128, the same value
// as the decimal(32, 3) min asserted further down.
assert_eq!(
filter.get_min_stat(
&column_name!("numeric.decimals.decimal128"),
&DataType::STRING
),
Some("\0\0\0\0\0\0\0\0\0\0\0\0+x".into())
);
assert_eq!(
filter.get_min_stat(&column_name!("numeric.ints.int64"), &DataType::LONG),
Some(1000000000i64.into())
);
// int32 column requested as LONG and as INTEGER.
assert_eq!(
filter.get_min_stat(&column_name!("numeric.ints.int32"), &DataType::LONG),
Some(1000000i64.into())
);
assert_eq!(
filter.get_min_stat(&column_name!("numeric.ints.int32"), &DataType::INTEGER),
Some(1000000i32.into())
);
assert_eq!(
filter.get_min_stat(&column_name!("numeric.ints.int16"), &DataType::SHORT),
Some(1000i16.into())
);
assert_eq!(
filter.get_min_stat(&column_name!("numeric.ints.int8"), &DataType::BYTE),
Some(0i8.into())
);
assert_eq!(
filter.get_min_stat(&column_name!("numeric.floats.float64"), &DataType::DOUBLE),
Some(1147f64.into())
);
// float32 column requested as DOUBLE and as FLOAT.
assert_eq!(
filter.get_min_stat(&column_name!("numeric.floats.float32"), &DataType::DOUBLE),
Some(139f64.into())
);
assert_eq!(
filter.get_min_stat(&column_name!("numeric.floats.float32"), &DataType::FLOAT),
Some(139f32.into())
);
assert_eq!(
filter.get_min_stat(&column_name!("bool"), &DataType::BOOLEAN),
Some(false.into())
);
assert_eq!(
filter.get_min_stat(&column_name!("varlen.binary"), &DataType::BINARY),
Some([].as_slice().into())
);
// Same raw decimal128 bytes as above, requested as BINARY.
assert_eq!(
filter.get_min_stat(
&column_name!("numeric.decimals.decimal128"),
&DataType::BINARY
),
Some(
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x2b, 0x78]
.as_slice()
.into()
)
);
// Decimals at their own precision, then widened to larger precisions (scale stays 3).
assert_eq!(
filter.get_min_stat(
&column_name!("numeric.decimals.decimal32"),
&DataType::decimal(8, 3).unwrap()
),
Some(Scalar::decimal(11032, 8, 3).unwrap())
);
assert_eq!(
filter.get_min_stat(
&column_name!("numeric.decimals.decimal64"),
&DataType::decimal(16, 3).unwrap()
),
Some(Scalar::decimal(11064, 16, 3).unwrap())
);
assert_eq!(
filter.get_min_stat(
&column_name!("numeric.decimals.decimal32"),
&DataType::decimal(16, 3).unwrap()
),
Some(Scalar::decimal(11032, 16, 3).unwrap())
);
assert_eq!(
filter.get_min_stat(
&column_name!("numeric.decimals.decimal128"),
&DataType::decimal(32, 3).unwrap()
),
Some(Scalar::decimal(11128, 32, 3).unwrap())
);
assert_eq!(
filter.get_min_stat(
&column_name!("numeric.decimals.decimal64"),
&DataType::decimal(32, 3).unwrap()
),
Some(Scalar::decimal(11064, 32, 3).unwrap())
);
assert_eq!(
filter.get_min_stat(
&column_name!("numeric.decimals.decimal32"),
&DataType::decimal(32, 3).unwrap()
),
Some(Scalar::decimal(11032, 32, 3).unwrap())
);
assert_eq!(
filter.get_min_stat(&column_name!("chrono.date32"), &DataType::DATE),
Some(PrimitiveType::Date.parse_scalar("1971-01-01").unwrap())
);
// NOTE(review): no min stat is produced for `chrono.timestamp` as TIMESTAMP; the reason is
// not visible here — confirm against the fixture / filter implementation.
assert_eq!(
filter.get_min_stat(&column_name!("chrono.timestamp"), &DataType::TIMESTAMP),
None );
// Requesting an unshredded variant type yields no stat.
assert_eq!(
filter.get_min_stat(
&column_name!("chrono.date32"),
&DataType::unshredded_variant()
),
None
);
// The timestamp_ntz column can be read back as either TIMESTAMP or TIMESTAMP_NTZ.
assert_eq!(
filter.get_min_stat(&column_name!("chrono.timestamp_ntz"), &DataType::TIMESTAMP),
Some(
PrimitiveType::Timestamp
.parse_scalar("1970-01-02 00:00:00.000000")
.unwrap()
)
);
assert_eq!(
filter.get_min_stat(
&column_name!("chrono.timestamp_ntz"),
&DataType::TIMESTAMP_NTZ
),
Some(
PrimitiveType::TimestampNtz
.parse_scalar("1970-01-02 00:00:00.000000")
.unwrap()
)
);
// A date min requested as TIMESTAMP_NTZ becomes midnight of that date.
assert_eq!(
filter.get_min_stat(&column_name!("chrono.date32"), &DataType::TIMESTAMP_NTZ),
Some(
PrimitiveType::TimestampNtz
.parse_scalar("1971-01-01 00:00:00.000000")
.unwrap()
)
);
// ---- Max stats (mirror the min checks above) ----
assert_eq!(
filter.get_max_stat(&column_name!("varlen.utf8"), &DataType::STRING),
Some("e".into())
);
// Raw decimal128 max bytes; 0x3b18 == 15128, the decimal(32, 3) max asserted below.
assert_eq!(
filter.get_max_stat(
&column_name!("numeric.decimals.decimal128"),
&DataType::STRING
),
Some("\0\0\0\0\0\0\0\0\0\0\0\0;\u{18}".into())
);
assert_eq!(
filter.get_max_stat(&column_name!("numeric.ints.int64"), &DataType::LONG),
Some(1000000004i64.into())
);
assert_eq!(
filter.get_max_stat(&column_name!("numeric.ints.int32"), &DataType::LONG),
Some(1000004i64.into())
);
assert_eq!(
filter.get_max_stat(&column_name!("numeric.ints.int32"), &DataType::INTEGER),
Some(1000004.into())
);
assert_eq!(
filter.get_max_stat(&column_name!("numeric.ints.int16"), &DataType::SHORT),
Some(1004i16.into())
);
assert_eq!(
filter.get_max_stat(&column_name!("numeric.ints.int8"), &DataType::BYTE),
Some(4i8.into())
);
assert_eq!(
filter.get_max_stat(&column_name!("numeric.floats.float64"), &DataType::DOUBLE),
Some(1125899906842747f64.into())
);
assert_eq!(
filter.get_max_stat(&column_name!("numeric.floats.float32"), &DataType::DOUBLE),
Some(1048699f64.into())
);
assert_eq!(
filter.get_max_stat(&column_name!("numeric.floats.float32"), &DataType::FLOAT),
Some(1048699f32.into())
);
assert_eq!(
filter.get_max_stat(&column_name!("bool"), &DataType::BOOLEAN),
Some(true.into())
);
assert_eq!(
filter.get_max_stat(&column_name!("varlen.binary"), &DataType::BINARY),
Some([0, 0, 0, 0].as_slice().into())
);
assert_eq!(
filter.get_max_stat(
&column_name!("numeric.decimals.decimal128"),
&DataType::BINARY
),
Some(
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x3b, 0x18]
.as_slice()
.into()
)
);
assert_eq!(
filter.get_max_stat(
&column_name!("numeric.decimals.decimal32"),
&DataType::decimal(8, 3).unwrap()
),
Some(Scalar::decimal(15032, 8, 3).unwrap())
);
assert_eq!(
filter.get_max_stat(
&column_name!("numeric.decimals.decimal64"),
&DataType::decimal(16, 3).unwrap()
),
Some(Scalar::decimal(15064, 16, 3).unwrap())
);
assert_eq!(
filter.get_max_stat(
&column_name!("numeric.decimals.decimal32"),
&DataType::decimal(16, 3).unwrap()
),
Some(Scalar::decimal(15032, 16, 3).unwrap())
);
assert_eq!(
filter.get_max_stat(
&column_name!("numeric.decimals.decimal128"),
&DataType::decimal(32, 3).unwrap()
),
Some(Scalar::decimal(15128, 32, 3).unwrap())
);
assert_eq!(
filter.get_max_stat(
&column_name!("numeric.decimals.decimal64"),
&DataType::decimal(32, 3).unwrap()
),
Some(Scalar::decimal(15064, 32, 3).unwrap())
);
assert_eq!(
filter.get_max_stat(
&column_name!("numeric.decimals.decimal32"),
&DataType::decimal(32, 3).unwrap()
),
Some(Scalar::decimal(15032, 32, 3).unwrap())
);
assert_eq!(
filter.get_max_stat(&column_name!("chrono.date32"), &DataType::DATE),
Some(PrimitiveType::Date.parse_scalar("1971-01-05").unwrap())
);
// NOTE(review): as with the min, no max stat is produced for `chrono.timestamp` — confirm why.
assert_eq!(
filter.get_max_stat(&column_name!("chrono.timestamp"), &DataType::TIMESTAMP),
None );
assert_eq!(
filter.get_max_stat(
&column_name!("chrono.date32"),
&DataType::unshredded_variant()
),
None
);
assert_eq!(
filter.get_max_stat(&column_name!("chrono.timestamp_ntz"), &DataType::TIMESTAMP),
Some(
PrimitiveType::Timestamp
.parse_scalar("1970-01-02 00:04:00.000000")
.unwrap()
)
);
assert_eq!(
filter.get_max_stat(
&column_name!("chrono.timestamp_ntz"),
&DataType::TIMESTAMP_NTZ
),
Some(
PrimitiveType::TimestampNtz
.parse_scalar("1970-01-02 00:04:00.000000")
.unwrap()
)
);
// A date max requested as TIMESTAMP_NTZ becomes midnight of that date.
assert_eq!(
filter.get_max_stat(&column_name!("chrono.date32"), &DataType::TIMESTAMP_NTZ),
Some(
PrimitiveType::TimestampNtz
.parse_scalar("1971-01-05 00:00:00.000000")
.unwrap()
)
);
}
/// Builds a nested struct column with a single `Int64` leaf.
///
/// The last element of `col_path` names a nullable `Int64` field holding `values`. Every earlier
/// element wraps the level below it in a nullable single-field struct. The first element is the
/// outermost. Returns the outermost field together with its matching array.
///
/// Panics if `col_path` is empty.
fn wrap_in_nested_struct(
    col_path: &[&str],
    values: Arc<Int64Array>,
) -> (Arc<Field>, Arc<dyn crate::arrow::array::Array>) {
    assert!(!col_path.is_empty());
    let (leaf_name, parents) = col_path.split_last().unwrap();
    let leaf_field = Arc::new(Field::new(*leaf_name, ArrowDataType::Int64, true));
    let leaf_array: Arc<dyn crate::arrow::array::Array> = values;
    // Walk from the innermost parent outwards, wrapping the current level each step.
    parents.iter().rev().fold(
        (leaf_field, leaf_array),
        |(inner_field, inner_array), &parent| {
            let wrapped = StructArray::from(vec![(inner_field.clone(), inner_array)]);
            let outer_field = Arc::new(Field::new(
                parent,
                ArrowDataType::Struct(Fields::from(vec![inner_field])),
                true,
            ));
            (
                outer_field,
                Arc::new(wrapped) as Arc<dyn crate::arrow::array::Array>,
            )
        },
    )
}
/// Builds one stats column named `stat_name` (e.g. `minValues`) containing the nested column at
/// `col_path` with the given `values`.
///
/// Returns the stats field and its struct array.
fn build_stat_column(
    stat_name: &str,
    col_path: &[&str],
    values: Arc<Int64Array>,
) -> (Arc<Field>, Arc<dyn crate::arrow::array::Array>) {
    let (inner_field, inner_array) = wrap_in_nested_struct(col_path, values);
    let stat_array = StructArray::from(vec![(inner_field, inner_array)]);
    // The struct's own fields are exactly the single child passed in above.
    let stat_type = ArrowDataType::Struct(stat_array.fields().clone());
    let stat_field = Arc::new(Field::new(stat_name, stat_type, true));
    (stat_field, Arc::new(stat_array))
}
/// Writes a single-row-group, checkpoint-shaped parquet file and returns the temp file.
///
/// The file has one top-level `add` struct column. Row `i` carries
/// `add.stats_parsed.{minValues,maxValues,nullCount}.<col_path>` taken from index `i` of
/// `min_values`, `max_values` and `null_counts`. When `part_values` is `Some`, the file also
/// gets an `add.partitionValues_parsed.part_col` string column.
///
/// The returned `NamedTempFile` must stay alive while the file is read; dropping it removes the
/// file.
fn write_checkpoint_parquet(
    min_values: &[Option<i64>],
    max_values: &[Option<i64>],
    null_counts: &[Option<i64>],
    col_path: &[&str],
    part_values: Option<&[Option<&str>]>,
) -> tempfile::NamedTempFile {
    let stat_column = |stat_name: &str, values: &[Option<i64>]| {
        build_stat_column(stat_name, col_path, Arc::new(Int64Array::from(values.to_vec())))
    };
    let stats_struct = StructArray::from(vec![
        stat_column("minValues", min_values),
        stat_column("maxValues", max_values),
        stat_column("nullCount", null_counts),
    ]);
    // Take the declared struct type from the array itself, so schema and data cannot disagree.
    let stats_parsed_field = Arc::new(Field::new(
        "stats_parsed",
        ArrowDataType::Struct(stats_struct.fields().clone()),
        true,
    ));
    let mut add_children: Vec<(Arc<Field>, Arc<dyn crate::arrow::array::Array>)> =
        vec![(stats_parsed_field, Arc::new(stats_struct))];
    if let Some(part_values) = part_values {
        let part_col_field = Arc::new(Field::new("part_col", ArrowDataType::Utf8, true));
        let part_col = Arc::new(StringArray::from(part_values.to_vec()));
        let pv_struct = StructArray::from(vec![(part_col_field, part_col as _)]);
        let pv_parsed_field = Arc::new(Field::new(
            "partitionValues_parsed",
            ArrowDataType::Struct(pv_struct.fields().clone()),
            true,
        ));
        add_children.push((pv_parsed_field, Arc::new(pv_struct)));
    }
    // `StructArray::from` consumes the children; no need to clone the whole Vec first.
    let add_struct = StructArray::from(add_children);
    let add_field = Arc::new(Field::new(
        "add",
        ArrowDataType::Struct(add_struct.fields().clone()),
        true,
    ));
    let schema = Arc::new(ArrowSchema::new(vec![add_field]));
    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(add_struct)]).unwrap();
    let tmp = tempfile::NamedTempFile::new().unwrap();
    // Write through a cloned handle so `tmp` keeps ownership of the path for the caller.
    let file = tmp.as_file().try_clone().unwrap();
    let mut writer = ArrowWriter::try_new(file, schema, None).unwrap();
    writer.write(&batch).unwrap();
    writer.close().unwrap();
    tmp
}
/// Re-opens the parquet file behind `tmp` and returns an owned copy of its footer metadata.
///
/// The caller can index row groups from the returned metadata.
fn checkpoint_row_group_metadata(
    tmp: &tempfile::NamedTempFile,
) -> crate::parquet::file::metadata::ParquetMetaData {
    let opened = File::open(tmp.path()).unwrap();
    SerializedFileReader::new(opened)
        .unwrap()
        .metadata()
        .clone()
}
/// When every file contributes stats, the filter aggregates them over the row group.
///
/// Min `x` is 10 and max `x` is 200. The null-count stat for per-file counts [2, 0] is 2. No
/// row count is reported.
#[test]
fn checkpoint_filter_returns_stats_when_no_nulls_in_stat_columns() {
    let x = column_name!("x");
    let tmp = write_checkpoint_parquet(
        &[Some(10), Some(20)],
        &[Some(100), Some(200)],
        &[Some(2), Some(0)],
        &["x"],
        Some(&[Some("a"), Some("b")]),
    );
    let metadata = checkpoint_row_group_metadata(&tmp);
    let predicate = Predicate::gt(x.clone(), Scalar::from(50i64));
    let filter =
        CheckpointRowGroupFilter::new(metadata.row_group(0), &predicate, &NO_PARTITIONS);

    assert_eq!(filter.get_min_stat(&x, &DataType::LONG), Some(10i64.into()));
    assert_eq!(filter.get_max_stat(&x, &DataType::LONG), Some(200i64.into()));
    assert_eq!(filter.get_nullcount_stat(&x), Some(2i64.into()));
    assert_eq!(filter.get_rowcount_stat(), None);
}
/// A single file with missing (`None`) stats makes every aggregated stat for `x` unavailable.
#[test]
fn checkpoint_filter_returns_none_when_stat_column_has_nulls() {
    let x = column_name!("x");
    let tmp = write_checkpoint_parquet(
        &[Some(10), None, Some(20)],
        &[Some(100), None, Some(200)],
        &[Some(2), None, Some(0)],
        &["x"],
        Some(&[Some("a"), Some("b"), Some("b")]),
    );
    let metadata = checkpoint_row_group_metadata(&tmp);
    let predicate = Predicate::gt(x.clone(), Scalar::from(50i64));
    let filter =
        CheckpointRowGroupFilter::new(metadata.row_group(0), &predicate, &NO_PARTITIONS);

    assert_eq!(filter.get_min_stat(&x, &DataType::LONG), None);
    assert_eq!(filter.get_max_stat(&x, &DataType::LONG), None);
    assert_eq!(filter.get_nullcount_stat(&x), None);
}
#[test]
fn checkpoint_filter_partition_columns_always_available() {
let tmp = write_checkpoint_parquet(
&[Some(10), None],
&[Some(100), None],
&[Some(2), None],
&["x"],
Some(&[Some("a"), Some("b")]),
);
let metadata = checkpoint_row_group_metadata(&tmp);
let row_group = metadata.row_group(0);
let partition_columns: HashSet<String> = ["part_col".to_string()].into();
let predicate = Predicate::and(
Predicate::gt(column_name!("x"), Scalar::from(50i64)),
Predicate::eq(column_name!("part_col"), Scalar::from("a")),
);
let filter = CheckpointRowGroupFilter::new(row_group, &predicate, &partition_columns);
assert_eq!(
filter.get_min_stat(&column_name!("part_col"), &DataType::STRING),
Some("a".into())
);
assert_eq!(
filter.get_max_stat(&column_name!("part_col"), &DataType::STRING),
Some("b".into())
);
assert_eq!(
filter.get_min_stat(&column_name!("x"), &DataType::LONG),
None
);
}
#[test]
fn checkpoint_filter_apply_keeps_row_group_with_missing_stats() {
let tmp = write_checkpoint_parquet(
&[Some(10), None],
&[Some(100), None],
&[Some(0), None],
&["x"],
Some(&[Some("a"), Some("b")]),
);
let metadata = checkpoint_row_group_metadata(&tmp);
let row_group = metadata.row_group(0);
let predicate = Predicate::gt(column_name!("x"), Scalar::from(500i64));
assert!(CheckpointRowGroupFilter::apply(
row_group,
&predicate,
&NO_PARTITIONS
));
}
#[test]
fn checkpoint_filter_apply_prunes_row_group_with_all_stats_present() {
let tmp = write_checkpoint_parquet(
&[Some(10), Some(20)],
&[Some(100), Some(200)],
&[Some(0), Some(0)],
&["x"],
Some(&[Some("a"), Some("b")]),
);
let metadata = checkpoint_row_group_metadata(&tmp);
let row_group = metadata.row_group(0);
let predicate = Predicate::gt(column_name!("x"), Scalar::from(500i64));
assert!(!CheckpointRowGroupFilter::apply(
row_group,
&predicate,
&NO_PARTITIONS
));
}
#[test]
fn checkpoint_filter_is_null_with_all_stats_present() {
let tmp = write_checkpoint_parquet(
&[Some(10), Some(20)],
&[Some(100), Some(200)],
&[Some(5), Some(0)],
&["x"],
Some(&[Some("a"), Some("b")]),
);
let metadata = checkpoint_row_group_metadata(&tmp);
let row_group = metadata.row_group(0);
let predicate = Predicate::is_null(column_name!("x"));
assert!(CheckpointRowGroupFilter::apply(
row_group,
&predicate,
&NO_PARTITIONS
));
}
#[test]
fn checkpoint_filter_is_null_all_zero_nullcounts() {
let tmp = write_checkpoint_parquet(
&[Some(10), Some(20)],
&[Some(100), Some(200)],
&[Some(0), Some(0)],
&["x"],
Some(&[Some("a"), Some("b")]),
);
let metadata = checkpoint_row_group_metadata(&tmp);
let row_group = metadata.row_group(0);
let predicate = Predicate::is_null(column_name!("x"));
let filter = CheckpointRowGroupFilter::new(row_group, &predicate, &NO_PARTITIONS);
let result = filter.eval_sql_where(&predicate);
assert_eq!(result, Some(false));
}
#[test]
fn checkpoint_filter_is_not_null_never_prunes() {
let tmp = write_checkpoint_parquet(
&[Some(10), Some(20)],
&[Some(100), Some(200)],
&[Some(5), Some(3)],
&["x"],
Some(&[Some("a"), Some("b")]),
);
let metadata = checkpoint_row_group_metadata(&tmp);
let row_group = metadata.row_group(0);
let predicate = Predicate::is_not_null(column_name!("x"));
let filter = CheckpointRowGroupFilter::new(row_group, &predicate, &NO_PARTITIONS);
let result = filter.eval_sql_where(&predicate);
assert_eq!(result, None);
}
#[test]
fn checkpoint_filter_timestamp_max_widened() {
let tmp = write_checkpoint_parquet(
&[Some(10), Some(20)],
&[Some(100), Some(200)],
&[Some(0), Some(0)],
&["x"],
Some(&[Some("a"), Some("b")]),
);
let metadata = checkpoint_row_group_metadata(&tmp);
let row_group = metadata.row_group(0);
let predicate = column_pred!("x");
let filter = CheckpointRowGroupFilter::new(row_group, &predicate, &NO_PARTITIONS);
assert_eq!(
filter.get_max_stat(&column_name!("x"), &DataType::TIMESTAMP),
Some(Scalar::Timestamp(1199))
);
assert_eq!(
filter.get_max_stat(&column_name!("x"), &DataType::TIMESTAMP_NTZ),
Some(Scalar::TimestampNtz(1199))
);
assert_eq!(
filter.get_max_stat(&column_name!("x"), &DataType::LONG),
Some(200i64.into())
);
}
#[test]
fn checkpoint_filter_unknown_column_returns_none() {
let tmp = write_checkpoint_parquet(
&[Some(10)],
&[Some(100)],
&[Some(0)],
&["x"],
Some(&[Some("a")]),
);
let metadata = checkpoint_row_group_metadata(&tmp);
let row_group = metadata.row_group(0);
let predicate = Predicate::gt(column_name!("y"), Scalar::from(50i64));
let filter = CheckpointRowGroupFilter::new(row_group, &predicate, &NO_PARTITIONS);
assert_eq!(
filter.get_min_stat(&column_name!("y"), &DataType::LONG),
None
);
assert_eq!(
filter.get_max_stat(&column_name!("y"), &DataType::LONG),
None
);
assert_eq!(filter.get_nullcount_stat(&column_name!("y")), None);
}
#[test]
fn checkpoint_filter_mixed_partition_and_data_predicate() {
let tmp = write_checkpoint_parquet(
&[Some(10), None],
&[Some(100), None],
&[Some(0), None],
&["x"],
Some(&[Some("a"), Some("b")]),
);
let metadata = checkpoint_row_group_metadata(&tmp);
let row_group = metadata.row_group(0);
let partition_columns: HashSet<String> = ["part_col".to_string()].into();
let predicate = Predicate::and(
Predicate::eq(column_name!("part_col"), Scalar::from("a")),
Predicate::gt(column_name!("x"), Scalar::from(500i64)),
);
assert!(CheckpointRowGroupFilter::apply(
row_group,
&predicate,
&partition_columns
));
let predicate = Predicate::and(
Predicate::eq(column_name!("part_col"), Scalar::from("c")),
Predicate::gt(column_name!("x"), Scalar::from(5i64)),
);
assert!(!CheckpointRowGroupFilter::apply(
row_group,
&predicate,
&partition_columns,
));
}
/// Test-only opaque predicate implementing `a < b`.
///
/// It applies when exactly one operand is a column and the other a literal, in either order.
/// Only the direct data-skipping path is implemented; the other evaluation hooks panic.
#[derive(Debug, PartialEq)]
struct OpaqueLessThanOp;

impl OpaquePredicateOp for OpaqueLessThanOp {
    fn name(&self) -> &str {
        "less_than"
    }

    fn eval_pred_scalar(
        &self,
        _eval_expr: &ScalarExpressionEvaluator<'_>,
        _evaluator: &DirectPredicateEvaluator<'_>,
        _exprs: &[Expression],
        _inverted: bool,
    ) -> DeltaResult<Option<bool>> {
        unimplemented!("not needed for data skipping tests")
    }

    /// Decides `a < b` (optionally negated) from column min/max stats.
    ///
    /// - `col < val` is satisfiable only if `min(col) < val`.
    /// - `NOT(col < val)`, i.e. `col >= val`, is satisfiable only if `NOT(max(col) < val)`.
    /// - The mirrored `val < col`, i.e. `col > val`, is satisfiable only if `max(col) > val`.
    /// - `NOT(val < col)` is satisfiable only if `NOT(min(col) > val)`.
    ///
    /// Returns `None` for unsupported operand shapes or when the needed stat is unavailable.
    fn eval_as_data_skipping_predicate(
        &self,
        evaluator: &DirectDataSkippingPredicateEvaluator<'_>,
        exprs: &[Expression],
        inverted: bool,
    ) -> Option<bool> {
        let (col, val, ord) = match exprs {
            [Expression::Column(col), Expression::Literal(val)] => (col, val, Ordering::Less),
            [Expression::Literal(val), Expression::Column(col)] => (col, val, Ordering::Greater),
            _ => return None,
        };
        // Always consulting the min stat would wrongly prune the mirrored and inverted forms
        // (e.g. min=10, max=200, `50 < x`). Pick the bound that actually decides satisfiability.
        if (ord == Ordering::Less) != inverted {
            evaluator.partial_cmp_min_stat(col, val, ord, inverted)
        } else {
            evaluator.partial_cmp_max_stat(col, val, ord, inverted)
        }
    }

    fn as_data_skipping_predicate(
        &self,
        _evaluator: &IndirectDataSkippingPredicateEvaluator<'_>,
        _exprs: &[Expression],
        _inverted: bool,
    ) -> Option<Predicate> {
        unimplemented!("not needed for data skipping tests")
    }
}
/// An opaque `x < bound` predicate is decided from complete stats, where min `x` is 10.
///
/// `x < 5` prunes the row group, while `x < 50` keeps it.
#[test]
fn checkpoint_filter_opaque_predicate_with_null_guarded_stats() {
    let tmp = write_checkpoint_parquet(
        &[Some(10), Some(20)],
        &[Some(100), Some(200)],
        &[Some(0), Some(0)],
        &["x"],
        Some(&[Some("a"), Some("b")]),
    );
    let metadata = checkpoint_row_group_metadata(&tmp);
    let row_group = metadata.row_group(0);
    let x_less_than = |bound: i64| {
        Predicate::opaque(
            OpaqueLessThanOp,
            vec![column_expr!("x"), Expression::literal(bound)],
        )
    };

    assert!(!CheckpointRowGroupFilter::apply(
        row_group,
        &x_less_than(5),
        &NO_PARTITIONS
    ));
    assert!(CheckpointRowGroupFilter::apply(
        row_group,
        &x_less_than(50),
        &NO_PARTITIONS
    ));
}
/// With one file missing stats, even the otherwise-unsatisfiable opaque `x < 5` keeps the row
/// group.
#[test]
fn checkpoint_filter_opaque_predicate_with_missing_stats() {
    let tmp = write_checkpoint_parquet(
        &[Some(10), None],
        &[Some(100), None],
        &[Some(0), None],
        &["x"],
        Some(&[Some("a"), Some("b")]),
    );
    let metadata = checkpoint_row_group_metadata(&tmp);
    let x_below_5 = Predicate::opaque(
        OpaqueLessThanOp,
        vec![column_expr!("x"), Expression::literal(5i64)],
    );
    let kept = CheckpointRowGroupFilter::apply(metadata.row_group(0), &x_below_5, &NO_PARTITIONS);
    assert!(kept);
}
/// Partition columns have no null-count stat, so `part_col IS NULL` stays undecided (`None`).
#[test]
fn checkpoint_filter_partition_nullcount_is_null() {
    let part_col = column_name!("part_col");
    let tmp = write_checkpoint_parquet(
        &[Some(10), Some(20)],
        &[Some(100), Some(200)],
        &[Some(0), Some(0)],
        &["x"],
        Some(&[Some("a"), Some("b")]),
    );
    let metadata = checkpoint_row_group_metadata(&tmp);
    let partition_columns: HashSet<String> = HashSet::from(["part_col".to_string()]);
    let part_is_null = Predicate::is_null(part_col.clone());
    let filter =
        CheckpointRowGroupFilter::new(metadata.row_group(0), &part_is_null, &partition_columns);

    assert_eq!(filter.get_nullcount_stat(&part_col), None);
    assert_eq!(filter.eval_sql_where(&part_is_null), None);
}
/// End-to-end row-group skipping through `with_checkpoint_row_group_filter`.
///
/// The file holds two 2-row row groups, with `maxValues.x` ranges [50, 100] and [500, 600].
/// For `x > 500` the first row group must be skipped, and only the second one's rows returned.
#[test]
fn checkpoint_filter_multi_row_group_skipping() {
    let col_field = Arc::new(Field::new("x", ArrowDataType::Int64, true));
    let min_field = Arc::new(Field::new(
        "minValues",
        ArrowDataType::Struct(Fields::from(vec![col_field.clone()])),
        true,
    ));
    let max_field = Arc::new(Field::new(
        "maxValues",
        ArrowDataType::Struct(Fields::from(vec![col_field.clone()])),
        true,
    ));
    let nc_field = Arc::new(Field::new(
        "nullCount",
        ArrowDataType::Struct(Fields::from(vec![col_field.clone()])),
        true,
    ));
    let stats_field = Arc::new(Field::new(
        "stats_parsed",
        ArrowDataType::Struct(Fields::from(vec![
            min_field.clone(),
            max_field.clone(),
            nc_field.clone(),
        ])),
        true,
    ));
    let add_field = Arc::new(Field::new(
        "add",
        ArrowDataType::Struct(Fields::from(vec![stats_field.clone()])),
        true,
    ));
    let schema = Arc::new(ArrowSchema::new(vec![add_field]));
    let make_batch = |mins: &[i64], maxs: &[i64], ncs: &[i64]| {
        let min_arr = Arc::new(Int64Array::from(mins.to_vec()));
        let max_arr = Arc::new(Int64Array::from(maxs.to_vec()));
        let nc_arr = Arc::new(Int64Array::from(ncs.to_vec()));
        let min_s = StructArray::from(vec![(col_field.clone(), min_arr as _)]);
        let max_s = StructArray::from(vec![(col_field.clone(), max_arr as _)]);
        let nc_s = StructArray::from(vec![(col_field.clone(), nc_arr as _)]);
        let stats_s = StructArray::from(vec![
            (min_field.clone(), Arc::new(min_s) as _),
            (max_field.clone(), Arc::new(max_s) as _),
            (nc_field.clone(), Arc::new(nc_s) as _),
        ]);
        let add_s = StructArray::from(vec![(stats_field.clone(), Arc::new(stats_s) as _)]);
        RecordBatch::try_new(schema.clone(), vec![Arc::new(add_s)]).unwrap()
    };
    let tmp = tempfile::NamedTempFile::new().unwrap();
    let file = tmp.as_file().try_clone().unwrap();
    // `set_max_row_group_size` is flagged deprecated by the parquet version in use (hence the
    // allow). It is still the simplest way to force one row group per 2-row batch here.
    // NOTE(review): migrate to its replacement when convenient.
    #[allow(deprecated)]
    let props = WriterProperties::builder()
        .set_max_row_group_size(2)
        .build();
    let mut writer = ArrowWriter::try_new(file, schema.clone(), Some(props)).unwrap();
    writer
        .write(&make_batch(&[10, 20], &[50, 100], &[0, 0]))
        .unwrap();
    writer
        .write(&make_batch(&[400, 450], &[500, 600], &[0, 0]))
        .unwrap();
    writer.close().unwrap();

    let file = File::open(tmp.path()).unwrap();
    let builder =
        crate::parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder::try_new(file)
            .unwrap();
    assert_eq!(builder.metadata().num_row_groups(), 2);
    let predicate = Predicate::gt(column_name!("x"), Scalar::from(500i64));
    let builder = builder.with_checkpoint_row_group_filter(&predicate, &NO_PARTITIONS, None);
    let reader = builder.build().unwrap();
    let batches: Vec<_> = reader.into_iter().collect::<Result<_, _>>().unwrap();
    assert_eq!(batches.len(), 1);
    assert_eq!(batches[0].num_rows(), 2);

    // Both row groups have 2 rows, so the counts above would also pass if the WRONG row group
    // survived. Check that the surviving rows carry the second row group's max values.
    let max_x = batches[0]
        .column(0)
        .as_any()
        .downcast_ref::<StructArray>()
        .and_then(|add| add.column_by_name("stats_parsed"))
        .and_then(|stats| stats.as_any().downcast_ref::<StructArray>())
        .and_then(|stats| stats.column_by_name("maxValues"))
        .and_then(|maxs| maxs.as_any().downcast_ref::<StructArray>())
        .and_then(|maxs| maxs.column_by_name("x"))
        .and_then(|x| x.as_any().downcast_ref::<Int64Array>())
        .expect("add.stats_parsed.maxValues.x should be an Int64 column");
    assert_eq!(max_x.values().to_vec(), vec![500i64, 600]);
}
/// Stats for a nested column `a.b` resolve through the nested struct layout.
///
/// Min is 10 and max is 200, so `a.b > 500` prunes the row group.
#[test]
fn checkpoint_filter_nested_struct_column_stats() {
    let nested = ColumnName::new(["a", "b"]);
    let tmp = write_checkpoint_parquet(
        &[Some(10), Some(20)],
        &[Some(100), Some(200)],
        &[Some(0), Some(0)],
        &["a", "b"],
        None,
    );
    let metadata = checkpoint_row_group_metadata(&tmp);
    let row_group = metadata.row_group(0);
    let over_500 = Predicate::gt(nested.clone(), Scalar::from(500i64));
    let filter = CheckpointRowGroupFilter::new(row_group, &over_500, &NO_PARTITIONS);

    assert_eq!(filter.get_min_stat(&nested, &DataType::LONG), Some(10i64.into()));
    assert_eq!(filter.get_max_stat(&nested, &DataType::LONG), Some(200i64.into()));
    let kept = CheckpointRowGroupFilter::apply(row_group, &over_500, &NO_PARTITIONS);
    assert!(!kept);
}