use std::collections::HashSet;
use std::fmt::{Display, Formatter};
use std::sync::Arc;
use arrow_schema::DataType;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::type_coercion::binary::get_result_type;
use datafusion_expr::Operator;
use petgraph::graph::NodeIndex;
use petgraph::stable_graph::{DefaultIx, StableGraph};
use petgraph::visit::{Bfs, Dfs, DfsPostOrder, EdgeRef};
use petgraph::Outgoing;
use crate::expressions::{BinaryExpr, CastExpr, Column, Literal};
use crate::intervals::interval_aritmetic::{
apply_operator, is_operator_supported, Interval,
};
use crate::utils::{build_dag, ExprTreeNode};
use crate::PhysicalExpr;
use super::IntervalBound;
/// Graph over a physical expression used for interval evaluation and
/// constraint propagation.
///
/// Each node pairs a sub-expression with its current interval. Traversals in
/// this module walk `Outgoing` edges from `root` towards the leaves, i.e.
/// edges point from a parent expression to its children.
#[derive(Clone, Debug)]
pub struct ExprIntervalGraph {
    // Node/edge storage; a `StableGraph` keeps node indices valid across removals.
    // NOTE(review): the `usize` edge weight is assigned by `build_dag` — confirm its meaning there.
    graph: StableGraph<ExprIntervalGraphNode, usize>,
    // Index of the node holding the top-level expression.
    root: NodeIndex,
}
impl ExprIntervalGraph {
    /// Returns an approximate memory footprint of the graph, in bytes.
    ///
    /// The estimate is the size of `self`, plus one node payload and one
    /// `NodeIndex` per node, plus one `usize` weight and two `NodeIndex`
    /// values per edge. It does not measure actual allocator usage.
    pub fn size(&self) -> usize {
        let bytes_per_node =
            std::mem::size_of::<ExprIntervalGraphNode>() + std::mem::size_of::<NodeIndex>();
        let bytes_per_edge =
            std::mem::size_of::<usize>() + 2 * std::mem::size_of::<NodeIndex>();
        let all_nodes = self.graph.node_count() * bytes_per_node;
        let all_edges = self.graph.edge_count() * bytes_per_edge;
        std::mem::size_of_val(self) + all_nodes + all_edges
    }
}
/// Outcome of running interval propagation on an [`ExprIntervalGraph`].
#[derive(PartialEq, Debug)]
pub enum PropagationResult {
    /// The root interval was neither certainly false nor uncertain, so no
    /// propagation was attempted.
    CannotPropagate,
    /// The constraint cannot be satisfied: either the root evaluated to
    /// certainly false, or propagation produced an empty child interval.
    Infeasible,
    /// Propagation ran to completion without detecting infeasibility.
    Success,
}
/// A node of [`ExprIntervalGraph`]: a physical expression together with the
/// interval currently known for its value.
#[derive(Clone, Debug)]
pub struct ExprIntervalGraphNode {
    // The (sub-)expression this node represents.
    expr: Arc<dyn PhysicalExpr>,
    // Current bounds for the value of `expr`; updated by evaluation/propagation.
    interval: Interval,
}
impl Display for ExprIntervalGraphNode {
    /// Displays the node as its underlying expression; the interval is not shown.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{}", self.expr))
    }
}
impl ExprIntervalGraphNode {
    /// Creates a node for `expr` whose interval is `Interval::default()`.
    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
        Self::new_with_interval(expr, Interval::default())
    }

    /// Creates a node for `expr` with the given starting `interval`.
    pub fn new_with_interval(expr: Arc<dyn PhysicalExpr>, interval: Interval) -> Self {
        Self { expr, interval }
    }

    /// Returns the interval currently stored for this node.
    pub fn interval(&self) -> &Interval {
        &self.interval
    }

    /// Converts an expression-tree node into a graph node.
    ///
    /// A `Literal` with value `v` starts with the closed singleton interval
    /// `[v, v]`; any other expression starts with the default interval.
    pub fn make_node(node: &ExprTreeNode<NodeIndex>) -> ExprIntervalGraphNode {
        let expr = node.expression().clone();
        let literal_interval = expr
            .as_any()
            .downcast_ref::<Literal>()
            .map(|literal| {
                let value = literal.value();
                // `false` marks both bounds as closed (inclusive).
                Interval::new(
                    IntervalBound::new(value.clone(), false),
                    IntervalBound::new(value.clone(), false),
                )
            });
        match literal_interval {
            Some(interval) => Self::new_with_interval(expr, interval),
            None => Self::new(expr),
        }
    }
}
/// Nodes compare equal when their expressions compare equal; the stored
/// intervals are not taken into account.
impl PartialEq for ExprIntervalGraphNode {
    fn eq(&self, other: &ExprIntervalGraphNode) -> bool {
        self.expr.eq(&other.expr)
    }
}
/// Returns the arithmetic operator that undoes `op` (`+` <-> `-`).
///
/// # Panics
/// Panics for any operator other than `Operator::Plus` / `Operator::Minus`.
fn get_inverse_op(op: Operator) -> Operator {
    if op == Operator::Plus {
        Operator::Minus
    } else if op == Operator::Minus {
        Operator::Plus
    } else {
        unreachable!()
    }
}
pub fn propagate_arithmetic(
op: &Operator,
parent: &Interval,
left_child: &Interval,
right_child: &Interval,
) -> Result<(Option<Interval>, Option<Interval>)> {
let inverse_op = get_inverse_op(*op);
match apply_operator(&inverse_op, parent, right_child)?.intersect(left_child)? {
Some(value) => {
let right = match op {
Operator::Minus => apply_operator(op, &value, parent),
Operator::Plus => apply_operator(&inverse_op, parent, &value),
_ => unreachable!(),
}?
.intersect(right_child)?;
Ok((Some(value), right))
}
None => Ok((None, None)),
}
}
fn comparison_operator_target(
left_datatype: &DataType,
op: &Operator,
right_datatype: &DataType,
) -> Result<Interval> {
let datatype = get_result_type(left_datatype, &Operator::Minus, right_datatype)?;
let unbounded = IntervalBound::make_unbounded(&datatype)?;
let zero = ScalarValue::new_zero(&datatype)?;
Ok(match *op {
Operator::GtEq => Interval::new(IntervalBound::new(zero, false), unbounded),
Operator::Gt => Interval::new(IntervalBound::new(zero, true), unbounded),
Operator::LtEq => Interval::new(unbounded, IntervalBound::new(zero, false)),
Operator::Lt => Interval::new(unbounded, IntervalBound::new(zero, true)),
Operator::Eq => Interval::new(
IntervalBound::new(zero.clone(), false),
IntervalBound::new(zero, false),
),
_ => unreachable!(),
})
}
/// Narrows both operands of the comparison `left op right`.
///
/// The comparison is rewritten as `left - right ∈ target`, where `target`
/// comes from `comparison_operator_target`. That interval is then propagated
/// through a subtraction with `propagate_arithmetic`.
pub fn propagate_comparison(
    op: &Operator,
    left_child: &Interval,
    right_child: &Interval,
) -> Result<(Option<Interval>, Option<Interval>)> {
    let left_type = left_child.get_datatype()?;
    let right_type = right_child.get_datatype()?;
    let difference_target = comparison_operator_target(&left_type, op, &right_type)?;
    propagate_arithmetic(
        &Operator::Minus,
        &difference_target,
        left_child,
        right_child,
    )
}
impl ExprIntervalGraph {
pub fn try_new(expr: Arc<dyn PhysicalExpr>) -> Result<Self> {
let (root, graph) = build_dag(expr, &ExprIntervalGraphNode::make_node)?;
Ok(Self { graph, root })
}
pub fn node_count(&self) -> usize {
self.graph.node_count()
}
pub fn gather_node_indices(
&mut self,
exprs: &[Arc<dyn PhysicalExpr>],
) -> Vec<(Arc<dyn PhysicalExpr>, usize)> {
let graph = &self.graph;
let mut bfs = Bfs::new(graph, self.root);
let mut removals = vec![];
let mut expr_node_indices = exprs
.iter()
.map(|e| (e.clone(), usize::MAX))
.collect::<Vec<_>>();
while let Some(node) = bfs.next(graph) {
let expr = &graph[node].expr;
if let Some(value) = exprs.iter().position(|e| expr.eq(e)) {
expr_node_indices[value].1 = node.index();
for edge in graph.edges_directed(node, Outgoing) {
removals.push(edge.id());
}
}
}
for edge_idx in removals {
self.graph.remove_edge(edge_idx);
}
let connected_nodes = self.connected_nodes();
self.graph
.retain_nodes(|_, index| connected_nodes.contains(&index));
expr_node_indices
}
fn connected_nodes(&self) -> HashSet<NodeIndex> {
let mut nodes = HashSet::new();
let mut dfs = Dfs::new(&self.graph, self.root);
while let Some(node) = dfs.next(&self.graph) {
nodes.insert(node);
}
nodes
}
pub fn assign_intervals(&mut self, assignments: &[(usize, Interval)]) {
for (index, interval) in assignments {
let node_index = NodeIndex::from(*index as DefaultIx);
self.graph[node_index].interval = interval.clone();
}
}
pub fn update_intervals(&self, assignments: &mut [(usize, Interval)]) {
for (index, interval) in assignments.iter_mut() {
let node_index = NodeIndex::from(*index as DefaultIx);
*interval = self.graph[node_index].interval.clone();
}
}
pub fn evaluate_bounds(&mut self) -> Result<&Interval> {
let mut dfs = DfsPostOrder::new(&self.graph, self.root);
while let Some(node) = dfs.next(&self.graph) {
let neighbors = self.graph.neighbors_directed(node, Outgoing);
let mut children_intervals = neighbors
.map(|child| self.graph[child].interval())
.collect::<Vec<_>>();
if !children_intervals.is_empty() {
children_intervals.reverse();
self.graph[node].interval =
self.graph[node].expr.evaluate_bounds(&children_intervals)?;
}
}
Ok(&self.graph[self.root].interval)
}
fn propagate_constraints(&mut self) -> Result<PropagationResult> {
let mut bfs = Bfs::new(&self.graph, self.root);
while let Some(node) = bfs.next(&self.graph) {
let neighbors = self.graph.neighbors_directed(node, Outgoing);
let mut children = neighbors.collect::<Vec<_>>();
if children.is_empty() {
continue;
}
children.reverse();
let children_intervals = children
.iter()
.map(|child| self.graph[*child].interval())
.collect::<Vec<_>>();
let node_interval = self.graph[node].interval();
let propagated_intervals = self.graph[node]
.expr
.propagate_constraints(node_interval, &children_intervals)?;
for (child, interval) in children.into_iter().zip(propagated_intervals) {
if let Some(interval) = interval {
self.graph[child].interval = interval;
} else {
return Ok(PropagationResult::Infeasible);
}
}
}
Ok(PropagationResult::Success)
}
pub fn update_ranges(
&mut self,
leaf_bounds: &mut [(usize, Interval)],
) -> Result<PropagationResult> {
self.assign_intervals(leaf_bounds);
let bounds = self.evaluate_bounds()?;
if bounds == &Interval::CERTAINLY_FALSE {
Ok(PropagationResult::Infeasible)
} else if bounds == &Interval::UNCERTAIN {
let result = self.propagate_constraints();
self.update_intervals(leaf_bounds);
result
} else {
Ok(PropagationResult::CannotPropagate)
}
}
pub fn get_interval(&self, index: usize) -> Interval {
self.graph[NodeIndex::new(index)].interval.clone()
}
}
/// Reports whether every node of `expr` can be handled by interval analysis.
///
/// A node is supported if it is a `BinaryExpr` whose operator passes
/// `is_operator_supported`, or a `Column`, `Literal` or `CastExpr`. The
/// check recurses into all children.
pub fn check_support(expr: &Arc<dyn PhysicalExpr>) -> bool {
    let as_any = expr.as_any();
    let node_is_supported = match as_any.downcast_ref::<BinaryExpr>() {
        Some(binary) => is_operator_supported(binary.op()),
        None => {
            as_any.is::<Column>() || as_any.is::<Literal>() || as_any.is::<CastExpr>()
        }
    };
    node_is_supported && expr.children().iter().all(check_support)
}
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::Itertools;
    use crate::expressions::{BinaryExpr, Column};
    use crate::intervals::test_utils::gen_conjunctive_numerical_expr;
    use datafusion_common::ScalarValue;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};
    use rstest::*;

    /// Runs one propagation experiment on `expr`.
    ///
    /// The two expressions in `exprs_with_interval` receive `left_interval` and
    /// `right_interval`, and `update_ranges` is called. The returned status must
    /// equal `result`. Each refined interval must contain its expected interval
    /// (`lower <= expected.lower` and `upper >= expected.upper`). Only
    /// containment is checked, not endpoint openness.
    fn experiment(
        expr: Arc<dyn PhysicalExpr>,
        exprs_with_interval: (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>),
        left_interval: Interval,
        right_interval: Interval,
        left_expected: Interval,
        right_expected: Interval,
        result: PropagationResult,
    ) -> Result<()> {
        let col_stats = vec![
            (exprs_with_interval.0.clone(), left_interval),
            (exprs_with_interval.1.clone(), right_interval),
        ];
        let expected = vec![
            (exprs_with_interval.0.clone(), left_expected),
            (exprs_with_interval.1.clone(), right_expected),
        ];
        let mut graph = ExprIntervalGraph::try_new(expr)?;
        let expr_indexes = graph
            .gather_node_indices(&col_stats.iter().map(|(e, _)| e.clone()).collect_vec());
        // Pair each given/expected interval with the graph node index of its expression.
        let mut col_stat_nodes = col_stats
            .iter()
            .zip(expr_indexes.iter())
            .map(|((_, interval), (_, index))| (*index, interval.clone()))
            .collect_vec();
        let expected_nodes = expected
            .iter()
            .zip(expr_indexes.iter())
            .map(|((_, interval), (_, index))| (*index, interval.clone()))
            .collect_vec();
        let exp_result = graph.update_ranges(&mut col_stat_nodes[..])?;
        assert_eq!(exp_result, result);
        col_stat_nodes.iter().zip(expected_nodes.iter()).for_each(
            |((_, calculated_interval_node), (_, expected))| {
                assert!(calculated_interval_node.lower.value <= expected.lower.value);
                assert!(calculated_interval_node.upper.value >= expected.upper.value);
            },
        );
        Ok(())
    }

    // Generates a randomized case runner for one numeric type.
    //
    // With `ASC = true`, both columns get only a lower bound drawn from
    // [0, 1000). The expected lower bounds are
    // `max(left, right + expr_left)` and `max(right, left + expr_right)`.
    // With `ASC = false`, they get only an upper bound, and the expected upper
    // bounds use `min` instead. The run must report `Success`.
    macro_rules! generate_cases {
        ($FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
            fn $FUNC_NAME<const ASC: bool>(
                expr: Arc<dyn PhysicalExpr>,
                left_col: Arc<dyn PhysicalExpr>,
                right_col: Arc<dyn PhysicalExpr>,
                seed: u64,
                expr_left: $TYPE,
                expr_right: $TYPE,
            ) -> Result<()> {
                let mut r = StdRng::seed_from_u64(seed);
                let (left_given, right_given, left_expected, right_expected) = if ASC {
                    let left = r.gen_range((0 as $TYPE)..(1000 as $TYPE));
                    let right = r.gen_range((0 as $TYPE)..(1000 as $TYPE));
                    (
                        (Some(left), None),
                        (Some(right), None),
                        (Some(<$TYPE>::max(left, right + expr_left)), None),
                        (Some(<$TYPE>::max(right, left + expr_right)), None),
                    )
                } else {
                    let left = r.gen_range((0 as $TYPE)..(1000 as $TYPE));
                    let right = r.gen_range((0 as $TYPE)..(1000 as $TYPE));
                    (
                        (None, Some(left)),
                        (None, Some(right)),
                        (None, Some(<$TYPE>::min(left, right + expr_left))),
                        (None, Some(<$TYPE>::min(right, left + expr_right))),
                    )
                };
                experiment(
                    expr,
                    (left_col, right_col),
                    Interval::make(left_given.0, left_given.1, (true, true)),
                    Interval::make(right_given.0, right_given.1, (true, true)),
                    Interval::make(left_expected.0, left_expected.1, (true, true)),
                    Interval::make(right_expected.0, right_expected.1, (true, true)),
                    PropagationResult::Success,
                )
            }
        };
    }
    generate_cases!(generate_case_i32, i32, Int32);
    generate_cases!(generate_case_i64, i64, Int64);
    generate_cases!(generate_case_f32, f32, Float32);
    generate_cases!(generate_case_f64, f64, Float64);

    /// `left + 5 > right` with left in [10, 20] and right >= 100 can never
    /// hold, so propagation must report `Infeasible`.
    #[test]
    fn testing_not_possible() -> Result<()> {
        let left_col = Arc::new(Column::new("left_watermark", 0));
        let right_col = Arc::new(Column::new("right_watermark", 0));
        let left_and_1 = Arc::new(BinaryExpr::new(
            left_col.clone(),
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::Int32(Some(5)))),
        ));
        let expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, right_col.clone()));
        experiment(
            expr,
            (left_col, right_col),
            Interval::make(Some(10), Some(20), (true, true)),
            Interval::make(Some(100), None, (true, true)),
            Interval::make(Some(10), Some(20), (true, true)),
            Interval::make(Some(100), None, (true, true)),
            PropagationResult::Infeasible,
        )
    }

    // Case 1: all four sides use `+` with constants 1, 11, 3, 33.
    // NOTE(review): `l_gt_r` / `r_gt_l` are the offsets implied if
    // `gen_conjunctive_numerical_expr` builds
    // `(left op a) greater_op (right op b) AND (left op c) less_op (right op d)`,
    // i.e. left > right + 10 and right > left - 30 — confirm in test_utils.
    // The descending run uses the negated offsets.
    macro_rules! integer_float_case_1 {
        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
            #[rstest]
            #[test]
            fn $TEST_FUNC_NAME(
                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
                seed: u64,
                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
            ) -> Result<()> {
                let left_col = Arc::new(Column::new("left_watermark", 0));
                let right_col = Arc::new(Column::new("right_watermark", 0));
                let expr = gen_conjunctive_numerical_expr(
                    left_col.clone(),
                    right_col.clone(),
                    (
                        Operator::Plus,
                        Operator::Plus,
                        Operator::Plus,
                        Operator::Plus,
                    ),
                    ScalarValue::$SCALAR(Some(1 as $TYPE)),
                    ScalarValue::$SCALAR(Some(11 as $TYPE)),
                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
                    ScalarValue::$SCALAR(Some(33 as $TYPE)),
                    (greater_op, less_op),
                );
                let l_gt_r = 10 as $TYPE;
                let r_gt_l = -30 as $TYPE;
                $GENERATE_CASE_FUNC_NAME::<true>(
                    expr.clone(),
                    left_col.clone(),
                    right_col.clone(),
                    seed,
                    l_gt_r,
                    r_gt_l,
                )?;
                let r_lt_l = -l_gt_r;
                let l_lt_r = -r_gt_l;
                $GENERATE_CASE_FUNC_NAME::<false>(
                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
                )
            }
        };
    }
    integer_float_case_1!(case_1_i32, generate_case_i32, i32, Int32);
    integer_float_case_1!(case_1_i64, generate_case_i64, i64, Int64);
    integer_float_case_1!(case_1_f64, generate_case_f64, f64, Float64);
    integer_float_case_1!(case_1_f32, generate_case_f32, f32, Float32);

    // Case 2: operators (-, +, +, +) with constants 1, 5, 3, 10.
    // Under the same assumption as case 1: left > right + 6, right > left - 7.
    macro_rules! integer_float_case_2 {
        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
            #[rstest]
            #[test]
            fn $TEST_FUNC_NAME(
                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
                seed: u64,
                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
            ) -> Result<()> {
                let left_col = Arc::new(Column::new("left_watermark", 0));
                let right_col = Arc::new(Column::new("right_watermark", 0));
                let expr = gen_conjunctive_numerical_expr(
                    left_col.clone(),
                    right_col.clone(),
                    (
                        Operator::Minus,
                        Operator::Plus,
                        Operator::Plus,
                        Operator::Plus,
                    ),
                    ScalarValue::$SCALAR(Some(1 as $TYPE)),
                    ScalarValue::$SCALAR(Some(5 as $TYPE)),
                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
                    (greater_op, less_op),
                );
                let l_gt_r = 6 as $TYPE;
                let r_gt_l = -7 as $TYPE;
                $GENERATE_CASE_FUNC_NAME::<true>(
                    expr.clone(),
                    left_col.clone(),
                    right_col.clone(),
                    seed,
                    l_gt_r,
                    r_gt_l,
                )?;
                let r_lt_l = -l_gt_r;
                let l_lt_r = -r_gt_l;
                $GENERATE_CASE_FUNC_NAME::<false>(
                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
                )
            }
        };
    }
    integer_float_case_2!(case_2_i32, generate_case_i32, i32, Int32);
    integer_float_case_2!(case_2_i64, generate_case_i64, i64, Int64);
    integer_float_case_2!(case_2_f64, generate_case_f64, f64, Float64);
    integer_float_case_2!(case_2_f32, generate_case_f32, f32, Float32);

    // Case 3: operators (-, +, -, +) with constants 1, 5, 3, 10.
    // Under the same assumption as case 1: left > right + 6, right > left - 13.
    macro_rules! integer_float_case_3 {
        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
            #[rstest]
            #[test]
            fn $TEST_FUNC_NAME(
                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
                seed: u64,
                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
            ) -> Result<()> {
                let left_col = Arc::new(Column::new("left_watermark", 0));
                let right_col = Arc::new(Column::new("right_watermark", 0));
                let expr = gen_conjunctive_numerical_expr(
                    left_col.clone(),
                    right_col.clone(),
                    (
                        Operator::Minus,
                        Operator::Plus,
                        Operator::Minus,
                        Operator::Plus,
                    ),
                    ScalarValue::$SCALAR(Some(1 as $TYPE)),
                    ScalarValue::$SCALAR(Some(5 as $TYPE)),
                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
                    (greater_op, less_op),
                );
                let l_gt_r = 6 as $TYPE;
                let r_gt_l = -13 as $TYPE;
                $GENERATE_CASE_FUNC_NAME::<true>(
                    expr.clone(),
                    left_col.clone(),
                    right_col.clone(),
                    seed,
                    l_gt_r,
                    r_gt_l,
                )?;
                let r_lt_l = -l_gt_r;
                let l_lt_r = -r_gt_l;
                $GENERATE_CASE_FUNC_NAME::<false>(
                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
                )
            }
        };
    }
    integer_float_case_3!(case_3_i32, generate_case_i32, i32, Int32);
    integer_float_case_3!(case_3_i64, generate_case_i64, i64, Int64);
    integer_float_case_3!(case_3_f64, generate_case_f64, f64, Float64);
    integer_float_case_3!(case_3_f32, generate_case_f32, f32, Float32);

    // Case 4: operators (-, -, -, +) with constants 10, 5, 3, 10.
    // Under the same assumption as case 1: left > right + 5, right > left - 13.
    macro_rules! integer_float_case_4 {
        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
            #[rstest]
            #[test]
            fn $TEST_FUNC_NAME(
                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
                seed: u64,
                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
            ) -> Result<()> {
                let left_col = Arc::new(Column::new("left_watermark", 0));
                let right_col = Arc::new(Column::new("right_watermark", 0));
                let expr = gen_conjunctive_numerical_expr(
                    left_col.clone(),
                    right_col.clone(),
                    (
                        Operator::Minus,
                        Operator::Minus,
                        Operator::Minus,
                        Operator::Plus,
                    ),
                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
                    ScalarValue::$SCALAR(Some(5 as $TYPE)),
                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
                    (greater_op, less_op),
                );
                let l_gt_r = 5 as $TYPE;
                let r_gt_l = -13 as $TYPE;
                $GENERATE_CASE_FUNC_NAME::<true>(
                    expr.clone(),
                    left_col.clone(),
                    right_col.clone(),
                    seed,
                    l_gt_r,
                    r_gt_l,
                )?;
                let r_lt_l = -l_gt_r;
                let l_lt_r = -r_gt_l;
                $GENERATE_CASE_FUNC_NAME::<false>(
                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
                )
            }
        };
    }
    integer_float_case_4!(case_4_i32, generate_case_i32, i32, Int32);
    integer_float_case_4!(case_4_i64, generate_case_i64, i64, Int64);
    integer_float_case_4!(case_4_f64, generate_case_f64, f64, Float64);
    integer_float_case_4!(case_4_f32, generate_case_f32, f32, Float32);

    // Case 5: all four sides use `-` with constants 10, 5, 30, 3.
    // Under the same assumption as case 1: left > right + 5, right > left - 27.
    macro_rules! integer_float_case_5 {
        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
            #[rstest]
            #[test]
            fn $TEST_FUNC_NAME(
                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
                seed: u64,
                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
            ) -> Result<()> {
                let left_col = Arc::new(Column::new("left_watermark", 0));
                let right_col = Arc::new(Column::new("right_watermark", 0));
                let expr = gen_conjunctive_numerical_expr(
                    left_col.clone(),
                    right_col.clone(),
                    (
                        Operator::Minus,
                        Operator::Minus,
                        Operator::Minus,
                        Operator::Minus,
                    ),
                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
                    ScalarValue::$SCALAR(Some(5 as $TYPE)),
                    ScalarValue::$SCALAR(Some(30 as $TYPE)),
                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
                    (greater_op, less_op),
                );
                let l_gt_r = 5 as $TYPE;
                let r_gt_l = -27 as $TYPE;
                $GENERATE_CASE_FUNC_NAME::<true>(
                    expr.clone(),
                    left_col.clone(),
                    right_col.clone(),
                    seed,
                    l_gt_r,
                    r_gt_l,
                )?;
                let r_lt_l = -l_gt_r;
                let l_lt_r = -r_gt_l;
                $GENERATE_CASE_FUNC_NAME::<false>(
                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
                )
            }
        };
    }
    integer_float_case_5!(case_5_i32, generate_case_i32, i32, Int32);
    integer_float_case_5!(case_5_i64, generate_case_i64, i64, Int64);
    integer_float_case_5!(case_5_f64, generate_case_f64, f64, Float64);
    integer_float_case_5!(case_5_f32, generate_case_f32, f32, Float32);

    /// `(a + b) + 1 > a - b`: making `a + b` a leaf prunes nothing, because
    /// `a` and `b` are still reachable through `a - b`.
    #[test]
    fn test_gather_node_indices_dont_remove() -> Result<()> {
        let left_expr = Arc::new(BinaryExpr::new(
            Arc::new(BinaryExpr::new(
                Arc::new(Column::new("a", 0)),
                Operator::Plus,
                Arc::new(Column::new("b", 1)),
            )),
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
        ));
        let right_expr = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Minus,
            Arc::new(Column::new("b", 1)),
        ));
        let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
        let mut graph = ExprIntervalGraph::try_new(expr)?;
        let leaf_node = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Plus,
            Arc::new(Column::new("b", 1)),
        ));
        let prev_node_count = graph.node_count();
        graph.gather_node_indices(&[leaf_node]);
        let final_node_count = graph.node_count();
        assert_eq!(prev_node_count, final_node_count);
        Ok(())
    }

    /// `(a + b) + 1 > y - z`: making `a + b` a leaf prunes `a` and `b`.
    #[test]
    fn test_gather_node_indices_remove() -> Result<()> {
        let left_expr = Arc::new(BinaryExpr::new(
            Arc::new(BinaryExpr::new(
                Arc::new(Column::new("a", 0)),
                Operator::Plus,
                Arc::new(Column::new("b", 1)),
            )),
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
        ));
        let right_expr = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("y", 0)),
            Operator::Minus,
            Arc::new(Column::new("z", 1)),
        ));
        let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
        let mut graph = ExprIntervalGraph::try_new(expr)?;
        let leaf_node = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Plus,
            Arc::new(Column::new("b", 1)),
        ));
        let prev_node_count = graph.node_count();
        graph.gather_node_indices(&[leaf_node]);
        let final_node_count = graph.node_count();
        assert_eq!(prev_node_count, final_node_count + 2);
        Ok(())
    }

    /// `(a + b) + 1 > a - z`: making `a + b` a leaf prunes only `b`, because
    /// `a` is still reachable through `a - z`.
    #[test]
    fn test_gather_node_indices_remove_one() -> Result<()> {
        let left_expr = Arc::new(BinaryExpr::new(
            Arc::new(BinaryExpr::new(
                Arc::new(Column::new("a", 0)),
                Operator::Plus,
                Arc::new(Column::new("b", 1)),
            )),
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
        ));
        let right_expr = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Minus,
            Arc::new(Column::new("z", 1)),
        ));
        let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
        let mut graph = ExprIntervalGraph::try_new(expr)?;
        let leaf_node = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Plus,
            Arc::new(Column::new("b", 1)),
        ));
        let prev_node_count = graph.node_count();
        graph.gather_node_indices(&[leaf_node]);
        let final_node_count = graph.node_count();
        assert_eq!(prev_node_count, final_node_count + 1);
        Ok(())
    }

    /// `((a + 1) + b) > y - z` contains no `a + b` node, so requesting it
    /// changes nothing.
    #[test]
    fn test_gather_node_indices_cannot_provide() -> Result<()> {
        let left_expr = Arc::new(BinaryExpr::new(
            Arc::new(BinaryExpr::new(
                Arc::new(Column::new("a", 0)),
                Operator::Plus,
                Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
            )),
            Operator::Plus,
            Arc::new(Column::new("b", 1)),
        ));
        let right_expr = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("y", 0)),
            Operator::Minus,
            Arc::new(Column::new("z", 1)),
        ));
        let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
        let mut graph = ExprIntervalGraph::try_new(expr)?;
        let leaf_node = Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Plus,
            Arc::new(Column::new("b", 1)),
        ));
        let prev_node_count = graph.node_count();
        graph.gather_node_indices(&[leaf_node]);
        let final_node_count = graph.node_count();
        assert_eq!(prev_node_count, final_node_count);
        Ok(())
    }
}