use std::collections::HashSet;
use std::fmt::{Display, Formatter};
use std::mem::{size_of, size_of_val};
use std::sync::Arc;
use super::utils::{
convert_duration_type_to_interval, convert_interval_type_to_duration, get_inverse_op,
};
use crate::PhysicalExpr;
use crate::expressions::{BinaryExpr, Literal};
use crate::utils::{ExprTreeNode, build_dag};
use arrow::datatypes::{DataType, Schema};
use datafusion_common::{Result, internal_err, not_impl_err};
use datafusion_expr::Operator;
use datafusion_expr::interval_arithmetic::{Interval, apply_operator, satisfy_greater};
use petgraph::Outgoing;
use petgraph::graph::NodeIndex;
use petgraph::stable_graph::{DefaultIx, StableGraph};
use petgraph::visit::{Bfs, Dfs, DfsPostOrder, EdgeRef};
/// A graph built (via `build_dag`) from a [`PhysicalExpr`] tree in which every
/// node carries the [`Interval`] of its sub-expression.
///
/// Intervals are evaluated bottom-up (see `evaluate_bounds`) and constraints
/// are propagated top-down (see `update_ranges`).
#[derive(Clone, Debug)]
pub struct ExprIntervalGraph {
    /// Expression nodes with their current intervals. Outgoing edges lead from
    /// an expression to its children; the edge weight type is `usize`, whose
    /// meaning is determined by `build_dag`.
    graph: StableGraph<ExprIntervalGraphNode, usize>,
    /// Index of the node holding the top-level expression.
    root: NodeIndex,
}
/// Outcome of constraint propagation through an [`ExprIntervalGraph`].
#[derive(PartialEq, Debug)]
pub enum PropagationResult {
    /// The given range already certainly contains the evaluated bounds, so
    /// there is nothing to tighten.
    CannotPropagate,
    /// The constraints cannot be satisfied by the current intervals.
    Infeasible,
    /// Constraints were propagated through the whole graph.
    Success,
}
/// A node of an [`ExprIntervalGraph`]: a sub-expression together with the
/// interval of values it may currently take.
#[derive(Clone, Debug)]
pub struct ExprIntervalGraphNode {
    /// The sub-expression represented by this node.
    expr: Arc<dyn PhysicalExpr>,
    /// Current range of the sub-expression's values.
    interval: Interval,
}
/// Nodes compare equal when their expressions compare equal; the attached
/// intervals are not taken into account.
impl PartialEq for ExprIntervalGraphNode {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.expr, &other.expr);
        lhs.eq(rhs)
    }
}
/// Renders a node as its underlying expression (the interval is omitted).
impl Display for ExprIntervalGraphNode {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let expr = &self.expr;
        write!(f, "{expr}")
    }
}
impl ExprIntervalGraphNode {
    /// Creates a node for `expr` whose interval is the unbounded interval of
    /// data type `dt`.
    ///
    /// # Errors
    /// Propagates any error from [`Interval::make_unbounded`].
    pub fn new_unbounded(expr: Arc<dyn PhysicalExpr>, dt: &DataType) -> Result<Self> {
        let interval = Interval::make_unbounded(dt)?;
        Ok(Self::new_with_interval(expr, interval))
    }

    /// Creates a node for `expr` with the given `interval`.
    pub fn new_with_interval(expr: Arc<dyn PhysicalExpr>, interval: Interval) -> Self {
        Self { expr, interval }
    }

    /// Returns the node's current interval.
    pub fn interval(&self) -> &Interval {
        &self.interval
    }

    /// Builds a graph node from an expression-tree node.
    ///
    /// A [`Literal`] gets the singleton interval `[value, value]`; any other
    /// expression gets the unbounded interval of its data type in `schema`.
    ///
    /// # Errors
    /// Propagates errors from interval construction or from resolving the
    /// expression's data type against `schema`.
    pub fn make_node(node: &ExprTreeNode<NodeIndex>, schema: &Schema) -> Result<Self> {
        let expr = Arc::clone(&node.expr);
        let literal_value = expr
            .as_any()
            .downcast_ref::<Literal>()
            .map(|literal| literal.value().clone());
        match literal_value {
            Some(value) => {
                let interval = Interval::try_new(value.clone(), value)?;
                Ok(Self::new_with_interval(expr, interval))
            }
            None => {
                let dt = expr.data_type(schema)?;
                Self::new_unbounded(expr, &dt)
            }
        }
    }
}
/// Refines the operand intervals of the arithmetic expression
/// `left_child op right_child`, given the interval `parent` of its result.
///
/// Timestamp / time-interval operand pairs are delegated to dedicated helpers.
/// For all other types the left operand is solved for first, using the
/// inverse of `op`. The refined left interval is then used to solve for the
/// right operand.
///
/// Returns `Ok(None)` when the constraint is infeasible.
///
/// # Errors
/// Propagates errors from `get_inverse_op` and from interval arithmetic.
pub fn propagate_arithmetic(
    op: &Operator,
    parent: &Interval,
    left_child: &Interval,
    right_child: &Interval,
) -> Result<Option<(Interval, Interval)>> {
    let inverse_op = get_inverse_op(*op)?;
    let left_type = left_child.data_type();
    let right_type = right_child.data_type();

    let left_is_timestamp = matches!(left_type, DataType::Timestamp(..));
    let right_is_timestamp = matches!(right_type, DataType::Timestamp(..));
    let left_is_interval = matches!(left_type, DataType::Interval(_));
    let right_is_interval = matches!(right_type, DataType::Interval(_));

    if left_is_timestamp && right_is_interval {
        return propagate_time_interval_at_right(
            left_child,
            right_child,
            parent,
            op,
            &inverse_op,
        );
    }
    if left_is_interval && right_is_timestamp {
        return propagate_time_interval_at_left(
            left_child,
            right_child,
            parent,
            op,
            &inverse_op,
        );
    }

    // Solve for the left operand; an empty intersection means infeasible.
    let Some(new_left) =
        apply_operator(&inverse_op, parent, right_child)?.intersect(left_child)?
    else {
        return Ok(None);
    };
    // Solve for the right operand using the refined left interval.
    let new_right = propagate_right(&new_left, parent, right_child, op, &inverse_op)?;
    Ok(new_right.map(|right| (new_left, right)))
}
pub fn propagate_comparison(
op: &Operator,
parent: &Interval,
left_child: &Interval,
right_child: &Interval,
) -> Result<Option<(Interval, Interval)>> {
if parent == &Interval::TRUE {
match op {
Operator::Eq => left_child.intersect(right_child).map(|result| {
result.map(|intersection| (intersection.clone(), intersection))
}),
Operator::Gt => satisfy_greater(left_child, right_child, true),
Operator::GtEq => satisfy_greater(left_child, right_child, false),
Operator::Lt => satisfy_greater(right_child, left_child, true)
.map(|t| t.map(reverse_tuple)),
Operator::LtEq => satisfy_greater(right_child, left_child, false)
.map(|t| t.map(reverse_tuple)),
_ => internal_err!(
"The operator must be a comparison operator to propagate intervals"
),
}
} else if parent == &Interval::FALSE {
match op {
Operator::Eq => {
Ok(None)
}
Operator::Gt => satisfy_greater(right_child, left_child, false),
Operator::GtEq => satisfy_greater(right_child, left_child, true),
Operator::Lt => satisfy_greater(left_child, right_child, false)
.map(|t| t.map(reverse_tuple)),
Operator::LtEq => satisfy_greater(left_child, right_child, true)
.map(|t| t.map(reverse_tuple)),
_ => internal_err!(
"The operator must be a comparison operator to propagate intervals"
),
}
} else {
Ok(None)
}
}
impl ExprIntervalGraph {
    /// Builds an interval graph for `expr` via `build_dag`.
    ///
    /// Nodes are created by [`ExprIntervalGraphNode::make_node`]. Literal nodes
    /// therefore start with singleton intervals, and all other nodes start with
    /// the unbounded interval of their data type under `schema`.
    ///
    /// # Errors
    /// Propagates errors from `build_dag` and from node construction (e.g. when
    /// an expression's data type cannot be resolved against `schema`).
    pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: &Schema) -> Result<Self> {
        let (root, graph) =
            build_dag(expr, &|node| ExprIntervalGraphNode::make_node(node, schema))?;
        Ok(Self { graph, root })
    }

    /// Returns the number of nodes currently in the graph.
    pub fn node_count(&self) -> usize {
        self.graph.node_count()
    }

    /// Returns an estimate, in bytes, of the memory used by this graph.
    ///
    /// The estimate is `size_of_val(self)` plus a fixed per-node and per-edge
    /// cost computed with `size_of`. Heap memory reachable from a node (e.g.
    /// the `Arc`-shared expression) is not counted.
    pub fn size(&self) -> usize {
        let node_memory_usage = self.graph.node_count()
            * (size_of::<ExprIntervalGraphNode>() + size_of::<NodeIndex>());
        let edge_memory_usage =
            self.graph.edge_count() * (size_of::<usize>() + size_of::<NodeIndex>() * 2);
        size_of_val(self) + node_memory_usage + edge_memory_usage
    }

    /// Locates each expression of `exprs` in the graph and returns
    /// `(expr, node_index)` pairs in the same order as `exprs`.
    ///
    /// An expression that is not found keeps the placeholder index
    /// `usize::MAX`. When a graph node matches several entries of `exprs`, only
    /// the first matching entry receives its index.
    ///
    /// Side effect: every matched node loses its outgoing edges, so it becomes
    /// a leaf whose interval is not recomputed by `evaluate_bounds`. Nodes that
    /// are then unreachable from the root are removed from the graph.
    pub fn gather_node_indices(
        &mut self,
        exprs: &[Arc<dyn PhysicalExpr>],
    ) -> Vec<(Arc<dyn PhysicalExpr>, usize)> {
        let graph = &self.graph;
        let mut bfs = Bfs::new(graph, self.root);
        // Edge ids are collected first and removed after the traversal, because
        // the graph is borrowed immutably by the BFS.
        let mut removals = vec![];
        // Pre-fill with the `usize::MAX` placeholder to preserve `exprs` order.
        let mut expr_node_indices = exprs
            .iter()
            .map(|e| (Arc::clone(e), usize::MAX))
            .collect::<Vec<_>>();
        while let Some(node) = bfs.next(graph) {
            let expr = &graph[node].expr;
            if let Some(value) = exprs.iter().position(|e| expr.eq(e)) {
                expr_node_indices[value].1 = node.index();
                // Slate the edges to this node's children for removal.
                for edge in graph.edges_directed(node, Outgoing) {
                    removals.push(edge.id());
                }
            }
        }
        for edge_idx in removals {
            self.graph.remove_edge(edge_idx);
        }
        // Drop every node that is no longer reachable from the root.
        let connected_nodes = self.connected_nodes();
        self.graph
            .retain_nodes(|_, index| connected_nodes.contains(&index));
        expr_node_indices
    }

    /// Returns the set of nodes reachable from the root (found by a DFS).
    fn connected_nodes(&self) -> HashSet<NodeIndex> {
        let mut nodes = HashSet::new();
        let mut dfs = Dfs::new(&self.graph, self.root);
        while let Some(node) = dfs.next(&self.graph) {
            nodes.insert(node);
        }
        nodes
    }

    /// Assigns `leaf_bounds` to their nodes, evaluates the graph bottom-up and
    /// then tries to constrain the root to `given_range`.
    ///
    /// * If `given_range.contains(bounds)` is certainly true, the root already
    ///   satisfies the range: returns [`PropagationResult::CannotPropagate`].
    /// * Else, if `bounds.contains(&given_range)` is not certainly false,
    ///   constraints are propagated top-down. The refined intervals are then
    ///   written back into `leaf_bounds`; this happens even if propagation
    ///   returns an error.
    /// * Otherwise returns [`PropagationResult::Infeasible`].
    ///
    /// # Errors
    /// Propagates errors from bound evaluation, containment checks and
    /// constraint propagation.
    pub fn update_ranges(
        &mut self,
        leaf_bounds: &mut [(usize, Interval)],
        given_range: Interval,
    ) -> Result<PropagationResult> {
        self.assign_intervals(leaf_bounds);
        let bounds = self.evaluate_bounds()?;
        if given_range.contains(bounds)? == Interval::TRUE {
            Ok(PropagationResult::CannotPropagate)
        } else if bounds.contains(&given_range)? != Interval::FALSE {
            let result = self.propagate_constraints(given_range);
            self.update_intervals(leaf_bounds);
            result
        } else {
            Ok(PropagationResult::Infeasible)
        }
    }

    /// Sets the interval of each node index in `assignments`.
    ///
    /// NOTE(review): the graph is indexed directly, so an index that is not a
    /// live node (e.g. the `usize::MAX` placeholder from
    /// `gather_node_indices`) presumably panics. Confirm against petgraph.
    pub fn assign_intervals(&mut self, assignments: &[(usize, Interval)]) {
        for (index, interval) in assignments {
            let node_index = NodeIndex::from(*index as DefaultIx);
            self.graph[node_index].interval = interval.clone();
        }
    }

    /// Overwrites each interval in `assignments` with the current interval of
    /// the corresponding node.
    pub fn update_intervals(&self, assignments: &mut [(usize, Interval)]) {
        for (index, interval) in assignments.iter_mut() {
            let node_index = NodeIndex::from(*index as DefaultIx);
            *interval = self.graph[node_index].interval.clone();
        }
    }

    /// Recomputes the interval of every inner node bottom-up, using a
    /// post-order DFS from the root, and returns the root's interval.
    ///
    /// Leaf nodes (nodes without children) keep the interval they were given
    /// at construction or through [`Self::assign_intervals`].
    ///
    /// # Errors
    /// Propagates errors from `PhysicalExpr::evaluate_bounds`.
    pub fn evaluate_bounds(&mut self) -> Result<&Interval> {
        let mut dfs = DfsPostOrder::new(&self.graph, self.root);
        while let Some(node) = dfs.next(&self.graph) {
            let neighbors = self.graph.neighbors_directed(node, Outgoing);
            let mut children_intervals = neighbors
                .map(|child| self.graph[child].interval())
                .collect::<Vec<_>>();
            if !children_intervals.is_empty() {
                // Presumably petgraph yields outgoing neighbors in reverse
                // insertion order; reversing restores the expression's child
                // order. TODO confirm against `build_dag`.
                children_intervals.reverse();
                self.graph[node].interval =
                    self.graph[node].expr.evaluate_bounds(&children_intervals)?;
            }
        }
        Ok(self.graph[self.root].interval())
    }

    /// Intersects the root interval with `given_range`, then propagates
    /// constraints top-down (BFS from the root). Each node's children are
    /// replaced with the intervals returned by
    /// `PhysicalExpr::propagate_constraints`.
    ///
    /// Returns [`PropagationResult::Infeasible`] as soon as an intersection is
    /// empty or a node reports an infeasible constraint, and
    /// [`PropagationResult::Success`] otherwise.
    ///
    /// # Errors
    /// Returns a not-implemented error for an `OR` node whose interval is
    /// `TRUE`, and propagates errors from interval operations.
    fn propagate_constraints(
        &mut self,
        given_range: Interval,
    ) -> Result<PropagationResult> {
        if let Some(interval) = self.graph[self.root].interval.intersect(given_range)? {
            self.graph[self.root].interval = interval;
        } else {
            return Ok(PropagationResult::Infeasible);
        }
        let mut bfs = Bfs::new(&self.graph, self.root);
        while let Some(node) = bfs.next(&self.graph) {
            let neighbors = self.graph.neighbors_directed(node, Outgoing);
            let mut children = neighbors.collect::<Vec<_>>();
            // A leaf's interval is final; there is nothing to push down.
            if children.is_empty() {
                continue;
            }
            // Presumably restores the expression's child order (petgraph yields
            // outgoing neighbors newest-first). TODO confirm.
            children.reverse();
            let children_intervals = children
                .iter()
                .map(|child| self.graph[*child].interval())
                .collect::<Vec<_>>();
            let node_interval = self.graph[node].interval();
            // A TRUE `OR` could be satisfied by several disjoint child
            // assignments, which a single interval per child cannot express.
            if node_interval == &Interval::TRUE
                && self.graph[node]
                    .expr
                    .as_any()
                    .downcast_ref::<BinaryExpr>()
                    .is_some_and(|expr| expr.op() == &Operator::Or)
            {
                return not_impl_err!("OR operator cannot yet propagate true intervals");
            }
            let propagated_intervals = self.graph[node]
                .expr
                .propagate_constraints(node_interval, &children_intervals)?;
            if let Some(propagated_intervals) = propagated_intervals {
                // `zip` stops at the shorter side, so an empty result leaves
                // the children unchanged.
                for (child, interval) in children.into_iter().zip(propagated_intervals) {
                    self.graph[child].interval = interval;
                }
            } else {
                return Ok(PropagationResult::Infeasible);
            }
        }
        Ok(PropagationResult::Success)
    }

    /// Returns a clone of the interval stored at node `index`.
    pub fn get_interval(&self, index: usize) -> Interval {
        self.graph[NodeIndex::new(index)].interval.clone()
    }
}
/// Computes a refined interval for the right operand of `left op right`.
///
/// The inputs are the (possibly refined) left interval `left` and the result
/// interval `parent`; the computed interval is intersected with the current
/// `right` interval. Only `+`, `-`, `*` and `/` are supported. `inverse_op`
/// is expected to be the inverse of `op`.
///
/// Returns `Ok(None)` if the solved interval does not intersect `right`.
fn propagate_right(
    left: &Interval,
    parent: &Interval,
    right: &Interval,
    op: &Operator,
    inverse_op: &Operator,
) -> Result<Option<Interval>> {
    // For `-` and `/`, the right operand is `left op parent`.
    // For `+` and `*`, it is `parent inverse_op left`.
    let solved = match op {
        Operator::Minus | Operator::Divide => apply_operator(op, left, parent)?,
        Operator::Plus | Operator::Multiply => apply_operator(inverse_op, parent, left)?,
        _ => {
            return internal_err!(
                "Interval arithmetic does not support the operator {}",
                op
            );
        }
    };
    solved.intersect(right)
}
/// Propagates `parent = left_child op right_child` when `left_child` is a
/// time-interval (`DataType::Interval`) and `right_child` is a timestamp.
///
/// If `left_child` converts to a duration interval, it is refined in duration
/// form and converted back, and the right side is refined from the refined
/// left. Otherwise (presumably when the interval cannot be represented as a
/// duration; confirm against `convert_interval_type_to_duration`) the left
/// side is kept unchanged and only the right side is refined.
///
/// Returns `Ok(None)` if the constraint is infeasible or the refined left
/// side cannot be converted back.
fn propagate_time_interval_at_left(
    left_child: &Interval,
    right_child: &Interval,
    parent: &Interval,
    op: &Operator,
    inverse_op: &Operator,
) -> Result<Option<(Interval, Interval)>> {
    let Some(left_as_duration) = convert_interval_type_to_duration(left_child) else {
        let refined_right = propagate_right(left_child, parent, right_child, op, inverse_op)?;
        return Ok(refined_right.map(|right| (left_child.clone(), right)));
    };
    let Some(refined_left) =
        apply_operator(inverse_op, parent, right_child)?.intersect(left_as_duration)?
    else {
        return Ok(None);
    };
    let converted_left = convert_duration_type_to_interval(&refined_left);
    let refined_right = propagate_right(&refined_left, parent, right_child, op, inverse_op)?;
    Ok(converted_left.zip(refined_right))
}
/// Propagates `parent = left_child op right_child` when `left_child` is a
/// timestamp and `right_child` is a time-interval (`DataType::Interval`).
///
/// If `right_child` converts to a duration interval, two steps follow. First
/// the timestamp side is refined. Then the interval side is solved from the
/// *refined* timestamp interval and converted back to the interval type.
/// Otherwise (presumably when the interval cannot be represented as a
/// duration; confirm against `convert_interval_type_to_duration`) only the
/// timestamp side is refined and `right_child` is returned unchanged.
///
/// Returns `Ok(None)` if the constraint is infeasible or the refined right
/// side cannot be converted back.
fn propagate_time_interval_at_right(
    left_child: &Interval,
    right_child: &Interval,
    parent: &Interval,
    op: &Operator,
    inverse_op: &Operator,
) -> Result<Option<(Interval, Interval)>> {
    let result = if let Some(duration) = convert_interval_type_to_duration(right_child) {
        match apply_operator(inverse_op, parent, &duration)?.intersect(left_child)? {
            Some(value) => {
                // Solve for the right side from the refined left interval
                // (`value` is a subset of `left_child`), as the generic path in
                // `propagate_arithmetic` does. This yields bounds that are at
                // least as tight and are still sound.
                propagate_right(&value, parent, &duration, op, inverse_op)?
                    .and_then(|right| convert_duration_type_to_interval(&right))
                    .map(|right| (value, right))
            }
            None => None,
        }
    } else {
        apply_operator(inverse_op, parent, right_child)?
            .intersect(left_child)?
            .map(|value| (value, right_child.clone()))
    };
    Ok(result)
}
/// Swaps the two elements of a pair: `(a, b)` becomes `(b, a)`.
fn reverse_tuple<T, U>(pair: (T, U)) -> (U, T) {
    let (first, second) = pair;
    (second, first)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{BinaryExpr, Column};
    use crate::intervals::test_utils::gen_conjunctive_numerical_expr;
    use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano};
    use arrow::datatypes::{Field, TimeUnit};
    use datafusion_common::ScalarValue;
    use itertools::Itertools;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};
    use rstest::*;

    /// Builds a graph for `expr`, assigns `left_interval` / `right_interval`
    /// to the two expressions in `exprs_with_interval`, runs `update_ranges`
    /// with `Interval::TRUE`, and checks that the outcome equals `result`.
    ///
    /// The interval check is loose. Each calculated lower bound must be at
    /// most one unit above the expected lower bound, and each calculated upper
    /// bound at least one unit below the expected upper bound.
    #[expect(clippy::too_many_arguments)]
    fn experiment(
        expr: Arc<dyn PhysicalExpr>,
        exprs_with_interval: (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>),
        left_interval: Interval,
        right_interval: Interval,
        left_expected: Interval,
        right_expected: Interval,
        result: PropagationResult,
        schema: &Schema,
    ) -> Result<()> {
        let col_stats = [
            (Arc::clone(&exprs_with_interval.0), left_interval),
            (Arc::clone(&exprs_with_interval.1), right_interval),
        ];
        let expected = [
            (Arc::clone(&exprs_with_interval.0), left_expected),
            (Arc::clone(&exprs_with_interval.1), right_expected),
        ];
        let mut graph = ExprIntervalGraph::try_new(expr, schema)?;
        let expr_indexes = graph.gather_node_indices(
            &col_stats.iter().map(|(e, _)| Arc::clone(e)).collect_vec(),
        );
        let mut col_stat_nodes = col_stats
            .iter()
            .zip(expr_indexes.iter())
            .map(|((_, interval), (_, index))| (*index, interval.clone()))
            .collect_vec();
        let expected_nodes = expected
            .iter()
            .zip(expr_indexes.iter())
            .map(|((_, interval), (_, index))| (*index, interval.clone()))
            .collect_vec();
        // `update_ranges` writes the refined intervals back into `col_stat_nodes`.
        let exp_result = graph.update_ranges(&mut col_stat_nodes[..], Interval::TRUE)?;
        assert_eq!(exp_result, result);
        col_stat_nodes.iter().zip(expected_nodes.iter()).for_each(
            |((_, calculated_interval_node), (_, expected))| {
                let one = ScalarValue::new_one(&expected.data_type()).unwrap();
                assert!(
                    calculated_interval_node.lower()
                        <= &expected.lower().add(&one).unwrap(),
                    "{}",
                    format!(
                        "Calculated {} must be less than or equal {}",
                        calculated_interval_node.lower(),
                        expected.lower()
                    )
                );
                assert!(
                    calculated_interval_node.upper()
                        >= &expected.upper().sub(&one).unwrap(),
                    "{}",
                    format!(
                        "Calculated {} must be greater than or equal {}",
                        calculated_interval_node.upper(),
                        expected.upper()
                    )
                );
            },
        );
        Ok(())
    }

    // Generates a helper that draws two random bounds in [0, 1000) from an RNG
    // seeded with `seed` and runs `experiment`, expecting `Success`.
    // With `ASC = true` only lower bounds are given, and the expected lower
    // bounds are derived with `max` and the offsets `expr_left` /
    // `expr_right`. With `ASC = false` only upper bounds are given, and the
    // expected upper bounds are derived with `min`.
    macro_rules! generate_cases {
        ($FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
            fn $FUNC_NAME<const ASC: bool>(
                expr: Arc<dyn PhysicalExpr>,
                left_col: Arc<dyn PhysicalExpr>,
                right_col: Arc<dyn PhysicalExpr>,
                seed: u64,
                expr_left: $TYPE,
                expr_right: $TYPE,
            ) -> Result<()> {
                let mut r = StdRng::seed_from_u64(seed);
                let (left_given, right_given, left_expected, right_expected) = if ASC {
                    let left = r.random_range((0 as $TYPE)..(1000 as $TYPE));
                    let right = r.random_range((0 as $TYPE)..(1000 as $TYPE));
                    (
                        (Some(left), None),
                        (Some(right), None),
                        (Some(<$TYPE>::max(left, right + expr_left)), None),
                        (Some(<$TYPE>::max(right, left + expr_right)), None),
                    )
                } else {
                    let left = r.random_range((0 as $TYPE)..(1000 as $TYPE));
                    let right = r.random_range((0 as $TYPE)..(1000 as $TYPE));
                    (
                        (None, Some(left)),
                        (None, Some(right)),
                        (None, Some(<$TYPE>::min(left, right + expr_left))),
                        (None, Some(<$TYPE>::min(right, left + expr_right))),
                    )
                };
                experiment(
                    expr,
                    (left_col.clone(), right_col.clone()),
                    Interval::make(left_given.0, left_given.1).unwrap(),
                    Interval::make(right_given.0, right_given.1).unwrap(),
                    Interval::make(left_expected.0, left_expected.1).unwrap(),
                    Interval::make(right_expected.0, right_expected.1).unwrap(),
                    PropagationResult::Success,
                    &Schema::new(vec![
                        Field::new(
                            left_col.as_any().downcast_ref::<Column>().unwrap().name(),
                            DataType::$SCALAR,
                            true,
                        ),
                        Field::new(
                            right_col.as_any().downcast_ref::<Column>().unwrap().name(),
                            DataType::$SCALAR,
                            true,
                        ),
                    ]),
                )
            }
        };
    }
    generate_cases!(generate_case_i32, i32, Int32);
    generate_cases!(generate_case_i64, i64, Int64);
    generate_cases!(generate_case_f32, f32, Float32);
    generate_cases!(generate_case_f64, f64, Float64);

    /// `left + 5 > right` with `left` in [10, 20] and `right` in [100, +inf)
    /// cannot be satisfied, so propagation must report `Infeasible`.
    #[test]
    fn testing_not_possible() -> Result<()> {
        let left_col = Arc::new(Column::new("left_watermark", 0));
        let right_col = Arc::new(Column::new("right_watermark", 0));
        let left_and_1 = Arc::new(BinaryExpr::new(
            Arc::clone(&left_col) as Arc<dyn PhysicalExpr>,
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::Int32(Some(5)))),
        ));
        let expr = Arc::new(BinaryExpr::new(
            left_and_1,
            Operator::Gt,
            Arc::clone(&right_col) as Arc<dyn PhysicalExpr>,
        ));
        experiment(
            expr,
            (
                Arc::clone(&left_col) as Arc<dyn PhysicalExpr>,
                Arc::clone(&right_col) as Arc<dyn PhysicalExpr>,
            ),
            Interval::make(Some(10_i32), Some(20_i32))?,
            Interval::make(Some(100), None)?,
            Interval::make(Some(10), Some(20))?,
            Interval::make(Some(100), None)?,
            PropagationResult::Infeasible,
            &Schema::new(vec![
                Field::new(
                    left_col.as_any().downcast_ref::<Column>().unwrap().name(),
                    DataType::Int32,
                    true,
                ),
                Field::new(
                    right_col.as_any().downcast_ref::<Column>().unwrap().name(),
                    DataType::Int32,
                    true,
                ),
            ]),
        )
    }

    // The `integer_float_case_*` macros each generate a parameterized test
    // over a conjunction built by `gen_conjunctive_numerical_expr`, with
    // different arithmetic operators and literals. They run the generated
    // helper once with lower bounds only (`ASC = true`) and once with upper
    // bounds only (`ASC = false`), using the offsets that the conjunction
    // implies between the two columns.
    macro_rules! integer_float_case_1 {
        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
            #[rstest]
            #[test]
            fn $TEST_FUNC_NAME(
                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
                seed: u64,
                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
            ) -> Result<()> {
                let left_col = Arc::new(Column::new("left_watermark", 0));
                let right_col = Arc::new(Column::new("right_watermark", 0));
                let expr = gen_conjunctive_numerical_expr(
                    left_col.clone(),
                    right_col.clone(),
                    (
                        Operator::Plus,
                        Operator::Plus,
                        Operator::Plus,
                        Operator::Plus,
                    ),
                    ScalarValue::$SCALAR(Some(1 as $TYPE)),
                    ScalarValue::$SCALAR(Some(11 as $TYPE)),
                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
                    ScalarValue::$SCALAR(Some(33 as $TYPE)),
                    (greater_op, less_op),
                );
                let l_gt_r = 10 as $TYPE;
                let r_gt_l = -30 as $TYPE;
                $GENERATE_CASE_FUNC_NAME::<true>(
                    expr.clone(),
                    left_col.clone(),
                    right_col.clone(),
                    seed,
                    l_gt_r,
                    r_gt_l,
                )?;
                let r_lt_l = -l_gt_r;
                let l_lt_r = -r_gt_l;
                $GENERATE_CASE_FUNC_NAME::<false>(
                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
                )
            }
        };
    }
    integer_float_case_1!(case_1_i32, generate_case_i32, i32, Int32);
    integer_float_case_1!(case_1_i64, generate_case_i64, i64, Int64);
    integer_float_case_1!(case_1_f64, generate_case_f64, f64, Float64);
    integer_float_case_1!(case_1_f32, generate_case_f32, f32, Float32);

    macro_rules! integer_float_case_2 {
        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
            #[rstest]
            #[test]
            fn $TEST_FUNC_NAME(
                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
                seed: u64,
                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
            ) -> Result<()> {
                let left_col = Arc::new(Column::new("left_watermark", 0));
                let right_col = Arc::new(Column::new("right_watermark", 0));
                let expr = gen_conjunctive_numerical_expr(
                    left_col.clone(),
                    right_col.clone(),
                    (
                        Operator::Minus,
                        Operator::Plus,
                        Operator::Plus,
                        Operator::Plus,
                    ),
                    ScalarValue::$SCALAR(Some(1 as $TYPE)),
                    ScalarValue::$SCALAR(Some(5 as $TYPE)),
                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
                    (greater_op, less_op),
                );
                let l_gt_r = 6 as $TYPE;
                let r_gt_l = -7 as $TYPE;
                $GENERATE_CASE_FUNC_NAME::<true>(
                    expr.clone(),
                    left_col.clone(),
                    right_col.clone(),
                    seed,
                    l_gt_r,
                    r_gt_l,
                )?;
                let r_lt_l = -l_gt_r;
                let l_lt_r = -r_gt_l;
                $GENERATE_CASE_FUNC_NAME::<false>(
                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
                )
            }
        };
    }
    integer_float_case_2!(case_2_i32, generate_case_i32, i32, Int32);
    integer_float_case_2!(case_2_i64, generate_case_i64, i64, Int64);
    integer_float_case_2!(case_2_f64, generate_case_f64, f64, Float64);
    integer_float_case_2!(case_2_f32, generate_case_f32, f32, Float32);

    macro_rules! integer_float_case_3 {
        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
            #[rstest]
            #[test]
            fn $TEST_FUNC_NAME(
                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
                seed: u64,
                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
            ) -> Result<()> {
                let left_col = Arc::new(Column::new("left_watermark", 0));
                let right_col = Arc::new(Column::new("right_watermark", 0));
                let expr = gen_conjunctive_numerical_expr(
                    left_col.clone(),
                    right_col.clone(),
                    (
                        Operator::Minus,
                        Operator::Plus,
                        Operator::Minus,
                        Operator::Plus,
                    ),
                    ScalarValue::$SCALAR(Some(1 as $TYPE)),
                    ScalarValue::$SCALAR(Some(5 as $TYPE)),
                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
                    (greater_op, less_op),
                );
                let l_gt_r = 6 as $TYPE;
                let r_gt_l = -13 as $TYPE;
                $GENERATE_CASE_FUNC_NAME::<true>(
                    expr.clone(),
                    left_col.clone(),
                    right_col.clone(),
                    seed,
                    l_gt_r,
                    r_gt_l,
                )?;
                let r_lt_l = -l_gt_r;
                let l_lt_r = -r_gt_l;
                $GENERATE_CASE_FUNC_NAME::<false>(
                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
                )
            }
        };
    }
    integer_float_case_3!(case_3_i32, generate_case_i32, i32, Int32);
    integer_float_case_3!(case_3_i64, generate_case_i64, i64, Int64);
    integer_float_case_3!(case_3_f64, generate_case_f64, f64, Float64);
    integer_float_case_3!(case_3_f32, generate_case_f32, f32, Float32);

    macro_rules! integer_float_case_4 {
        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
            #[rstest]
            #[test]
            fn $TEST_FUNC_NAME(
                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
                seed: u64,
                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
            ) -> Result<()> {
                let left_col = Arc::new(Column::new("left_watermark", 0));
                let right_col = Arc::new(Column::new("right_watermark", 0));
                let expr = gen_conjunctive_numerical_expr(
                    left_col.clone(),
                    right_col.clone(),
                    (
                        Operator::Minus,
                        Operator::Minus,
                        Operator::Minus,
                        Operator::Plus,
                    ),
                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
                    ScalarValue::$SCALAR(Some(5 as $TYPE)),
                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
                    (greater_op, less_op),
                );
                let l_gt_r = 5 as $TYPE;
                let r_gt_l = -13 as $TYPE;
                $GENERATE_CASE_FUNC_NAME::<true>(
                    expr.clone(),
                    left_col.clone(),
                    right_col.clone(),
                    seed,
                    l_gt_r,
                    r_gt_l,
                )?;
                let r_lt_l = -l_gt_r;
                let l_lt_r = -r_gt_l;
                $GENERATE_CASE_FUNC_NAME::<false>(
                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
                )
            }
        };
    }
    integer_float_case_4!(case_4_i32, generate_case_i32, i32, Int32);
    integer_float_case_4!(case_4_i64, generate_case_i64, i64, Int64);
    integer_float_case_4!(case_4_f64, generate_case_f64, f64, Float64);
    integer_float_case_4!(case_4_f32, generate_case_f32, f32, Float32);

    macro_rules! integer_float_case_5 {
        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
            #[rstest]
            #[test]
            fn $TEST_FUNC_NAME(
                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
                seed: u64,
                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
            ) -> Result<()> {
                let left_col = Arc::new(Column::new("left_watermark", 0));
                let right_col = Arc::new(Column::new("right_watermark", 0));
                let expr = gen_conjunctive_numerical_expr(
                    left_col.clone(),
                    right_col.clone(),
                    (
                        Operator::Minus,
                        Operator::Minus,
                        Operator::Minus,
                        Operator::Minus,
                    ),
                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
                    ScalarValue::$SCALAR(Some(5 as $TYPE)),
                    ScalarValue::$SCALAR(Some(30 as $TYPE)),
                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
                    (greater_op, less_op),
                );
                let l_gt_r = 5 as $TYPE;
                let r_gt_l = -27 as $TYPE;
                $GENERATE_CASE_FUNC_NAME::<true>(
                    expr.clone(),
                    left_col.clone(),
                    right_col.clone(),
                    seed,
                    l_gt_r,
                    r_gt_l,
                )?;
                let r_lt_l = -l_gt_r;
                let l_lt_r = -r_gt_l;
                $GENERATE_CASE_FUNC_NAME::<false>(
                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
                )
            }
        };
    }
    integer_float_case_5!(case_5_i32, generate_case_i32, i32, Int32);
    integer_float_case_5!(case_5_i64, generate_case_i64, i64, Int64);
    integer_float_case_5!(case_5_f64, generate_case_f64, f64, Float64);
    integer_float_case_5!(case_5_f32, generate_case_f32, f32, Float32);

    /// Gathering `a + b` in `(a + b) + 1 > a - b` removes no nodes, because
    /// `a` and `b` stay reachable through `a - b`.
    #[test]
    fn test_gather_node_indices_dont_remove() -> Result<()> {
        let left_expr = Arc::new(BinaryExpr::new(
            Arc::new(BinaryExpr::new(
                Arc::new(Column::new("a", 0)),
                Operator::Plus,
                Arc::new(Column::new("b", 1)),
            )),
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
        ));
        let right_expr = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Minus,
            Arc::new(Column::new("b", 1)),
        ));
        let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
        let mut graph = ExprIntervalGraph::try_new(
            expr,
            &Schema::new(vec![
                Field::new("a", DataType::Int32, true),
                Field::new("b", DataType::Int32, true),
            ]),
        )
        .unwrap();
        let leaf_node = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Plus,
            Arc::new(Column::new("b", 1)),
        ));
        let prev_node_count = graph.node_count();
        graph.gather_node_indices(&[leaf_node]);
        let final_node_count = graph.node_count();
        assert_eq!(prev_node_count, final_node_count);
        Ok(())
    }

    /// Gathering `a + b` in `(a + b) + 1 > y - z` removes both `a` and `b`,
    /// which become unreachable.
    #[test]
    fn test_gather_node_indices_remove() -> Result<()> {
        let left_expr = Arc::new(BinaryExpr::new(
            Arc::new(BinaryExpr::new(
                Arc::new(Column::new("a", 0)),
                Operator::Plus,
                Arc::new(Column::new("b", 1)),
            )),
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
        ));
        let right_expr = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("y", 0)),
            Operator::Minus,
            Arc::new(Column::new("z", 1)),
        ));
        let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
        let mut graph = ExprIntervalGraph::try_new(
            expr,
            &Schema::new(vec![
                Field::new("a", DataType::Int32, true),
                Field::new("b", DataType::Int32, true),
                Field::new("y", DataType::Int32, true),
                Field::new("z", DataType::Int32, true),
            ]),
        )
        .unwrap();
        let leaf_node = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Plus,
            Arc::new(Column::new("b", 1)),
        ));
        let prev_node_count = graph.node_count();
        graph.gather_node_indices(&[leaf_node]);
        let final_node_count = graph.node_count();
        assert_eq!(prev_node_count, final_node_count + 2);
        Ok(())
    }

    /// Gathering `a + b` in `(a + b) + 1 > a - z` removes only `b`; `a` stays
    /// reachable through `a - z`.
    #[test]
    fn test_gather_node_indices_remove_one() -> Result<()> {
        let left_expr = Arc::new(BinaryExpr::new(
            Arc::new(BinaryExpr::new(
                Arc::new(Column::new("a", 0)),
                Operator::Plus,
                Arc::new(Column::new("b", 1)),
            )),
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
        ));
        let right_expr = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Minus,
            Arc::new(Column::new("z", 1)),
        ));
        let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
        let mut graph = ExprIntervalGraph::try_new(
            expr,
            &Schema::new(vec![
                Field::new("a", DataType::Int32, true),
                Field::new("b", DataType::Int32, true),
                Field::new("z", DataType::Int32, true),
            ]),
        )
        .unwrap();
        let leaf_node = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Plus,
            Arc::new(Column::new("b", 1)),
        ));
        let prev_node_count = graph.node_count();
        graph.gather_node_indices(&[leaf_node]);
        let final_node_count = graph.node_count();
        assert_eq!(prev_node_count, final_node_count + 1);
        Ok(())
    }

    /// `a + b` is not a sub-expression of `(a + 1) + b > y - z`, so gathering
    /// it leaves the graph unchanged.
    #[test]
    fn test_gather_node_indices_cannot_provide() -> Result<()> {
        let left_expr = Arc::new(BinaryExpr::new(
            Arc::new(BinaryExpr::new(
                Arc::new(Column::new("a", 0)),
                Operator::Plus,
                Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
            )),
            Operator::Plus,
            Arc::new(Column::new("b", 1)),
        ));
        let right_expr = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("y", 0)),
            Operator::Minus,
            Arc::new(Column::new("z", 1)),
        ));
        let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
        let mut graph = ExprIntervalGraph::try_new(
            expr,
            &Schema::new(vec![
                Field::new("a", DataType::Int32, true),
                Field::new("b", DataType::Int32, true),
                Field::new("y", DataType::Int32, true),
                Field::new("z", DataType::Int32, true),
            ]),
        )
        .unwrap();
        let leaf_node = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Plus,
            Arc::new(Column::new("b", 1)),
        ));
        let prev_node_count = graph.node_count();
        graph.gather_node_indices(&[leaf_node]);
        let final_node_count = graph.node_count();
        assert_eq!(prev_node_count, final_node_count);
        Ok(())
    }

    /// Timestamp plus a singleton month-day-nano interval: the timestamp side
    /// is narrowed and the singleton interval is preserved.
    #[test]
    fn test_propagate_constraints_singleton_interval_at_right() -> Result<()> {
        let expression = BinaryExpr::new(
            Arc::new(Column::new("ts_column", 0)),
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::new_interval_mdn(0, 1, 321))),
        );
        let parent = Interval::try_new(
            ScalarValue::TimestampNanosecond(Some(1_602_756_672_000_000_321), None),
            ScalarValue::TimestampNanosecond(Some(1_602_843_072_000_000_321), None),
        )?;
        let left_child = Interval::try_new(
            ScalarValue::TimestampNanosecond(Some(1_602_324_672_000_000_000), None),
            ScalarValue::TimestampNanosecond(Some(1_603_188_672_000_000_000), None),
        )?;
        let right_child = Interval::try_new(
            ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano {
                months: 0,
                days: 1,
                nanoseconds: 321,
            })),
            ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano {
                months: 0,
                days: 1,
                nanoseconds: 321,
            })),
        )?;
        let children = vec![&left_child, &right_child];
        let result = expression
            .propagate_constraints(&parent, &children)?
            .unwrap();
        assert_eq!(
            vec![
                Interval::try_new(
                    ScalarValue::TimestampNanosecond(
                        Some(1_602_670_272_000_000_000),
                        None
                    ),
                    ScalarValue::TimestampNanosecond(
                        Some(1_602_756_672_000_000_000),
                        None
                    ),
                )?,
                Interval::try_new(
                    ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano {
                        months: 0,
                        days: 1,
                        nanoseconds: 321,
                    })),
                    ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano {
                        months: 0,
                        days: 1,
                        nanoseconds: 321,
                    })),
                )?
            ],
            result
        );
        Ok(())
    }

    /// Day-time interval column plus timestamp column: both the interval side
    /// and the timestamp side are narrowed.
    #[test]
    fn test_propagate_constraints_column_interval_at_left() -> Result<()> {
        let expression = BinaryExpr::new(
            Arc::new(Column::new("interval_column", 1)),
            Operator::Plus,
            Arc::new(Column::new("ts_column", 0)),
        );
        let parent = Interval::try_new(
            ScalarValue::TimestampMillisecond(Some(1_602_756_672_000), None),
            ScalarValue::TimestampMillisecond(Some(1_602_843_072_000), None),
        )?;
        let right_child = Interval::try_new(
            ScalarValue::TimestampMillisecond(Some(1_602_324_672_000), None),
            ScalarValue::TimestampMillisecond(Some(1_603_188_672_000), None),
        )?;
        let left_child = Interval::try_new(
            ScalarValue::IntervalDayTime(Some(IntervalDayTime {
                days: 0,
                milliseconds: 172_800_000,
            })),
            ScalarValue::IntervalDayTime(Some(IntervalDayTime {
                days: 0,
                milliseconds: 864_000_000,
            })),
        )?;
        let children = vec![&left_child, &right_child];
        let result = expression
            .propagate_constraints(&parent, &children)?
            .unwrap();
        assert_eq!(
            vec![
                Interval::try_new(
                    ScalarValue::IntervalDayTime(Some(IntervalDayTime {
                        days: 0,
                        milliseconds: 172_800_000,
                    })),
                    ScalarValue::IntervalDayTime(Some(IntervalDayTime {
                        days: 0,
                        milliseconds: 518_400_000,
                    })),
                )?,
                Interval::try_new(
                    ScalarValue::TimestampMillisecond(Some(1_602_324_672_000), None),
                    ScalarValue::TimestampMillisecond(Some(1_602_670_272_000), None),
                )?
            ],
            result
        );
        Ok(())
    }

    /// A certainly-true `left < 1000` narrows an unbounded `left` to
    /// `(-inf, 999]`, for Int64 and for timestamps with and without a time zone.
    #[test]
    fn test_propagate_comparison() -> Result<()> {
        let left = Interval::make_unbounded(&DataType::Int64)?;
        let right = Interval::make(Some(1000_i64), Some(1000_i64))?;
        assert_eq!(
            (Some((
                Interval::make(None, Some(999_i64))?,
                Interval::make(Some(1000_i64), Some(1000_i64))?,
            ))),
            propagate_comparison(&Operator::Lt, &Interval::TRUE, &left, &right)?
        );
        let left =
            Interval::make_unbounded(&DataType::Timestamp(TimeUnit::Nanosecond, None))?;
        let right = Interval::try_new(
            ScalarValue::TimestampNanosecond(Some(1000), None),
            ScalarValue::TimestampNanosecond(Some(1000), None),
        )?;
        assert_eq!(
            (Some((
                Interval::try_new(
                    ScalarValue::try_from(&DataType::Timestamp(
                        TimeUnit::Nanosecond,
                        None
                    ))
                    .unwrap(),
                    ScalarValue::TimestampNanosecond(Some(999), None),
                )?,
                Interval::try_new(
                    ScalarValue::TimestampNanosecond(Some(1000), None),
                    ScalarValue::TimestampNanosecond(Some(1000), None),
                )?
            ))),
            propagate_comparison(&Operator::Lt, &Interval::TRUE, &left, &right)?
        );
        let left = Interval::make_unbounded(&DataType::Timestamp(
            TimeUnit::Nanosecond,
            Some("+05:00".into()),
        ))?;
        let right = Interval::try_new(
            ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())),
            ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())),
        )?;
        assert_eq!(
            (Some((
                Interval::try_new(
                    ScalarValue::try_from(&DataType::Timestamp(
                        TimeUnit::Nanosecond,
                        Some("+05:00".into()),
                    ))
                    .unwrap(),
                    ScalarValue::TimestampNanosecond(Some(999), Some("+05:00".into())),
                )?,
                Interval::try_new(
                    ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())),
                    ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())),
                )?
            ))),
            propagate_comparison(&Operator::Lt, &Interval::TRUE, &left, &right)?
        );
        Ok(())
    }

    /// `OR` propagation:
    /// * A FALSE parent forces both children to FALSE, and is infeasible when
    ///   either child is TRUE.
    /// * A TRUE parent with one FALSE child forces the other child to TRUE.
    /// * A TRUE parent with two uncertain children yields no refinement (an
    ///   empty vector).
    #[test]
    fn test_propagate_or() -> Result<()> {
        let expr = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Or,
            Arc::new(Column::new("b", 1)),
        ));
        let parent = Interval::FALSE;
        let children_set = vec![
            vec![&Interval::FALSE, &Interval::TRUE_OR_FALSE],
            vec![&Interval::TRUE_OR_FALSE, &Interval::FALSE],
            vec![&Interval::FALSE, &Interval::FALSE],
            vec![&Interval::TRUE_OR_FALSE, &Interval::TRUE_OR_FALSE],
        ];
        for children in children_set {
            assert_eq!(
                expr.propagate_constraints(&parent, &children)?.unwrap(),
                vec![Interval::FALSE, Interval::FALSE],
            );
        }
        let parent = Interval::FALSE;
        let children_set = vec![
            vec![&Interval::TRUE, &Interval::TRUE_OR_FALSE],
            vec![&Interval::TRUE_OR_FALSE, &Interval::TRUE],
        ];
        for children in children_set {
            assert_eq!(expr.propagate_constraints(&parent, &children)?, None,);
        }
        let parent = Interval::TRUE;
        let children = vec![&Interval::FALSE, &Interval::TRUE_OR_FALSE];
        assert_eq!(
            expr.propagate_constraints(&parent, &children)?.unwrap(),
            vec![Interval::FALSE, Interval::TRUE]
        );
        let parent = Interval::TRUE;
        let children = vec![&Interval::TRUE_OR_FALSE, &Interval::TRUE_OR_FALSE];
        assert_eq!(
            expr.propagate_constraints(&parent, &children)?.unwrap(),
            vec![]
        );
        Ok(())
    }

    /// `AND` with a FALSE parent: a TRUE child forces the other child to
    /// FALSE. Otherwise no refinement is made (an empty vector).
    #[test]
    fn test_propagate_certainly_false_and() -> Result<()> {
        let expr = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::And,
            Arc::new(Column::new("b", 1)),
        ));
        let parent = Interval::FALSE;
        let children_and_results_set = vec![
            (
                vec![&Interval::TRUE, &Interval::TRUE_OR_FALSE],
                vec![Interval::TRUE, Interval::FALSE],
            ),
            (
                vec![&Interval::TRUE_OR_FALSE, &Interval::TRUE],
                vec![Interval::FALSE, Interval::TRUE],
            ),
            (
                vec![&Interval::TRUE_OR_FALSE, &Interval::TRUE_OR_FALSE],
                vec![],
            ),
            (vec![&Interval::FALSE, &Interval::TRUE_OR_FALSE], vec![]),
        ];
        for (children, result) in children_and_results_set {
            assert_eq!(
                expr.propagate_constraints(&parent, &children)?.unwrap(),
                result
            );
        }
        Ok(())
    }
}