datafusion_physical_expr/intervals/
cp_solver.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Constraint propagator/solver for custom [`PhysicalExpr`] graphs.
19//!
20//! The constraint propagator/solver in DataFusion uses interval arithmetic to
21//! perform mathematical operations on intervals, which represent a range of
22//! possible values rather than a single point value. This allows for the
23//! propagation of ranges through mathematical operations, and can be used to
24//! compute bounds for a complicated expression. The key idea is that by
25//! breaking down a complicated expression into simpler terms, and then
26//! combining the bounds for those simpler terms, one can obtain bounds for the
27//! overall expression.
28//!
29//! This way of using interval arithmetic to compute bounds for a complex
30//! expression by combining the bounds for the constituent terms within the
31//! original expression allows us to reason about the range of possible values
32//! of the expression. This information later can be used in range pruning of
33//! the provably unnecessary parts of `RecordBatch`es.
34//!
35//! # Example
36//!
37//! For example, consider a mathematical expression such as `x^2 + y = 4` \[1\].
38//! Since this expression would be a binary tree in [`PhysicalExpr`] notation,
//! this type of hierarchical computation is well-suited for a graph-based
40//! implementation. In such an implementation, an equation system `f(x) = 0` is
41//! represented by a directed acyclic expression graph (DAEG).
42//!
43//! In order to use interval arithmetic to compute bounds for this expression,
44//! one would first determine intervals that represent the possible values of
//! `x` and `y`. Let's say that the interval for `x` is `[1, 2]` and the interval
46//! for `y` is `[-3, 1]`. In the chart below, you can see how the computation
47//! takes place.
48//!
49//! # References
50//!
51//! 1. Kabak, Mehmet Ozan. Analog Circuit Start-Up Behavior Analysis: An Interval
52//!    Arithmetic Based Approach, Chapter 4. Stanford University, 2015.
53//! 2. Moore, Ramon E. Interval analysis. Vol. 4. Englewood Cliffs: Prentice-Hall, 1966.
54//! 3. F. Messine, "Deterministic global optimization using interval constraint
55//!    propagation techniques," RAIRO-Operations Research, vol. 38, no. 04,
56//!    pp. 277-293, 2004.
57//!
58//! # Illustration
59//!
60//! ## Computing bounds for an expression using interval arithmetic
61//!
62//! ```text
63//!             +-----+                         +-----+
64//!        +----|  +  |----+               +----|  +  |----+
65//!        |    |     |    |               |    |     |    |
66//!        |    +-----+    |               |    +-----+    |
67//!        |               |               |               |
68//!    +-----+           +-----+       +-----+           +-----+
69//!    |   2 |           |  y  |       |   2 | [1, 4]    |  y  |
70//!    |[.]  |           |     |       |[.]  |           |     |
71//!    +-----+           +-----+       +-----+           +-----+
72//!       |                               |
73//!       |                               |
74//!     +---+                           +---+
75//!     | x | [1, 2]                    | x | [1, 2]
76//!     +---+                           +---+
77//!
//!  (a) Bottom-up evaluation: Step 1 (b) Bottom-up evaluation: Step 2
79//!
80//!                                      [1 - 3, 4 + 1] = [-2, 5]
81//!             +-----+                         +-----+
82//!        +----|  +  |----+               +----|  +  |----+
83//!        |    |     |    |               |    |     |    |
84//!        |    +-----+    |               |    +-----+    |
85//!        |               |               |               |
86//!    +-----+           +-----+       +-----+           +-----+
87//!    |   2 |[1, 4]     |  y  |       |   2 |[1, 4]     |  y  |
88//!    |[.]  |           |     |       |[.]  |           |     |
89//!    +-----+           +-----+       +-----+           +-----+
90//!       |              [-3, 1]          |              [-3, 1]
91//!       |                               |
92//!     +---+                           +---+
93//!     | x | [1, 2]                    | x | [1, 2]
94//!     +---+                           +---+
95//!
96//!  (c) Bottom-up evaluation: Step 3 (d) Bottom-up evaluation: Step 4
97//! ```
98//!
99//! ## Top-down constraint propagation using inverse semantics
100//!
101//! ```text
102//!    [-2, 5] ∩ [4, 4] = [4, 4]               [4, 4]
103//!            +-----+                         +-----+
104//!       +----|  +  |----+               +----|  +  |----+
105//!       |    |     |    |               |    |     |    |
106//!       |    +-----+    |               |    +-----+    |
107//!       |               |               |               |
108//!    +-----+           +-----+       +-----+           +-----+
109//!    |   2 | [1, 4]    |  y  |       |   2 | [1, 4]    |  y  | [0, 1]*
110//!    |[.]  |           |     |       |[.]  |           |     |
111//!    +-----+           +-----+       +-----+           +-----+
112//!      |              [-3, 1]          |
113//!      |                               |
114//!    +---+                           +---+
115//!    | x | [1, 2]                    | x | [1, 2]
116//!    +---+                           +---+
117//!
118//!  (a) Top-down propagation: Step 1 (b) Top-down propagation: Step 2
119//!
//!                                            [4, 4]
//!            +-----+                         +-----+
//!       +----|  +  |----+               +----|  +  |----+
//!       |    |     |    |               |    |     |    |
//!       |    +-----+    |               |    +-----+    |
//!       |               |               |               |
//!    +-----+           +-----+       +-----+           +-----+
//!    |   2 |[3, 4]**   |  y  |       |   2 |[3, 4]     |  y  |
//!    |[.]  |           |     |       |[.]  |           |     |
//!    +-----+           +-----+       +-----+           +-----+
//!      |              [0, 1]           |              [0, 1]
131//!      |                               |
132//!    +---+                           +---+
133//!    | x | [1, 2]                    | x | [sqrt(3), 2]***
134//!    +---+                           +---+
135//!
136//!  (c) Top-down propagation: Step 3  (d) Top-down propagation: Step 4
137//!
138//!    * [-3, 1] ∩ ([4, 4] - [1, 4]) = [0, 1]
139//!    ** [1, 4] ∩ ([4, 4] - [0, 1]) = [3, 4]
140//!    *** [1, 2] ∩ [sqrt(3), sqrt(4)] = [sqrt(3), 2]
141//! ```
142
143use std::collections::HashSet;
144use std::fmt::{Display, Formatter};
145use std::mem::{size_of, size_of_val};
146use std::sync::Arc;
147
148use super::utils::{
149    convert_duration_type_to_interval, convert_interval_type_to_duration, get_inverse_op,
150};
151use crate::expressions::{BinaryExpr, Literal};
152use crate::utils::{build_dag, ExprTreeNode};
153use crate::PhysicalExpr;
154
155use arrow::datatypes::{DataType, Schema};
156use datafusion_common::{internal_err, not_impl_err, Result};
157use datafusion_expr::interval_arithmetic::{apply_operator, satisfy_greater, Interval};
158use datafusion_expr::Operator;
159
160use petgraph::graph::NodeIndex;
161use petgraph::stable_graph::{DefaultIx, StableGraph};
162use petgraph::visit::{Bfs, Dfs, DfsPostOrder, EdgeRef};
163use petgraph::Outgoing;
164
165/// This object implements a directed acyclic expression graph (DAEG) that
166/// is used to compute ranges for expressions through interval arithmetic.
167#[derive(Clone, Debug)]
168pub struct ExprIntervalGraph {
169    graph: StableGraph<ExprIntervalGraphNode, usize>,
170    root: NodeIndex,
171}
172
/// The possible outcomes of a constraint propagation attempt.
#[derive(PartialEq, Debug)]
pub enum PropagationResult {
    /// The requested range already covers the computed bounds, so there is
    /// nothing to propagate.
    CannotPropagate,
    /// The requested range cannot be satisfied.
    Infeasible,
    /// Constraints were propagated through the graph.
    Success,
}
180
181/// This is a node in the DAEG; it encapsulates a reference to the actual
182/// [`PhysicalExpr`] as well as an interval containing expression bounds.
183#[derive(Clone, Debug)]
184pub struct ExprIntervalGraphNode {
185    expr: Arc<dyn PhysicalExpr>,
186    interval: Interval,
187}
188
189impl PartialEq for ExprIntervalGraphNode {
190    fn eq(&self, other: &Self) -> bool {
191        self.expr.eq(&other.expr)
192    }
193}
194
195impl Display for ExprIntervalGraphNode {
196    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
197        write!(f, "{}", self.expr)
198    }
199}
200
201impl ExprIntervalGraphNode {
202    /// Constructs a new DAEG node with an `[-∞, ∞]` range.
203    pub fn new_unbounded(expr: Arc<dyn PhysicalExpr>, dt: &DataType) -> Result<Self> {
204        Interval::make_unbounded(dt)
205            .map(|interval| ExprIntervalGraphNode { expr, interval })
206    }
207
208    /// Constructs a new DAEG node with the given range.
209    pub fn new_with_interval(expr: Arc<dyn PhysicalExpr>, interval: Interval) -> Self {
210        ExprIntervalGraphNode { expr, interval }
211    }
212
213    /// Get the interval object representing the range of the expression.
214    pub fn interval(&self) -> &Interval {
215        &self.interval
216    }
217
218    /// This function creates a DAEG node from DataFusion's [`ExprTreeNode`]
219    /// object. Literals are created with definite, singleton intervals while
220    /// any other expression starts with an indefinite interval (`[-∞, ∞]`).
221    pub fn make_node(node: &ExprTreeNode<NodeIndex>, schema: &Schema) -> Result<Self> {
222        let expr = Arc::clone(&node.expr);
223        if let Some(literal) = expr.as_any().downcast_ref::<Literal>() {
224            let value = literal.value();
225            Interval::try_new(value.clone(), value.clone())
226                .map(|interval| Self::new_with_interval(expr, interval))
227        } else {
228            expr.data_type(schema)
229                .and_then(|dt| Self::new_unbounded(expr, &dt))
230        }
231    }
232}
233
234/// This function refines intervals `left_child` and `right_child` by applying
235/// constraint propagation through `parent` via operation. The main idea is
236/// that we can shrink ranges of variables x and y using parent interval p.
237///
238/// Assuming that x,y and p has ranges `[xL, xU]`, `[yL, yU]`, and `[pL, pU]`, we
239/// apply the following operations:
240/// - For plus operation, specifically, we would first do
241///     - `[xL, xU]` <- (`[pL, pU]` - `[yL, yU]`) ∩ `[xL, xU]`, and then
242///     - `[yL, yU]` <- (`[pL, pU]` - `[xL, xU]`) ∩ `[yL, yU]`.
243/// - For minus operation, specifically, we would first do
244///     - `[xL, xU]` <- (`[yL, yU]` + `[pL, pU]`) ∩ `[xL, xU]`, and then
245///     - `[yL, yU]` <- (`[xL, xU]` - `[pL, pU]`) ∩ `[yL, yU]`.
246/// - For multiplication operation, specifically, we would first do
247///     - `[xL, xU]` <- (`[pL, pU]` / `[yL, yU]`) ∩ `[xL, xU]`, and then
248///     - `[yL, yU]` <- (`[pL, pU]` / `[xL, xU]`) ∩ `[yL, yU]`.
249/// - For division operation, specifically, we would first do
250///     - `[xL, xU]` <- (`[yL, yU]` * `[pL, pU]`) ∩ `[xL, xU]`, and then
251///     - `[yL, yU]` <- (`[xL, xU]` / `[pL, pU]`) ∩ `[yL, yU]`.
252pub fn propagate_arithmetic(
253    op: &Operator,
254    parent: &Interval,
255    left_child: &Interval,
256    right_child: &Interval,
257) -> Result<Option<(Interval, Interval)>> {
258    let inverse_op = get_inverse_op(*op)?;
259    match (left_child.data_type(), right_child.data_type()) {
260        // If we have a child whose type is a time interval (i.e. DataType::Interval),
261        // we need special handling since timestamp differencing results in a
262        // Duration type.
263        (DataType::Timestamp(..), DataType::Interval(_)) => {
264            propagate_time_interval_at_right(
265                left_child,
266                right_child,
267                parent,
268                op,
269                &inverse_op,
270            )
271        }
272        (DataType::Interval(_), DataType::Timestamp(..)) => {
273            propagate_time_interval_at_left(
274                left_child,
275                right_child,
276                parent,
277                op,
278                &inverse_op,
279            )
280        }
281        _ => {
282            // First, propagate to the left:
283            match apply_operator(&inverse_op, parent, right_child)?
284                .intersect(left_child)?
285            {
286                // Left is feasible:
287                Some(value) => Ok(
288                    // Propagate to the right using the new left.
289                    propagate_right(&value, parent, right_child, op, &inverse_op)?
290                        .map(|right| (value, right)),
291                ),
292                // If the left child is infeasible, short-circuit.
293                None => Ok(None),
294            }
295        }
296    }
297}
298
299/// This function refines intervals `left_child` and `right_child` by applying
300/// comparison propagation through `parent` via operation. The main idea is
301/// that we can shrink ranges of variables x and y using parent interval p.
302/// Two intervals can be ordered in 6 ways for a Gt `>` operator:
303/// ```text
304///                           (1): Infeasible, short-circuit
305/// left:   |        ================                                               |
306/// right:  |                           ========================                    |
307///
308///                             (2): Update both interval
309/// left:   |              ======================                                   |
310/// right:  |                             ======================                    |
311///                                          |
312///                                          V
313/// left:   |                             =======                                   |
314/// right:  |                             =======                                   |
315///
316///                             (3): Update left interval
317/// left:   |                  ==============================                       |
318/// right:  |                           ==========                                  |
319///                                          |
320///                                          V
321/// left:   |                           =====================                       |
322/// right:  |                           ==========                                  |
323///
324///                             (4): Update right interval
325/// left:   |                           ==========                                  |
326/// right:  |                   ===========================                         |
327///                                          |
328///                                          V
329/// left:   |                           ==========                                  |
330/// right   |                   ==================                                  |
331///
332///                                   (5): No change
333/// left:   |                       ============================                    |
334/// right:  |               ===================                                     |
335///
336///                                   (6): No change
337/// left:   |                                    ====================               |
338/// right:  |                ===============                                        |
339///
340///         -inf --------------------------------------------------------------- +inf
341/// ```
342pub fn propagate_comparison(
343    op: &Operator,
344    parent: &Interval,
345    left_child: &Interval,
346    right_child: &Interval,
347) -> Result<Option<(Interval, Interval)>> {
348    if parent == &Interval::CERTAINLY_TRUE {
349        match op {
350            Operator::Eq => left_child.intersect(right_child).map(|result| {
351                result.map(|intersection| (intersection.clone(), intersection))
352            }),
353            Operator::Gt => satisfy_greater(left_child, right_child, true),
354            Operator::GtEq => satisfy_greater(left_child, right_child, false),
355            Operator::Lt => satisfy_greater(right_child, left_child, true)
356                .map(|t| t.map(reverse_tuple)),
357            Operator::LtEq => satisfy_greater(right_child, left_child, false)
358                .map(|t| t.map(reverse_tuple)),
359            _ => internal_err!(
360                "The operator must be a comparison operator to propagate intervals"
361            ),
362        }
363    } else if parent == &Interval::CERTAINLY_FALSE {
364        match op {
365            Operator::Eq => {
366                // TODO: Propagation is not possible until we support interval sets.
367                Ok(None)
368            }
369            Operator::Gt => satisfy_greater(right_child, left_child, false),
370            Operator::GtEq => satisfy_greater(right_child, left_child, true),
371            Operator::Lt => satisfy_greater(left_child, right_child, false)
372                .map(|t| t.map(reverse_tuple)),
373            Operator::LtEq => satisfy_greater(left_child, right_child, true)
374                .map(|t| t.map(reverse_tuple)),
375            _ => internal_err!(
376                "The operator must be a comparison operator to propagate intervals"
377            ),
378        }
379    } else {
380        // Uncertainty cannot change any end-point of the intervals.
381        Ok(None)
382    }
383}
384
385impl ExprIntervalGraph {
386    pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: &Schema) -> Result<Self> {
387        // Build the full graph:
388        let (root, graph) =
389            build_dag(expr, &|node| ExprIntervalGraphNode::make_node(node, schema))?;
390        Ok(Self { graph, root })
391    }
392
393    pub fn node_count(&self) -> usize {
394        self.graph.node_count()
395    }
396
397    /// Estimate size of bytes including `Self`.
398    pub fn size(&self) -> usize {
399        let node_memory_usage = self.graph.node_count()
400            * (size_of::<ExprIntervalGraphNode>() + size_of::<NodeIndex>());
401        let edge_memory_usage =
402            self.graph.edge_count() * (size_of::<usize>() + size_of::<NodeIndex>() * 2);
403
404        size_of_val(self) + node_memory_usage + edge_memory_usage
405    }
406
407    // Sometimes, we do not want to calculate and/or propagate intervals all
408    // way down to leaf expressions. For example, assume that we have a
409    // `SymmetricHashJoin` which has a child with an output ordering like:
410    //
411    // ```text
412    // PhysicalSortExpr {
413    //     expr: BinaryExpr('a', +, 'b'),
414    //     sort_option: ..
415    // }
416    // ```
417    //
418    // i.e. its output order comes from a clause like `ORDER BY a + b`. In such
419    // a case, we must calculate the interval for the `BinaryExpr(a, +, b)`
420    // instead of the columns inside this `BinaryExpr`, because this interval
421    // decides whether we prune or not. Therefore, children `PhysicalExpr`s of
422    // this `BinaryExpr` may be pruned for performance. The figure below
423    // explains this example visually.
424    //
425    // Note that we just remove the nodes from the DAEG, do not make any change
426    // to the plan itself.
427    //
428    // ```text
429    //
430    //                                  +-----+                                          +-----+
431    //                                  | GT  |                                          | GT  |
432    //                         +--------|     |-------+                         +--------|     |-------+
433    //                         |        +-----+       |                         |        +-----+       |
434    //                         |                      |                         |                      |
435    //                      +-----+                   |                      +-----+                   |
436    //                      |Cast |                   |                      |Cast |                   |
437    //                      |     |                   |             --\      |     |                   |
438    //                      +-----+                   |       ----------     +-----+                   |
439    //                         |                      |             --/         |                      |
440    //                         |                      |                         |                      |
441    //                      +-----+                +-----+                   +-----+                +-----+
442    //                   +--|Plus |--+          +--|Plus |--+                |Plus |             +--|Plus |--+
443    //                   |  |     |  |          |  |     |  |                |     |             |  |     |  |
444    //  Prune from here  |  +-----+  |          |  +-----+  |                +-----+             |  +-----+  |
445    //  ------------------------------------    |           |                                    |           |
446    //                   |           |          |           |                                    |           |
447    //                +-----+     +-----+    +-----+     +-----+                              +-----+     +-----+
448    //                | a   |     |  b  |    |  c  |     |  2  |                              |  c  |     |  2  |
449    //                |     |     |     |    |     |     |     |                              |     |     |     |
450    //                +-----+     +-----+    +-----+     +-----+                              +-----+     +-----+
451    //
452    // ```
453
454    /// This function associates stable node indices with [`PhysicalExpr`]s so
455    /// that we can match `Arc<dyn PhysicalExpr>` and NodeIndex objects during
456    /// membership tests.
457    pub fn gather_node_indices(
458        &mut self,
459        exprs: &[Arc<dyn PhysicalExpr>],
460    ) -> Vec<(Arc<dyn PhysicalExpr>, usize)> {
461        let graph = &self.graph;
462        let mut bfs = Bfs::new(graph, self.root);
463        // We collect the node indices (usize) of [PhysicalExpr]s in the order
464        // given by argument `exprs`. To preserve this order, we initialize each
465        // expression's node index with usize::MAX, and then find the corresponding
466        // node indices by traversing the graph.
467        let mut removals = vec![];
468        let mut expr_node_indices = exprs
469            .iter()
470            .map(|e| (Arc::clone(e), usize::MAX))
471            .collect::<Vec<_>>();
472        while let Some(node) = bfs.next(graph) {
473            // Get the plan corresponding to this node:
474            let expr = &graph[node].expr;
475            // If the current expression is among `exprs`, slate its children
476            // for removal:
477            if let Some(value) = exprs.iter().position(|e| expr.eq(e)) {
478                // Update the node index of the associated `PhysicalExpr`:
479                expr_node_indices[value].1 = node.index();
480                for edge in graph.edges_directed(node, Outgoing) {
481                    // Slate the child for removal, do not remove immediately.
482                    removals.push(edge.id());
483                }
484            }
485        }
486        for edge_idx in removals {
487            self.graph.remove_edge(edge_idx);
488        }
489        // Get the set of node indices reachable from the root node:
490        let connected_nodes = self.connected_nodes();
491        // Remove nodes not connected to the root node:
492        self.graph
493            .retain_nodes(|_, index| connected_nodes.contains(&index));
494        expr_node_indices
495    }
496
497    /// Returns the set of node indices reachable from the root node via a
498    /// simple depth-first search.
499    fn connected_nodes(&self) -> HashSet<NodeIndex> {
500        let mut nodes = HashSet::new();
501        let mut dfs = Dfs::new(&self.graph, self.root);
502        while let Some(node) = dfs.next(&self.graph) {
503            nodes.insert(node);
504        }
505        nodes
506    }
507
508    /// Updates intervals for all expressions in the DAEG by successive
509    /// bottom-up and top-down traversals.
510    pub fn update_ranges(
511        &mut self,
512        leaf_bounds: &mut [(usize, Interval)],
513        given_range: Interval,
514    ) -> Result<PropagationResult> {
515        self.assign_intervals(leaf_bounds);
516        let bounds = self.evaluate_bounds()?;
517        // There are three possible cases to consider:
518        // (1) given_range ⊇ bounds => Nothing to propagate
519        // (2) ∅ ⊂ (given_range ∩ bounds) ⊂ bounds => Can propagate
520        // (3) Disjoint sets => Infeasible
521        if given_range.contains(bounds)? == Interval::CERTAINLY_TRUE {
522            // First case:
523            Ok(PropagationResult::CannotPropagate)
524        } else if bounds.contains(&given_range)? != Interval::CERTAINLY_FALSE {
525            // Second case:
526            let result = self.propagate_constraints(given_range);
527            self.update_intervals(leaf_bounds);
528            result
529        } else {
530            // Third case:
531            Ok(PropagationResult::Infeasible)
532        }
533    }
534
535    /// This function assigns given ranges to expressions in the DAEG.
536    /// The argument `assignments` associates indices of sought expressions
537    /// with their corresponding new ranges.
538    pub fn assign_intervals(&mut self, assignments: &[(usize, Interval)]) {
539        for (index, interval) in assignments {
540            let node_index = NodeIndex::from(*index as DefaultIx);
541            self.graph[node_index].interval = interval.clone();
542        }
543    }
544
545    /// This function fetches ranges of expressions from the DAEG. The argument
546    /// `assignments` associates indices of sought expressions with their ranges,
547    /// which this function modifies to reflect the intervals in the DAEG.
548    pub fn update_intervals(&self, assignments: &mut [(usize, Interval)]) {
549        for (index, interval) in assignments.iter_mut() {
550            let node_index = NodeIndex::from(*index as DefaultIx);
551            *interval = self.graph[node_index].interval.clone();
552        }
553    }
554
555    /// Computes bounds for an expression using interval arithmetic via a
556    /// bottom-up traversal.
557    ///
558    /// # Examples
559    ///
560    /// ```
561    /// use arrow::datatypes::DataType;
562    /// use arrow::datatypes::Field;
563    /// use arrow::datatypes::Schema;
564    /// use datafusion_common::ScalarValue;
565    /// use datafusion_expr::interval_arithmetic::Interval;
566    /// use datafusion_expr::Operator;
567    /// use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
568    /// use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
569    /// use datafusion_physical_expr::PhysicalExpr;
570    /// use std::sync::Arc;
571    ///
572    /// let expr = Arc::new(BinaryExpr::new(
573    ///     Arc::new(Column::new("gnz", 0)),
574    ///     Operator::Plus,
575    ///     Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
576    /// ));
577    ///
578    /// let schema = Schema::new(vec![Field::new("gnz".to_string(), DataType::Int32, true)]);
579    ///
580    /// let mut graph = ExprIntervalGraph::try_new(expr, &schema).unwrap();
581    /// // Do it once, while constructing.
582    /// let node_indices = graph
583    ///     .gather_node_indices(&[Arc::new(Column::new("gnz", 0))]);
584    /// let left_index = node_indices.get(0).unwrap().1;
585    ///
586    /// // Provide intervals for leaf variables (here, there is only one).
587    /// let intervals = vec![(
588    ///     left_index,
589    ///     Interval::make(Some(10), Some(20)).unwrap(),
590    /// )];
591    ///
592    /// // Evaluate bounds for the composite expression:
593    /// graph.assign_intervals(&intervals);
594    /// assert_eq!(
595    ///     graph.evaluate_bounds().unwrap(),
596    ///     &Interval::make(Some(20), Some(30)).unwrap(),
597    /// )
598    /// ```
599    pub fn evaluate_bounds(&mut self) -> Result<&Interval> {
600        let mut dfs = DfsPostOrder::new(&self.graph, self.root);
601        while let Some(node) = dfs.next(&self.graph) {
602            let neighbors = self.graph.neighbors_directed(node, Outgoing);
603            let mut children_intervals = neighbors
604                .map(|child| self.graph[child].interval())
605                .collect::<Vec<_>>();
606            // If the current expression is a leaf, its interval should already
607            // be set externally, just continue with the evaluation procedure:
608            if !children_intervals.is_empty() {
609                // Reverse to align with `PhysicalExpr`'s children:
610                children_intervals.reverse();
611                self.graph[node].interval =
612                    self.graph[node].expr.evaluate_bounds(&children_intervals)?;
613            }
614        }
615        Ok(self.graph[self.root].interval())
616    }
617
618    /// Updates/shrinks bounds for leaf expressions using interval arithmetic
619    /// via a top-down traversal.
620    fn propagate_constraints(
621        &mut self,
622        given_range: Interval,
623    ) -> Result<PropagationResult> {
624        // Adjust the root node with the given range:
625        if let Some(interval) = self.graph[self.root].interval.intersect(given_range)? {
626            self.graph[self.root].interval = interval;
627        } else {
628            return Ok(PropagationResult::Infeasible);
629        }
630
631        let mut bfs = Bfs::new(&self.graph, self.root);
632
633        while let Some(node) = bfs.next(&self.graph) {
634            let neighbors = self.graph.neighbors_directed(node, Outgoing);
635            let mut children = neighbors.collect::<Vec<_>>();
636            // If the current expression is a leaf, its range is now final.
637            // So, just continue with the propagation procedure:
638            if children.is_empty() {
639                continue;
640            }
641            // Reverse to align with `PhysicalExpr`'s children:
642            children.reverse();
643            let children_intervals = children
644                .iter()
645                .map(|child| self.graph[*child].interval())
646                .collect::<Vec<_>>();
647            let node_interval = self.graph[node].interval();
648            // Special case: true OR could in principle be propagated by 3 interval sets,
649            // (i.e. left true, or right true, or both true) however we do not support this yet.
650            if node_interval == &Interval::CERTAINLY_TRUE
651                && self.graph[node]
652                    .expr
653                    .as_any()
654                    .downcast_ref::<BinaryExpr>()
655                    .is_some_and(|expr| expr.op() == &Operator::Or)
656            {
657                return not_impl_err!("OR operator cannot yet propagate true intervals");
658            }
659            let propagated_intervals = self.graph[node]
660                .expr
661                .propagate_constraints(node_interval, &children_intervals)?;
662            if let Some(propagated_intervals) = propagated_intervals {
663                for (child, interval) in children.into_iter().zip(propagated_intervals) {
664                    self.graph[child].interval = interval;
665                }
666            } else {
667                // The constraint is infeasible, report:
668                return Ok(PropagationResult::Infeasible);
669            }
670        }
671        Ok(PropagationResult::Success)
672    }
673
674    /// Returns the interval associated with the node at the given `index`.
675    pub fn get_interval(&self, index: usize) -> Interval {
676        self.graph[NodeIndex::new(index)].interval.clone()
677    }
678}
679
680/// This is a subfunction of the `propagate_arithmetic` function that propagates to the right child.
681fn propagate_right(
682    left: &Interval,
683    parent: &Interval,
684    right: &Interval,
685    op: &Operator,
686    inverse_op: &Operator,
687) -> Result<Option<Interval>> {
688    match op {
689        Operator::Minus => apply_operator(op, left, parent),
690        Operator::Plus => apply_operator(inverse_op, parent, left),
691        Operator::Divide => apply_operator(op, left, parent),
692        Operator::Multiply => apply_operator(inverse_op, parent, left),
693        _ => internal_err!("Interval arithmetic does not support the operator {}", op),
694    }?
695    .intersect(right)
696}
697
698/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`],
699/// if there exists a `timestamp - timestamp` operation, the result would be
700/// of type `Duration`. However, we may encounter a situation where a time interval
701/// is involved in an arithmetic operation with a `Duration` type. This function
702/// offers special handling for such cases, where the time interval resides on
703/// the left side of the operation.
704fn propagate_time_interval_at_left(
705    left_child: &Interval,
706    right_child: &Interval,
707    parent: &Interval,
708    op: &Operator,
709    inverse_op: &Operator,
710) -> Result<Option<(Interval, Interval)>> {
711    // We check if the child's time interval(s) has a non-zero month or day field(s).
712    // If so, we return it as is without propagating. Otherwise, we first convert
713    // the time intervals to the `Duration` type, then propagate, and then convert
714    // the bounds to time intervals again.
715    let result = if let Some(duration) = convert_interval_type_to_duration(left_child) {
716        match apply_operator(inverse_op, parent, right_child)?.intersect(duration)? {
717            Some(value) => {
718                let left = convert_duration_type_to_interval(&value);
719                let right = propagate_right(&value, parent, right_child, op, inverse_op)?;
720                match (left, right) {
721                    (Some(left), Some(right)) => Some((left, right)),
722                    _ => None,
723                }
724            }
725            None => None,
726        }
727    } else {
728        propagate_right(left_child, parent, right_child, op, inverse_op)?
729            .map(|right| (left_child.clone(), right))
730    };
731    Ok(result)
732}
733
734/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`],
735/// if there exists a `timestamp - timestamp` operation, the result would be
736/// of type `Duration`. However, we may encounter a situation where a time interval
737/// is involved in an arithmetic operation with a `Duration` type. This function
738/// offers special handling for such cases, where the time interval resides on
739/// the right side of the operation.
740fn propagate_time_interval_at_right(
741    left_child: &Interval,
742    right_child: &Interval,
743    parent: &Interval,
744    op: &Operator,
745    inverse_op: &Operator,
746) -> Result<Option<(Interval, Interval)>> {
747    // We check if the child's time interval(s) has a non-zero month or day field(s).
748    // If so, we return it as is without propagating. Otherwise, we first convert
749    // the time intervals to the `Duration` type, then propagate, and then convert
750    // the bounds to time intervals again.
751    let result = if let Some(duration) = convert_interval_type_to_duration(right_child) {
752        match apply_operator(inverse_op, parent, &duration)?.intersect(left_child)? {
753            Some(value) => {
754                propagate_right(left_child, parent, &duration, op, inverse_op)?
755                    .and_then(|right| convert_duration_type_to_interval(&right))
756                    .map(|right| (value, right))
757            }
758            None => None,
759        }
760    } else {
761        apply_operator(inverse_op, parent, right_child)?
762            .intersect(left_child)?
763            .map(|value| (value, right_child.clone()))
764    };
765    Ok(result)
766}
767
/// Swaps the two components of a pair, turning `(T, U)` into `(U, T)`.
fn reverse_tuple<T, U>(pair: (T, U)) -> (U, T) {
    let (head, tail) = pair;
    (tail, head)
}
771
772#[cfg(test)]
773mod tests {
774    use super::*;
775    use crate::expressions::{BinaryExpr, Column};
776    use crate::intervals::test_utils::gen_conjunctive_numerical_expr;
777
778    use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano};
779    use arrow::datatypes::{Field, TimeUnit};
780    use datafusion_common::ScalarValue;
781
782    use itertools::Itertools;
783    use rand::rngs::StdRng;
784    use rand::{Rng, SeedableRng};
785    use rstest::*;
786
787    #[allow(clippy::too_many_arguments)]
788    fn experiment(
789        expr: Arc<dyn PhysicalExpr>,
790        exprs_with_interval: (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>),
791        left_interval: Interval,
792        right_interval: Interval,
793        left_expected: Interval,
794        right_expected: Interval,
795        result: PropagationResult,
796        schema: &Schema,
797    ) -> Result<()> {
798        let col_stats = vec![
799            (Arc::clone(&exprs_with_interval.0), left_interval),
800            (Arc::clone(&exprs_with_interval.1), right_interval),
801        ];
802        let expected = vec![
803            (Arc::clone(&exprs_with_interval.0), left_expected),
804            (Arc::clone(&exprs_with_interval.1), right_expected),
805        ];
806        let mut graph = ExprIntervalGraph::try_new(expr, schema)?;
807        let expr_indexes = graph.gather_node_indices(
808            &col_stats.iter().map(|(e, _)| Arc::clone(e)).collect_vec(),
809        );
810
811        let mut col_stat_nodes = col_stats
812            .iter()
813            .zip(expr_indexes.iter())
814            .map(|((_, interval), (_, index))| (*index, interval.clone()))
815            .collect_vec();
816        let expected_nodes = expected
817            .iter()
818            .zip(expr_indexes.iter())
819            .map(|((_, interval), (_, index))| (*index, interval.clone()))
820            .collect_vec();
821
822        let exp_result =
823            graph.update_ranges(&mut col_stat_nodes[..], Interval::CERTAINLY_TRUE)?;
824        assert_eq!(exp_result, result);
825        col_stat_nodes.iter().zip(expected_nodes.iter()).for_each(
826            |((_, calculated_interval_node), (_, expected))| {
827                // NOTE: These randomized tests only check for conservative containment,
828                // not openness/closedness of endpoints.
829
830                // Calculated bounds are relaxed by 1 to cover all strict and
831                // and non-strict comparison cases since we have only closed bounds.
832                let one = ScalarValue::new_one(&expected.data_type()).unwrap();
833                assert!(
834                    calculated_interval_node.lower()
835                        <= &expected.lower().add(&one).unwrap(),
836                    "{}",
837                    format!(
838                        "Calculated {} must be less than or equal {}",
839                        calculated_interval_node.lower(),
840                        expected.lower()
841                    )
842                );
843                assert!(
844                    calculated_interval_node.upper()
845                        >= &expected.upper().sub(&one).unwrap(),
846                    "{}",
847                    format!(
848                        "Calculated {} must be greater than or equal {}",
849                        calculated_interval_node.upper(),
850                        expected.upper()
851                    )
852                );
853            },
854        );
855        Ok(())
856    }
857
858    macro_rules! generate_cases {
859        ($FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
860            fn $FUNC_NAME<const ASC: bool>(
861                expr: Arc<dyn PhysicalExpr>,
862                left_col: Arc<dyn PhysicalExpr>,
863                right_col: Arc<dyn PhysicalExpr>,
864                seed: u64,
865                expr_left: $TYPE,
866                expr_right: $TYPE,
867            ) -> Result<()> {
868                let mut r = StdRng::seed_from_u64(seed);
869
870                let (left_given, right_given, left_expected, right_expected) = if ASC {
871                    let left = r.random_range((0 as $TYPE)..(1000 as $TYPE));
872                    let right = r.random_range((0 as $TYPE)..(1000 as $TYPE));
873                    (
874                        (Some(left), None),
875                        (Some(right), None),
876                        (Some(<$TYPE>::max(left, right + expr_left)), None),
877                        (Some(<$TYPE>::max(right, left + expr_right)), None),
878                    )
879                } else {
880                    let left = r.random_range((0 as $TYPE)..(1000 as $TYPE));
881                    let right = r.random_range((0 as $TYPE)..(1000 as $TYPE));
882                    (
883                        (None, Some(left)),
884                        (None, Some(right)),
885                        (None, Some(<$TYPE>::min(left, right + expr_left))),
886                        (None, Some(<$TYPE>::min(right, left + expr_right))),
887                    )
888                };
889
890                experiment(
891                    expr,
892                    (left_col.clone(), right_col.clone()),
893                    Interval::make(left_given.0, left_given.1).unwrap(),
894                    Interval::make(right_given.0, right_given.1).unwrap(),
895                    Interval::make(left_expected.0, left_expected.1).unwrap(),
896                    Interval::make(right_expected.0, right_expected.1).unwrap(),
897                    PropagationResult::Success,
898                    &Schema::new(vec![
899                        Field::new(
900                            left_col.as_any().downcast_ref::<Column>().unwrap().name(),
901                            DataType::$SCALAR,
902                            true,
903                        ),
904                        Field::new(
905                            right_col.as_any().downcast_ref::<Column>().unwrap().name(),
906                            DataType::$SCALAR,
907                            true,
908                        ),
909                    ]),
910                )
911            }
912        };
913    }
914    generate_cases!(generate_case_i32, i32, Int32);
915    generate_cases!(generate_case_i64, i64, Int64);
916    generate_cases!(generate_case_f32, f32, Float32);
917    generate_cases!(generate_case_f64, f64, Float64);
918
919    #[test]
920    fn testing_not_possible() -> Result<()> {
921        let left_col = Arc::new(Column::new("left_watermark", 0));
922        let right_col = Arc::new(Column::new("right_watermark", 0));
923
924        // left_watermark > right_watermark + 5
925        let left_and_1 = Arc::new(BinaryExpr::new(
926            Arc::clone(&left_col) as Arc<dyn PhysicalExpr>,
927            Operator::Plus,
928            Arc::new(Literal::new(ScalarValue::Int32(Some(5)))),
929        ));
930        let expr = Arc::new(BinaryExpr::new(
931            left_and_1,
932            Operator::Gt,
933            Arc::clone(&right_col) as Arc<dyn PhysicalExpr>,
934        ));
935        experiment(
936            expr,
937            (
938                Arc::clone(&left_col) as Arc<dyn PhysicalExpr>,
939                Arc::clone(&right_col) as Arc<dyn PhysicalExpr>,
940            ),
941            Interval::make(Some(10_i32), Some(20_i32))?,
942            Interval::make(Some(100), None)?,
943            Interval::make(Some(10), Some(20))?,
944            Interval::make(Some(100), None)?,
945            PropagationResult::Infeasible,
946            &Schema::new(vec![
947                Field::new(
948                    left_col.as_any().downcast_ref::<Column>().unwrap().name(),
949                    DataType::Int32,
950                    true,
951                ),
952                Field::new(
953                    right_col.as_any().downcast_ref::<Column>().unwrap().name(),
954                    DataType::Int32,
955                    true,
956                ),
957            ]),
958        )
959    }
960
961    macro_rules! integer_float_case_1 {
962        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
963            #[rstest]
964            #[test]
965            fn $TEST_FUNC_NAME(
966                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
967                seed: u64,
968                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
969                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
970            ) -> Result<()> {
971                let left_col = Arc::new(Column::new("left_watermark", 0));
972                let right_col = Arc::new(Column::new("right_watermark", 0));
973
974                // left_watermark + 1 > right_watermark + 11 AND left_watermark + 3 < right_watermark + 33
975                let expr = gen_conjunctive_numerical_expr(
976                    left_col.clone(),
977                    right_col.clone(),
978                    (
979                        Operator::Plus,
980                        Operator::Plus,
981                        Operator::Plus,
982                        Operator::Plus,
983                    ),
984                    ScalarValue::$SCALAR(Some(1 as $TYPE)),
985                    ScalarValue::$SCALAR(Some(11 as $TYPE)),
986                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
987                    ScalarValue::$SCALAR(Some(33 as $TYPE)),
988                    (greater_op, less_op),
989                );
990                // l > r + 10 AND r > l - 30
991                let l_gt_r = 10 as $TYPE;
992                let r_gt_l = -30 as $TYPE;
993                $GENERATE_CASE_FUNC_NAME::<true>(
994                    expr.clone(),
995                    left_col.clone(),
996                    right_col.clone(),
997                    seed,
998                    l_gt_r,
999                    r_gt_l,
1000                )?;
1001                // Descending tests
1002                // r < l - 10 AND l < r + 30
1003                let r_lt_l = -l_gt_r;
1004                let l_lt_r = -r_gt_l;
1005                $GENERATE_CASE_FUNC_NAME::<false>(
1006                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
1007                )
1008            }
1009        };
1010    }
1011
1012    integer_float_case_1!(case_1_i32, generate_case_i32, i32, Int32);
1013    integer_float_case_1!(case_1_i64, generate_case_i64, i64, Int64);
1014    integer_float_case_1!(case_1_f64, generate_case_f64, f64, Float64);
1015    integer_float_case_1!(case_1_f32, generate_case_f32, f32, Float32);
1016
1017    macro_rules! integer_float_case_2 {
1018        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
1019            #[rstest]
1020            #[test]
1021            fn $TEST_FUNC_NAME(
1022                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
1023                seed: u64,
1024                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
1025                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
1026            ) -> Result<()> {
1027                let left_col = Arc::new(Column::new("left_watermark", 0));
1028                let right_col = Arc::new(Column::new("right_watermark", 0));
1029
1030                // left_watermark - 1 > right_watermark + 5 AND left_watermark + 3 < right_watermark + 10
1031                let expr = gen_conjunctive_numerical_expr(
1032                    left_col.clone(),
1033                    right_col.clone(),
1034                    (
1035                        Operator::Minus,
1036                        Operator::Plus,
1037                        Operator::Plus,
1038                        Operator::Plus,
1039                    ),
1040                    ScalarValue::$SCALAR(Some(1 as $TYPE)),
1041                    ScalarValue::$SCALAR(Some(5 as $TYPE)),
1042                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
1043                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
1044                    (greater_op, less_op),
1045                );
1046                // l > r + 6 AND r > l - 7
1047                let l_gt_r = 6 as $TYPE;
1048                let r_gt_l = -7 as $TYPE;
1049                $GENERATE_CASE_FUNC_NAME::<true>(
1050                    expr.clone(),
1051                    left_col.clone(),
1052                    right_col.clone(),
1053                    seed,
1054                    l_gt_r,
1055                    r_gt_l,
1056                )?;
1057                // Descending tests
1058                // r < l - 6 AND l < r + 7
1059                let r_lt_l = -l_gt_r;
1060                let l_lt_r = -r_gt_l;
1061                $GENERATE_CASE_FUNC_NAME::<false>(
1062                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
1063                )
1064            }
1065        };
1066    }
1067
1068    integer_float_case_2!(case_2_i32, generate_case_i32, i32, Int32);
1069    integer_float_case_2!(case_2_i64, generate_case_i64, i64, Int64);
1070    integer_float_case_2!(case_2_f64, generate_case_f64, f64, Float64);
1071    integer_float_case_2!(case_2_f32, generate_case_f32, f32, Float32);
1072
1073    macro_rules! integer_float_case_3 {
1074        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
1075            #[rstest]
1076            #[test]
1077            fn $TEST_FUNC_NAME(
1078                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
1079                seed: u64,
1080                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
1081                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
1082            ) -> Result<()> {
1083                let left_col = Arc::new(Column::new("left_watermark", 0));
1084                let right_col = Arc::new(Column::new("right_watermark", 0));
1085
1086                // left_watermark - 1 > right_watermark + 5 AND left_watermark - 3 < right_watermark + 10
1087                let expr = gen_conjunctive_numerical_expr(
1088                    left_col.clone(),
1089                    right_col.clone(),
1090                    (
1091                        Operator::Minus,
1092                        Operator::Plus,
1093                        Operator::Minus,
1094                        Operator::Plus,
1095                    ),
1096                    ScalarValue::$SCALAR(Some(1 as $TYPE)),
1097                    ScalarValue::$SCALAR(Some(5 as $TYPE)),
1098                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
1099                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
1100                    (greater_op, less_op),
1101                );
1102                // l > r + 6 AND r > l - 13
1103                let l_gt_r = 6 as $TYPE;
1104                let r_gt_l = -13 as $TYPE;
1105                $GENERATE_CASE_FUNC_NAME::<true>(
1106                    expr.clone(),
1107                    left_col.clone(),
1108                    right_col.clone(),
1109                    seed,
1110                    l_gt_r,
1111                    r_gt_l,
1112                )?;
1113                // Descending tests
1114                // r < l - 6 AND l < r + 13
1115                let r_lt_l = -l_gt_r;
1116                let l_lt_r = -r_gt_l;
1117                $GENERATE_CASE_FUNC_NAME::<false>(
1118                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
1119                )
1120            }
1121        };
1122    }
1123
1124    integer_float_case_3!(case_3_i32, generate_case_i32, i32, Int32);
1125    integer_float_case_3!(case_3_i64, generate_case_i64, i64, Int64);
1126    integer_float_case_3!(case_3_f64, generate_case_f64, f64, Float64);
1127    integer_float_case_3!(case_3_f32, generate_case_f32, f32, Float32);
1128
1129    macro_rules! integer_float_case_4 {
1130        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
1131            #[rstest]
1132            #[test]
1133            fn $TEST_FUNC_NAME(
1134                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
1135                seed: u64,
1136                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
1137                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
1138            ) -> Result<()> {
1139                let left_col = Arc::new(Column::new("left_watermark", 0));
1140                let right_col = Arc::new(Column::new("right_watermark", 0));
1141
1142                // left_watermark - 10 > right_watermark - 5 AND left_watermark - 30 < right_watermark - 3
1143                let expr = gen_conjunctive_numerical_expr(
1144                    left_col.clone(),
1145                    right_col.clone(),
1146                    (
1147                        Operator::Minus,
1148                        Operator::Minus,
1149                        Operator::Minus,
1150                        Operator::Plus,
1151                    ),
1152                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
1153                    ScalarValue::$SCALAR(Some(5 as $TYPE)),
1154                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
1155                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
1156                    (greater_op, less_op),
1157                );
1158                // l > r + 5 AND r > l - 13
1159                let l_gt_r = 5 as $TYPE;
1160                let r_gt_l = -13 as $TYPE;
1161                $GENERATE_CASE_FUNC_NAME::<true>(
1162                    expr.clone(),
1163                    left_col.clone(),
1164                    right_col.clone(),
1165                    seed,
1166                    l_gt_r,
1167                    r_gt_l,
1168                )?;
1169                // Descending tests
1170                // r < l - 5 AND l < r + 13
1171                let r_lt_l = -l_gt_r;
1172                let l_lt_r = -r_gt_l;
1173                $GENERATE_CASE_FUNC_NAME::<false>(
1174                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
1175                )
1176            }
1177        };
1178    }
1179
1180    integer_float_case_4!(case_4_i32, generate_case_i32, i32, Int32);
1181    integer_float_case_4!(case_4_i64, generate_case_i64, i64, Int64);
1182    integer_float_case_4!(case_4_f64, generate_case_f64, f64, Float64);
1183    integer_float_case_4!(case_4_f32, generate_case_f32, f32, Float32);
1184
1185    macro_rules! integer_float_case_5 {
1186        ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
1187            #[rstest]
1188            #[test]
1189            fn $TEST_FUNC_NAME(
1190                #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
1191                seed: u64,
1192                #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
1193                #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
1194            ) -> Result<()> {
1195                let left_col = Arc::new(Column::new("left_watermark", 0));
1196                let right_col = Arc::new(Column::new("right_watermark", 0));
1197
1198                // left_watermark - 10 > right_watermark - 5 AND left_watermark - 30 < right_watermark - 3
1199                let expr = gen_conjunctive_numerical_expr(
1200                    left_col.clone(),
1201                    right_col.clone(),
1202                    (
1203                        Operator::Minus,
1204                        Operator::Minus,
1205                        Operator::Minus,
1206                        Operator::Minus,
1207                    ),
1208                    ScalarValue::$SCALAR(Some(10 as $TYPE)),
1209                    ScalarValue::$SCALAR(Some(5 as $TYPE)),
1210                    ScalarValue::$SCALAR(Some(30 as $TYPE)),
1211                    ScalarValue::$SCALAR(Some(3 as $TYPE)),
1212                    (greater_op, less_op),
1213                );
1214                // l > r + 5 AND r > l - 27
1215                let l_gt_r = 5 as $TYPE;
1216                let r_gt_l = -27 as $TYPE;
1217                $GENERATE_CASE_FUNC_NAME::<true>(
1218                    expr.clone(),
1219                    left_col.clone(),
1220                    right_col.clone(),
1221                    seed,
1222                    l_gt_r,
1223                    r_gt_l,
1224                )?;
1225                // Descending tests
1226                // r < l - 5 AND l < r + 27
1227                let r_lt_l = -l_gt_r;
1228                let l_lt_r = -r_gt_l;
1229                $GENERATE_CASE_FUNC_NAME::<false>(
1230                    expr, left_col, right_col, seed, l_lt_r, r_lt_l,
1231                )
1232            }
1233        };
1234    }
1235
1236    integer_float_case_5!(case_5_i32, generate_case_i32, i32, Int32);
1237    integer_float_case_5!(case_5_i64, generate_case_i64, i64, Int64);
1238    integer_float_case_5!(case_5_f64, generate_case_f64, f64, Float64);
1239    integer_float_case_5!(case_5_f32, generate_case_f32, f32, Float32);
1240
1241    #[test]
1242    fn test_gather_node_indices_dont_remove() -> Result<()> {
1243        // Expression: a@0 + b@1 + 1 > a@0 - b@1, given a@0 + b@1.
1244        // Do not remove a@0 or b@1, only remove edges since a@0 - b@1 also
1245        // depends on leaf nodes a@0 and b@1.
1246        let left_expr = Arc::new(BinaryExpr::new(
1247            Arc::new(BinaryExpr::new(
1248                Arc::new(Column::new("a", 0)),
1249                Operator::Plus,
1250                Arc::new(Column::new("b", 1)),
1251            )),
1252            Operator::Plus,
1253            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
1254        ));
1255
1256        let right_expr = Arc::new(BinaryExpr::new(
1257            Arc::new(Column::new("a", 0)),
1258            Operator::Minus,
1259            Arc::new(Column::new("b", 1)),
1260        ));
1261        let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
1262        let mut graph = ExprIntervalGraph::try_new(
1263            expr,
1264            &Schema::new(vec![
1265                Field::new("a", DataType::Int32, true),
1266                Field::new("b", DataType::Int32, true),
1267            ]),
1268        )
1269        .unwrap();
1270        // Define a test leaf node.
1271        let leaf_node = Arc::new(BinaryExpr::new(
1272            Arc::new(Column::new("a", 0)),
1273            Operator::Plus,
1274            Arc::new(Column::new("b", 1)),
1275        ));
1276        // Store the current node count.
1277        let prev_node_count = graph.node_count();
1278        // Gather the index of node in the expression graph that match the test leaf node.
1279        graph.gather_node_indices(&[leaf_node]);
1280        // Store the final node count.
1281        let final_node_count = graph.node_count();
1282        // Assert that the final node count is equal the previous node count.
1283        // This means we did not remove any node.
1284        assert_eq!(prev_node_count, final_node_count);
1285        Ok(())
1286    }
1287
1288    #[test]
1289    fn test_gather_node_indices_remove() -> Result<()> {
1290        // Expression: a@0 + b@1 + 1 > y@0 - z@1, given a@0 + b@1.
1291        // We expect to remove two nodes since we do not need a@ and b@.
1292        let left_expr = Arc::new(BinaryExpr::new(
1293            Arc::new(BinaryExpr::new(
1294                Arc::new(Column::new("a", 0)),
1295                Operator::Plus,
1296                Arc::new(Column::new("b", 1)),
1297            )),
1298            Operator::Plus,
1299            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
1300        ));
1301
1302        let right_expr = Arc::new(BinaryExpr::new(
1303            Arc::new(Column::new("y", 0)),
1304            Operator::Minus,
1305            Arc::new(Column::new("z", 1)),
1306        ));
1307        let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
1308        let mut graph = ExprIntervalGraph::try_new(
1309            expr,
1310            &Schema::new(vec![
1311                Field::new("a", DataType::Int32, true),
1312                Field::new("b", DataType::Int32, true),
1313                Field::new("y", DataType::Int32, true),
1314                Field::new("z", DataType::Int32, true),
1315            ]),
1316        )
1317        .unwrap();
1318        // Define a test leaf node.
1319        let leaf_node = Arc::new(BinaryExpr::new(
1320            Arc::new(Column::new("a", 0)),
1321            Operator::Plus,
1322            Arc::new(Column::new("b", 1)),
1323        ));
1324        // Store the current node count.
1325        let prev_node_count = graph.node_count();
1326        // Gather the index of node in the expression graph that match the test leaf node.
1327        graph.gather_node_indices(&[leaf_node]);
1328        // Store the final node count.
1329        let final_node_count = graph.node_count();
1330        // Assert that the final node count is two less than the previous node
1331        // count; i.e. that we did remove two nodes.
1332        assert_eq!(prev_node_count, final_node_count + 2);
1333        Ok(())
1334    }
1335
#[test]
fn test_gather_node_indices_remove_one() -> Result<()> {
    // Predicate under test: a@0 + b@1 + 1 > a@0 - z@1, where the expression
    // handed to `gather_node_indices` is a@0 + b@1.
    // Exactly one node is expected to disappear: b@1 is no longer needed,
    // whereas a@0 is still referenced by the right-hand side.
    // NOTE(review): `z` is declared at position 2 of the schema below but is
    // referenced with index 1 -- confirm this is intentional.
    let column = |name: &str, index: usize| Arc::new(Column::new(name, index));

    let a_plus_b = Arc::new(BinaryExpr::new(
        column("a", 0),
        Operator::Plus,
        column("b", 1),
    ));
    let lhs = Arc::new(BinaryExpr::new(
        a_plus_b,
        Operator::Plus,
        Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
    ));
    let rhs = Arc::new(BinaryExpr::new(
        column("a", 0),
        Operator::Minus,
        column("z", 1),
    ));
    let predicate = Arc::new(BinaryExpr::new(lhs, Operator::Gt, rhs));

    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
        Field::new("z", DataType::Int32, true),
    ]);
    let mut graph = ExprIntervalGraph::try_new(predicate, &schema).unwrap();

    // The expression whose node we ask the graph to locate.
    let target = Arc::new(BinaryExpr::new(
        column("a", 0),
        Operator::Plus,
        column("b", 1),
    ));

    // Compare the graph size before and after gathering.
    let nodes_before = graph.node_count();
    graph.gather_node_indices(&[target]);
    let nodes_after = graph.node_count();

    // Exactly one node must have been removed.
    assert_eq!(nodes_before, nodes_after + 1);
    Ok(())
}
1382
#[test]
fn test_gather_node_indices_cannot_provide() -> Result<()> {
    // Predicate under test: a@0 + 1 + b@1 > y@0 - z@1, where the expression
    // handed to `gather_node_indices` is a@0 + b@1.
    // TODO: Ideally a@0 and b@1 would be pruned and their intervals supplied
    // through an a@0 + b@1 node. Because expressions form a binary tree, the
    // graph contains (a@0 + 1) + b@1 and has no node that is exactly
    // a@0 + b@1. Only exact `BinaryExpr` matches are supported today;
    // support beyond exact matches is planned for the future.
    // NOTE(review): `y` and `z` are declared at positions 2 and 3 of the
    // schema below but are referenced with indices 0 and 1 -- confirm this is
    // intentional.
    let column = |name: &str, index: usize| Arc::new(Column::new(name, index));

    let a_plus_one = Arc::new(BinaryExpr::new(
        column("a", 0),
        Operator::Plus,
        Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
    ));
    let lhs = Arc::new(BinaryExpr::new(
        a_plus_one,
        Operator::Plus,
        column("b", 1),
    ));
    let rhs = Arc::new(BinaryExpr::new(
        column("y", 0),
        Operator::Minus,
        column("z", 1),
    ));
    let predicate = Arc::new(BinaryExpr::new(lhs, Operator::Gt, rhs));

    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
        Field::new("y", DataType::Int32, true),
        Field::new("z", DataType::Int32, true),
    ]);
    let mut graph = ExprIntervalGraph::try_new(predicate, &schema).unwrap();

    // The expression whose node we ask the graph to locate.
    let target = Arc::new(BinaryExpr::new(
        column("a", 0),
        Operator::Plus,
        column("b", 1),
    ));

    // Compare the graph size before and after gathering.
    let nodes_before = graph.node_count();
    graph.gather_node_indices(&[target]);
    let nodes_after = graph.node_count();

    // No exact match exists, so the node count must be unchanged.
    assert_eq!(nodes_before, nodes_after);
    Ok(())
}
1432
#[test]
fn test_propagate_constraints_singleton_interval_at_right() -> Result<()> {
    // Small builders for the scalar kinds used throughout this test.
    let timestamp_ns =
        |nanos: i64| ScalarValue::TimestampNanosecond(Some(nanos), None);
    let one_day_321_ns = || {
        ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano {
            months: 0,
            days: 1,
            nanoseconds: 321,
        }))
    };

    // ts_column + INTERVAL '1 day 321 ns'
    let sum_expr = BinaryExpr::new(
        Arc::new(Column::new("ts_column", 0)),
        Operator::Plus,
        Arc::new(Literal::new(ScalarValue::new_interval_mdn(0, 1, 321))),
    );

    // Bounds known for the sum:
    // [15.10.2020 10:11:12.000_000_321 AM, 16.10.2020 10:11:12.000_000_321 AM]
    let sum_bounds = Interval::try_new(
        timestamp_ns(1_602_756_672_000_000_321),
        timestamp_ns(1_602_843_072_000_000_321),
    )?;
    // Bounds known for the timestamp operand:
    // [10.10.2020 10:11:12 AM, 20.10.2020 10:11:12 AM]
    let ts_bounds = Interval::try_new(
        timestamp_ns(1_602_324_672_000_000_000),
        timestamp_ns(1_603_188_672_000_000_000),
    )?;
    // The interval operand is a singleton: exactly 1 day 321 ns.
    let shift_bounds = Interval::try_new(one_day_321_ns(), one_day_321_ns())?;

    let refined = sum_expr
        .propagate_constraints(&sum_bounds, &[&ts_bounds, &shift_bounds])?
        .unwrap();

    let expected = vec![
        // Timestamp operand narrowed to
        // [14.10.2020 10:11:12 AM, 15.10.2020 10:11:12 AM].
        Interval::try_new(
            timestamp_ns(1_602_670_272_000_000_000),
            timestamp_ns(1_602_756_672_000_000_000),
        )?,
        // The singleton interval operand (1 day 321 ns, month-day-nano
        // representation) stays as it was.
        Interval::try_new(one_day_321_ns(), one_day_321_ns())?,
    ];
    assert_eq!(expected, refined);

    Ok(())
}
1505
#[test]
fn test_propagate_constraints_column_interval_at_left() -> Result<()> {
    // Small builders for the scalar kinds used throughout this test.
    let timestamp_ms =
        |millis: i64| ScalarValue::TimestampMillisecond(Some(millis), None);
    let day_time_ms = |milliseconds: i32| {
        ScalarValue::IntervalDayTime(Some(IntervalDayTime {
            days: 0,
            milliseconds,
        }))
    };

    // interval_column + ts_column (the interval operand is on the left).
    let sum_expr = BinaryExpr::new(
        Arc::new(Column::new("interval_column", 1)),
        Operator::Plus,
        Arc::new(Column::new("ts_column", 0)),
    );

    // Bounds known for the sum:
    // [15.10.2020 10:11:12 AM, 16.10.2020 10:11:12 AM]
    let sum_bounds = Interval::try_new(
        timestamp_ms(1_602_756_672_000),
        timestamp_ms(1_602_843_072_000),
    )?;
    // Bounds known for the timestamp (right) operand:
    // [10.10.2020 10:11:12 AM, 20.10.2020 10:11:12 AM]
    let ts_bounds = Interval::try_new(
        timestamp_ms(1_602_324_672_000),
        timestamp_ms(1_603_188_672_000),
    )?;
    // Bounds known for the interval (left) operand: [2 days, 10 days],
    // both expressed in milliseconds.
    let shift_bounds =
        Interval::try_new(day_time_ms(172_800_000), day_time_ms(864_000_000))?;

    let refined = sum_expr
        .propagate_constraints(&sum_bounds, &[&shift_bounds, &ts_bounds])?
        .unwrap();

    let expected = vec![
        // Interval operand narrowed to [2 days, 6 days].
        Interval::try_new(day_time_ms(172_800_000), day_time_ms(518_400_000))?,
        // Timestamp operand narrowed to
        // [10.10.2020 10:11:12 AM, 14.10.2020 10:11:12 AM].
        Interval::try_new(
            timestamp_ms(1_602_324_672_000),
            timestamp_ms(1_602_670_272_000),
        )?,
    ];
    assert_eq!(expected, refined);

    Ok(())
}
1568
#[test]
fn test_propagate_comparison() -> Result<()> {
    // Every case below uses the same setup:
    //   * `left` is unbounded: [?, ?]
    //   * `right` is the singleton [1000, 1000]
    // Requiring `left < right` to be certainly true is expected to leave
    // `right` as it is and to cap `left` at 999: [?, 999].

    // Case 1: plain Int64 bounds.
    let left = Interval::make_unbounded(&DataType::Int64)?;
    let right = Interval::make(Some(1000_i64), Some(1000_i64))?;
    let expected = Some((
        Interval::make(None, Some(999_i64))?,
        Interval::make(Some(1000_i64), Some(1000_i64))?,
    ));
    let actual = propagate_comparison(
        &Operator::Lt,
        &Interval::CERTAINLY_TRUE,
        &left,
        &right,
    )?;
    assert_eq!(expected, actual);

    // Cases 2 and 3: nanosecond timestamps, first without and then with a
    // time zone. The scenario is identical apart from the zone attached to
    // every timestamp value.
    let check_timestamp_case = |tz: Option<Arc<str>>| -> Result<()> {
        let ts_type = DataType::Timestamp(TimeUnit::Nanosecond, tz.clone());
        let left = Interval::make_unbounded(&ts_type)?;
        let right = Interval::try_new(
            ScalarValue::TimestampNanosecond(Some(1000), tz.clone()),
            ScalarValue::TimestampNanosecond(Some(1000), tz.clone()),
        )?;
        let expected = Some((
            Interval::try_new(
                // Lower bound: the scalar obtained from the data type alone.
                ScalarValue::try_from(&ts_type).unwrap(),
                ScalarValue::TimestampNanosecond(Some(999), tz.clone()),
            )?,
            Interval::try_new(
                ScalarValue::TimestampNanosecond(Some(1000), tz.clone()),
                ScalarValue::TimestampNanosecond(Some(1000), tz),
            )?,
        ));
        let actual = propagate_comparison(
            &Operator::Lt,
            &Interval::CERTAINLY_TRUE,
            &left,
            &right,
        )?;
        assert_eq!(expected, actual);
        Ok(())
    };
    check_timestamp_case(None)?;
    check_timestamp_case(Some("+05:00".into()))?;

    Ok(())
}
1652
#[test]
fn test_propagate_or() -> Result<()> {
    // a OR b
    let or_expr = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("a", 0)),
        Operator::Or,
        Arc::new(Column::new("b", 1)),
    ));

    // Parent certainly false, neither child certainly true: both children
    // are expected to come back as certainly false.
    let parent = Interval::CERTAINLY_FALSE;
    let consistent_inputs = [
        [&Interval::CERTAINLY_FALSE, &Interval::UNCERTAIN],
        [&Interval::UNCERTAIN, &Interval::CERTAINLY_FALSE],
        [&Interval::CERTAINLY_FALSE, &Interval::CERTAINLY_FALSE],
        [&Interval::UNCERTAIN, &Interval::UNCERTAIN],
    ];
    for children in consistent_inputs {
        assert_eq!(
            or_expr.propagate_constraints(&parent, &children)?.unwrap(),
            vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_FALSE],
        );
    }

    // Parent certainly false while one child is certainly true: the
    // propagation is expected to yield `None`.
    let conflicting_inputs = [
        [&Interval::CERTAINLY_TRUE, &Interval::UNCERTAIN],
        [&Interval::UNCERTAIN, &Interval::CERTAINLY_TRUE],
    ];
    for children in conflicting_inputs {
        assert_eq!(or_expr.propagate_constraints(&parent, &children)?, None);
    }

    // Parent certainly true with a certainly-false left child: the right
    // child is expected to become certainly true.
    let parent = Interval::CERTAINLY_TRUE;
    let children = [&Interval::CERTAINLY_FALSE, &Interval::UNCERTAIN];
    assert_eq!(
        or_expr.propagate_constraints(&parent, &children)?.unwrap(),
        vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_TRUE]
    );

    // Parent certainly true with both children uncertain.
    let children = [&Interval::UNCERTAIN, &Interval::UNCERTAIN];
    assert_eq!(
        or_expr.propagate_constraints(&parent, &children)?.unwrap(),
        // Empty means unchanged intervals.
        vec![]
    );

    Ok(())
}
1700
#[test]
fn test_propagate_certainly_false_and() -> Result<()> {
    // a AND b, with the conjunction known to be certainly false.
    let and_expr = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("a", 0)),
        Operator::And,
        Arc::new(Column::new("b", 1)),
    ));
    let parent = Interval::CERTAINLY_FALSE;

    // Each entry pairs the children's input intervals with the intervals the
    // propagation is expected to return. An empty vector means the intervals
    // are unchanged.
    let cases = [
        (
            [&Interval::CERTAINLY_TRUE, &Interval::UNCERTAIN],
            vec![Interval::CERTAINLY_TRUE, Interval::CERTAINLY_FALSE],
        ),
        (
            [&Interval::UNCERTAIN, &Interval::CERTAINLY_TRUE],
            vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_TRUE],
        ),
        ([&Interval::UNCERTAIN, &Interval::UNCERTAIN], vec![]),
        ([&Interval::CERTAINLY_FALSE, &Interval::UNCERTAIN], vec![]),
    ];
    for (children, expected) in cases {
        let refined = and_expr.propagate_constraints(&parent, &children)?.unwrap();
        assert_eq!(refined, expected);
    }

    Ok(())
}
1737}