datafusion_physical_expr/intervals/cp_solver.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Constraint propagator/solver for custom [`PhysicalExpr`] graphs.
19//!
20//! The constraint propagator/solver in DataFusion uses interval arithmetic to
21//! perform mathematical operations on intervals, which represent a range of
22//! possible values rather than a single point value. This allows for the
23//! propagation of ranges through mathematical operations, and can be used to
24//! compute bounds for a complicated expression. The key idea is that by
25//! breaking down a complicated expression into simpler terms, and then
26//! combining the bounds for those simpler terms, one can obtain bounds for the
27//! overall expression.
28//!
29//! This way of using interval arithmetic to compute bounds for a complex
30//! expression by combining the bounds for the constituent terms within the
31//! original expression allows us to reason about the range of possible values
32//! of the expression. This information later can be used in range pruning of
33//! the provably unnecessary parts of `RecordBatch`es.
34//!
35//! # Example
36//!
37//! For example, consider a mathematical expression such as `x^2 + y = 4` \[1\].
38//! Since this expression would be a binary tree in [`PhysicalExpr`] notation,
//! this type of hierarchical computation is well-suited for a graph-based
40//! implementation. In such an implementation, an equation system `f(x) = 0` is
41//! represented by a directed acyclic expression graph (DAEG).
42//!
43//! In order to use interval arithmetic to compute bounds for this expression,
44//! one would first determine intervals that represent the possible values of
//! `x` and `y`. Let's say that the interval for `x` is `[1, 2]` and the interval
46//! for `y` is `[-3, 1]`. In the chart below, you can see how the computation
47//! takes place.
48//!
49//! # References
50//!
51//! 1. Kabak, Mehmet Ozan. Analog Circuit Start-Up Behavior Analysis: An Interval
52//! Arithmetic Based Approach, Chapter 4. Stanford University, 2015.
53//! 2. Moore, Ramon E. Interval analysis. Vol. 4. Englewood Cliffs: Prentice-Hall, 1966.
54//! 3. F. Messine, "Deterministic global optimization using interval constraint
55//! propagation techniques," RAIRO-Operations Research, vol. 38, no. 04,
56//! pp. 277-293, 2004.
57//!
58//! # Illustration
59//!
60//! ## Computing bounds for an expression using interval arithmetic
61//!
62//! ```text
63//! +-----+ +-----+
64//! +----| + |----+ +----| + |----+
65//! | | | | | | | |
66//! | +-----+ | | +-----+ |
67//! | | | |
68//! +-----+ +-----+ +-----+ +-----+
69//! | 2 | | y | | 2 | [1, 4] | y |
70//! |[.] | | | |[.] | | |
71//! +-----+ +-----+ +-----+ +-----+
72//! | |
73//! | |
74//! +---+ +---+
75//! | x | [1, 2] | x | [1, 2]
76//! +---+ +---+
77//!
//! (a) Bottom-up evaluation: Step 1 (b) Bottom-up evaluation: Step 2
79//!
80//! [1 - 3, 4 + 1] = [-2, 5]
81//! +-----+ +-----+
82//! +----| + |----+ +----| + |----+
83//! | | | | | | | |
84//! | +-----+ | | +-----+ |
85//! | | | |
86//! +-----+ +-----+ +-----+ +-----+
87//! | 2 |[1, 4] | y | | 2 |[1, 4] | y |
88//! |[.] | | | |[.] | | |
89//! +-----+ +-----+ +-----+ +-----+
90//! | [-3, 1] | [-3, 1]
91//! | |
92//! +---+ +---+
93//! | x | [1, 2] | x | [1, 2]
94//! +---+ +---+
95//!
96//! (c) Bottom-up evaluation: Step 3 (d) Bottom-up evaluation: Step 4
97//! ```
98//!
99//! ## Top-down constraint propagation using inverse semantics
100//!
101//! ```text
102//! [-2, 5] ∩ [4, 4] = [4, 4] [4, 4]
103//! +-----+ +-----+
104//! +----| + |----+ +----| + |----+
105//! | | | | | | | |
106//! | +-----+ | | +-----+ |
107//! | | | |
108//! +-----+ +-----+ +-----+ +-----+
109//! | 2 | [1, 4] | y | | 2 | [1, 4] | y | [0, 1]*
110//! |[.] | | | |[.] | | |
111//! +-----+ +-----+ +-----+ +-----+
112//! | [-3, 1] |
113//! | |
114//! +---+ +---+
115//! | x | [1, 2] | x | [1, 2]
116//! +---+ +---+
117//!
118//! (a) Top-down propagation: Step 1 (b) Top-down propagation: Step 2
119//!
120//! [1 - 3, 4 + 1] = [-2, 5]
121//! +-----+ +-----+
122//! +----| + |----+ +----| + |----+
123//! | | | | | | | |
124//! | +-----+ | | +-----+ |
125//! | | | |
126//! +-----+ +-----+ +-----+ +-----+
127//! | 2 |[3, 4]** | y | | 2 |[3, 4] | y |
128//! |[.] | | | |[.] | | |
129//! +-----+ +-----+ +-----+ +-----+
130//! | [0, 1] | [-3, 1]
131//! | |
132//! +---+ +---+
133//! | x | [1, 2] | x | [sqrt(3), 2]***
134//! +---+ +---+
135//!
136//! (c) Top-down propagation: Step 3 (d) Top-down propagation: Step 4
137//!
138//! * [-3, 1] ∩ ([4, 4] - [1, 4]) = [0, 1]
139//! ** [1, 4] ∩ ([4, 4] - [0, 1]) = [3, 4]
140//! *** [1, 2] ∩ [sqrt(3), sqrt(4)] = [sqrt(3), 2]
141//! ```
142
143use std::collections::HashSet;
144use std::fmt::{Display, Formatter};
145use std::mem::{size_of, size_of_val};
146use std::sync::Arc;
147
148use super::utils::{
149 convert_duration_type_to_interval, convert_interval_type_to_duration, get_inverse_op,
150};
151use crate::expressions::{BinaryExpr, Literal};
152use crate::utils::{build_dag, ExprTreeNode};
153use crate::PhysicalExpr;
154
155use arrow::datatypes::{DataType, Schema};
156use datafusion_common::{internal_err, not_impl_err, Result};
157use datafusion_expr::interval_arithmetic::{apply_operator, satisfy_greater, Interval};
158use datafusion_expr::Operator;
159
160use petgraph::graph::NodeIndex;
161use petgraph::stable_graph::{DefaultIx, StableGraph};
162use petgraph::visit::{Bfs, Dfs, DfsPostOrder, EdgeRef};
163use petgraph::Outgoing;
164
/// This object implements a directed acyclic expression graph (DAEG) that
/// is used to compute ranges for expressions through interval arithmetic.
///
/// Each node pairs a [`PhysicalExpr`] with the [`Interval`] currently known
/// for it (see [`ExprIntervalGraphNode`]).
#[derive(Clone, Debug)]
pub struct ExprIntervalGraph {
    /// Expression nodes; the outgoing edges of a node lead to its children.
    // NOTE(review): the meaning of the `usize` edge weight is decided by
    // `build_dag`, which is not visible in this file -- confirm before use.
    graph: StableGraph<ExprIntervalGraphNode, usize>,
    /// Index of the node at which every traversal in this file starts.
    root: NodeIndex,
}
172
/// This object encapsulates all possible constraint propagation results.
#[derive(PartialEq, Debug)]
pub enum PropagationResult {
    /// The requested range already contains the evaluated bounds of the root
    /// expression, so there is nothing to propagate.
    CannotPropagate,
    /// The requested range cannot be satisfied: either it is disjoint from
    /// the evaluated bounds, or propagation reported an infeasible child.
    Infeasible,
    /// Top-down propagation ran to completion.
    Success,
}
180
/// This is a node in the DAEG; it encapsulates a reference to the actual
/// [`PhysicalExpr`] as well as an interval containing expression bounds.
#[derive(Clone, Debug)]
pub struct ExprIntervalGraphNode {
    /// The expression this node stands for.
    expr: Arc<dyn PhysicalExpr>,
    /// The range currently associated with `expr`.
    interval: Interval,
}
188
189impl PartialEq for ExprIntervalGraphNode {
190 fn eq(&self, other: &Self) -> bool {
191 self.expr.eq(&other.expr)
192 }
193}
194
195impl Display for ExprIntervalGraphNode {
196 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
197 write!(f, "{}", self.expr)
198 }
199}
200
201impl ExprIntervalGraphNode {
202 /// Constructs a new DAEG node with an `[-∞, ∞]` range.
203 pub fn new_unbounded(expr: Arc<dyn PhysicalExpr>, dt: &DataType) -> Result<Self> {
204 Interval::make_unbounded(dt)
205 .map(|interval| ExprIntervalGraphNode { expr, interval })
206 }
207
208 /// Constructs a new DAEG node with the given range.
209 pub fn new_with_interval(expr: Arc<dyn PhysicalExpr>, interval: Interval) -> Self {
210 ExprIntervalGraphNode { expr, interval }
211 }
212
213 /// Get the interval object representing the range of the expression.
214 pub fn interval(&self) -> &Interval {
215 &self.interval
216 }
217
218 /// This function creates a DAEG node from DataFusion's [`ExprTreeNode`]
219 /// object. Literals are created with definite, singleton intervals while
220 /// any other expression starts with an indefinite interval (`[-∞, ∞]`).
221 pub fn make_node(node: &ExprTreeNode<NodeIndex>, schema: &Schema) -> Result<Self> {
222 let expr = Arc::clone(&node.expr);
223 if let Some(literal) = expr.as_any().downcast_ref::<Literal>() {
224 let value = literal.value();
225 Interval::try_new(value.clone(), value.clone())
226 .map(|interval| Self::new_with_interval(expr, interval))
227 } else {
228 expr.data_type(schema)
229 .and_then(|dt| Self::new_unbounded(expr, &dt))
230 }
231 }
232}
233
234/// This function refines intervals `left_child` and `right_child` by applying
235/// constraint propagation through `parent` via operation. The main idea is
236/// that we can shrink ranges of variables x and y using parent interval p.
237///
238/// Assuming that x,y and p has ranges `[xL, xU]`, `[yL, yU]`, and `[pL, pU]`, we
239/// apply the following operations:
240/// - For plus operation, specifically, we would first do
241/// - `[xL, xU]` <- (`[pL, pU]` - `[yL, yU]`) ∩ `[xL, xU]`, and then
242/// - `[yL, yU]` <- (`[pL, pU]` - `[xL, xU]`) ∩ `[yL, yU]`.
243/// - For minus operation, specifically, we would first do
244/// - `[xL, xU]` <- (`[yL, yU]` + `[pL, pU]`) ∩ `[xL, xU]`, and then
245/// - `[yL, yU]` <- (`[xL, xU]` - `[pL, pU]`) ∩ `[yL, yU]`.
246/// - For multiplication operation, specifically, we would first do
247/// - `[xL, xU]` <- (`[pL, pU]` / `[yL, yU]`) ∩ `[xL, xU]`, and then
248/// - `[yL, yU]` <- (`[pL, pU]` / `[xL, xU]`) ∩ `[yL, yU]`.
249/// - For division operation, specifically, we would first do
250/// - `[xL, xU]` <- (`[yL, yU]` * `[pL, pU]`) ∩ `[xL, xU]`, and then
251/// - `[yL, yU]` <- (`[xL, xU]` / `[pL, pU]`) ∩ `[yL, yU]`.
252pub fn propagate_arithmetic(
253 op: &Operator,
254 parent: &Interval,
255 left_child: &Interval,
256 right_child: &Interval,
257) -> Result<Option<(Interval, Interval)>> {
258 let inverse_op = get_inverse_op(*op)?;
259 match (left_child.data_type(), right_child.data_type()) {
260 // If we have a child whose type is a time interval (i.e. DataType::Interval),
261 // we need special handling since timestamp differencing results in a
262 // Duration type.
263 (DataType::Timestamp(..), DataType::Interval(_)) => {
264 propagate_time_interval_at_right(
265 left_child,
266 right_child,
267 parent,
268 op,
269 &inverse_op,
270 )
271 }
272 (DataType::Interval(_), DataType::Timestamp(..)) => {
273 propagate_time_interval_at_left(
274 left_child,
275 right_child,
276 parent,
277 op,
278 &inverse_op,
279 )
280 }
281 _ => {
282 // First, propagate to the left:
283 match apply_operator(&inverse_op, parent, right_child)?
284 .intersect(left_child)?
285 {
286 // Left is feasible:
287 Some(value) => Ok(
288 // Propagate to the right using the new left.
289 propagate_right(&value, parent, right_child, op, &inverse_op)?
290 .map(|right| (value, right)),
291 ),
292 // If the left child is infeasible, short-circuit.
293 None => Ok(None),
294 }
295 }
296 }
297}
298
299/// This function refines intervals `left_child` and `right_child` by applying
300/// comparison propagation through `parent` via operation. The main idea is
301/// that we can shrink ranges of variables x and y using parent interval p.
302/// Two intervals can be ordered in 6 ways for a Gt `>` operator:
303/// ```text
304/// (1): Infeasible, short-circuit
305/// left: | ================ |
306/// right: | ======================== |
307///
308/// (2): Update both interval
309/// left: | ====================== |
310/// right: | ====================== |
311/// |
312/// V
313/// left: | ======= |
314/// right: | ======= |
315///
316/// (3): Update left interval
317/// left: | ============================== |
318/// right: | ========== |
319/// |
320/// V
321/// left: | ===================== |
322/// right: | ========== |
323///
324/// (4): Update right interval
325/// left: | ========== |
326/// right: | =========================== |
327/// |
328/// V
329/// left: | ========== |
330/// right | ================== |
331///
332/// (5): No change
333/// left: | ============================ |
334/// right: | =================== |
335///
336/// (6): No change
337/// left: | ==================== |
338/// right: | =============== |
339///
340/// -inf --------------------------------------------------------------- +inf
341/// ```
342pub fn propagate_comparison(
343 op: &Operator,
344 parent: &Interval,
345 left_child: &Interval,
346 right_child: &Interval,
347) -> Result<Option<(Interval, Interval)>> {
348 if parent == &Interval::CERTAINLY_TRUE {
349 match op {
350 Operator::Eq => left_child.intersect(right_child).map(|result| {
351 result.map(|intersection| (intersection.clone(), intersection))
352 }),
353 Operator::Gt => satisfy_greater(left_child, right_child, true),
354 Operator::GtEq => satisfy_greater(left_child, right_child, false),
355 Operator::Lt => satisfy_greater(right_child, left_child, true)
356 .map(|t| t.map(reverse_tuple)),
357 Operator::LtEq => satisfy_greater(right_child, left_child, false)
358 .map(|t| t.map(reverse_tuple)),
359 _ => internal_err!(
360 "The operator must be a comparison operator to propagate intervals"
361 ),
362 }
363 } else if parent == &Interval::CERTAINLY_FALSE {
364 match op {
365 Operator::Eq => {
366 // TODO: Propagation is not possible until we support interval sets.
367 Ok(None)
368 }
369 Operator::Gt => satisfy_greater(right_child, left_child, false),
370 Operator::GtEq => satisfy_greater(right_child, left_child, true),
371 Operator::Lt => satisfy_greater(left_child, right_child, false)
372 .map(|t| t.map(reverse_tuple)),
373 Operator::LtEq => satisfy_greater(left_child, right_child, true)
374 .map(|t| t.map(reverse_tuple)),
375 _ => internal_err!(
376 "The operator must be a comparison operator to propagate intervals"
377 ),
378 }
379 } else {
380 // Uncertainty cannot change any end-point of the intervals.
381 Ok(None)
382 }
383}
384
385impl ExprIntervalGraph {
386 pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: &Schema) -> Result<Self> {
387 // Build the full graph:
388 let (root, graph) =
389 build_dag(expr, &|node| ExprIntervalGraphNode::make_node(node, schema))?;
390 Ok(Self { graph, root })
391 }
392
393 pub fn node_count(&self) -> usize {
394 self.graph.node_count()
395 }
396
397 /// Estimate size of bytes including `Self`.
398 pub fn size(&self) -> usize {
399 let node_memory_usage = self.graph.node_count()
400 * (size_of::<ExprIntervalGraphNode>() + size_of::<NodeIndex>());
401 let edge_memory_usage =
402 self.graph.edge_count() * (size_of::<usize>() + size_of::<NodeIndex>() * 2);
403
404 size_of_val(self) + node_memory_usage + edge_memory_usage
405 }
406
    // Sometimes, we do not want to calculate and/or propagate intervals all
    // the way down to leaf expressions. For example, assume that we have a
409 // `SymmetricHashJoin` which has a child with an output ordering like:
410 //
411 // ```text
412 // PhysicalSortExpr {
413 // expr: BinaryExpr('a', +, 'b'),
414 // sort_option: ..
415 // }
416 // ```
417 //
418 // i.e. its output order comes from a clause like `ORDER BY a + b`. In such
419 // a case, we must calculate the interval for the `BinaryExpr(a, +, b)`
420 // instead of the columns inside this `BinaryExpr`, because this interval
421 // decides whether we prune or not. Therefore, children `PhysicalExpr`s of
422 // this `BinaryExpr` may be pruned for performance. The figure below
423 // explains this example visually.
424 //
425 // Note that we just remove the nodes from the DAEG, do not make any change
426 // to the plan itself.
427 //
428 // ```text
429 //
430 // +-----+ +-----+
431 // | GT | | GT |
432 // +--------| |-------+ +--------| |-------+
433 // | +-----+ | | +-----+ |
434 // | | | |
435 // +-----+ | +-----+ |
436 // |Cast | | |Cast | |
437 // | | | --\ | | |
438 // +-----+ | ---------- +-----+ |
439 // | | --/ | |
440 // | | | |
441 // +-----+ +-----+ +-----+ +-----+
442 // +--|Plus |--+ +--|Plus |--+ |Plus | +--|Plus |--+
443 // | | | | | | | | | | | | | |
444 // Prune from here | +-----+ | | +-----+ | +-----+ | +-----+ |
445 // ------------------------------------ | | | |
446 // | | | | | |
447 // +-----+ +-----+ +-----+ +-----+ +-----+ +-----+
448 // | a | | b | | c | | 2 | | c | | 2 |
449 // | | | | | | | | | | | |
450 // +-----+ +-----+ +-----+ +-----+ +-----+ +-----+
451 //
452 // ```
453
454 /// This function associates stable node indices with [`PhysicalExpr`]s so
455 /// that we can match `Arc<dyn PhysicalExpr>` and NodeIndex objects during
456 /// membership tests.
457 pub fn gather_node_indices(
458 &mut self,
459 exprs: &[Arc<dyn PhysicalExpr>],
460 ) -> Vec<(Arc<dyn PhysicalExpr>, usize)> {
461 let graph = &self.graph;
462 let mut bfs = Bfs::new(graph, self.root);
463 // We collect the node indices (usize) of [PhysicalExpr]s in the order
464 // given by argument `exprs`. To preserve this order, we initialize each
465 // expression's node index with usize::MAX, and then find the corresponding
466 // node indices by traversing the graph.
467 let mut removals = vec![];
468 let mut expr_node_indices = exprs
469 .iter()
470 .map(|e| (Arc::clone(e), usize::MAX))
471 .collect::<Vec<_>>();
472 while let Some(node) = bfs.next(graph) {
473 // Get the plan corresponding to this node:
474 let expr = &graph[node].expr;
475 // If the current expression is among `exprs`, slate its children
476 // for removal:
477 if let Some(value) = exprs.iter().position(|e| expr.eq(e)) {
478 // Update the node index of the associated `PhysicalExpr`:
479 expr_node_indices[value].1 = node.index();
480 for edge in graph.edges_directed(node, Outgoing) {
481 // Slate the child for removal, do not remove immediately.
482 removals.push(edge.id());
483 }
484 }
485 }
486 for edge_idx in removals {
487 self.graph.remove_edge(edge_idx);
488 }
489 // Get the set of node indices reachable from the root node:
490 let connected_nodes = self.connected_nodes();
491 // Remove nodes not connected to the root node:
492 self.graph
493 .retain_nodes(|_, index| connected_nodes.contains(&index));
494 expr_node_indices
495 }
496
497 /// Returns the set of node indices reachable from the root node via a
498 /// simple depth-first search.
499 fn connected_nodes(&self) -> HashSet<NodeIndex> {
500 let mut nodes = HashSet::new();
501 let mut dfs = Dfs::new(&self.graph, self.root);
502 while let Some(node) = dfs.next(&self.graph) {
503 nodes.insert(node);
504 }
505 nodes
506 }
507
508 /// Updates intervals for all expressions in the DAEG by successive
509 /// bottom-up and top-down traversals.
510 pub fn update_ranges(
511 &mut self,
512 leaf_bounds: &mut [(usize, Interval)],
513 given_range: Interval,
514 ) -> Result<PropagationResult> {
515 self.assign_intervals(leaf_bounds);
516 let bounds = self.evaluate_bounds()?;
517 // There are three possible cases to consider:
518 // (1) given_range ⊇ bounds => Nothing to propagate
519 // (2) ∅ ⊂ (given_range ∩ bounds) ⊂ bounds => Can propagate
520 // (3) Disjoint sets => Infeasible
521 if given_range.contains(bounds)? == Interval::CERTAINLY_TRUE {
522 // First case:
523 Ok(PropagationResult::CannotPropagate)
524 } else if bounds.contains(&given_range)? != Interval::CERTAINLY_FALSE {
525 // Second case:
526 let result = self.propagate_constraints(given_range);
527 self.update_intervals(leaf_bounds);
528 result
529 } else {
530 // Third case:
531 Ok(PropagationResult::Infeasible)
532 }
533 }
534
535 /// This function assigns given ranges to expressions in the DAEG.
536 /// The argument `assignments` associates indices of sought expressions
537 /// with their corresponding new ranges.
538 pub fn assign_intervals(&mut self, assignments: &[(usize, Interval)]) {
539 for (index, interval) in assignments {
540 let node_index = NodeIndex::from(*index as DefaultIx);
541 self.graph[node_index].interval = interval.clone();
542 }
543 }
544
545 /// This function fetches ranges of expressions from the DAEG. The argument
546 /// `assignments` associates indices of sought expressions with their ranges,
547 /// which this function modifies to reflect the intervals in the DAEG.
548 pub fn update_intervals(&self, assignments: &mut [(usize, Interval)]) {
549 for (index, interval) in assignments.iter_mut() {
550 let node_index = NodeIndex::from(*index as DefaultIx);
551 *interval = self.graph[node_index].interval.clone();
552 }
553 }
554
555 /// Computes bounds for an expression using interval arithmetic via a
556 /// bottom-up traversal.
557 ///
558 /// # Examples
559 ///
560 /// ```
561 /// use arrow::datatypes::DataType;
562 /// use arrow::datatypes::Field;
563 /// use arrow::datatypes::Schema;
564 /// use datafusion_common::ScalarValue;
565 /// use datafusion_expr::interval_arithmetic::Interval;
566 /// use datafusion_expr::Operator;
567 /// use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
568 /// use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
569 /// use datafusion_physical_expr::PhysicalExpr;
570 /// use std::sync::Arc;
571 ///
572 /// let expr = Arc::new(BinaryExpr::new(
573 /// Arc::new(Column::new("gnz", 0)),
574 /// Operator::Plus,
575 /// Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
576 /// ));
577 ///
578 /// let schema = Schema::new(vec![Field::new("gnz".to_string(), DataType::Int32, true)]);
579 ///
580 /// let mut graph = ExprIntervalGraph::try_new(expr, &schema).unwrap();
581 /// // Do it once, while constructing.
582 /// let node_indices = graph
583 /// .gather_node_indices(&[Arc::new(Column::new("gnz", 0))]);
584 /// let left_index = node_indices.get(0).unwrap().1;
585 ///
586 /// // Provide intervals for leaf variables (here, there is only one).
587 /// let intervals = vec![(
588 /// left_index,
589 /// Interval::make(Some(10), Some(20)).unwrap(),
590 /// )];
591 ///
592 /// // Evaluate bounds for the composite expression:
593 /// graph.assign_intervals(&intervals);
594 /// assert_eq!(
595 /// graph.evaluate_bounds().unwrap(),
596 /// &Interval::make(Some(20), Some(30)).unwrap(),
597 /// )
598 /// ```
599 pub fn evaluate_bounds(&mut self) -> Result<&Interval> {
600 let mut dfs = DfsPostOrder::new(&self.graph, self.root);
601 while let Some(node) = dfs.next(&self.graph) {
602 let neighbors = self.graph.neighbors_directed(node, Outgoing);
603 let mut children_intervals = neighbors
604 .map(|child| self.graph[child].interval())
605 .collect::<Vec<_>>();
606 // If the current expression is a leaf, its interval should already
607 // be set externally, just continue with the evaluation procedure:
608 if !children_intervals.is_empty() {
609 // Reverse to align with `PhysicalExpr`'s children:
610 children_intervals.reverse();
611 self.graph[node].interval =
612 self.graph[node].expr.evaluate_bounds(&children_intervals)?;
613 }
614 }
615 Ok(self.graph[self.root].interval())
616 }
617
618 /// Updates/shrinks bounds for leaf expressions using interval arithmetic
619 /// via a top-down traversal.
620 fn propagate_constraints(
621 &mut self,
622 given_range: Interval,
623 ) -> Result<PropagationResult> {
624 // Adjust the root node with the given range:
625 if let Some(interval) = self.graph[self.root].interval.intersect(given_range)? {
626 self.graph[self.root].interval = interval;
627 } else {
628 return Ok(PropagationResult::Infeasible);
629 }
630
631 let mut bfs = Bfs::new(&self.graph, self.root);
632
633 while let Some(node) = bfs.next(&self.graph) {
634 let neighbors = self.graph.neighbors_directed(node, Outgoing);
635 let mut children = neighbors.collect::<Vec<_>>();
636 // If the current expression is a leaf, its range is now final.
637 // So, just continue with the propagation procedure:
638 if children.is_empty() {
639 continue;
640 }
641 // Reverse to align with `PhysicalExpr`'s children:
642 children.reverse();
643 let children_intervals = children
644 .iter()
645 .map(|child| self.graph[*child].interval())
646 .collect::<Vec<_>>();
647 let node_interval = self.graph[node].interval();
648 // Special case: true OR could in principle be propagated by 3 interval sets,
649 // (i.e. left true, or right true, or both true) however we do not support this yet.
650 if node_interval == &Interval::CERTAINLY_TRUE
651 && self.graph[node]
652 .expr
653 .as_any()
654 .downcast_ref::<BinaryExpr>()
655 .is_some_and(|expr| expr.op() == &Operator::Or)
656 {
657 return not_impl_err!("OR operator cannot yet propagate true intervals");
658 }
659 let propagated_intervals = self.graph[node]
660 .expr
661 .propagate_constraints(node_interval, &children_intervals)?;
662 if let Some(propagated_intervals) = propagated_intervals {
663 for (child, interval) in children.into_iter().zip(propagated_intervals) {
664 self.graph[child].interval = interval;
665 }
666 } else {
667 // The constraint is infeasible, report:
668 return Ok(PropagationResult::Infeasible);
669 }
670 }
671 Ok(PropagationResult::Success)
672 }
673
674 /// Returns the interval associated with the node at the given `index`.
675 pub fn get_interval(&self, index: usize) -> Interval {
676 self.graph[NodeIndex::new(index)].interval.clone()
677 }
678}
679
680/// This is a subfunction of the `propagate_arithmetic` function that propagates to the right child.
681fn propagate_right(
682 left: &Interval,
683 parent: &Interval,
684 right: &Interval,
685 op: &Operator,
686 inverse_op: &Operator,
687) -> Result<Option<Interval>> {
688 match op {
689 Operator::Minus => apply_operator(op, left, parent),
690 Operator::Plus => apply_operator(inverse_op, parent, left),
691 Operator::Divide => apply_operator(op, left, parent),
692 Operator::Multiply => apply_operator(inverse_op, parent, left),
693 _ => internal_err!("Interval arithmetic does not support the operator {}", op),
694 }?
695 .intersect(right)
696}
697
698/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`],
699/// if there exists a `timestamp - timestamp` operation, the result would be
700/// of type `Duration`. However, we may encounter a situation where a time interval
701/// is involved in an arithmetic operation with a `Duration` type. This function
702/// offers special handling for such cases, where the time interval resides on
703/// the left side of the operation.
704fn propagate_time_interval_at_left(
705 left_child: &Interval,
706 right_child: &Interval,
707 parent: &Interval,
708 op: &Operator,
709 inverse_op: &Operator,
710) -> Result<Option<(Interval, Interval)>> {
711 // We check if the child's time interval(s) has a non-zero month or day field(s).
712 // If so, we return it as is without propagating. Otherwise, we first convert
713 // the time intervals to the `Duration` type, then propagate, and then convert
714 // the bounds to time intervals again.
715 let result = if let Some(duration) = convert_interval_type_to_duration(left_child) {
716 match apply_operator(inverse_op, parent, right_child)?.intersect(duration)? {
717 Some(value) => {
718 let left = convert_duration_type_to_interval(&value);
719 let right = propagate_right(&value, parent, right_child, op, inverse_op)?;
720 match (left, right) {
721 (Some(left), Some(right)) => Some((left, right)),
722 _ => None,
723 }
724 }
725 None => None,
726 }
727 } else {
728 propagate_right(left_child, parent, right_child, op, inverse_op)?
729 .map(|right| (left_child.clone(), right))
730 };
731 Ok(result)
732}
733
734/// During the propagation of [`Interval`] values on an [`ExprIntervalGraph`],
735/// if there exists a `timestamp - timestamp` operation, the result would be
736/// of type `Duration`. However, we may encounter a situation where a time interval
737/// is involved in an arithmetic operation with a `Duration` type. This function
738/// offers special handling for such cases, where the time interval resides on
739/// the right side of the operation.
740fn propagate_time_interval_at_right(
741 left_child: &Interval,
742 right_child: &Interval,
743 parent: &Interval,
744 op: &Operator,
745 inverse_op: &Operator,
746) -> Result<Option<(Interval, Interval)>> {
747 // We check if the child's time interval(s) has a non-zero month or day field(s).
748 // If so, we return it as is without propagating. Otherwise, we first convert
749 // the time intervals to the `Duration` type, then propagate, and then convert
750 // the bounds to time intervals again.
751 let result = if let Some(duration) = convert_interval_type_to_duration(right_child) {
752 match apply_operator(inverse_op, parent, &duration)?.intersect(left_child)? {
753 Some(value) => {
754 propagate_right(left_child, parent, &duration, op, inverse_op)?
755 .and_then(|right| convert_duration_type_to_interval(&right))
756 .map(|right| (value, right))
757 }
758 None => None,
759 }
760 } else {
761 apply_operator(inverse_op, parent, right_child)?
762 .intersect(left_child)?
763 .map(|value| (value, right_child.clone()))
764 };
765 Ok(result)
766}
767
/// Swaps the two components of a pair.
fn reverse_tuple<T, U>(pair: (T, U)) -> (U, T) {
    let (head, tail) = pair;
    (tail, head)
}
771
772#[cfg(test)]
773mod tests {
774 use super::*;
775 use crate::expressions::{BinaryExpr, Column};
776 use crate::intervals::test_utils::gen_conjunctive_numerical_expr;
777
778 use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano};
779 use arrow::datatypes::{Field, TimeUnit};
780 use datafusion_common::ScalarValue;
781
782 use itertools::Itertools;
783 use rand::rngs::StdRng;
784 use rand::{Rng, SeedableRng};
785 use rstest::*;
786
787 #[allow(clippy::too_many_arguments)]
788 fn experiment(
789 expr: Arc<dyn PhysicalExpr>,
790 exprs_with_interval: (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>),
791 left_interval: Interval,
792 right_interval: Interval,
793 left_expected: Interval,
794 right_expected: Interval,
795 result: PropagationResult,
796 schema: &Schema,
797 ) -> Result<()> {
798 let col_stats = vec![
799 (Arc::clone(&exprs_with_interval.0), left_interval),
800 (Arc::clone(&exprs_with_interval.1), right_interval),
801 ];
802 let expected = vec![
803 (Arc::clone(&exprs_with_interval.0), left_expected),
804 (Arc::clone(&exprs_with_interval.1), right_expected),
805 ];
806 let mut graph = ExprIntervalGraph::try_new(expr, schema)?;
807 let expr_indexes = graph.gather_node_indices(
808 &col_stats.iter().map(|(e, _)| Arc::clone(e)).collect_vec(),
809 );
810
811 let mut col_stat_nodes = col_stats
812 .iter()
813 .zip(expr_indexes.iter())
814 .map(|((_, interval), (_, index))| (*index, interval.clone()))
815 .collect_vec();
816 let expected_nodes = expected
817 .iter()
818 .zip(expr_indexes.iter())
819 .map(|((_, interval), (_, index))| (*index, interval.clone()))
820 .collect_vec();
821
822 let exp_result =
823 graph.update_ranges(&mut col_stat_nodes[..], Interval::CERTAINLY_TRUE)?;
824 assert_eq!(exp_result, result);
825 col_stat_nodes.iter().zip(expected_nodes.iter()).for_each(
826 |((_, calculated_interval_node), (_, expected))| {
827 // NOTE: These randomized tests only check for conservative containment,
828 // not openness/closedness of endpoints.
829
830 // Calculated bounds are relaxed by 1 to cover all strict and
831 // and non-strict comparison cases since we have only closed bounds.
832 let one = ScalarValue::new_one(&expected.data_type()).unwrap();
833 assert!(
834 calculated_interval_node.lower()
835 <= &expected.lower().add(&one).unwrap(),
836 "{}",
837 format!(
838 "Calculated {} must be less than or equal {}",
839 calculated_interval_node.lower(),
840 expected.lower()
841 )
842 );
843 assert!(
844 calculated_interval_node.upper()
845 >= &expected.upper().sub(&one).unwrap(),
846 "{}",
847 format!(
848 "Calculated {} must be greater than or equal {}",
849 calculated_interval_node.upper(),
850 expected.upper()
851 )
852 );
853 },
854 );
855 Ok(())
856 }
857
858 macro_rules! generate_cases {
859 ($FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
860 fn $FUNC_NAME<const ASC: bool>(
861 expr: Arc<dyn PhysicalExpr>,
862 left_col: Arc<dyn PhysicalExpr>,
863 right_col: Arc<dyn PhysicalExpr>,
864 seed: u64,
865 expr_left: $TYPE,
866 expr_right: $TYPE,
867 ) -> Result<()> {
868 let mut r = StdRng::seed_from_u64(seed);
869
870 let (left_given, right_given, left_expected, right_expected) = if ASC {
871 let left = r.random_range((0 as $TYPE)..(1000 as $TYPE));
872 let right = r.random_range((0 as $TYPE)..(1000 as $TYPE));
873 (
874 (Some(left), None),
875 (Some(right), None),
876 (Some(<$TYPE>::max(left, right + expr_left)), None),
877 (Some(<$TYPE>::max(right, left + expr_right)), None),
878 )
879 } else {
880 let left = r.random_range((0 as $TYPE)..(1000 as $TYPE));
881 let right = r.random_range((0 as $TYPE)..(1000 as $TYPE));
882 (
883 (None, Some(left)),
884 (None, Some(right)),
885 (None, Some(<$TYPE>::min(left, right + expr_left))),
886 (None, Some(<$TYPE>::min(right, left + expr_right))),
887 )
888 };
889
890 experiment(
891 expr,
892 (left_col.clone(), right_col.clone()),
893 Interval::make(left_given.0, left_given.1).unwrap(),
894 Interval::make(right_given.0, right_given.1).unwrap(),
895 Interval::make(left_expected.0, left_expected.1).unwrap(),
896 Interval::make(right_expected.0, right_expected.1).unwrap(),
897 PropagationResult::Success,
898 &Schema::new(vec![
899 Field::new(
900 left_col.as_any().downcast_ref::<Column>().unwrap().name(),
901 DataType::$SCALAR,
902 true,
903 ),
904 Field::new(
905 right_col.as_any().downcast_ref::<Column>().unwrap().name(),
906 DataType::$SCALAR,
907 true,
908 ),
909 ]),
910 )
911 }
912 };
913 }
914 generate_cases!(generate_case_i32, i32, Int32);
915 generate_cases!(generate_case_i64, i64, Int64);
916 generate_cases!(generate_case_f32, f32, Float32);
917 generate_cases!(generate_case_f64, f64, Float64);
918
/// Checks that an unsatisfiable predicate is reported as infeasible.
///
/// With `left_watermark` in `[10, 20]` and `right_watermark` in `[100, ∞)`,
/// `left_watermark + 5` is at most 25 and can never exceed `right_watermark`.
/// The given intervals are also passed as the expected ones.
#[test]
fn testing_not_possible() -> Result<()> {
    // A column's index is its position in the schema built below.
    let left_col = Arc::new(Column::new("left_watermark", 0));
    let right_col = Arc::new(Column::new("right_watermark", 1));

    // left_watermark + 5 > right_watermark
    let left_plus_5 = Arc::new(BinaryExpr::new(
        Arc::clone(&left_col) as Arc<dyn PhysicalExpr>,
        Operator::Plus,
        Arc::new(Literal::new(ScalarValue::Int32(Some(5)))),
    ));
    let expr = Arc::new(BinaryExpr::new(
        left_plus_5,
        Operator::Gt,
        Arc::clone(&right_col) as Arc<dyn PhysicalExpr>,
    ));

    // `left_col`/`right_col` are already `Arc<Column>`, so the names can be
    // read directly without going through `as_any().downcast_ref()`.
    let schema = Schema::new(vec![
        Field::new(left_col.name(), DataType::Int32, true),
        Field::new(right_col.name(), DataType::Int32, true),
    ]);

    experiment(
        expr,
        (
            Arc::clone(&left_col) as Arc<dyn PhysicalExpr>,
            Arc::clone(&right_col) as Arc<dyn PhysicalExpr>,
        ),
        Interval::make(Some(10_i32), Some(20_i32))?,
        Interval::make(Some(100), None)?,
        Interval::make(Some(10), Some(20))?,
        Interval::make(Some(100), None)?,
        PropagationResult::Infeasible,
        &schema,
    )
}
960
961 macro_rules! integer_float_case_1 {
962 ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
963 #[rstest]
964 #[test]
965 fn $TEST_FUNC_NAME(
966 #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
967 seed: u64,
968 #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
969 #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
970 ) -> Result<()> {
971 let left_col = Arc::new(Column::new("left_watermark", 0));
972 let right_col = Arc::new(Column::new("right_watermark", 0));
973
974 // left_watermark + 1 > right_watermark + 11 AND left_watermark + 3 < right_watermark + 33
975 let expr = gen_conjunctive_numerical_expr(
976 left_col.clone(),
977 right_col.clone(),
978 (
979 Operator::Plus,
980 Operator::Plus,
981 Operator::Plus,
982 Operator::Plus,
983 ),
984 ScalarValue::$SCALAR(Some(1 as $TYPE)),
985 ScalarValue::$SCALAR(Some(11 as $TYPE)),
986 ScalarValue::$SCALAR(Some(3 as $TYPE)),
987 ScalarValue::$SCALAR(Some(33 as $TYPE)),
988 (greater_op, less_op),
989 );
990 // l > r + 10 AND r > l - 30
991 let l_gt_r = 10 as $TYPE;
992 let r_gt_l = -30 as $TYPE;
993 $GENERATE_CASE_FUNC_NAME::<true>(
994 expr.clone(),
995 left_col.clone(),
996 right_col.clone(),
997 seed,
998 l_gt_r,
999 r_gt_l,
1000 )?;
1001 // Descending tests
1002 // r < l - 10 AND l < r + 30
1003 let r_lt_l = -l_gt_r;
1004 let l_lt_r = -r_gt_l;
1005 $GENERATE_CASE_FUNC_NAME::<false>(
1006 expr, left_col, right_col, seed, l_lt_r, r_lt_l,
1007 )
1008 }
1009 };
1010 }
1011
1012 integer_float_case_1!(case_1_i32, generate_case_i32, i32, Int32);
1013 integer_float_case_1!(case_1_i64, generate_case_i64, i64, Int64);
1014 integer_float_case_1!(case_1_f64, generate_case_f64, f64, Float64);
1015 integer_float_case_1!(case_1_f32, generate_case_f32, f32, Float32);
1016
1017 macro_rules! integer_float_case_2 {
1018 ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
1019 #[rstest]
1020 #[test]
1021 fn $TEST_FUNC_NAME(
1022 #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
1023 seed: u64,
1024 #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
1025 #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
1026 ) -> Result<()> {
1027 let left_col = Arc::new(Column::new("left_watermark", 0));
1028 let right_col = Arc::new(Column::new("right_watermark", 0));
1029
1030 // left_watermark - 1 > right_watermark + 5 AND left_watermark + 3 < right_watermark + 10
1031 let expr = gen_conjunctive_numerical_expr(
1032 left_col.clone(),
1033 right_col.clone(),
1034 (
1035 Operator::Minus,
1036 Operator::Plus,
1037 Operator::Plus,
1038 Operator::Plus,
1039 ),
1040 ScalarValue::$SCALAR(Some(1 as $TYPE)),
1041 ScalarValue::$SCALAR(Some(5 as $TYPE)),
1042 ScalarValue::$SCALAR(Some(3 as $TYPE)),
1043 ScalarValue::$SCALAR(Some(10 as $TYPE)),
1044 (greater_op, less_op),
1045 );
1046 // l > r + 6 AND r > l - 7
1047 let l_gt_r = 6 as $TYPE;
1048 let r_gt_l = -7 as $TYPE;
1049 $GENERATE_CASE_FUNC_NAME::<true>(
1050 expr.clone(),
1051 left_col.clone(),
1052 right_col.clone(),
1053 seed,
1054 l_gt_r,
1055 r_gt_l,
1056 )?;
1057 // Descending tests
1058 // r < l - 6 AND l < r + 7
1059 let r_lt_l = -l_gt_r;
1060 let l_lt_r = -r_gt_l;
1061 $GENERATE_CASE_FUNC_NAME::<false>(
1062 expr, left_col, right_col, seed, l_lt_r, r_lt_l,
1063 )
1064 }
1065 };
1066 }
1067
1068 integer_float_case_2!(case_2_i32, generate_case_i32, i32, Int32);
1069 integer_float_case_2!(case_2_i64, generate_case_i64, i64, Int64);
1070 integer_float_case_2!(case_2_f64, generate_case_f64, f64, Float64);
1071 integer_float_case_2!(case_2_f32, generate_case_f32, f32, Float32);
1072
1073 macro_rules! integer_float_case_3 {
1074 ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
1075 #[rstest]
1076 #[test]
1077 fn $TEST_FUNC_NAME(
1078 #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
1079 seed: u64,
1080 #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
1081 #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
1082 ) -> Result<()> {
1083 let left_col = Arc::new(Column::new("left_watermark", 0));
1084 let right_col = Arc::new(Column::new("right_watermark", 0));
1085
1086 // left_watermark - 1 > right_watermark + 5 AND left_watermark - 3 < right_watermark + 10
1087 let expr = gen_conjunctive_numerical_expr(
1088 left_col.clone(),
1089 right_col.clone(),
1090 (
1091 Operator::Minus,
1092 Operator::Plus,
1093 Operator::Minus,
1094 Operator::Plus,
1095 ),
1096 ScalarValue::$SCALAR(Some(1 as $TYPE)),
1097 ScalarValue::$SCALAR(Some(5 as $TYPE)),
1098 ScalarValue::$SCALAR(Some(3 as $TYPE)),
1099 ScalarValue::$SCALAR(Some(10 as $TYPE)),
1100 (greater_op, less_op),
1101 );
1102 // l > r + 6 AND r > l - 13
1103 let l_gt_r = 6 as $TYPE;
1104 let r_gt_l = -13 as $TYPE;
1105 $GENERATE_CASE_FUNC_NAME::<true>(
1106 expr.clone(),
1107 left_col.clone(),
1108 right_col.clone(),
1109 seed,
1110 l_gt_r,
1111 r_gt_l,
1112 )?;
1113 // Descending tests
1114 // r < l - 6 AND l < r + 13
1115 let r_lt_l = -l_gt_r;
1116 let l_lt_r = -r_gt_l;
1117 $GENERATE_CASE_FUNC_NAME::<false>(
1118 expr, left_col, right_col, seed, l_lt_r, r_lt_l,
1119 )
1120 }
1121 };
1122 }
1123
1124 integer_float_case_3!(case_3_i32, generate_case_i32, i32, Int32);
1125 integer_float_case_3!(case_3_i64, generate_case_i64, i64, Int64);
1126 integer_float_case_3!(case_3_f64, generate_case_f64, f64, Float64);
1127 integer_float_case_3!(case_3_f32, generate_case_f32, f32, Float32);
1128
1129 macro_rules! integer_float_case_4 {
1130 ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
1131 #[rstest]
1132 #[test]
1133 fn $TEST_FUNC_NAME(
1134 #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
1135 seed: u64,
1136 #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
1137 #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
1138 ) -> Result<()> {
1139 let left_col = Arc::new(Column::new("left_watermark", 0));
1140 let right_col = Arc::new(Column::new("right_watermark", 0));
1141
1142 // left_watermark - 10 > right_watermark - 5 AND left_watermark - 30 < right_watermark - 3
1143 let expr = gen_conjunctive_numerical_expr(
1144 left_col.clone(),
1145 right_col.clone(),
1146 (
1147 Operator::Minus,
1148 Operator::Minus,
1149 Operator::Minus,
1150 Operator::Plus,
1151 ),
1152 ScalarValue::$SCALAR(Some(10 as $TYPE)),
1153 ScalarValue::$SCALAR(Some(5 as $TYPE)),
1154 ScalarValue::$SCALAR(Some(3 as $TYPE)),
1155 ScalarValue::$SCALAR(Some(10 as $TYPE)),
1156 (greater_op, less_op),
1157 );
1158 // l > r + 5 AND r > l - 13
1159 let l_gt_r = 5 as $TYPE;
1160 let r_gt_l = -13 as $TYPE;
1161 $GENERATE_CASE_FUNC_NAME::<true>(
1162 expr.clone(),
1163 left_col.clone(),
1164 right_col.clone(),
1165 seed,
1166 l_gt_r,
1167 r_gt_l,
1168 )?;
1169 // Descending tests
1170 // r < l - 5 AND l < r + 13
1171 let r_lt_l = -l_gt_r;
1172 let l_lt_r = -r_gt_l;
1173 $GENERATE_CASE_FUNC_NAME::<false>(
1174 expr, left_col, right_col, seed, l_lt_r, r_lt_l,
1175 )
1176 }
1177 };
1178 }
1179
1180 integer_float_case_4!(case_4_i32, generate_case_i32, i32, Int32);
1181 integer_float_case_4!(case_4_i64, generate_case_i64, i64, Int64);
1182 integer_float_case_4!(case_4_f64, generate_case_f64, f64, Float64);
1183 integer_float_case_4!(case_4_f32, generate_case_f32, f32, Float32);
1184
1185 macro_rules! integer_float_case_5 {
1186 ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
1187 #[rstest]
1188 #[test]
1189 fn $TEST_FUNC_NAME(
1190 #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215, 4123)]
1191 seed: u64,
1192 #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
1193 #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
1194 ) -> Result<()> {
1195 let left_col = Arc::new(Column::new("left_watermark", 0));
1196 let right_col = Arc::new(Column::new("right_watermark", 0));
1197
1198 // left_watermark - 10 > right_watermark - 5 AND left_watermark - 30 < right_watermark - 3
1199 let expr = gen_conjunctive_numerical_expr(
1200 left_col.clone(),
1201 right_col.clone(),
1202 (
1203 Operator::Minus,
1204 Operator::Minus,
1205 Operator::Minus,
1206 Operator::Minus,
1207 ),
1208 ScalarValue::$SCALAR(Some(10 as $TYPE)),
1209 ScalarValue::$SCALAR(Some(5 as $TYPE)),
1210 ScalarValue::$SCALAR(Some(30 as $TYPE)),
1211 ScalarValue::$SCALAR(Some(3 as $TYPE)),
1212 (greater_op, less_op),
1213 );
1214 // l > r + 5 AND r > l - 27
1215 let l_gt_r = 5 as $TYPE;
1216 let r_gt_l = -27 as $TYPE;
1217 $GENERATE_CASE_FUNC_NAME::<true>(
1218 expr.clone(),
1219 left_col.clone(),
1220 right_col.clone(),
1221 seed,
1222 l_gt_r,
1223 r_gt_l,
1224 )?;
1225 // Descending tests
1226 // r < l - 5 AND l < r + 27
1227 let r_lt_l = -l_gt_r;
1228 let l_lt_r = -r_gt_l;
1229 $GENERATE_CASE_FUNC_NAME::<false>(
1230 expr, left_col, right_col, seed, l_lt_r, r_lt_l,
1231 )
1232 }
1233 };
1234 }
1235
1236 integer_float_case_5!(case_5_i32, generate_case_i32, i32, Int32);
1237 integer_float_case_5!(case_5_i64, generate_case_i64, i64, Int64);
1238 integer_float_case_5!(case_5_f64, generate_case_f64, f64, Float64);
1239 integer_float_case_5!(case_5_f32, generate_case_f32, f32, Float32);
1240
/// Gathering the node for `a@0 + b@1` must not shrink the graph when the
/// leaves `a@0` and `b@1` are still used by another sub-expression.
#[test]
fn test_gather_node_indices_dont_remove() -> Result<()> {
    // Expression: a@0 + b@1 + 1 > a@0 - b@1, given a@0 + b@1.
    // Do not remove a@0 or b@1, only remove edges, since a@0 - b@1 also
    // depends on the leaf nodes a@0 and b@1.
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
    ]);

    let left_expr = Arc::new(BinaryExpr::new(
        Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Plus,
            Arc::new(Column::new("b", 1)),
        )),
        Operator::Plus,
        Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
    ));
    let right_expr = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("a", 0)),
        Operator::Minus,
        Arc::new(Column::new("b", 1)),
    ));
    let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
    // Propagate a construction failure as a test error instead of panicking.
    let mut graph = ExprIntervalGraph::try_new(expr, &schema)?;

    // The expression whose interval is provided from outside the graph.
    let leaf_node = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("a", 0)),
        Operator::Plus,
        Arc::new(Column::new("b", 1)),
    ));

    // Compare the node count before and after gathering the indices of the
    // nodes that match the provided expression.
    let prev_node_count = graph.node_count();
    graph.gather_node_indices(&[leaf_node]);
    let final_node_count = graph.node_count();

    // Equal counts mean that no node was removed.
    assert_eq!(prev_node_count, final_node_count);
    Ok(())
}
1287
/// Gathering the node for `a@0 + b@1` must drop both leaves `a@0` and `b@1`
/// when no other part of the expression uses them.
#[test]
fn test_gather_node_indices_remove() -> Result<()> {
    // Expression: a@0 + b@1 + 1 > y@2 - z@3, given a@0 + b@1.
    // We expect to remove two nodes since we do not need a@0 and b@1.
    // Column indices match the field positions in this schema.
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
        Field::new("y", DataType::Int32, true),
        Field::new("z", DataType::Int32, true),
    ]);

    let left_expr = Arc::new(BinaryExpr::new(
        Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Plus,
            Arc::new(Column::new("b", 1)),
        )),
        Operator::Plus,
        Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
    ));
    let right_expr = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("y", 2)),
        Operator::Minus,
        Arc::new(Column::new("z", 3)),
    ));
    let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
    // Propagate a construction failure as a test error instead of panicking.
    let mut graph = ExprIntervalGraph::try_new(expr, &schema)?;

    // The expression whose interval is provided from outside the graph.
    let leaf_node = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("a", 0)),
        Operator::Plus,
        Arc::new(Column::new("b", 1)),
    ));

    // Compare the node count before and after gathering the indices of the
    // nodes that match the provided expression.
    let prev_node_count = graph.node_count();
    graph.gather_node_indices(&[leaf_node]);
    let final_node_count = graph.node_count();

    // The final node count is two less than the previous one, i.e. exactly
    // two nodes were removed.
    assert_eq!(prev_node_count, final_node_count + 2);
    Ok(())
}
1335
/// Gathering the node for `a@0 + b@1` must drop only `b@1` when `a@0` is
/// still used by another sub-expression.
#[test]
fn test_gather_node_indices_remove_one() -> Result<()> {
    // Expression: a@0 + b@1 + 1 > a@0 - z@2, given a@0 + b@1.
    // We expect to remove one node since we still need a@0 but not b@1.
    // Column indices match the field positions in this schema.
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
        Field::new("z", DataType::Int32, true),
    ]);

    let left_expr = Arc::new(BinaryExpr::new(
        Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Plus,
            Arc::new(Column::new("b", 1)),
        )),
        Operator::Plus,
        Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
    ));
    let right_expr = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("a", 0)),
        Operator::Minus,
        Arc::new(Column::new("z", 2)),
    ));
    let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
    // Propagate a construction failure as a test error instead of panicking.
    let mut graph = ExprIntervalGraph::try_new(expr, &schema)?;

    // The expression whose interval is provided from outside the graph.
    let leaf_node = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("a", 0)),
        Operator::Plus,
        Arc::new(Column::new("b", 1)),
    ));

    // Compare the node count before and after gathering the indices of the
    // nodes that match the provided expression.
    let prev_node_count = graph.node_count();
    graph.gather_node_indices(&[leaf_node]);
    let final_node_count = graph.node_count();

    // The final node count is one less than the previous one, i.e. exactly
    // one node was removed.
    assert_eq!(prev_node_count, final_node_count + 1);
    Ok(())
}
1382
/// Gathering `a@0 + b@1` must leave the graph untouched when the expression
/// tree contains no node that is exactly `a@0 + b@1`.
#[test]
fn test_gather_node_indices_cannot_provide() -> Result<()> {
    // Expression: a@0 + 1 + b@1 > y@2 - z@3 -> provide a@0 + b@1
    // TODO: We expect nodes a@0 and b@1 to be pruned, and intervals to be
    // provided from the a@0 + b@1 node. However, we do not have an exact
    // node for a@0 + b@1 due to the binary tree structure of the
    // expressions: the tree is (a@0 + 1) + b@1. Pruning and interval
    // providing for BinaryExpr expressions are more challenging without
    // exact matches. Currently, we only support exact matches for
    // BinaryExprs, but we plan to extend support beyond exact matches in
    // the future.
    // Column indices match the field positions in this schema.
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
        Field::new("y", DataType::Int32, true),
        Field::new("z", DataType::Int32, true),
    ]);

    let left_expr = Arc::new(BinaryExpr::new(
        Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a", 0)),
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
        )),
        Operator::Plus,
        Arc::new(Column::new("b", 1)),
    ));
    let right_expr = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("y", 2)),
        Operator::Minus,
        Arc::new(Column::new("z", 3)),
    ));
    let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr));
    // Propagate a construction failure as a test error instead of panicking.
    let mut graph = ExprIntervalGraph::try_new(expr, &schema)?;

    // The expression whose interval is provided from outside the graph.
    let leaf_node = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("a", 0)),
        Operator::Plus,
        Arc::new(Column::new("b", 1)),
    ));

    // Compare the node count before and after gathering the indices of the
    // nodes that match the provided expression.
    let prev_node_count = graph.node_count();
    graph.gather_node_indices(&[leaf_node]);
    let final_node_count = graph.node_count();

    // Equal counts mean that no node was pruned.
    assert_eq!(prev_node_count, final_node_count);
    Ok(())
}
1432
/// Propagates a parent timestamp range through `ts_column + <interval>`
/// where the right operand is the singleton interval "1 day 321 ns".
#[test]
fn test_propagate_constraints_singleton_interval_at_right() -> Result<()> {
    // Timezone-less nanosecond timestamp scalar.
    let ts_ns = |nanos: i64| ScalarValue::TimestampNanosecond(Some(nanos), None);
    // 1 day 321 ns as a month-day-nano interval scalar.
    let one_day_321_ns = || {
        ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano {
            months: 0,
            days: 1,
            nanoseconds: 321,
        }))
    };

    let expression = BinaryExpr::new(
        Arc::new(Column::new("ts_column", 0)),
        Operator::Plus,
        Arc::new(Literal::new(ScalarValue::new_interval_mdn(0, 1, 321))),
    );

    // [15.10.2020 - 10:11:12.000_000_321 AM, 16.10.2020 - 10:11:12.000_000_321 AM]
    let parent = Interval::try_new(
        ts_ns(1_602_756_672_000_000_321),
        ts_ns(1_602_843_072_000_000_321),
    )?;
    // [10.10.2020 - 10:11:12 AM, 20.10.2020 - 10:11:12 AM]
    let left_child = Interval::try_new(
        ts_ns(1_602_324_672_000_000_000),
        ts_ns(1_603_188_672_000_000_000),
    )?;
    // [1 day 321 ns, 1 day 321 ns]
    let right_child = Interval::try_new(one_day_321_ns(), one_day_321_ns())?;

    let children = vec![&left_child, &right_child];
    let result = expression
        .propagate_constraints(&parent, &children)?
        .unwrap();

    let expected = vec![
        // [14.10.2020 - 10:11:12 AM, 15.10.2020 - 10:11:12 AM]
        Interval::try_new(
            ts_ns(1_602_670_272_000_000_000),
            ts_ns(1_602_756_672_000_000_000),
        )?,
        // The singleton right operand: [1 day 321 ns, 1 day 321 ns].
        Interval::try_new(one_day_321_ns(), one_day_321_ns())?,
    ];
    assert_eq!(expected, result);

    Ok(())
}
1505
/// Propagates a parent timestamp range through `interval_column + ts_column`
/// where the interval operand is on the left-hand side.
#[test]
fn test_propagate_constraints_column_interval_at_left() -> Result<()> {
    // Timezone-less millisecond timestamp scalar.
    let ts_ms = |millis: i64| ScalarValue::TimestampMillisecond(Some(millis), None);
    // Day-time interval scalar expressed purely in milliseconds.
    let day_time_ms = |milliseconds: i32| {
        ScalarValue::IntervalDayTime(Some(IntervalDayTime {
            days: 0,
            milliseconds,
        }))
    };

    let expression = BinaryExpr::new(
        Arc::new(Column::new("interval_column", 1)),
        Operator::Plus,
        Arc::new(Column::new("ts_column", 0)),
    );

    // [15.10.2020 - 10:11:12 AM, 16.10.2020 - 10:11:12 AM]
    let parent = Interval::try_new(ts_ms(1_602_756_672_000), ts_ms(1_602_843_072_000))?;
    // [10.10.2020 - 10:11:12 AM, 20.10.2020 - 10:11:12 AM]
    let right_child =
        Interval::try_new(ts_ms(1_602_324_672_000), ts_ms(1_603_188_672_000))?;
    // [2 days, 10 days], both in milliseconds
    let left_child =
        Interval::try_new(day_time_ms(172_800_000), day_time_ms(864_000_000))?;

    let children = vec![&left_child, &right_child];
    let result = expression
        .propagate_constraints(&parent, &children)?
        .unwrap();

    let expected = vec![
        // [2 days, 6 days], both in milliseconds
        Interval::try_new(day_time_ms(172_800_000), day_time_ms(518_400_000))?,
        // [10.10.2020 - 10:11:12 AM, 14.10.2020 - 10:11:12 AM]
        Interval::try_new(ts_ms(1_602_324_672_000), ts_ms(1_602_670_272_000))?,
    ];
    assert_eq!(expected, result);

    Ok(())
}
1568
/// Checks `propagate_comparison` for `left < right` with an unbounded `left`
/// and a singleton `right`, over Int64 and nanosecond timestamps (without and
/// with a time zone).
#[test]
fn test_propagate_comparison() -> Result<()> {
    // In every scenario below:
    // `left` is unbounded: [?, ?],
    // `right` is known to be [1000, 1000],
    // so `left` < `right` yields no new knowledge about `right`, while `left`
    // is now known to be < 1000: [?, 999].
    let assert_left_tightened =
        |left: &Interval, right: &Interval, expected_left: Interval| -> Result<()> {
            assert_eq!(
                Some((expected_left, right.clone())),
                propagate_comparison(
                    &Operator::Lt,
                    &Interval::CERTAINLY_TRUE,
                    left,
                    right
                )?
            );
            Ok(())
        };

    // Int64
    let left = Interval::make_unbounded(&DataType::Int64)?;
    let right = Interval::make(Some(1000_i64), Some(1000_i64))?;
    assert_left_tightened(&left, &right, Interval::make(None, Some(999_i64))?)?;

    // Nanosecond timestamps without a time zone
    let ts_type = DataType::Timestamp(TimeUnit::Nanosecond, None);
    let left = Interval::make_unbounded(&ts_type)?;
    let right = Interval::try_new(
        ScalarValue::TimestampNanosecond(Some(1000), None),
        ScalarValue::TimestampNanosecond(Some(1000), None),
    )?;
    let expected_left = Interval::try_new(
        ScalarValue::try_from(&ts_type).unwrap(),
        ScalarValue::TimestampNanosecond(Some(999), None),
    )?;
    assert_left_tightened(&left, &right, expected_left)?;

    // Nanosecond timestamps carrying a "+05:00" time zone
    let tz: Option<Arc<str>> = Some("+05:00".into());
    let ts_tz_type = DataType::Timestamp(TimeUnit::Nanosecond, tz.clone());
    let left = Interval::make_unbounded(&ts_tz_type)?;
    let right = Interval::try_new(
        ScalarValue::TimestampNanosecond(Some(1000), tz.clone()),
        ScalarValue::TimestampNanosecond(Some(1000), tz.clone()),
    )?;
    let expected_left = Interval::try_new(
        ScalarValue::try_from(&ts_tz_type).unwrap(),
        ScalarValue::TimestampNanosecond(Some(999), tz.clone()),
    )?;
    assert_left_tightened(&left, &right, expected_left)?;

    Ok(())
}
1652
/// Exercises constraint propagation through `a OR b` for certainly-false and
/// certainly-true parents.
#[test]
fn test_propagate_or() -> Result<()> {
    let or_expr = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("a", 0)),
        Operator::Or,
        Arc::new(Column::new("b", 1)),
    ));
    let (yes, no, maybe) = (
        Interval::CERTAINLY_TRUE,
        Interval::CERTAINLY_FALSE,
        Interval::UNCERTAIN,
    );

    // A false disjunction forces both operands to be false.
    for (lhs, rhs) in [(&no, &maybe), (&maybe, &no), (&no, &no), (&maybe, &maybe)] {
        let children = vec![lhs, rhs];
        assert_eq!(
            or_expr.propagate_constraints(&no, &children)?.unwrap(),
            vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_FALSE],
        );
    }

    // A false disjunction with a certainly-true operand yields `None`.
    for (lhs, rhs) in [(&yes, &maybe), (&maybe, &yes)] {
        let children = vec![lhs, rhs];
        assert_eq!(or_expr.propagate_constraints(&no, &children)?, None);
    }

    // A true disjunction whose left operand is false pins the right one to true.
    let children = vec![&no, &maybe];
    assert_eq!(
        or_expr.propagate_constraints(&yes, &children)?.unwrap(),
        vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_TRUE]
    );

    // A true disjunction over two uncertain operands: the empty vector
    // means the intervals are unchanged.
    let children = vec![&maybe, &maybe];
    assert_eq!(
        or_expr.propagate_constraints(&yes, &children)?.unwrap(),
        Vec::<Interval>::new()
    );

    Ok(())
}
1700
/// Exercises constraint propagation through `a AND b` when the parent is
/// certainly false.
#[test]
fn test_propagate_certainly_false_and() -> Result<()> {
    let and_expr = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("a", 0)),
        Operator::And,
        Arc::new(Column::new("b", 1)),
    ));
    let (yes, no, maybe) = (
        Interval::CERTAINLY_TRUE,
        Interval::CERTAINLY_FALSE,
        Interval::UNCERTAIN,
    );

    // (children, expected propagation result); an empty vector means the
    // intervals are unchanged.
    let cases = [
        // A certainly-true operand forces the other one to be false.
        (
            [&yes, &maybe],
            vec![Interval::CERTAINLY_TRUE, Interval::CERTAINLY_FALSE],
        ),
        (
            [&maybe, &yes],
            vec![Interval::CERTAINLY_FALSE, Interval::CERTAINLY_TRUE],
        ),
        ([&maybe, &maybe], vec![]),
        ([&no, &maybe], vec![]),
    ];
    for (children, expected) in cases {
        assert_eq!(
            and_expr.propagate_constraints(&no, &children)?.unwrap(),
            expected
        );
    }

    Ok(())
}
1737}