Skip to main content

datafusion_physical_expr/utils/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18mod guarantee;
19pub use guarantee::{Guarantee, LiteralGuarantee};
20
21use std::borrow::Borrow;
22use std::sync::Arc;
23
24use crate::expressions::{BinaryExpr, Column, Literal};
25use crate::tree_node::ExprContext;
26use crate::{
27    AcrossPartitions, ConstExpr, EquivalenceProperties, PhysicalExpr, PhysicalSortExpr,
28};
29
30use arrow::datatypes::Schema;
31use datafusion_common::tree_node::{
32    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
33};
34use datafusion_common::{HashMap, HashSet, Result};
35use datafusion_expr::Operator;
36
37use petgraph::graph::NodeIndex;
38use petgraph::stable_graph::StableGraph;
39
40/// Assume the predicate is in the form of CNF, split the predicate to a Vec of PhysicalExprs.
41///
42/// For example, split "a1 = a2 AND b1 <= b2 AND c1 != c2" into ["a1 = a2", "b1 <= b2", "c1 != c2"]
43pub fn split_conjunction(
44    predicate: &Arc<dyn PhysicalExpr>,
45) -> Vec<&Arc<dyn PhysicalExpr>> {
46    split_impl(Operator::And, predicate, vec![])
47}
48
49impl ConstExpr {
50    /// Collects predicate-derived constants from equality conjunctions.
51    ///
52    /// For each equality predicate of the form `lhs = rhs`, if either side is
53    /// already known constant according to `input_eqs`, or is a literal, then
54    /// the other side is also constant and will be returned as a [`ConstExpr`].
55    ///
56    /// Literals are treated as uniform constants across partitions, so
57    /// `col = literal` produces a constant for `col` with the literal value.
58    ///
59    /// For example, given predicate `a = 5 AND b = c` where `c` is already
60    /// known constant, this returns constants for both `a` (Uniform with value
61    /// 5) and `b` (propagating `c`'s across-partitions value).
62    pub fn collect_predicate_constants(
63        input_eqs: &EquivalenceProperties,
64        predicate: &Arc<dyn PhysicalExpr>,
65    ) -> Vec<ConstExpr> {
66        /// Returns the `AcrossPartitions` value for `expr` if it is constant:
67        /// either already known constant in `input_eqs`, or a `Literal`
68        /// (which is inherently constant across all partitions).
69        fn expr_constant_or_literal(
70            expr: &Arc<dyn PhysicalExpr>,
71            input_eqs: &EquivalenceProperties,
72        ) -> Option<AcrossPartitions> {
73            input_eqs.is_expr_constant(expr).or_else(|| {
74                expr.downcast_ref::<Literal>()
75                    .map(|l| AcrossPartitions::Uniform(Some(l.value().clone())))
76            })
77        }
78
79        let mut constants = Vec::new();
80        for conjunction in split_conjunction(predicate) {
81            if let Some(binary) = conjunction.downcast_ref::<BinaryExpr>()
82                && binary.op() == &Operator::Eq
83            {
84                // Check if either side is constant — either already known
85                // constant from the input equivalence properties, or a literal
86                // value (which is inherently constant across all partitions).
87                let left_const = expr_constant_or_literal(binary.left(), input_eqs);
88                let right_const = expr_constant_or_literal(binary.right(), input_eqs);
89
90                if let Some(left_across) = left_const {
91                    // LEFT is constant, so RIGHT must also be constant.
92                    // Use RIGHT's known across value if available, otherwise
93                    // propagate LEFT's (e.g. Uniform from a literal).
94                    let across = right_const.unwrap_or(left_across);
95                    constants.push(ConstExpr::new(Arc::clone(binary.right()), across));
96                } else if let Some(right_across) = right_const {
97                    // RIGHT is constant, so LEFT must also be constant.
98                    constants
99                        .push(ConstExpr::new(Arc::clone(binary.left()), right_across));
100                }
101            }
102        }
103
104        constants
105    }
106}
107
108/// Create a conjunction of the given predicates.
109/// If the input is empty, return a literal true.
110/// If the input contains a single predicate, return the predicate.
111/// Otherwise, return a conjunction of the predicates (e.g. `a AND b AND c`).
112pub fn conjunction(
113    predicates: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
114) -> Arc<dyn PhysicalExpr> {
115    conjunction_opt(predicates).unwrap_or_else(|| crate::expressions::lit(true))
116}
117
118/// Create a conjunction of the given predicates.
119/// If the input is empty or the return None.
120/// If the input contains a single predicate, return Some(predicate).
121/// Otherwise, return a Some(..) of a conjunction of the predicates (e.g. `Some(a AND b AND c)`).
122pub fn conjunction_opt(
123    predicates: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
124) -> Option<Arc<dyn PhysicalExpr>> {
125    predicates
126        .into_iter()
127        .fold(None, |acc, predicate| match acc {
128            None => Some(predicate),
129            Some(acc) => Some(Arc::new(BinaryExpr::new(acc, Operator::And, predicate))),
130        })
131}
132
133/// Assume the predicate is in the form of DNF, split the predicate to a Vec of PhysicalExprs.
134///
135/// For example, split "a1 = a2 OR b1 <= b2 OR c1 != c2" into ["a1 = a2", "b1 <= b2", "c1 != c2"]
136pub fn split_disjunction(
137    predicate: &Arc<dyn PhysicalExpr>,
138) -> Vec<&Arc<dyn PhysicalExpr>> {
139    split_impl(Operator::Or, predicate, vec![])
140}
141
142fn split_impl<'a>(
143    operator: Operator,
144    predicate: &'a Arc<dyn PhysicalExpr>,
145    mut exprs: Vec<&'a Arc<dyn PhysicalExpr>>,
146) -> Vec<&'a Arc<dyn PhysicalExpr>> {
147    match predicate.downcast_ref::<BinaryExpr>() {
148        Some(binary) if binary.op() == &operator => {
149            let exprs = split_impl(operator, binary.left(), exprs);
150            split_impl(operator, binary.right(), exprs)
151        }
152        Some(_) | None => {
153            exprs.push(predicate);
154            exprs
155        }
156    }
157}
158
159/// This function maps back requirement after ProjectionExec
160/// to the Executor for its input.
161// Specifically, `ProjectionExec` changes index of `Column`s in the schema of its input executor.
162// This function changes requirement given according to ProjectionExec schema to the requirement
163// according to schema of input executor to the ProjectionExec.
164// For instance, Column{"a", 0} would turn to Column{"a", 1}. Please note that this function assumes that
165// name of the Column is unique. If we have a requirement such that Column{"a", 0}, Column{"a", 1}.
166// This function will produce incorrect result (It will only emit single Column as a result).
167pub fn map_columns_before_projection(
168    parent_required: &[Arc<dyn PhysicalExpr>],
169    proj_exprs: &[(Arc<dyn PhysicalExpr>, String)],
170) -> Vec<Arc<dyn PhysicalExpr>> {
171    if parent_required.is_empty() {
172        // No need to build mapping.
173        return vec![];
174    }
175    let column_mapping = proj_exprs
176        .iter()
177        .filter_map(|(expr, name)| {
178            expr.downcast_ref::<Column>()
179                .map(|column| (name.clone(), column.clone()))
180        })
181        .collect::<HashMap<_, _>>();
182    parent_required
183        .iter()
184        .filter_map(|r| {
185            r.downcast_ref::<Column>()
186                .and_then(|c| column_mapping.get(c.name()))
187        })
188        .map(|e| Arc::new(e.clone()) as _)
189        .collect()
190}
191
192/// This function returns all `Arc<dyn PhysicalExpr>`s inside the given
193/// `PhysicalSortExpr` sequence.
194pub fn convert_to_expr<T: Borrow<PhysicalSortExpr>>(
195    sequence: impl IntoIterator<Item = T>,
196) -> Vec<Arc<dyn PhysicalExpr>> {
197    sequence
198        .into_iter()
199        .map(|elem| Arc::clone(&elem.borrow().expr))
200        .collect()
201}
202
203/// This function finds the indices of `targets` within `items` using strict
204/// equality.
205pub fn get_indices_of_exprs_strict<T: Borrow<Arc<dyn PhysicalExpr>>>(
206    targets: impl IntoIterator<Item = T>,
207    items: &[Arc<dyn PhysicalExpr>],
208) -> Vec<usize> {
209    targets
210        .into_iter()
211        .filter_map(|target| items.iter().position(|e| e.eq(target.borrow())))
212        .collect()
213}
214
/// An [`ExprContext`] whose per-node payload is an `Option<T>`.
///
/// [`build_dag`] uses `ExprTreeNode<NodeIndex>`. Once a node has been visited,
/// its payload holds `Some` index of the graph node created (or reused) for it.
pub type ExprTreeNode<T> = ExprContext<Option<T>>;
216
217/// This struct is used to convert a [`PhysicalExpr`] tree into a DAEG (i.e. an expression
218/// DAG) by collecting identical expressions in one node. Caller specifies the node type
219/// in the DAEG via the `constructor` argument, which constructs nodes in the DAEG from
220/// the [`ExprTreeNode`] ancillary object.
221struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> {
222    // The resulting DAEG (expression DAG).
223    graph: StableGraph<T, usize>,
224    // A vector of visited expression nodes and their corresponding node indices.
225    visited_plans: Vec<(Arc<dyn PhysicalExpr>, NodeIndex)>,
226    // A function to convert an input expression node to T.
227    constructor: &'a F,
228}
229
230impl<T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> PhysicalExprDAEGBuilder<'_, T, F> {
231    // This method mutates an expression node by transforming it to a physical expression
232    // and adding it to the graph. The method returns the mutated expression node.
233    fn mutate(
234        &mut self,
235        mut node: ExprTreeNode<NodeIndex>,
236    ) -> Result<Transformed<ExprTreeNode<NodeIndex>>> {
237        // Get the expression associated with the input expression node.
238        let expr = &node.expr;
239
240        // Check if the expression has already been visited.
241        let node_idx = match self.visited_plans.iter().find(|(e, _)| expr.eq(e)) {
242            // If the expression has been visited, return the corresponding node index.
243            Some((_, idx)) => *idx,
244            // If the expression has not been visited, add a new node to the graph and
245            // add edges to its child nodes. Add the visited expression to the vector
246            // of visited expressions and return the newly created node index.
247            None => {
248                let node_idx = self.graph.add_node((self.constructor)(&node)?);
249                for expr_node in node.children.iter() {
250                    self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0);
251                }
252                self.visited_plans.push((Arc::clone(expr), node_idx));
253                node_idx
254            }
255        };
256        // Set the data field of the input expression node to the corresponding node index.
257        node.data = Some(node_idx);
258        // Return the mutated expression node.
259        Ok(Transformed::yes(node))
260    }
261}
262
263// A function that builds a directed acyclic graph of physical expression trees.
264pub fn build_dag<T, F>(
265    expr: Arc<dyn PhysicalExpr>,
266    constructor: &F,
267) -> Result<(NodeIndex, StableGraph<T, usize>)>
268where
269    F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>,
270{
271    // Create a new expression tree node from the input expression.
272    let init = ExprTreeNode::new_default(expr);
273    // Create a new `PhysicalExprDAEGBuilder` instance.
274    let mut builder = PhysicalExprDAEGBuilder {
275        graph: StableGraph::<T, usize>::new(),
276        visited_plans: Vec::<(Arc<dyn PhysicalExpr>, NodeIndex)>::new(),
277        constructor,
278    };
279    // Use the builder to transform the expression tree node into a DAG.
280    let root = init.transform_up(|node| builder.mutate(node)).data()?;
281    // Return a tuple containing the root node index and the DAG.
282    Ok((root.data.unwrap(), builder.graph))
283}
284
285/// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`].
286pub fn collect_columns(expr: &Arc<dyn PhysicalExpr>) -> HashSet<Column> {
287    let mut columns = HashSet::<Column>::new();
288    expr.apply(|expr| {
289        if let Some(column) = expr.downcast_ref::<Column>() {
290            columns.get_or_insert_with(column, |c| c.clone());
291        }
292        Ok(TreeNodeRecursion::Continue)
293    })
294    // pre_visit always returns OK, so this will always too
295    .expect("no way to return error during recursion");
296    columns
297}
298
299/// Re-assign indices of [`Column`]s within the given [`PhysicalExpr`] according to
300/// the provided [`Schema`].
301///
302/// This can be useful when attempting to map an expression onto a different schema.
303///
304/// # Errors
305///
306/// This function will return an error if any column in the expression cannot be found
307/// in the provided schema.
308pub fn reassign_expr_columns(
309    expr: Arc<dyn PhysicalExpr>,
310    schema: &Schema,
311) -> Result<Arc<dyn PhysicalExpr>> {
312    expr.transform_down(|expr| {
313        if let Some(column) = expr.downcast_ref::<Column>() {
314            let index = schema.index_of(column.name())?;
315
316            return Ok(Transformed::yes(Arc::new(Column::new(
317                column.name(),
318                index,
319            ))));
320        }
321        Ok(Transformed::no(expr))
322    })
323    .data()
324}
325
#[cfg(test)]
pub(crate) mod tests {

    use std::fmt::{Display, Formatter};

    use super::*;
    use crate::expressions::{Literal, binary, cast, col, in_list, lit};

    use arrow::array::{ArrayRef, Float32Array, Float64Array};
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::{ScalarValue, exec_err, internal_datafusion_err};
    use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
    use datafusion_expr::{
        ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
    };

    use petgraph::visit::Bfs;

    /// Test-only scalar UDF that applies `floor` to its single `Float64` or
    /// `Float32` argument.
    #[derive(Debug, PartialEq, Eq, Hash)]
    pub struct TestScalarUDF {
        pub(crate) signature: Signature,
    }

    impl TestScalarUDF {
        /// Creates the UDF with an immutable, one-argument signature that
        /// accepts `Float64` or `Float32`.
        pub fn new() -> Self {
            use DataType::*;
            Self {
                signature: Signature::uniform(
                    1,
                    vec![Float64, Float32],
                    Volatility::Immutable,
                ),
            }
        }
    }

    impl ScalarUDFImpl for TestScalarUDF {
        fn name(&self) -> &str {
            "test-scalar-udf"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        // `Float32` input keeps `Float32`; every other input type reports `Float64`.
        fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
            let arg_type = &arg_types[0];

            match arg_type {
                DataType::Float32 => Ok(DataType::Float32),
                _ => Ok(DataType::Float64),
            }
        }

        // `floor` is monotonically non-decreasing, so the output simply
        // inherits the sort properties of the (single) input.
        fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
            Ok(input[0].sort_properties)
        }

        // Applies `floor` element-wise to the first argument (as an array),
        // preserving nulls. Unsupported input types produce an execution error.
        fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            let args = ColumnarValue::values_to_arrays(&args.args)?;

            let arr: ArrayRef = match args[0].data_type() {
                DataType::Float64 => Arc::new({
                    let arg = &args[0]
                        .as_any()
                        .downcast_ref::<Float64Array>()
                        .ok_or_else(|| {
                            internal_datafusion_err!(
                                "could not cast {} to {}",
                                self.name(),
                                std::any::type_name::<Float64Array>()
                            )
                        })?;

                    arg.iter()
                        .map(|a| a.map(f64::floor))
                        .collect::<Float64Array>()
                }),
                DataType::Float32 => Arc::new({
                    let arg = &args[0]
                        .as_any()
                        .downcast_ref::<Float32Array>()
                        .ok_or_else(|| {
                            internal_datafusion_err!(
                                "could not cast {} to {}",
                                self.name(),
                                std::any::type_name::<Float32Array>()
                            )
                        })?;

                    arg.iter()
                        .map(|a| a.map(f32::floor))
                        .collect::<Float32Array>()
                }),
                other => {
                    return exec_err!(
                        "Unsupported data type {other:?} for function {}",
                        self.name()
                    );
                }
            };
            Ok(ColumnarValue::Array(arr))
        }
    }

    /// Dummy per-node property: a coarse expression kind
    /// ("Binary", "Column", "Literal" or "Other").
    #[derive(Clone)]
    struct DummyProperty {
        expr_type: String,
    }

    /// This is a dummy node in the DAEG; it stores a reference to the actual
    /// [PhysicalExpr] as well as a dummy property.
    #[derive(Clone)]
    struct PhysicalExprDummyNode {
        pub expr: Arc<dyn PhysicalExpr>,
        pub property: DummyProperty,
    }

    impl Display for PhysicalExprDummyNode {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.expr)
        }
    }

    /// Node constructor passed to `build_dag`: wraps the expression and tags
    /// it with its coarse kind.
    fn make_dummy_node(node: &ExprTreeNode<NodeIndex>) -> Result<PhysicalExprDummyNode> {
        let expr = Arc::clone(&node.expr);
        let dummy_property = if expr.is::<BinaryExpr>() {
            "Binary"
        } else if expr.is::<Column>() {
            "Column"
        } else if expr.is::<Literal>() {
            "Literal"
        } else {
            "Other"
        }
        .to_owned();
        Ok(PhysicalExprDummyNode {
            expr,
            property: DummyProperty {
                expr_type: dummy_property,
            },
        })
    }

    #[test]
    fn test_build_dag() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
            Field::new("2", DataType::Int32, true),
        ]);
        // CAST("0" + "1" AS Int64) > CAST("2" AS Int64) + 10
        let expr = binary(
            cast(
                binary(
                    col("0", &schema)?,
                    Operator::Plus,
                    col("1", &schema)?,
                    &schema,
                )?,
                &schema,
                DataType::Int64,
            )?,
            Operator::Gt,
            binary(
                cast(col("2", &schema)?, &schema, DataType::Int64)?,
                Operator::Plus,
                lit(ScalarValue::Int64(Some(10))),
                &schema,
            )?,
            &schema,
        )?;
        // Collect the property of every node reachable from the root.
        let mut vector_dummy_props = vec![];
        let (root, graph) = build_dag(expr, &make_dummy_node)?;
        let mut bfs = Bfs::new(&graph, root);
        while let Some(node_index) = bfs.next(&graph) {
            let node = &graph[node_index];
            vector_dummy_props.push(node.property.clone());
        }

        // Binaries: `>` and the two `+`.
        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Binary")
                .count(),
            3
        );
        // Columns: "0", "1", "2".
        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Column")
                .count(),
            3
        );
        // Literal: 10.
        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Literal")
                .count(),
            1
        );
        // Other: the two casts.
        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Other")
                .count(),
            2
        );
        Ok(())
    }

    #[test]
    fn test_convert_to_expr() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::UInt64, false)]);
        let sort_expr = vec![PhysicalSortExpr {
            expr: col("a", &schema)?,
            options: Default::default(),
        }];
        // The extracted expression must equal the sort expression's own `expr`.
        assert!(convert_to_expr(&sort_expr)[0].eq(&sort_expr[0].expr));
        Ok(())
    }

    #[test]
    fn test_get_indices_of_exprs_strict() {
        let list1: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(Column::new("a", 0)),
            Arc::new(Column::new("b", 1)),
            Arc::new(Column::new("c", 2)),
            Arc::new(Column::new("d", 3)),
        ];
        let list2: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(Column::new("b", 1)),
            Arc::new(Column::new("c", 2)),
            Arc::new(Column::new("a", 0)),
        ];
        // "d" is absent from list2, so it is skipped rather than reported.
        assert_eq!(get_indices_of_exprs_strict(&list1, &list2), vec![2, 0, 1]);
        assert_eq!(get_indices_of_exprs_strict(&list2, &list1), vec![1, 2, 0]);
    }

    #[test]
    fn test_reassign_expr_columns_in_list() {
        // "id" sits at index 1 in `schema_big` but at index 0 in `schema_small`.
        let int_field = Field::new("should_not_matter", DataType::Int64, true);
        let dict_field = Field::new(
            "id",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            true,
        );
        let schema_small = Arc::new(Schema::new(vec![dict_field.clone()]));
        let schema_big = Arc::new(Schema::new(vec![int_field, dict_field]));
        let pred = in_list(
            Arc::new(Column::new_with_schema("id", &schema_big).unwrap()),
            vec![lit(ScalarValue::Dictionary(
                Box::new(DataType::Int32),
                Box::new(ScalarValue::from("2")),
            ))],
            &false,
            &schema_big,
        )
        .unwrap();

        let actual = reassign_expr_columns(pred, &schema_small).unwrap();

        // Reassigning must match the same IN-list built directly against `schema_small`.
        let expected = in_list(
            Arc::new(Column::new_with_schema("id", &schema_small).unwrap()),
            vec![lit(ScalarValue::Dictionary(
                Box::new(DataType::Int32),
                Box::new(ScalarValue::from("2")),
            ))],
            &false,
            &schema_small,
        )
        .unwrap();

        assert_eq!(actual.as_ref(), expected.as_ref());
    }

    #[test]
    fn test_collect_columns() -> Result<()> {
        // A lone column yields just itself.
        let expr1 = Arc::new(Column::new("col1", 2)) as _;
        let mut expected = HashSet::new();
        expected.insert(Column::new("col1", 2));
        assert_eq!(collect_columns(&expr1), expected);

        let expr2 = Arc::new(Column::new("col2", 5)) as _;
        let mut expected = HashSet::new();
        expected.insert(Column::new("col2", 5));
        assert_eq!(collect_columns(&expr2), expected);

        // Columns nested under a binary expression are all found.
        let expr3 = Arc::new(BinaryExpr::new(expr1, Operator::Plus, expr2)) as _;
        let mut expected = HashSet::new();
        expected.insert(Column::new("col1", 2));
        expected.insert(Column::new("col2", 5));
        assert_eq!(collect_columns(&expr3), expected);
        Ok(())
    }

    #[test]
    fn test_collect_predicate_constants_propagates_uniform_literal_value() -> Result<()> {
        // `ticker = 'NGJ26'` with no prior constants: `ticker` must become a
        // `Uniform` constant carrying the literal's value.
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ticker",
            DataType::Utf8,
            false,
        )]));
        let predicate = binary(
            col("ticker", schema.as_ref())?,
            Operator::Eq,
            lit(ScalarValue::Utf8(Some("NGJ26".to_string()))),
            schema.as_ref(),
        )?;
        let eq_properties = EquivalenceProperties::new(schema);

        let constants =
            ConstExpr::collect_predicate_constants(&eq_properties, &predicate);

        assert_eq!(constants.len(), 1);
        assert_eq!(
            constants[0].across_partitions,
            AcrossPartitions::Uniform(Some(ScalarValue::Utf8(Some("NGJ26".to_string()))))
        );

        Ok(())
    }
}