use vortex_array::ArrayRef;
use vortex_array::ArrayView;
use vortex_array::IntoArray;
use vortex_array::arrays::Constant;
use vortex_array::arrays::ConstantArray;
use vortex_array::arrays::Filter;
use vortex_array::arrays::ScalarFnArray;
use vortex_array::arrays::filter::FilterReduceAdaptor;
use vortex_array::arrays::scalar_fn::AnyScalarFn;
use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
use vortex_array::arrays::scalar_fn::ScalarFnVTable;
use vortex_array::arrays::slice::SliceReduceAdaptor;
use vortex_array::builtins::ArrayBuiltins;
use vortex_array::dtype::DType;
use vortex_array::extension::datetime::Timestamp;
use vortex_array::optimizer::ArrayOptimizer;
use vortex_array::optimizer::rules::ArrayParentReduceRule;
use vortex_array::optimizer::rules::ParentRuleSet;
use vortex_array::scalar_fn::fns::between::Between;
use vortex_array::scalar_fn::fns::binary::Binary;
use vortex_array::scalar_fn::fns::cast::CastReduceAdaptor;
use vortex_array::scalar_fn::fns::mask::MaskReduceAdaptor;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use crate::DateTimeParts;
use crate::array::DateTimePartsArrayExt;
use crate::timestamp;
/// Parent-reduce rules registered for the [`DateTimeParts`] encoding.
///
/// The set contains the two rules implemented in this file
/// ([`DTPFilterPushDownRule`] and [`DTPComparisonPushDownRule`]) followed by the
/// cast, filter, mask, and slice reduce adaptors instantiated for [`DateTimeParts`].
pub(crate) const PARENT_RULES: ParentRuleSet<DateTimeParts> = ParentRuleSet::new(&[
// Rules implemented in this file.
// NOTE(review): a dedicated filter rule and `FilterReduceAdaptor` are both listed;
// presumably rules are tried in declaration order — confirm before reordering.
ParentRuleSet::lift(&DTPFilterPushDownRule),
ParentRuleSet::lift(&DTPComparisonPushDownRule),
// Reduce adaptors instantiated for `DateTimeParts`.
ParentRuleSet::lift(&CastReduceAdaptor(DateTimeParts)),
ParentRuleSet::lift(&FilterReduceAdaptor(DateTimeParts)),
ParentRuleSet::lift(&MaskReduceAdaptor(DateTimeParts)),
ParentRuleSet::lift(&SliceReduceAdaptor(DateTimeParts)),
]);
/// Pushes a [`Filter`] parent into a [`DateTimeParts`] child whose `seconds` and
/// `subseconds` components are both [`Constant`]-encoded.
///
/// Only the `days` component is filtered; the two constant components are rebuilt
/// as fresh constant arrays whose length is the mask's true count.
#[derive(Debug)]
struct DTPFilterPushDownRule;

impl ArrayParentReduceRule<DateTimeParts> for DTPFilterPushDownRule {
    type Parent = Filter;

    /// Returns `Ok(None)` when either `seconds` or `subseconds` is not
    /// constant-encoded; otherwise returns a new [`DateTimeParts`] array built from
    /// the filtered `days` and the re-sized constant components.
    ///
    /// Errors from filtering `days` or from [`DateTimeParts::try_new`] are propagated.
    fn reduce_parent(
        &self,
        child: ArrayView<'_, DateTimeParts>,
        parent: ArrayView<'_, Filter>,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        debug_assert_eq!(child_idx, 0);

        let time_parts_are_constant =
            child.seconds().is::<Constant>() && child.subseconds().is::<Constant>();
        if !time_parts_are_constant {
            return Ok(None);
        }

        let mask = parent.filter_mask();
        let filtered_days = child.days().clone().filter(mask.clone())?;

        // Both time-of-day components were verified to be constant above, so they only
        // need to be re-created at the post-filter length.
        let kept = mask.true_count();
        let resized_constant = |part: &ArrayRef| {
            ConstantArray::new(part.as_constant().vortex_expect("constant"), kept).into_array()
        };
        let seconds = resized_constant(child.seconds());
        let subseconds = resized_constant(child.subseconds());

        let rebuilt =
            DateTimeParts::try_new(child.dtype().clone(), filtered_days, seconds, subseconds)?;
        Ok(Some(rebuilt.into_array()))
    }
}
/// Rewrites a comparison (or [`Between`]) over a [`DateTimeParts`] array into the
/// same scalar function applied to its `days` component.
///
/// The rewrite applies only when the array's `seconds` and `subseconds` components
/// are constant zero and every other operand is a constant from which
/// [`try_extract_days_constant`] can extract a day count.
#[derive(Debug)]
struct DTPComparisonPushDownRule;

impl ArrayParentReduceRule<DateTimeParts> for DTPComparisonPushDownRule {
    type Parent = AnyScalarFn;

    /// Returns `Ok(None)` when the parent is neither a comparison [`Binary`] nor a
    /// [`Between`], when the child has a non-zero (or non-constant) time-of-day
    /// component, or when any sibling operand cannot be reduced to a day count.
    /// Otherwise returns the optimized scalar-function array over `days`.
    fn reduce_parent(
        &self,
        child: ArrayView<'_, DateTimeParts>,
        parent: ArrayView<'_, ScalarFnVTable>,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        let scalar_fn = parent.scalar_fn();
        let is_comparison = scalar_fn
            .as_opt::<Binary>()
            .is_some_and(|binary| binary.is_comparison());
        if !is_comparison && !scalar_fn.is::<Between>() {
            return Ok(None);
        }

        let at_midnight =
            is_constant_zero(child.seconds()) && is_constant_zero(child.subseconds());
        if !at_midnight {
            return Ok(None);
        }

        let days = child.days();
        let row_count = days.len();
        let days_dtype = days.dtype();

        let mut operands = Vec::with_capacity(parent.nchildren());
        for (position, operand) in parent.iter_children().enumerate() {
            if position == child_idx {
                // The DateTimeParts operand itself is replaced by its days component.
                operands.push(days.clone());
                continue;
            }

            // Every other operand must reduce to a whole-day constant, which is then
            // cast to the dtype of the days component.
            let Some(day_count) = try_extract_days_constant(operand) else {
                return Ok(None);
            };
            let day_constant = ConstantArray::new(day_count, row_count).into_array();
            operands.push(day_constant.cast(days_dtype.clone())?);
        }

        let rewritten = ScalarFnArray::try_new(scalar_fn.clone(), operands, parent.len())?
            .into_array()
            .optimize()?;
        Ok(Some(rewritten))
    }
}
/// Extracts the day count from a constant timestamp array whose value has no
/// time-of-day component.
///
/// Returns `None` when:
/// - `array` has no constant value,
/// - the constant's dtype is not [`DType::Extension`],
/// - the storage value is not available as an `i64`,
/// - [`timestamp::split`] fails, or
/// - the split timestamp has non-zero `seconds` or `subseconds`.
fn try_extract_days_constant(array: &ArrayRef) -> Option<i64> {
    let constant = array.as_constant()?;

    // Check the dtype *before* using the extension accessors below; performing this
    // check afterwards would make the guard ineffective for non-extension scalars.
    let DType::Extension(ext_dtype) = constant.dtype() else {
        return None;
    };

    let raw_timestamp = constant
        .as_extension()
        .to_storage_scalar()
        .as_primitive()
        .as_::<i64>()?;

    // NOTE(review): assumes the extension dtype is a `Timestamp`; the behavior of
    // `metadata::<Timestamp>()` for other extension types is not visible here — confirm.
    let options = ext_dtype.metadata::<Timestamp>();
    let parts = timestamp::split(raw_timestamp, options.unit).ok()?;

    // Only exact-midnight timestamps can be represented by the day count alone.
    (parts.seconds == 0 && parts.subseconds == 0).then_some(parts.days)
}
/// Returns `true` only when `array` is [`Constant`]-encoded and its scalar reports
/// `is_zero() == Some(true)`.
fn is_constant_zero(array: &ArrayRef) -> bool {
    let Some(constant) = array.as_opt::<Constant>() else {
        return false;
    };
    constant.scalar().is_zero() == Some(true)
}
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::arrays::TemporalArray;
    use vortex_array::arrays::scalar_fn::ScalarFnFactoryExt;
    use vortex_array::extension::datetime::TimeUnit;
    use vortex_array::extension::datetime::TimestampOptions;
    use vortex_array::optimizer::ArrayOptimizer;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar_fn::fns::between::BetweenOptions;
    use vortex_array::scalar_fn::fns::between::StrictComparison;
    use vortex_array::scalar_fn::fns::operators::Operator;
    use vortex_array::validity::Validity;
    use vortex_buffer::Buffer;

    use super::*;
    use crate::DateTimeParts;
    use crate::DateTimePartsArray;

    const SECONDS_PER_DAY: i64 = 86400;

    /// Number of `time_unit` ticks in one second. Panics for [`TimeUnit::Days`].
    fn ticks_per_second(time_unit: TimeUnit) -> i64 {
        match time_unit {
            TimeUnit::Seconds => 1,
            TimeUnit::Milliseconds => 1_000,
            TimeUnit::Microseconds => 1_000_000,
            TimeUnit::Nanoseconds => 1_000_000_000,
            TimeUnit::Days => panic!("Days not supported"),
        }
    }

    /// Builds a constant timestamp array (no time zone) of `len` rows holding `raw`.
    fn timestamp_constant(raw: i64, time_unit: TimeUnit, len: usize) -> ArrayRef {
        let scalar = Scalar::extension::<Timestamp>(
            TimestampOptions {
                unit: time_unit,
                tz: None,
            },
            raw.into(),
        );
        ConstantArray::new(scalar, len).into_array()
    }

    /// Builds a `DateTimeParts` array with one midnight timestamp per entry of `days`.
    fn dtp_at_midnight(days: &[i64], time_unit: TimeUnit) -> DateTimePartsArray {
        let multiplier = ticks_per_second(time_unit);
        let timestamps: Vec<i64> = days
            .iter()
            .map(|&day| day * SECONDS_PER_DAY * multiplier)
            .collect();
        let buffer: Buffer<i64> = timestamps.into();
        let storage = PrimitiveArray::new(buffer, Validity::NonNullable).into_array();
        let temporal = TemporalArray::new_timestamp(storage, time_unit, None);
        DateTimeParts::try_from_temporal(temporal)
            .vortex_expect("TemporalArray must produce valid DateTimeParts")
    }

    /// Constant timestamp array at midnight of `day`.
    fn midnight_constant(day: i64, time_unit: TimeUnit, len: usize) -> ArrayRef {
        let multiplier = ticks_per_second(time_unit);
        timestamp_constant(day * SECONDS_PER_DAY * multiplier, time_unit, len)
    }

    /// Constant timestamp array `seconds` seconds after midnight of `day`.
    fn non_midnight_constant(day: i64, seconds: i64, time_unit: TimeUnit, len: usize) -> ArrayRef {
        let multiplier = ticks_per_second(time_unit);
        timestamp_constant((day * SECONDS_PER_DAY + seconds) * multiplier, time_unit, len)
    }

    /// `dtp <= midnight(1)` must optimize away the DTP encoding and keep days 0 and 1.
    #[test]
    fn test_binary_comparison_pushdown() {
        let lhs = dtp_at_midnight(&[0, 1, 2], TimeUnit::Seconds);
        let len = lhs.len();
        let rhs = midnight_constant(1, TimeUnit::Seconds, len);

        let expr = Binary
            .try_new_array(len, Operator::Lte, [lhs.into_array(), rhs])
            .unwrap();
        let optimized = expr.optimize().unwrap();

        assert!(
            !optimized.is::<DateTimeParts>(),
            "Expected pushdown to remove DTP from expression"
        );
        assert_eq!(optimized.as_bool_typed().true_count().unwrap(), 2);
    }

    /// Inclusive `between(midnight(1), midnight(3))` keeps days 1, 2 and 3.
    #[test]
    fn test_between_pushdown() {
        let values = dtp_at_midnight(&[0, 1, 2, 3, 4], TimeUnit::Seconds);
        let len = values.len();
        let lower = midnight_constant(1, TimeUnit::Seconds, len);
        let upper = midnight_constant(3, TimeUnit::Seconds, len);
        let options = BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::NonStrict,
        };

        let expr = Between
            .try_new_array(len, options, [values.into_array(), lower, upper])
            .unwrap();
        let optimized = expr.optimize().unwrap();

        assert_eq!(optimized.as_bool_typed().true_count().unwrap(), 3);
    }

    /// A constant at noon of day 1 still yields the correct result (days 0 and 1).
    #[test]
    fn test_no_pushdown_non_midnight_constant() {
        let lhs = dtp_at_midnight(&[0, 1, 2], TimeUnit::Seconds);
        let len = lhs.len();
        let rhs = non_midnight_constant(1, 43200, TimeUnit::Seconds, len);

        let expr = Binary
            .try_new_array(len, Operator::Lte, [lhs.into_array(), rhs])
            .unwrap();
        let optimized = expr.optimize().unwrap();

        assert_eq!(optimized.as_bool_typed().true_count().unwrap(), 2);
    }

    /// Timestamps one hour past midnight still compare correctly (only day 0 matches).
    #[test]
    fn test_no_pushdown_non_zero_dtp_seconds() {
        let timestamps: Buffer<i64> = (0..3_i64)
            .map(|day| day * SECONDS_PER_DAY + 3600)
            .collect::<Vec<i64>>()
            .into();
        let temporal = TemporalArray::new_timestamp(
            PrimitiveArray::new(timestamps, Validity::NonNullable).into_array(),
            TimeUnit::Seconds,
            None,
        );
        let lhs = DateTimeParts::try_from_temporal(temporal).unwrap();
        let len = lhs.len();
        let rhs = midnight_constant(1, TimeUnit::Seconds, len);

        let expr = Binary
            .try_new_array(len, Operator::Lte, [lhs.into_array(), rhs])
            .unwrap();
        let optimized = expr.optimize().unwrap();

        assert_eq!(optimized.as_bool_typed().true_count().unwrap(), 1);
    }
}