datafusion_physical_expr/utils/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18mod guarantee;
19pub use guarantee::{Guarantee, LiteralGuarantee};
20
21use std::borrow::Borrow;
22use std::sync::Arc;
23
24use crate::expressions::{BinaryExpr, Column};
25use crate::tree_node::ExprContext;
26use crate::PhysicalExpr;
27use crate::PhysicalSortExpr;
28
29use arrow::datatypes::SchemaRef;
30use datafusion_common::tree_node::{
31    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
32};
33use datafusion_common::{HashMap, HashSet, Result};
34use datafusion_expr::Operator;
35
36use datafusion_physical_expr_common::sort_expr::LexOrdering;
37use itertools::Itertools;
38use petgraph::graph::NodeIndex;
39use petgraph::stable_graph::StableGraph;
40
41/// Assume the predicate is in the form of CNF, split the predicate to a Vec of PhysicalExprs.
42///
43/// For example, split "a1 = a2 AND b1 <= b2 AND c1 != c2" into ["a1 = a2", "b1 <= b2", "c1 != c2"]
44pub fn split_conjunction(
45    predicate: &Arc<dyn PhysicalExpr>,
46) -> Vec<&Arc<dyn PhysicalExpr>> {
47    split_impl(Operator::And, predicate, vec![])
48}
49
50/// Create a conjunction of the given predicates.
51/// If the input is empty, return a literal true.
52/// If the input contains a single predicate, return the predicate.
53/// Otherwise, return a conjunction of the predicates (e.g. `a AND b AND c`).
54pub fn conjunction(
55    predicates: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
56) -> Arc<dyn PhysicalExpr> {
57    conjunction_opt(predicates).unwrap_or_else(|| crate::expressions::lit(true))
58}
59
60/// Create a conjunction of the given predicates.
61/// If the input is empty or the return None.
62/// If the input contains a single predicate, return Some(predicate).
63/// Otherwise, return a Some(..) of a conjunction of the predicates (e.g. `Some(a AND b AND c)`).
64pub fn conjunction_opt(
65    predicates: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
66) -> Option<Arc<dyn PhysicalExpr>> {
67    predicates
68        .into_iter()
69        .fold(None, |acc, predicate| match acc {
70            None => Some(predicate),
71            Some(acc) => Some(Arc::new(BinaryExpr::new(acc, Operator::And, predicate))),
72        })
73}
74
75/// Assume the predicate is in the form of DNF, split the predicate to a Vec of PhysicalExprs.
76///
77/// For example, split "a1 = a2 OR b1 <= b2 OR c1 != c2" into ["a1 = a2", "b1 <= b2", "c1 != c2"]
78pub fn split_disjunction(
79    predicate: &Arc<dyn PhysicalExpr>,
80) -> Vec<&Arc<dyn PhysicalExpr>> {
81    split_impl(Operator::Or, predicate, vec![])
82}
83
84fn split_impl<'a>(
85    operator: Operator,
86    predicate: &'a Arc<dyn PhysicalExpr>,
87    mut exprs: Vec<&'a Arc<dyn PhysicalExpr>>,
88) -> Vec<&'a Arc<dyn PhysicalExpr>> {
89    match predicate.as_any().downcast_ref::<BinaryExpr>() {
90        Some(binary) if binary.op() == &operator => {
91            let exprs = split_impl(operator, binary.left(), exprs);
92            split_impl(operator, binary.right(), exprs)
93        }
94        Some(_) | None => {
95            exprs.push(predicate);
96            exprs
97        }
98    }
99}
100
101/// This function maps back requirement after ProjectionExec
102/// to the Executor for its input.
103// Specifically, `ProjectionExec` changes index of `Column`s in the schema of its input executor.
104// This function changes requirement given according to ProjectionExec schema to the requirement
105// according to schema of input executor to the ProjectionExec.
106// For instance, Column{"a", 0} would turn to Column{"a", 1}. Please note that this function assumes that
107// name of the Column is unique. If we have a requirement such that Column{"a", 0}, Column{"a", 1}.
108// This function will produce incorrect result (It will only emit single Column as a result).
109pub fn map_columns_before_projection(
110    parent_required: &[Arc<dyn PhysicalExpr>],
111    proj_exprs: &[(Arc<dyn PhysicalExpr>, String)],
112) -> Vec<Arc<dyn PhysicalExpr>> {
113    if parent_required.is_empty() {
114        // No need to build mapping.
115        return vec![];
116    }
117    let column_mapping = proj_exprs
118        .iter()
119        .filter_map(|(expr, name)| {
120            expr.as_any()
121                .downcast_ref::<Column>()
122                .map(|column| (name.clone(), column.clone()))
123        })
124        .collect::<HashMap<_, _>>();
125    parent_required
126        .iter()
127        .filter_map(|r| {
128            r.as_any()
129                .downcast_ref::<Column>()
130                .and_then(|c| column_mapping.get(c.name()))
131        })
132        .map(|e| Arc::new(e.clone()) as _)
133        .collect()
134}
135
136/// This function returns all `Arc<dyn PhysicalExpr>`s inside the given
137/// `PhysicalSortExpr` sequence.
138pub fn convert_to_expr<T: Borrow<PhysicalSortExpr>>(
139    sequence: impl IntoIterator<Item = T>,
140) -> Vec<Arc<dyn PhysicalExpr>> {
141    sequence
142        .into_iter()
143        .map(|elem| Arc::clone(&elem.borrow().expr))
144        .collect()
145}
146
147/// This function finds the indices of `targets` within `items` using strict
148/// equality.
149pub fn get_indices_of_exprs_strict<T: Borrow<Arc<dyn PhysicalExpr>>>(
150    targets: impl IntoIterator<Item = T>,
151    items: &[Arc<dyn PhysicalExpr>],
152) -> Vec<usize> {
153    targets
154        .into_iter()
155        .filter_map(|target| items.iter().position(|e| e.eq(target.borrow())))
156        .collect()
157}
158
/// An [`ExprContext`] whose `data` payload is an `Option<T>`. In this file
/// the payload is set to `Some(..)` for each node as it is processed while
/// building the DAEG.
pub type ExprTreeNode<T> = ExprContext<Option<T>>;
160
161/// This struct is used to convert a [`PhysicalExpr`] tree into a DAEG (i.e. an expression
162/// DAG) by collecting identical expressions in one node. Caller specifies the node type
163/// in the DAEG via the `constructor` argument, which constructs nodes in the DAEG from
164/// the [`ExprTreeNode`] ancillary object.
165struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> {
166    // The resulting DAEG (expression DAG).
167    graph: StableGraph<T, usize>,
168    // A vector of visited expression nodes and their corresponding node indices.
169    visited_plans: Vec<(Arc<dyn PhysicalExpr>, NodeIndex)>,
170    // A function to convert an input expression node to T.
171    constructor: &'a F,
172}
173
174impl<T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> PhysicalExprDAEGBuilder<'_, T, F> {
175    // This method mutates an expression node by transforming it to a physical expression
176    // and adding it to the graph. The method returns the mutated expression node.
177    fn mutate(
178        &mut self,
179        mut node: ExprTreeNode<NodeIndex>,
180    ) -> Result<Transformed<ExprTreeNode<NodeIndex>>> {
181        // Get the expression associated with the input expression node.
182        let expr = &node.expr;
183
184        // Check if the expression has already been visited.
185        let node_idx = match self.visited_plans.iter().find(|(e, _)| expr.eq(e)) {
186            // If the expression has been visited, return the corresponding node index.
187            Some((_, idx)) => *idx,
188            // If the expression has not been visited, add a new node to the graph and
189            // add edges to its child nodes. Add the visited expression to the vector
190            // of visited expressions and return the newly created node index.
191            None => {
192                let node_idx = self.graph.add_node((self.constructor)(&node)?);
193                for expr_node in node.children.iter() {
194                    self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0);
195                }
196                self.visited_plans.push((Arc::clone(expr), node_idx));
197                node_idx
198            }
199        };
200        // Set the data field of the input expression node to the corresponding node index.
201        node.data = Some(node_idx);
202        // Return the mutated expression node.
203        Ok(Transformed::yes(node))
204    }
205}
206
207// A function that builds a directed acyclic graph of physical expression trees.
208pub fn build_dag<T, F>(
209    expr: Arc<dyn PhysicalExpr>,
210    constructor: &F,
211) -> Result<(NodeIndex, StableGraph<T, usize>)>
212where
213    F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>,
214{
215    // Create a new expression tree node from the input expression.
216    let init = ExprTreeNode::new_default(expr);
217    // Create a new `PhysicalExprDAEGBuilder` instance.
218    let mut builder = PhysicalExprDAEGBuilder {
219        graph: StableGraph::<T, usize>::new(),
220        visited_plans: Vec::<(Arc<dyn PhysicalExpr>, NodeIndex)>::new(),
221        constructor,
222    };
223    // Use the builder to transform the expression tree node into a DAG.
224    let root = init.transform_up(|node| builder.mutate(node)).data()?;
225    // Return a tuple containing the root node index and the DAG.
226    Ok((root.data.unwrap(), builder.graph))
227}
228
229/// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`].
230pub fn collect_columns(expr: &Arc<dyn PhysicalExpr>) -> HashSet<Column> {
231    let mut columns = HashSet::<Column>::new();
232    expr.apply(|expr| {
233        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
234            columns.get_or_insert_owned(column);
235        }
236        Ok(TreeNodeRecursion::Continue)
237    })
238    // pre_visit always returns OK, so this will always too
239    .expect("no way to return error during recursion");
240    columns
241}
242
243/// Re-assign column indices referenced in predicate according to given schema.
244/// This may be helpful when dealing with projections.
245pub fn reassign_predicate_columns(
246    pred: Arc<dyn PhysicalExpr>,
247    schema: &SchemaRef,
248    ignore_not_found: bool,
249) -> Result<Arc<dyn PhysicalExpr>> {
250    pred.transform_down(|expr| {
251        let expr_any = expr.as_any();
252
253        if let Some(column) = expr_any.downcast_ref::<Column>() {
254            let index = match schema.index_of(column.name()) {
255                Ok(idx) => idx,
256                Err(_) if ignore_not_found => usize::MAX,
257                Err(e) => return Err(e.into()),
258            };
259            return Ok(Transformed::yes(Arc::new(Column::new(
260                column.name(),
261                index,
262            ))));
263        }
264        Ok(Transformed::no(expr))
265    })
266    .data()
267}
268
269/// Merge left and right sort expressions, checking for duplicates.
270pub fn merge_vectors(left: &LexOrdering, right: &LexOrdering) -> LexOrdering {
271    left.iter()
272        .cloned()
273        .chain(right.iter().cloned())
274        .unique()
275        .collect()
276}
277
#[cfg(test)]
pub(crate) mod tests {
    use std::any::Any;
    use std::fmt::{Display, Formatter};

    use super::*;
    use crate::expressions::{binary, cast, col, in_list, lit, Literal};

    use arrow::array::{ArrayRef, Float32Array, Float64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{exec_err, DataFusionError, ScalarValue};
    use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
    use datafusion_expr::{
        ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
    };

    use petgraph::visit::Bfs;

    /// Scalar UDF for tests: takes one `Float64` or `Float32` argument and
    /// returns its floor.
    #[derive(Debug, Clone)]
    pub struct TestScalarUDF {
        pub(crate) signature: Signature,
    }

    impl TestScalarUDF {
        pub fn new() -> Self {
            let accepted_types = vec![DataType::Float64, DataType::Float32];
            let signature = Signature::uniform(1, accepted_types, Volatility::Immutable);
            Self { signature }
        }
    }

    impl ScalarUDFImpl for TestScalarUDF {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "test-scalar-udf"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        // `Float32` input keeps its type; everything else maps to `Float64`.
        fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
            let output_type = if matches!(arg_types[0], DataType::Float32) {
                DataType::Float32
            } else {
                DataType::Float64
            };
            Ok(output_type)
        }

        // The output inherits the sort properties of the single input.
        fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
            let first_input = &input[0];
            Ok(first_input.sort_properties)
        }

        fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            let arrays = ColumnarValue::values_to_arrays(&args.args)?;
            let input = &arrays[0];

            // Error raised when the array is not of the type its data type claims.
            let cast_error = |target: &str| {
                DataFusionError::Internal(format!(
                    "could not cast {} to {}",
                    self.name(),
                    target
                ))
            };

            let floored: ArrayRef = match input.data_type() {
                DataType::Float64 => {
                    let values = input
                        .as_any()
                        .downcast_ref::<Float64Array>()
                        .ok_or_else(|| {
                            cast_error(std::any::type_name::<Float64Array>())
                        })?;
                    let result = values
                        .iter()
                        .map(|value| value.map(f64::floor))
                        .collect::<Float64Array>();
                    Arc::new(result)
                }
                DataType::Float32 => {
                    let values = input
                        .as_any()
                        .downcast_ref::<Float32Array>()
                        .ok_or_else(|| {
                            cast_error(std::any::type_name::<Float32Array>())
                        })?;
                    let result = values
                        .iter()
                        .map(|value| value.map(f32::floor))
                        .collect::<Float32Array>();
                    Arc::new(result)
                }
                other => {
                    return exec_err!(
                        "Unsupported data type {other:?} for function {}",
                        self.name()
                    );
                }
            };
            Ok(ColumnarValue::Array(floored))
        }
    }

    #[derive(Clone)]
    struct DummyProperty {
        expr_type: String,
    }

    /// Dummy DAEG node for tests: holds the actual [PhysicalExpr] together
    /// with a dummy property describing its kind.
    #[derive(Clone)]
    struct PhysicalExprDummyNode {
        pub expr: Arc<dyn PhysicalExpr>,
        pub property: DummyProperty,
    }

    impl Display for PhysicalExprDummyNode {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.expr)
        }
    }

    // Classifies the expression as "Binary", "Column", "Literal" or "Other".
    fn make_dummy_node(node: &ExprTreeNode<NodeIndex>) -> Result<PhysicalExprDummyNode> {
        let as_any = node.expr.as_any();
        let kind = if as_any.is::<BinaryExpr>() {
            "Binary"
        } else if as_any.is::<Column>() {
            "Column"
        } else if as_any.is::<Literal>() {
            "Literal"
        } else {
            "Other"
        };
        let property = DummyProperty {
            expr_type: kind.to_owned(),
        };
        Ok(PhysicalExprDummyNode {
            expr: Arc::clone(&node.expr),
            property,
        })
    }

    #[test]
    fn test_build_dag() -> Result<()> {
        let fields = ["0", "1", "2"]
            .into_iter()
            .map(|name| Field::new(name, DataType::Int32, true))
            .collect::<Vec<_>>();
        let schema = Schema::new(fields);

        // CAST("0" + "1" AS Int64) > CAST("2" AS Int64) + 10
        let sum_of_first_two = binary(
            col("0", &schema)?,
            Operator::Plus,
            col("1", &schema)?,
            &schema,
        )?;
        let lhs = cast(sum_of_first_two, &schema, DataType::Int64)?;
        let rhs = binary(
            cast(col("2", &schema)?, &schema, DataType::Int64)?,
            Operator::Plus,
            lit(ScalarValue::Int64(Some(10))),
            &schema,
        )?;
        let expr = binary(lhs, Operator::Gt, rhs, &schema)?;

        let (root, graph) = build_dag(expr, &make_dummy_node)?;

        // Collect the property of every node reachable from the root.
        let mut visited_props = Vec::new();
        let mut bfs = Bfs::new(&graph, root);
        while let Some(idx) = bfs.next(&graph) {
            visited_props.push(graph[idx].property.clone());
        }

        let count_of = |kind: &str| {
            visited_props
                .iter()
                .filter(|property| property.expr_type == kind)
                .count()
        };
        assert_eq!(count_of("Binary"), 3);
        assert_eq!(count_of("Column"), 3);
        assert_eq!(count_of("Literal"), 1);
        assert_eq!(count_of("Other"), 2);
        Ok(())
    }

    #[test]
    fn test_convert_to_expr() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::UInt64, false)]);
        let sort_exprs = vec![PhysicalSortExpr {
            expr: col("a", &schema)?,
            options: Default::default(),
        }];
        let converted = convert_to_expr(&sort_exprs);
        assert!(converted[0].eq(&sort_exprs[0].expr));
        Ok(())
    }

    #[test]
    fn test_get_indices_of_exprs_strict() {
        let column = |name: &str, index: usize| -> Arc<dyn PhysicalExpr> {
            Arc::new(Column::new(name, index))
        };
        let list1 = vec![
            column("a", 0),
            column("b", 1),
            column("c", 2),
            column("d", 3),
        ];
        let list2 = vec![column("b", 1), column("c", 2), column("a", 0)];

        // "d" does not occur in `list2`, so it yields no index.
        assert_eq!(get_indices_of_exprs_strict(&list1, &list2), vec![2, 0, 1]);
        assert_eq!(get_indices_of_exprs_strict(&list2, &list1), vec![1, 2, 0]);
    }

    #[test]
    fn test_reassign_predicate_columns_in_list() {
        let int_field = Field::new("should_not_matter", DataType::Int64, true);
        let dict_field = Field::new(
            "id",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            true,
        );
        let schema_small = Arc::new(Schema::new(vec![dict_field.clone()]));
        let schema_big = Arc::new(Schema::new(vec![int_field, dict_field]));

        // Builds `id IN ('2')` with the column resolved against `schema`.
        let id_in_list = |schema: &SchemaRef| {
            in_list(
                Arc::new(Column::new_with_schema("id", schema).unwrap()),
                vec![lit(ScalarValue::Dictionary(
                    Box::new(DataType::Int32),
                    Box::new(ScalarValue::from("2")),
                ))],
                &false,
                schema,
            )
            .unwrap()
        };

        let pred = id_in_list(&schema_big);
        let actual = reassign_predicate_columns(pred, &schema_small, false).unwrap();
        let expected = id_in_list(&schema_small);

        assert_eq!(actual.as_ref(), expected.as_ref());
    }

    #[test]
    fn test_collect_columns() -> Result<()> {
        let expr1: Arc<dyn PhysicalExpr> = Arc::new(Column::new("col1", 2));
        let expected: HashSet<Column> =
            [Column::new("col1", 2)].into_iter().collect();
        assert_eq!(collect_columns(&expr1), expected);

        let expr2: Arc<dyn PhysicalExpr> = Arc::new(Column::new("col2", 5));
        let expected: HashSet<Column> =
            [Column::new("col2", 5)].into_iter().collect();
        assert_eq!(collect_columns(&expr2), expected);

        let expr3: Arc<dyn PhysicalExpr> =
            Arc::new(BinaryExpr::new(expr1, Operator::Plus, expr2));
        let expected: HashSet<Column> =
            [Column::new("col1", 2), Column::new("col2", 5)]
                .into_iter()
                .collect();
        assert_eq!(collect_columns(&expr3), expected);
        Ok(())
    }
}