use std::collections::HashSet;
use std::fmt::{Display, Formatter};
use std::sync::Arc;
use super::utils::{
convert_duration_type_to_interval, convert_interval_type_to_duration, get_inverse_op,
};
use super::IntervalBound;
use crate::expressions::Literal;
use crate::intervals::interval_aritmetic::{apply_operator, Interval};
use crate::utils::{build_dag, ExprTreeNode};
use crate::PhysicalExpr;
use arrow_schema::DataType;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::type_coercion::binary::get_result_type;
use datafusion_expr::Operator;
use petgraph::graph::NodeIndex;
use petgraph::stable_graph::{DefaultIx, StableGraph};
use petgraph::visit::{Bfs, Dfs, DfsPostOrder, EdgeRef};
use petgraph::Outgoing;
/// A directed graph of [`ExprIntervalGraphNode`]s built from a physical
/// expression (see `try_new`), together with the index of its root node.
/// Each node pairs an expression with the interval currently known for it.
#[derive(Clone, Debug)]
pub struct ExprIntervalGraph {
// Expression nodes; edges carry a `usize` weight and are followed in the
// `Outgoing` direction (parent -> child) by every traversal in this file.
graph: StableGraph<ExprIntervalGraphNode, usize>,
// Starting point of every traversal (BFS/DFS) over `graph`.
root: NodeIndex,
}
impl ExprIntervalGraph {
    /// Estimates the memory footprint of this graph in bytes.
    ///
    /// The estimate is the size of the struct itself, plus one
    /// `ExprIntervalGraphNode` and one `NodeIndex` per node, plus one `usize`
    /// and two `NodeIndex` values per edge. Data owned indirectly by the
    /// nodes is not followed.
    pub fn size(&self) -> usize {
        use std::mem::{size_of, size_of_val};

        let bytes_per_node = size_of::<ExprIntervalGraphNode>() + size_of::<NodeIndex>();
        let bytes_per_edge = size_of::<usize>() + size_of::<NodeIndex>() * 2;

        let nodes_total = self.graph.node_count() * bytes_per_node;
        let edges_total = self.graph.edge_count() * bytes_per_edge;

        size_of_val(self) + nodes_total + edges_total
    }
}
/// Outcome reported by `ExprIntervalGraph::update_ranges`.
#[derive(PartialEq, Debug)]
pub enum PropagationResult {
/// The root bounds were neither `Interval::CERTAINLY_FALSE` nor
/// `Interval::UNCERTAIN`, so no constraint propagation was attempted.
CannotPropagate,
/// The root evaluated to `Interval::CERTAINLY_FALSE`, or propagation
/// produced no interval (`None`) for some child.
Infeasible,
/// Constraint propagation visited the whole graph without hitting an
/// infeasible child.
Success,
}
/// A single node of the [`ExprIntervalGraph`]: a physical expression plus the
/// interval currently associated with it.
#[derive(Clone, Debug)]
pub struct ExprIntervalGraphNode {
// The expression this node stands for; also the only field used by the
// `Display` and `PartialEq` implementations.
expr: Arc<dyn PhysicalExpr>,
// Current bounds for `expr`; overwritten by bound evaluation and
// constraint propagation.
interval: Interval,
}
impl Display for ExprIntervalGraphNode {
    /// Renders the node as its underlying expression; the interval is not shown.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{}", self.expr))
    }
}
impl ExprIntervalGraphNode {
    /// Creates a node for `expr` whose interval is `Interval::default()`.
    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
        Self::new_with_interval(expr, Interval::default())
    }

    /// Creates a node for `expr` that starts out with the given `interval`.
    pub fn new_with_interval(expr: Arc<dyn PhysicalExpr>, interval: Interval) -> Self {
        Self { expr, interval }
    }

    /// Returns the interval currently stored in this node.
    pub fn interval(&self) -> &Interval {
        &self.interval
    }

    /// Builds a graph node from an expression-tree node.
    ///
    /// A `Literal` expression gets the singleton interval `[value, value]`
    /// with both bounds closed; any other expression gets the default interval.
    pub fn make_node(node: &ExprTreeNode<NodeIndex>) -> ExprIntervalGraphNode {
        let expr = node.expression().clone();
        // Compute the singleton interval first so the borrow of `expr` taken
        // by the downcast has ended before `expr` is moved into the node.
        let singleton = expr.as_any().downcast_ref::<Literal>().map(|lit| {
            let point = lit.value();
            Interval::new(
                IntervalBound::new_closed(point.clone()),
                IntervalBound::new_closed(point.clone()),
            )
        });
        match singleton {
            Some(interval) => Self::new_with_interval(expr, interval),
            None => Self::new(expr),
        }
    }
}
// Two nodes compare equal when their expressions compare equal; the stored
// intervals are not consulted.
impl PartialEq for ExprIntervalGraphNode {
fn eq(&self, other: &ExprIntervalGraphNode) -> bool {
// NOTE(review): keep the explicit `.eq(..)` method-call form. Which
// `PartialEq` impl this resolves to depends on the trait bounds of
// `PhysicalExpr`, which are not visible in this file -- confirm before
// rewriting it as `==`.
self.expr.eq(&other.expr)
}
}
/// Narrows the operand intervals of `parent = left_child op right_child`.
///
/// Timestamp/Interval operand pairs are delegated to the dedicated
/// time-interval helpers. For every other pair, the new left interval is
/// `(parent inverse_op right_child)` intersected with `left_child`, and the
/// new right interval is then derived from that narrowed left interval by
/// `propagate_right`.
///
/// Returns `(None, None)` when the left intersection yields `None`. The right
/// element alone may be `None` when `propagate_right` finds no intersection.
///
/// # Errors
/// Propagates errors from datatype lookup, `apply_operator`, `intersect`
/// and the helpers.
pub fn propagate_arithmetic(
    op: &Operator,
    parent: &Interval,
    left_child: &Interval,
    right_child: &Interval,
) -> Result<(Option<Interval>, Option<Interval>)> {
    let inverse_op = get_inverse_op(*op);
    let left_type = left_child.get_datatype()?;
    let right_type = right_child.get_datatype()?;

    match (left_type, right_type) {
        (DataType::Timestamp(..), DataType::Interval(_)) => propagate_time_interval_at_right(
            left_child,
            right_child,
            parent,
            op,
            &inverse_op,
        ),
        (DataType::Interval(_), DataType::Timestamp(..)) => propagate_time_interval_at_left(
            left_child,
            right_child,
            parent,
            op,
            &inverse_op,
        ),
        _ => {
            let candidate = apply_operator(&inverse_op, parent, right_child)?;
            if let Some(new_left) = candidate.intersect(left_child)? {
                // The right side is derived from the already narrowed left side.
                let new_right =
                    propagate_right(&new_left, parent, right_child, op, &inverse_op)?;
                Ok((Some(new_left), new_right))
            } else {
                Ok((None, None))
            }
        }
    }
}
/// Builds the interval that `left - right` must fall into for the comparison
/// `left op right` to hold.
///
/// The interval's datatype is the result type of `left_datatype - right_datatype`.
/// One bound is that type's zero (closed for `>=`, `<=`, `=`; open for `>`, `<`)
/// and the other is unbounded; for `=` both bounds are a closed zero.
///
/// # Panics
/// Hits `unreachable!()` for any operator other than `>=`, `>`, `<=`, `<`, `=`.
fn comparison_operator_target(
    left_datatype: &DataType,
    op: &Operator,
    right_datatype: &DataType,
) -> Result<Interval> {
    let datatype = get_result_type(left_datatype, &Operator::Minus, right_datatype)?;
    let unbounded = IntervalBound::make_unbounded(&datatype)?;
    let zero = ScalarValue::new_zero(&datatype)?;

    // Choose the (lower, upper) pair first; the interval is assembled once below.
    let (lower, upper) = match *op {
        Operator::GtEq => (IntervalBound::new_closed(zero), unbounded),
        Operator::Gt => (IntervalBound::new_open(zero), unbounded),
        Operator::LtEq => (unbounded, IntervalBound::new_closed(zero)),
        Operator::Lt => (unbounded, IntervalBound::new_open(zero)),
        Operator::Eq => (
            IntervalBound::new_closed(zero.clone()),
            IntervalBound::new_closed(zero),
        ),
        _ => unreachable!(),
    };
    Ok(Interval::new(lower, upper))
}
pub fn propagate_comparison(
op: &Operator,
left_child: &Interval,
right_child: &Interval,
) -> Result<(Option<Interval>, Option<Interval>)> {
let left_type = left_child.get_datatype()?;
let right_type = right_child.get_datatype()?;
let parent = comparison_operator_target(&left_type, op, &right_type)?;
match (&left_type, &right_type) {
(DataType::Interval(_), DataType::Duration(_)) => {
propagate_comparison_to_time_interval_at_left(
left_child,
&parent,
right_child,
)
}
(DataType::Duration(_), DataType::Interval(_)) => {
propagate_comparison_to_time_interval_at_left(
left_child,
&parent,
right_child,
)
}
_ => propagate_arithmetic(&Operator::Minus, &parent, left_child, right_child),
}
}
// Construction, pruning, bound evaluation and constraint propagation for the
// expression interval graph.
impl ExprIntervalGraph {
/// Builds the graph for `expr` with `build_dag`, creating every node through
/// `ExprIntervalGraphNode::make_node`. Errors from `build_dag` are returned.
pub fn try_new(expr: Arc<dyn PhysicalExpr>) -> Result<Self> {
let (root, graph) = build_dag(expr, &ExprIntervalGraphNode::make_node)?;
Ok(Self { graph, root })
}
/// Returns the number of nodes currently in the graph.
pub fn node_count(&self) -> usize {
self.graph.node_count()
}
/// Locates the graph nodes that correspond to `exprs` and prunes the graph
/// below them.
///
/// Returns one `(expression, node index)` pair per entry of `exprs`, in the
/// same order. An expression with no matching node keeps the placeholder
/// index `usize::MAX`.
///
/// Side effects: every outgoing edge of a matched node is removed, after
/// which all nodes that are no longer reachable from the root are dropped.
pub fn gather_node_indices(
&mut self,
exprs: &[Arc<dyn PhysicalExpr>],
) -> Vec<(Arc<dyn PhysicalExpr>, usize)> {
let graph = &self.graph;
let mut bfs = Bfs::new(graph, self.root);
// Edges are only collected during the traversal and removed afterwards,
// because `graph` is borrowed immutably while the BFS is running.
let mut removals = vec![];
// Pre-fill the result in the order of `exprs`, using `usize::MAX` as the
// "not found" marker.
let mut expr_node_indices = exprs
.iter()
.map(|e| (e.clone(), usize::MAX))
.collect::<Vec<_>>();
while let Some(node) = bfs.next(graph) {
let expr = &graph[node].expr;
// Only the first entry of `exprs` equal to this node's expression is
// updated (`position` stops at the first match).
if let Some(value) = exprs.iter().position(|e| expr.eq(e)) {
expr_node_indices[value].1 = node.index();
// Slate the matched node's outgoing edges for removal.
for edge in graph.edges_directed(node, Outgoing) {
removals.push(edge.id());
}
}
}
for edge_idx in removals {
self.graph.remove_edge(edge_idx);
}
// Keep only the nodes still reachable from the root.
// NOTE(review): the indices recorded above stay meaningful only if node
// indices survive removals; petgraph documents this for `StableGraph` --
// confirm.
let connected_nodes = self.connected_nodes();
self.graph
.retain_nodes(|_, index| connected_nodes.contains(&index));
expr_node_indices
}
// Collects every node visited by a depth-first search starting at the root.
fn connected_nodes(&self) -> HashSet<NodeIndex> {
let mut nodes = HashSet::new();
let mut dfs = Dfs::new(&self.graph, self.root);
while let Some(node) = dfs.next(&self.graph) {
nodes.insert(node);
}
nodes
}
/// Overwrites the interval of each node named in `assignments`
/// (`(node index, interval)` pairs) with a clone of the given interval.
///
/// NOTE(review): the index is narrowed with `as DefaultIx`, and indexing the
/// graph presumably panics for an index that is not in the graph -- confirm.
pub fn assign_intervals(&mut self, assignments: &[(usize, Interval)]) {
for (index, interval) in assignments {
let node_index = NodeIndex::from(*index as DefaultIx);
self.graph[node_index].interval = interval.clone();
}
}
/// The reverse of `assign_intervals`: for each `(node index, interval)` pair,
/// replaces the interval with a clone of the one stored in the graph.
pub fn update_intervals(&self, assignments: &mut [(usize, Interval)]) {
for (index, interval) in assignments.iter_mut() {
let node_index = NodeIndex::from(*index as DefaultIx);
*interval = self.graph[node_index].interval.clone();
}
}
/// Recomputes intervals from the children towards the root and returns the
/// root's interval.
///
/// Nodes are visited in DFS post-order. A node without outgoing neighbors
/// keeps its current interval; every other node gets the result of its
/// expression's `evaluate_bounds` applied to its children's intervals.
pub fn evaluate_bounds(&mut self) -> Result<&Interval> {
let mut dfs = DfsPostOrder::new(&self.graph, self.root);
while let Some(node) = dfs.next(&self.graph) {
let neighbors = self.graph.neighbors_directed(node, Outgoing);
let mut children_intervals = neighbors
.map(|child| self.graph[child].interval())
.collect::<Vec<_>>();
if !children_intervals.is_empty() {
// NOTE(review): the reversal presumably restores the expression's own
// child order, assuming petgraph yields neighbors in reverse insertion
// order -- confirm.
children_intervals.reverse();
self.graph[node].interval =
self.graph[node].expr.evaluate_bounds(&children_intervals)?;
}
}
Ok(&self.graph[self.root].interval)
}
// Pushes constraints from the root down to the children, in BFS order.
//
// For every node that has children, the node's expression derives new child
// intervals from the node's interval and the children's current intervals.
// Returns `Infeasible` as soon as a child's propagated interval is `None`;
// otherwise returns `Success` once the traversal finishes.
fn propagate_constraints(&mut self) -> Result<PropagationResult> {
let mut bfs = Bfs::new(&self.graph, self.root);
while let Some(node) = bfs.next(&self.graph) {
let neighbors = self.graph.neighbors_directed(node, Outgoing);
let mut children = neighbors.collect::<Vec<_>>();
if children.is_empty() {
continue;
}
// Same reversal as in `evaluate_bounds`, so that `children` lines up
// with the order in which the expression reports its results.
children.reverse();
let children_intervals = children
.iter()
.map(|child| self.graph[*child].interval())
.collect::<Vec<_>>();
let node_interval = self.graph[node].interval();
let propagated_intervals = self.graph[node]
.expr
.propagate_constraints(node_interval, &children_intervals)?;
// `zip` stops at the shorter side: children without a corresponding
// propagated entry keep their current interval.
for (child, interval) in children.into_iter().zip(propagated_intervals) {
if let Some(interval) = interval {
self.graph[child].interval = interval;
} else {
return Ok(PropagationResult::Infeasible);
}
}
}
Ok(PropagationResult::Success)
}
/// Assigns `leaf_bounds` to the graph, evaluates the root and, when the
/// outcome is uncertain, propagates constraints back down.
///
/// * Root equals `Interval::CERTAINLY_FALSE`: returns `Infeasible`;
///   `leaf_bounds` is left as passed in.
/// * Root equals `Interval::UNCERTAIN`: runs constraint propagation, writes
///   the resulting node intervals back into `leaf_bounds`, and returns the
///   propagation result. The write-back also happens when propagation
///   returns an error.
/// * Anything else: returns `CannotPropagate`.
pub fn update_ranges(
&mut self,
leaf_bounds: &mut [(usize, Interval)],
) -> Result<PropagationResult> {
self.assign_intervals(leaf_bounds);
let bounds = self.evaluate_bounds()?;
if bounds == &Interval::CERTAINLY_FALSE {
Ok(PropagationResult::Infeasible)
} else if bounds == &Interval::UNCERTAIN {
let result = self.propagate_constraints();
self.update_intervals(leaf_bounds);
result
} else {
Ok(PropagationResult::CannotPropagate)
}
}
/// Returns a clone of the interval stored at node `index`.
pub fn get_interval(&self, index: usize) -> Interval {
self.graph[NodeIndex::new(index)].interval.clone()
}
}
/// Propagation for `parent = left_child op right_child` when the time
/// interval is the left operand.
///
/// If `left_child` can be converted with `convert_interval_type_to_duration`,
/// propagation runs on the converted value and the narrowed left side is
/// converted back with `convert_duration_type_to_interval`. Otherwise
/// `left_child` is returned unchanged and only the right side is narrowed.
// NOTE(review): conversion presumably fails for intervals with a non-zero
// month field (see the error text used by the comparison helpers) -- confirm.
fn propagate_time_interval_at_left(
    left_child: &Interval,
    right_child: &Interval,
    parent: &Interval,
    op: &Operator,
    inverse_op: &Operator,
) -> Result<(Option<Interval>, Option<Interval>)> {
    match convert_interval_type_to_duration(left_child) {
        Some(duration) => {
            let candidate = apply_operator(inverse_op, parent, right_child)?;
            if let Some(narrowed) = candidate.intersect(duration)? {
                let new_right =
                    propagate_right(&narrowed, parent, right_child, op, inverse_op)?;
                // Hand the left side back in its original (time interval) type.
                Ok((convert_duration_type_to_interval(&narrowed), new_right))
            } else {
                Ok((None, None))
            }
        }
        None => {
            let new_right =
                propagate_right(left_child, parent, right_child, op, inverse_op)?;
            Ok((Some(left_child.clone()), new_right))
        }
    }
}
/// Propagation for `parent = left_child op right_child` when the time
/// interval is the right operand.
///
/// If `right_child` can be converted with `convert_interval_type_to_duration`,
/// both sides are narrowed using the converted value and the right result is
/// converted back with `convert_duration_type_to_interval`. Otherwise only
/// the left side is narrowed and `right_child` is returned unchanged.
/// In both cases `(None, None)` is returned when the left intersection
/// yields `None`.
fn propagate_time_interval_at_right(
    left_child: &Interval,
    right_child: &Interval,
    parent: &Interval,
    op: &Operator,
    inverse_op: &Operator,
) -> Result<(Option<Interval>, Option<Interval>)> {
    match convert_interval_type_to_duration(right_child) {
        Some(duration) => {
            let candidate = apply_operator(inverse_op, parent, &duration)?;
            match candidate.intersect(left_child)? {
                Some(new_left) => {
                    // Note: the right side is derived from the original
                    // `left_child`, not from the narrowed `new_left`.
                    let new_right =
                        propagate_right(left_child, parent, &duration, op, inverse_op)?
                            .and_then(|narrowed| convert_duration_type_to_interval(&narrowed));
                    Ok((Some(new_left), new_right))
                }
                None => Ok((None, None)),
            }
        }
        None => {
            let candidate = apply_operator(inverse_op, parent, right_child)?;
            Ok(candidate
                .intersect(left_child)?
                .map_or((None, None), |new_left| {
                    (Some(new_left), Some(right_child.clone()))
                }))
        }
    }
}
/// Narrows the right operand of `parent = left op right`.
///
/// * `Minus`: candidate is `left - parent`;
/// * `Plus`: candidate is `parent inverse_op left`.
///
/// The candidate is then intersected with `right`; the result of that
/// intersection (possibly `None`) is returned.
///
/// # Panics
/// Hits `unreachable!()` for any operator other than `Plus` and `Minus`.
fn propagate_right(
    left: &Interval,
    parent: &Interval,
    right: &Interval,
    op: &Operator,
    inverse_op: &Operator,
) -> Result<Option<Interval>> {
    let candidate = match op {
        Operator::Minus => apply_operator(op, left, parent)?,
        Operator::Plus => apply_operator(inverse_op, parent, left)?,
        _ => unreachable!(),
    };
    candidate.intersect(right)
}
/// Propagates a comparison whose **left** operand is a time interval.
///
/// `left_child` is converted with `convert_interval_type_to_duration` and the
/// comparison is then propagated as `converted - right_child ∈ parent`.
///
/// # Errors
/// Returns an internal error when `left_child` cannot be converted.
pub fn propagate_comparison_to_time_interval_at_left(
    left_child: &Interval,
    parent: &Interval,
    right_child: &Interval,
) -> Result<(Option<Interval>, Option<Interval>)> {
    let converted = convert_interval_type_to_duration(left_child).ok_or_else(|| {
        DataFusionError::Internal(
            "Interval type has a non-zero month field, cannot compare with a Duration type".to_string(),
        )
    })?;
    propagate_arithmetic(&Operator::Minus, parent, &converted, right_child)
}
/// Propagates a comparison whose **right** operand is a time interval.
///
/// `right_child` is converted with `convert_interval_type_to_duration` and the
/// comparison is then propagated as `left_child - converted ∈ parent`.
///
/// # Errors
/// Returns an internal error when `right_child` cannot be converted.
pub fn propagate_comparison_to_time_interval_at_right(
    left_child: &Interval,
    parent: &Interval,
    right_child: &Interval,
) -> Result<(Option<Interval>, Option<Interval>)> {
    let converted = convert_interval_type_to_duration(right_child).ok_or_else(|| {
        DataFusionError::Internal(
            "Interval type has a non-zero month field, cannot compare with a Duration type".to_string(),
        )
    })?;
    propagate_arithmetic(&Operator::Minus, parent, left_child, &converted)
}
#[cfg(test)]
mod tests {
use super::*;
use itertools::Itertools;
use crate::expressions::{BinaryExpr, Column};
use crate::intervals::test_utils::gen_conjunctive_numerical_expr;
use arrow::datatypes::TimeUnit;
use datafusion_common::ScalarValue;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rstest::*;
// Shared test driver: builds the graph for `expr`, feeds it the given
// intervals for the two tracked expressions, runs `update_ranges`, and checks
// both the propagation result and that each computed interval contains the
// corresponding expected interval (lower <= expected lower, upper >= expected
// upper).
fn experiment(
    expr: Arc<dyn PhysicalExpr>,
    exprs_with_interval: (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>),
    left_interval: Interval,
    right_interval: Interval,
    left_expected: Interval,
    right_expected: Interval,
    result: PropagationResult,
) -> Result<()> {
    let (left_expr, right_expr) = exprs_with_interval;
    let mut graph = ExprIntervalGraph::try_new(expr)?;

    // Node indices come back in the order the expressions were requested.
    let node_ids = graph
        .gather_node_indices(&[left_expr, right_expr])
        .into_iter()
        .map(|(_, index)| index)
        .collect_vec();
    let mut col_stat_nodes = node_ids
        .iter()
        .copied()
        .zip(vec![left_interval, right_interval])
        .collect_vec();
    let expected_nodes = node_ids
        .iter()
        .copied()
        .zip(vec![left_expected, right_expected])
        .collect_vec();

    let exp_result = graph.update_ranges(&mut col_stat_nodes[..])?;
    assert_eq!(exp_result, result);

    for ((_, calculated), (_, expected)) in col_stat_nodes.iter().zip(expected_nodes.iter()) {
        assert!(calculated.lower.value <= expected.lower.value);
        assert!(calculated.upper.value >= expected.upper.value);
    }
    Ok(())
}
// Generates a seeded, randomized test driver per numeric type.
//
// `ASC = true` gives both columns a lower bound only; the expected lower bounds are
// `max(left, right + expr_left)` and `max(right, left + expr_right)`.
// `ASC = false` gives both columns an upper bound only; the expected upper
// bounds use `min` in the same way. The `$SCALAR` argument is accepted but
// not used by the generated function.
macro_rules! generate_cases {
    ($FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
        fn $FUNC_NAME<const ASC: bool>(
            expr: Arc<dyn PhysicalExpr>,
            left_col: Arc<dyn PhysicalExpr>,
            right_col: Arc<dyn PhysicalExpr>,
            seed: u64,
            expr_left: $TYPE,
            expr_right: $TYPE,
        ) -> Result<()> {
            let mut rng = StdRng::seed_from_u64(seed);
            // Both directions draw the same two samples, in the same order.
            let left = rng.gen_range((0 as $TYPE)..(1000 as $TYPE));
            let right = rng.gen_range((0 as $TYPE)..(1000 as $TYPE));

            let (left_given, right_given, left_expected, right_expected) = if ASC {
                (
                    Interval::make(Some(left), None, (true, true)),
                    Interval::make(Some(right), None, (true, true)),
                    Interval::make(
                        Some(<$TYPE>::max(left, right + expr_left)),
                        None,
                        (true, true),
                    ),
                    Interval::make(
                        Some(<$TYPE>::max(right, left + expr_right)),
                        None,
                        (true, true),
                    ),
                )
            } else {
                (
                    Interval::make(None, Some(left), (true, true)),
                    Interval::make(None, Some(right), (true, true)),
                    Interval::make(
                        None,
                        Some(<$TYPE>::min(left, right + expr_left)),
                        (true, true),
                    ),
                    Interval::make(
                        None,
                        Some(<$TYPE>::min(right, left + expr_right)),
                        (true, true),
                    ),
                )
            };

            experiment(
                expr,
                (left_col, right_col),
                left_given,
                right_given,
                left_expected,
                right_expected,
                PropagationResult::Success,
            )
        }
    };
}
generate_cases!(generate_case_i32, i32, Int32);
generate_cases!(generate_case_i64, i64, Int64);
generate_cases!(generate_case_f32, f32, Float32);
generate_cases!(generate_case_f64, f64, Float64);
// `left + 5 > right` with left in [10, 20] and right in [100, +inf) must be
// reported as infeasible.
#[test]
fn testing_not_possible() -> Result<()> {
    let left_col = Arc::new(Column::new("left_watermark", 0));
    let right_col = Arc::new(Column::new("right_watermark", 0));

    let five = Arc::new(Literal::new(ScalarValue::Int32(Some(5))));
    let left_plus_five = Arc::new(BinaryExpr::new(left_col.clone(), Operator::Plus, five));
    let predicate = Arc::new(BinaryExpr::new(
        left_plus_five,
        Operator::Gt,
        right_col.clone(),
    ));

    // The expected intervals are the given ones.
    let left_bounds = Interval::make(Some(10), Some(20), (true, true));
    let right_bounds = Interval::make(Some(100), None, (true, true));

    experiment(
        predicate,
        (left_col, right_col),
        left_bounds.clone(),
        right_bounds.clone(),
        left_bounds,
        right_bounds,
        PropagationResult::Infeasible,
    )
}
// Case 1: generates an rstest-parameterized test over 14 seeds, both
// "greater" operators and both "less" operators, for one numeric type.
// NOTE(review): the predicate presumably reads
//   left + 1 >(=) right + 11 AND left + 3 <(=) right + 33
// -- confirm the operand layout against `gen_conjunctive_numerical_expr`.
macro_rules! integer_float_case_1 {
($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
#[rstest]
#[test]
fn $TEST_FUNC_NAME(
#[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
seed: u64,
#[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
#[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
) -> Result<()> {
let left_col = Arc::new(Column::new("left_watermark", 0));
let right_col = Arc::new(Column::new("right_watermark", 0));
let expr = gen_conjunctive_numerical_expr(
left_col.clone(),
right_col.clone(),
(
Operator::Plus,
Operator::Plus,
Operator::Plus,
Operator::Plus,
),
ScalarValue::$SCALAR(Some(1 as $TYPE)),
ScalarValue::$SCALAR(Some(11 as $TYPE)),
ScalarValue::$SCALAR(Some(3 as $TYPE)),
ScalarValue::$SCALAR(Some(33 as $TYPE)),
(greater_op, less_op),
);
// Under the reading above: left > right + 10 and right > left - 30.
let l_gt_r = 10 as $TYPE;
let r_gt_l = -30 as $TYPE;
// Lower-bound direction (ASC = true).
$GENERATE_CASE_FUNC_NAME::<true>(
expr.clone(),
left_col.clone(),
right_col.clone(),
seed,
l_gt_r,
r_gt_l,
)?;
// Upper-bound direction (ASC = false) uses the negated, swapped offsets.
let r_lt_l = -l_gt_r;
let l_lt_r = -r_gt_l;
$GENERATE_CASE_FUNC_NAME::<false>(
expr, left_col, right_col, seed, l_lt_r, r_lt_l,
)
}
};
}
integer_float_case_1!(case_1_i32, generate_case_i32, i32, Int32);
integer_float_case_1!(case_1_i64, generate_case_i64, i64, Int64);
integer_float_case_1!(case_1_f64, generate_case_f64, f64, Float64);
integer_float_case_1!(case_1_f32, generate_case_f32, f32, Float32);
// Case 2: generates an rstest-parameterized test over 14 seeds, both
// "greater" operators and both "less" operators, for one numeric type.
// NOTE(review): the predicate presumably reads
//   left - 1 >(=) right + 5 AND left + 3 <(=) right + 10
// -- confirm the operand layout against `gen_conjunctive_numerical_expr`.
macro_rules! integer_float_case_2 {
($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
#[rstest]
#[test]
fn $TEST_FUNC_NAME(
#[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
seed: u64,
#[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
#[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
) -> Result<()> {
let left_col = Arc::new(Column::new("left_watermark", 0));
let right_col = Arc::new(Column::new("right_watermark", 0));
let expr = gen_conjunctive_numerical_expr(
left_col.clone(),
right_col.clone(),
(
Operator::Minus,
Operator::Plus,
Operator::Plus,
Operator::Plus,
),
ScalarValue::$SCALAR(Some(1 as $TYPE)),
ScalarValue::$SCALAR(Some(5 as $TYPE)),
ScalarValue::$SCALAR(Some(3 as $TYPE)),
ScalarValue::$SCALAR(Some(10 as $TYPE)),
(greater_op, less_op),
);
// Under the reading above: left > right + 6 and right > left - 7.
let l_gt_r = 6 as $TYPE;
let r_gt_l = -7 as $TYPE;
// Lower-bound direction (ASC = true).
$GENERATE_CASE_FUNC_NAME::<true>(
expr.clone(),
left_col.clone(),
right_col.clone(),
seed,
l_gt_r,
r_gt_l,
)?;
// Upper-bound direction (ASC = false) uses the negated, swapped offsets.
let r_lt_l = -l_gt_r;
let l_lt_r = -r_gt_l;
$GENERATE_CASE_FUNC_NAME::<false>(
expr, left_col, right_col, seed, l_lt_r, r_lt_l,
)
}
};
}
integer_float_case_2!(case_2_i32, generate_case_i32, i32, Int32);
integer_float_case_2!(case_2_i64, generate_case_i64, i64, Int64);
integer_float_case_2!(case_2_f64, generate_case_f64, f64, Float64);
integer_float_case_2!(case_2_f32, generate_case_f32, f32, Float32);
// Case 3: generates an rstest-parameterized test over 14 seeds, both
// "greater" operators and both "less" operators, for one numeric type.
// NOTE(review): the predicate presumably reads
//   left - 1 >(=) right + 5 AND left - 3 <(=) right + 10
// -- confirm the operand layout against `gen_conjunctive_numerical_expr`.
macro_rules! integer_float_case_3 {
($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
#[rstest]
#[test]
fn $TEST_FUNC_NAME(
#[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
seed: u64,
#[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
#[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
) -> Result<()> {
let left_col = Arc::new(Column::new("left_watermark", 0));
let right_col = Arc::new(Column::new("right_watermark", 0));
let expr = gen_conjunctive_numerical_expr(
left_col.clone(),
right_col.clone(),
(
Operator::Minus,
Operator::Plus,
Operator::Minus,
Operator::Plus,
),
ScalarValue::$SCALAR(Some(1 as $TYPE)),
ScalarValue::$SCALAR(Some(5 as $TYPE)),
ScalarValue::$SCALAR(Some(3 as $TYPE)),
ScalarValue::$SCALAR(Some(10 as $TYPE)),
(greater_op, less_op),
);
// Under the reading above: left > right + 6 and right > left - 13.
let l_gt_r = 6 as $TYPE;
let r_gt_l = -13 as $TYPE;
// Lower-bound direction (ASC = true).
$GENERATE_CASE_FUNC_NAME::<true>(
expr.clone(),
left_col.clone(),
right_col.clone(),
seed,
l_gt_r,
r_gt_l,
)?;
// Upper-bound direction (ASC = false) uses the negated, swapped offsets.
let r_lt_l = -l_gt_r;
let l_lt_r = -r_gt_l;
$GENERATE_CASE_FUNC_NAME::<false>(
expr, left_col, right_col, seed, l_lt_r, r_lt_l,
)
}
};
}
integer_float_case_3!(case_3_i32, generate_case_i32, i32, Int32);
integer_float_case_3!(case_3_i64, generate_case_i64, i64, Int64);
integer_float_case_3!(case_3_f64, generate_case_f64, f64, Float64);
integer_float_case_3!(case_3_f32, generate_case_f32, f32, Float32);
// Case 4: generates an rstest-parameterized test over 14 seeds, both
// "greater" operators and both "less" operators, for one numeric type.
// NOTE(review): the predicate presumably reads
//   left - 10 >(=) right - 5 AND left - 3 <(=) right + 10
// -- confirm the operand layout against `gen_conjunctive_numerical_expr`.
macro_rules! integer_float_case_4 {
($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
#[rstest]
#[test]
fn $TEST_FUNC_NAME(
#[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
seed: u64,
#[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
#[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
) -> Result<()> {
let left_col = Arc::new(Column::new("left_watermark", 0));
let right_col = Arc::new(Column::new("right_watermark", 0));
let expr = gen_conjunctive_numerical_expr(
left_col.clone(),
right_col.clone(),
(
Operator::Minus,
Operator::Minus,
Operator::Minus,
Operator::Plus,
),
ScalarValue::$SCALAR(Some(10 as $TYPE)),
ScalarValue::$SCALAR(Some(5 as $TYPE)),
ScalarValue::$SCALAR(Some(3 as $TYPE)),
ScalarValue::$SCALAR(Some(10 as $TYPE)),
(greater_op, less_op),
);
// Under the reading above: left > right + 5 and right > left - 13.
let l_gt_r = 5 as $TYPE;
let r_gt_l = -13 as $TYPE;
// Lower-bound direction (ASC = true).
$GENERATE_CASE_FUNC_NAME::<true>(
expr.clone(),
left_col.clone(),
right_col.clone(),
seed,
l_gt_r,
r_gt_l,
)?;
// Upper-bound direction (ASC = false) uses the negated, swapped offsets.
let r_lt_l = -l_gt_r;
let l_lt_r = -r_gt_l;
$GENERATE_CASE_FUNC_NAME::<false>(
expr, left_col, right_col, seed, l_lt_r, r_lt_l,
)
}
};
}
integer_float_case_4!(case_4_i32, generate_case_i32, i32, Int32);
integer_float_case_4!(case_4_i64, generate_case_i64, i64, Int64);
integer_float_case_4!(case_4_f64, generate_case_f64, f64, Float64);
integer_float_case_4!(case_4_f32, generate_case_f32, f32, Float32);
// Case 5: generates an rstest-parameterized test over 14 seeds, both
// "greater" operators and both "less" operators, for one numeric type.
// NOTE(review): the predicate presumably reads
//   left - 10 >(=) right - 5 AND left - 30 <(=) right - 3
// -- confirm the operand layout against `gen_conjunctive_numerical_expr`.
macro_rules! integer_float_case_5 {
($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
#[rstest]
#[test]
fn $TEST_FUNC_NAME(
#[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
seed: u64,
#[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
#[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
) -> Result<()> {
let left_col = Arc::new(Column::new("left_watermark", 0));
let right_col = Arc::new(Column::new("right_watermark", 0));
let expr = gen_conjunctive_numerical_expr(
left_col.clone(),
right_col.clone(),
(
Operator::Minus,
Operator::Minus,
Operator::Minus,
Operator::Minus,
),
ScalarValue::$SCALAR(Some(10 as $TYPE)),
ScalarValue::$SCALAR(Some(5 as $TYPE)),
ScalarValue::$SCALAR(Some(30 as $TYPE)),
ScalarValue::$SCALAR(Some(3 as $TYPE)),
(greater_op, less_op),
);
// Under the reading above: left > right + 5 and right > left - 27.
let l_gt_r = 5 as $TYPE;
let r_gt_l = -27 as $TYPE;
// Lower-bound direction (ASC = true).
$GENERATE_CASE_FUNC_NAME::<true>(
expr.clone(),
left_col.clone(),
right_col.clone(),
seed,
l_gt_r,
r_gt_l,
)?;
// Upper-bound direction (ASC = false) uses the negated, swapped offsets.
let r_lt_l = -l_gt_r;
let l_lt_r = -r_gt_l;
$GENERATE_CASE_FUNC_NAME::<false>(
expr, left_col, right_col, seed, l_lt_r, r_lt_l,
)
}
};
}
integer_float_case_5!(case_5_i32, generate_case_i32, i32, Int32);
integer_float_case_5!(case_5_i64, generate_case_i64, i64, Int64);
integer_float_case_5!(case_5_f64, generate_case_f64, f64, Float64);
integer_float_case_5!(case_5_f32, generate_case_f32, f32, Float32);
// Gathering `a + b` from `(a + b) + 1 > a - b` must not change the node count.
#[test]
fn test_gather_node_indices_dont_remove() -> Result<()> {
    // Small builders so the trees below read like the formulas they encode.
    let col = |name: &str, index: usize| -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new(name, index))
    };
    let binary = |lhs: Arc<dyn PhysicalExpr>,
                  op: Operator,
                  rhs: Arc<dyn PhysicalExpr>|
     -> Arc<dyn PhysicalExpr> { Arc::new(BinaryExpr::new(lhs, op, rhs)) };
    let one: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));

    // Predicate: (a + b) + 1 > a - b
    let a_plus_b = binary(col("a", 0), Operator::Plus, col("b", 1));
    let lhs = binary(a_plus_b, Operator::Plus, one);
    let rhs = binary(col("a", 0), Operator::Minus, col("b", 1));
    let predicate = binary(lhs, Operator::Gt, rhs);
    let mut graph = ExprIntervalGraph::try_new(predicate).unwrap();

    // Requested expression: a + b
    let target = binary(col("a", 0), Operator::Plus, col("b", 1));
    let nodes_before = graph.node_count();
    graph.gather_node_indices(&[target]);
    let nodes_after = graph.node_count();

    // Presumably `a` and `b` stay reachable through `a - b`, so nothing is pruned.
    assert_eq!(nodes_before, nodes_after);
    Ok(())
}
// Gathering `a + b` from `(a + b) + 1 > y - z` must drop exactly two nodes.
#[test]
fn test_gather_node_indices_remove() -> Result<()> {
    // Small builders so the trees below read like the formulas they encode.
    let col = |name: &str, index: usize| -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new(name, index))
    };
    let binary = |lhs: Arc<dyn PhysicalExpr>,
                  op: Operator,
                  rhs: Arc<dyn PhysicalExpr>|
     -> Arc<dyn PhysicalExpr> { Arc::new(BinaryExpr::new(lhs, op, rhs)) };
    let one: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));

    // Predicate: (a + b) + 1 > y - z
    let a_plus_b = binary(col("a", 0), Operator::Plus, col("b", 1));
    let lhs = binary(a_plus_b, Operator::Plus, one);
    let rhs = binary(col("y", 0), Operator::Minus, col("z", 1));
    let predicate = binary(lhs, Operator::Gt, rhs);
    let mut graph = ExprIntervalGraph::try_new(predicate).unwrap();

    // Requested expression: a + b
    let target = binary(col("a", 0), Operator::Plus, col("b", 1));
    let nodes_before = graph.node_count();
    graph.gather_node_indices(&[target]);
    let nodes_after = graph.node_count();

    // Presumably the two pruned nodes are `a` and `b`, which are used nowhere else.
    assert_eq!(nodes_before, nodes_after + 2);
    Ok(())
}
// Gathering `a + b` from `(a + b) + 1 > a - z` must drop exactly one node.
#[test]
fn test_gather_node_indices_remove_one() -> Result<()> {
    // Small builders so the trees below read like the formulas they encode.
    let col = |name: &str, index: usize| -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new(name, index))
    };
    let binary = |lhs: Arc<dyn PhysicalExpr>,
                  op: Operator,
                  rhs: Arc<dyn PhysicalExpr>|
     -> Arc<dyn PhysicalExpr> { Arc::new(BinaryExpr::new(lhs, op, rhs)) };
    let one: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));

    // Predicate: (a + b) + 1 > a - z
    let a_plus_b = binary(col("a", 0), Operator::Plus, col("b", 1));
    let lhs = binary(a_plus_b, Operator::Plus, one);
    let rhs = binary(col("a", 0), Operator::Minus, col("z", 1));
    let predicate = binary(lhs, Operator::Gt, rhs);
    let mut graph = ExprIntervalGraph::try_new(predicate).unwrap();

    // Requested expression: a + b
    let target = binary(col("a", 0), Operator::Plus, col("b", 1));
    let nodes_before = graph.node_count();
    graph.gather_node_indices(&[target]);
    let nodes_after = graph.node_count();

    // Presumably only `b` is pruned, since `a` is still used by `a - z`.
    assert_eq!(nodes_before, nodes_after + 1);
    Ok(())
}
// `a + b` does not occur in `(a + 1) + b > y - z`, so the node count must not change.
#[test]
fn test_gather_node_indices_cannot_provide() -> Result<()> {
    // Small builders so the trees below read like the formulas they encode.
    let col = |name: &str, index: usize| -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new(name, index))
    };
    let binary = |lhs: Arc<dyn PhysicalExpr>,
                  op: Operator,
                  rhs: Arc<dyn PhysicalExpr>|
     -> Arc<dyn PhysicalExpr> { Arc::new(BinaryExpr::new(lhs, op, rhs)) };
    let one: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));

    // Predicate: (a + 1) + b > y - z
    let a_plus_one = binary(col("a", 0), Operator::Plus, one);
    let lhs = binary(a_plus_one, Operator::Plus, col("b", 1));
    let rhs = binary(col("y", 0), Operator::Minus, col("z", 1));
    let predicate = binary(lhs, Operator::Gt, rhs);
    let mut graph = ExprIntervalGraph::try_new(predicate).unwrap();

    // Requested expression: a + b
    let target = binary(col("a", 0), Operator::Plus, col("b", 1));
    let nodes_before = graph.node_count();
    graph.gather_node_indices(&[target]);
    let nodes_after = graph.node_count();

    assert_eq!(nodes_before, nodes_after);
    Ok(())
}
// `ts_column + <singleton interval>`: the timestamp side is narrowed and the
// singleton interval on the right comes back unchanged.
#[test]
fn test_propagate_constraints_singleton_interval_at_right() -> Result<()> {
    // Bound builders; every bound in this test is created with `false` as
    // the second argument of `IntervalBound::new`.
    let ts = |nanos| {
        IntervalBound::new(ScalarValue::TimestampNanosecond(Some(nanos), None), false)
    };
    let mdn = || {
        IntervalBound::new(
            ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)),
            false,
        )
    };

    let expression = BinaryExpr::new(
        Arc::new(Column::new("ts_column", 0)),
        Operator::Plus,
        Arc::new(Literal::new(ScalarValue::new_interval_mdn(0, 1, 321))),
    );
    let parent = Interval::new(
        ts(1_602_756_672_000_000_321),
        ts(1_602_843_072_000_000_321),
    );
    let left_child = Interval::new(
        ts(1_602_324_672_000_000_000),
        ts(1_603_188_672_000_000_000),
    );
    let right_child = Interval::new(mdn(), mdn());

    let children = vec![&left_child, &right_child];
    let result = expression.propagate_constraints(&parent, &children)?;

    let expected_left = Interval::new(
        ts(1_602_670_272_000_000_000),
        ts(1_602_756_672_000_000_000),
    );
    let expected_right = Interval::new(mdn(), mdn());
    assert_eq!(Some(expected_left), result[0]);
    assert_eq!(Some(expected_right), result[1]);
    Ok(())
}
// `interval_column + ts_column`: both the interval (left) and the timestamp
// (right) sides are narrowed.
#[test]
fn test_propagate_constraints_column_interval_at_left() -> Result<()> {
    // Bound builders; every bound in this test is created with `false` as
    // the second argument of `IntervalBound::new`.
    let ts = |millis| {
        IntervalBound::new(ScalarValue::TimestampMillisecond(Some(millis), None), false)
    };
    let day_time =
        |packed| IntervalBound::new(ScalarValue::IntervalDayTime(Some(packed)), false);

    let expression = BinaryExpr::new(
        Arc::new(Column::new("interval_column", 1)),
        Operator::Plus,
        Arc::new(Column::new("ts_column", 0)),
    );
    let parent = Interval::new(ts(1_602_756_672_000), ts(1_602_843_072_000));
    let right_child = Interval::new(ts(1_602_324_672_000), ts(1_603_188_672_000));
    let left_child = Interval::new(day_time(172_800_000), day_time(864_000_000));

    let children = vec![&left_child, &right_child];
    let result = expression.propagate_constraints(&parent, &children)?;

    let expected_ts = Interval::new(ts(1_602_324_672_000), ts(1_602_670_272_000));
    let expected_interval = Interval::new(day_time(172_800_000), day_time(518_400_000));
    assert_eq!(Some(expected_ts), result[1]);
    assert_eq!(Some(expected_interval), result[0]);
    Ok(())
}
// `left < right` with an unbounded left side and the singleton right side
// `1000`: the left side's upper bound must become `1000` (created with the
// flag `true`) and the right side must come back unchanged. Checked for
// Int64, timestamps without a time zone, and timestamps with one.
#[test]
fn test_propagate_comparison() {
    let check = |data_type: DataType, point: ScalarValue| {
        let left = Interval::new(
            IntervalBound::make_unbounded(&data_type).unwrap(),
            IntervalBound::make_unbounded(&data_type).unwrap(),
        );
        let right = Interval::new(
            IntervalBound::new(point.clone(), false),
            IntervalBound::new(point.clone(), false),
        );

        let expected_left = Interval::new(
            IntervalBound::make_unbounded(&data_type).unwrap(),
            IntervalBound::new(point.clone(), true),
        );
        let expected_right = Interval::new(
            IntervalBound::new(point.clone(), false),
            IntervalBound::new(point, false),
        );
        assert_eq!(
            (Some(expected_left), Some(expected_right)),
            propagate_comparison(&Operator::Lt, &left, &right).unwrap()
        );
    };

    check(DataType::Int64, ScalarValue::Int64(Some(1000)));
    check(
        DataType::Timestamp(TimeUnit::Nanosecond, None),
        ScalarValue::TimestampNanosecond(Some(1000), None),
    );
    check(
        DataType::Timestamp(TimeUnit::Nanosecond, Some("+05:00".into())),
        ScalarValue::TimestampNanosecond(Some(1000), Some("+05:00".into())),
    );
}
}