datafusion_physical_expr/utils/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18mod guarantee;
19pub use guarantee::{Guarantee, LiteralGuarantee};
20
21use std::borrow::Borrow;
22use std::sync::Arc;
23
24use crate::expressions::{BinaryExpr, Column};
25use crate::tree_node::ExprContext;
26use crate::PhysicalExpr;
27use crate::PhysicalSortExpr;
28
29use arrow::datatypes::Schema;
30use datafusion_common::tree_node::{
31    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
32};
33use datafusion_common::{HashMap, HashSet, Result};
34use datafusion_expr::Operator;
35
36use petgraph::graph::NodeIndex;
37use petgraph::stable_graph::StableGraph;
38
39/// Assume the predicate is in the form of CNF, split the predicate to a Vec of PhysicalExprs.
40///
41/// For example, split "a1 = a2 AND b1 <= b2 AND c1 != c2" into ["a1 = a2", "b1 <= b2", "c1 != c2"]
42pub fn split_conjunction(
43    predicate: &Arc<dyn PhysicalExpr>,
44) -> Vec<&Arc<dyn PhysicalExpr>> {
45    split_impl(Operator::And, predicate, vec![])
46}
47
48/// Create a conjunction of the given predicates.
49/// If the input is empty, return a literal true.
50/// If the input contains a single predicate, return the predicate.
51/// Otherwise, return a conjunction of the predicates (e.g. `a AND b AND c`).
52pub fn conjunction(
53    predicates: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
54) -> Arc<dyn PhysicalExpr> {
55    conjunction_opt(predicates).unwrap_or_else(|| crate::expressions::lit(true))
56}
57
58/// Create a conjunction of the given predicates.
59/// If the input is empty or the return None.
60/// If the input contains a single predicate, return Some(predicate).
61/// Otherwise, return a Some(..) of a conjunction of the predicates (e.g. `Some(a AND b AND c)`).
62pub fn conjunction_opt(
63    predicates: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
64) -> Option<Arc<dyn PhysicalExpr>> {
65    predicates
66        .into_iter()
67        .fold(None, |acc, predicate| match acc {
68            None => Some(predicate),
69            Some(acc) => Some(Arc::new(BinaryExpr::new(acc, Operator::And, predicate))),
70        })
71}
72
73/// Assume the predicate is in the form of DNF, split the predicate to a Vec of PhysicalExprs.
74///
75/// For example, split "a1 = a2 OR b1 <= b2 OR c1 != c2" into ["a1 = a2", "b1 <= b2", "c1 != c2"]
76pub fn split_disjunction(
77    predicate: &Arc<dyn PhysicalExpr>,
78) -> Vec<&Arc<dyn PhysicalExpr>> {
79    split_impl(Operator::Or, predicate, vec![])
80}
81
82fn split_impl<'a>(
83    operator: Operator,
84    predicate: &'a Arc<dyn PhysicalExpr>,
85    mut exprs: Vec<&'a Arc<dyn PhysicalExpr>>,
86) -> Vec<&'a Arc<dyn PhysicalExpr>> {
87    match predicate.as_any().downcast_ref::<BinaryExpr>() {
88        Some(binary) if binary.op() == &operator => {
89            let exprs = split_impl(operator, binary.left(), exprs);
90            split_impl(operator, binary.right(), exprs)
91        }
92        Some(_) | None => {
93            exprs.push(predicate);
94            exprs
95        }
96    }
97}
98
99/// This function maps back requirement after ProjectionExec
100/// to the Executor for its input.
101// Specifically, `ProjectionExec` changes index of `Column`s in the schema of its input executor.
102// This function changes requirement given according to ProjectionExec schema to the requirement
103// according to schema of input executor to the ProjectionExec.
104// For instance, Column{"a", 0} would turn to Column{"a", 1}. Please note that this function assumes that
105// name of the Column is unique. If we have a requirement such that Column{"a", 0}, Column{"a", 1}.
106// This function will produce incorrect result (It will only emit single Column as a result).
107pub fn map_columns_before_projection(
108    parent_required: &[Arc<dyn PhysicalExpr>],
109    proj_exprs: &[(Arc<dyn PhysicalExpr>, String)],
110) -> Vec<Arc<dyn PhysicalExpr>> {
111    if parent_required.is_empty() {
112        // No need to build mapping.
113        return vec![];
114    }
115    let column_mapping = proj_exprs
116        .iter()
117        .filter_map(|(expr, name)| {
118            expr.as_any()
119                .downcast_ref::<Column>()
120                .map(|column| (name.clone(), column.clone()))
121        })
122        .collect::<HashMap<_, _>>();
123    parent_required
124        .iter()
125        .filter_map(|r| {
126            r.as_any()
127                .downcast_ref::<Column>()
128                .and_then(|c| column_mapping.get(c.name()))
129        })
130        .map(|e| Arc::new(e.clone()) as _)
131        .collect()
132}
133
134/// This function returns all `Arc<dyn PhysicalExpr>`s inside the given
135/// `PhysicalSortExpr` sequence.
136pub fn convert_to_expr<T: Borrow<PhysicalSortExpr>>(
137    sequence: impl IntoIterator<Item = T>,
138) -> Vec<Arc<dyn PhysicalExpr>> {
139    sequence
140        .into_iter()
141        .map(|elem| Arc::clone(&elem.borrow().expr))
142        .collect()
143}
144
145/// This function finds the indices of `targets` within `items` using strict
146/// equality.
147pub fn get_indices_of_exprs_strict<T: Borrow<Arc<dyn PhysicalExpr>>>(
148    targets: impl IntoIterator<Item = T>,
149    items: &[Arc<dyn PhysicalExpr>],
150) -> Vec<usize> {
151    targets
152        .into_iter()
153        .filter_map(|target| items.iter().position(|e| e.eq(target.borrow())))
154        .collect()
155}
156
/// An [`ExprContext`] whose payload is an `Option<T>`.
///
/// In this module it is used with `T = NodeIndex`: while building the
/// expression graph, each visited node's payload is set to `Some(index)` of
/// the graph node that represents its expression.
pub type ExprTreeNode<T> = ExprContext<Option<T>>;
158
/// This struct is used to convert a [`PhysicalExpr`] tree into a DAEG (i.e. an expression
/// DAG) by collecting identical expressions in one node. Caller specifies the node type
/// in the DAEG via the `constructor` argument, which constructs nodes in the DAEG from
/// the [`ExprTreeNode`] ancillary object.
///
/// Type parameters:
/// * `T` - the node weight stored in the resulting graph.
/// * `F` - the fallible function that produces a `T` from an expression tree node.
struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> {
    // The resulting DAEG (expression DAG). Edge weights are `usize`.
    graph: StableGraph<T, usize>,
    // Expressions that already have a graph node, paired with that node's index.
    // Used to detect identical expressions so they share a single node.
    visited_plans: Vec<(Arc<dyn PhysicalExpr>, NodeIndex)>,
    // A function to convert an input expression node to T. Borrowed from the caller.
    constructor: &'a F,
}
171
172impl<T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> PhysicalExprDAEGBuilder<'_, T, F> {
173    // This method mutates an expression node by transforming it to a physical expression
174    // and adding it to the graph. The method returns the mutated expression node.
175    fn mutate(
176        &mut self,
177        mut node: ExprTreeNode<NodeIndex>,
178    ) -> Result<Transformed<ExprTreeNode<NodeIndex>>> {
179        // Get the expression associated with the input expression node.
180        let expr = &node.expr;
181
182        // Check if the expression has already been visited.
183        let node_idx = match self.visited_plans.iter().find(|(e, _)| expr.eq(e)) {
184            // If the expression has been visited, return the corresponding node index.
185            Some((_, idx)) => *idx,
186            // If the expression has not been visited, add a new node to the graph and
187            // add edges to its child nodes. Add the visited expression to the vector
188            // of visited expressions and return the newly created node index.
189            None => {
190                let node_idx = self.graph.add_node((self.constructor)(&node)?);
191                for expr_node in node.children.iter() {
192                    self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0);
193                }
194                self.visited_plans.push((Arc::clone(expr), node_idx));
195                node_idx
196            }
197        };
198        // Set the data field of the input expression node to the corresponding node index.
199        node.data = Some(node_idx);
200        // Return the mutated expression node.
201        Ok(Transformed::yes(node))
202    }
203}
204
205// A function that builds a directed acyclic graph of physical expression trees.
206pub fn build_dag<T, F>(
207    expr: Arc<dyn PhysicalExpr>,
208    constructor: &F,
209) -> Result<(NodeIndex, StableGraph<T, usize>)>
210where
211    F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>,
212{
213    // Create a new expression tree node from the input expression.
214    let init = ExprTreeNode::new_default(expr);
215    // Create a new `PhysicalExprDAEGBuilder` instance.
216    let mut builder = PhysicalExprDAEGBuilder {
217        graph: StableGraph::<T, usize>::new(),
218        visited_plans: Vec::<(Arc<dyn PhysicalExpr>, NodeIndex)>::new(),
219        constructor,
220    };
221    // Use the builder to transform the expression tree node into a DAG.
222    let root = init.transform_up(|node| builder.mutate(node)).data()?;
223    // Return a tuple containing the root node index and the DAG.
224    Ok((root.data.unwrap(), builder.graph))
225}
226
227/// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`].
228pub fn collect_columns(expr: &Arc<dyn PhysicalExpr>) -> HashSet<Column> {
229    let mut columns = HashSet::<Column>::new();
230    expr.apply(|expr| {
231        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
232            columns.get_or_insert_owned(column);
233        }
234        Ok(TreeNodeRecursion::Continue)
235    })
236    // pre_visit always returns OK, so this will always too
237    .expect("no way to return error during recursion");
238    columns
239}
240
241/// Re-assign column indices referenced in predicate according to given schema.
242/// This may be helpful when dealing with projections.
243pub fn reassign_predicate_columns(
244    pred: Arc<dyn PhysicalExpr>,
245    schema: &Schema,
246    ignore_not_found: bool,
247) -> Result<Arc<dyn PhysicalExpr>> {
248    pred.transform_down(|expr| {
249        let expr_any = expr.as_any();
250
251        if let Some(column) = expr_any.downcast_ref::<Column>() {
252            let index = match schema.index_of(column.name()) {
253                Ok(idx) => idx,
254                Err(_) if ignore_not_found => usize::MAX,
255                Err(e) => return Err(e.into()),
256            };
257            return Ok(Transformed::yes(Arc::new(Column::new(
258                column.name(),
259                index,
260            ))));
261        }
262        Ok(Transformed::no(expr))
263    })
264    .data()
265}
266
267#[cfg(test)]
268pub(crate) mod tests {
269    use std::any::Any;
270    use std::fmt::{Display, Formatter};
271
272    use super::*;
273    use crate::expressions::{binary, cast, col, in_list, lit, Literal};
274
275    use arrow::array::{ArrayRef, Float32Array, Float64Array};
276    use arrow::datatypes::{DataType, Field, Schema};
277    use datafusion_common::{exec_err, DataFusionError, ScalarValue};
278    use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
279    use datafusion_expr::{
280        ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
281    };
282
283    use petgraph::visit::Bfs;
284
    /// A scalar UDF used in tests. Its `ScalarUDFImpl` implementation in this
    /// module floors `Float64`/`Float32` input arrays.
    #[derive(Debug, Clone)]
    pub struct TestScalarUDF {
        /// The signature reported by `ScalarUDFImpl::signature`.
        pub(crate) signature: Signature,
    }
289
290    impl TestScalarUDF {
291        pub fn new() -> Self {
292            use DataType::*;
293            Self {
294                signature: Signature::uniform(
295                    1,
296                    vec![Float64, Float32],
297                    Volatility::Immutable,
298                ),
299            }
300        }
301    }
302
303    impl ScalarUDFImpl for TestScalarUDF {
304        fn as_any(&self) -> &dyn Any {
305            self
306        }
307        fn name(&self) -> &str {
308            "test-scalar-udf"
309        }
310
311        fn signature(&self) -> &Signature {
312            &self.signature
313        }
314
315        fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
316            let arg_type = &arg_types[0];
317
318            match arg_type {
319                DataType::Float32 => Ok(DataType::Float32),
320                _ => Ok(DataType::Float64),
321            }
322        }
323
324        fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
325            Ok(input[0].sort_properties)
326        }
327
328        fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
329            let args = ColumnarValue::values_to_arrays(&args.args)?;
330
331            let arr: ArrayRef = match args[0].data_type() {
332                DataType::Float64 => Arc::new({
333                    let arg = &args[0]
334                        .as_any()
335                        .downcast_ref::<Float64Array>()
336                        .ok_or_else(|| {
337                            DataFusionError::Internal(format!(
338                                "could not cast {} to {}",
339                                self.name(),
340                                std::any::type_name::<Float64Array>()
341                            ))
342                        })?;
343
344                    arg.iter()
345                        .map(|a| a.map(f64::floor))
346                        .collect::<Float64Array>()
347                }),
348                DataType::Float32 => Arc::new({
349                    let arg = &args[0]
350                        .as_any()
351                        .downcast_ref::<Float32Array>()
352                        .ok_or_else(|| {
353                            DataFusionError::Internal(format!(
354                                "could not cast {} to {}",
355                                self.name(),
356                                std::any::type_name::<Float32Array>()
357                            ))
358                        })?;
359
360                    arg.iter()
361                        .map(|a| a.map(f32::floor))
362                        .collect::<Float32Array>()
363                }),
364                other => {
365                    return exec_err!(
366                        "Unsupported data type {other:?} for function {}",
367                        self.name()
368                    );
369                }
370            };
371            Ok(ColumnarValue::Array(arr))
372        }
373    }
374
    /// Dummy payload attached to each test graph node.
    #[derive(Clone)]
    struct DummyProperty {
        // Label for the kind of expression; `make_dummy_node` sets it to one of
        // "Binary", "Column", "Literal" or "Other".
        expr_type: String,
    }
379
    /// This is a dummy node in the DAEG; it stores a reference to the actual
    /// [PhysicalExpr] as well as a dummy property.
    #[derive(Clone)]
    struct PhysicalExprDummyNode {
        /// The expression this graph node stands for.
        pub expr: Arc<dyn PhysicalExpr>,
        /// The label describing the kind of `expr`.
        pub property: DummyProperty,
    }
387
388    impl Display for PhysicalExprDummyNode {
389        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
390            write!(f, "{}", self.expr)
391        }
392    }
393
394    fn make_dummy_node(node: &ExprTreeNode<NodeIndex>) -> Result<PhysicalExprDummyNode> {
395        let expr = Arc::clone(&node.expr);
396        let dummy_property = if expr.as_any().is::<BinaryExpr>() {
397            "Binary"
398        } else if expr.as_any().is::<Column>() {
399            "Column"
400        } else if expr.as_any().is::<Literal>() {
401            "Literal"
402        } else {
403            "Other"
404        }
405        .to_owned();
406        Ok(PhysicalExprDummyNode {
407            expr,
408            property: DummyProperty {
409                expr_type: dummy_property,
410            },
411        })
412    }
413
414    #[test]
415    fn test_build_dag() -> Result<()> {
416        let schema = Schema::new(vec![
417            Field::new("0", DataType::Int32, true),
418            Field::new("1", DataType::Int32, true),
419            Field::new("2", DataType::Int32, true),
420        ]);
421        let expr = binary(
422            cast(
423                binary(
424                    col("0", &schema)?,
425                    Operator::Plus,
426                    col("1", &schema)?,
427                    &schema,
428                )?,
429                &schema,
430                DataType::Int64,
431            )?,
432            Operator::Gt,
433            binary(
434                cast(col("2", &schema)?, &schema, DataType::Int64)?,
435                Operator::Plus,
436                lit(ScalarValue::Int64(Some(10))),
437                &schema,
438            )?,
439            &schema,
440        )?;
441        let mut vector_dummy_props = vec![];
442        let (root, graph) = build_dag(expr, &make_dummy_node)?;
443        let mut bfs = Bfs::new(&graph, root);
444        while let Some(node_index) = bfs.next(&graph) {
445            let node = &graph[node_index];
446            vector_dummy_props.push(node.property.clone());
447        }
448
449        assert_eq!(
450            vector_dummy_props
451                .iter()
452                .filter(|property| property.expr_type == "Binary")
453                .count(),
454            3
455        );
456        assert_eq!(
457            vector_dummy_props
458                .iter()
459                .filter(|property| property.expr_type == "Column")
460                .count(),
461            3
462        );
463        assert_eq!(
464            vector_dummy_props
465                .iter()
466                .filter(|property| property.expr_type == "Literal")
467                .count(),
468            1
469        );
470        assert_eq!(
471            vector_dummy_props
472                .iter()
473                .filter(|property| property.expr_type == "Other")
474                .count(),
475            2
476        );
477        Ok(())
478    }
479
480    #[test]
481    fn test_convert_to_expr() -> Result<()> {
482        let schema = Schema::new(vec![Field::new("a", DataType::UInt64, false)]);
483        let sort_expr = vec![PhysicalSortExpr {
484            expr: col("a", &schema)?,
485            options: Default::default(),
486        }];
487        assert!(convert_to_expr(&sort_expr)[0].eq(&sort_expr[0].expr));
488        Ok(())
489    }
490
491    #[test]
492    fn test_get_indices_of_exprs_strict() {
493        let list1: Vec<Arc<dyn PhysicalExpr>> = vec![
494            Arc::new(Column::new("a", 0)),
495            Arc::new(Column::new("b", 1)),
496            Arc::new(Column::new("c", 2)),
497            Arc::new(Column::new("d", 3)),
498        ];
499        let list2: Vec<Arc<dyn PhysicalExpr>> = vec![
500            Arc::new(Column::new("b", 1)),
501            Arc::new(Column::new("c", 2)),
502            Arc::new(Column::new("a", 0)),
503        ];
504        assert_eq!(get_indices_of_exprs_strict(&list1, &list2), vec![2, 0, 1]);
505        assert_eq!(get_indices_of_exprs_strict(&list2, &list1), vec![1, 2, 0]);
506    }
507
508    #[test]
509    fn test_reassign_predicate_columns_in_list() {
510        let int_field = Field::new("should_not_matter", DataType::Int64, true);
511        let dict_field = Field::new(
512            "id",
513            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
514            true,
515        );
516        let schema_small = Arc::new(Schema::new(vec![dict_field.clone()]));
517        let schema_big = Arc::new(Schema::new(vec![int_field, dict_field]));
518        let pred = in_list(
519            Arc::new(Column::new_with_schema("id", &schema_big).unwrap()),
520            vec![lit(ScalarValue::Dictionary(
521                Box::new(DataType::Int32),
522                Box::new(ScalarValue::from("2")),
523            ))],
524            &false,
525            &schema_big,
526        )
527        .unwrap();
528
529        let actual = reassign_predicate_columns(pred, &schema_small, false).unwrap();
530
531        let expected = in_list(
532            Arc::new(Column::new_with_schema("id", &schema_small).unwrap()),
533            vec![lit(ScalarValue::Dictionary(
534                Box::new(DataType::Int32),
535                Box::new(ScalarValue::from("2")),
536            ))],
537            &false,
538            &schema_small,
539        )
540        .unwrap();
541
542        assert_eq!(actual.as_ref(), expected.as_ref());
543    }
544
545    #[test]
546    fn test_collect_columns() -> Result<()> {
547        let expr1 = Arc::new(Column::new("col1", 2)) as _;
548        let mut expected = HashSet::new();
549        expected.insert(Column::new("col1", 2));
550        assert_eq!(collect_columns(&expr1), expected);
551
552        let expr2 = Arc::new(Column::new("col2", 5)) as _;
553        let mut expected = HashSet::new();
554        expected.insert(Column::new("col2", 5));
555        assert_eq!(collect_columns(&expr2), expected);
556
557        let expr3 = Arc::new(BinaryExpr::new(expr1, Operator::Plus, expr2)) as _;
558        let mut expected = HashSet::new();
559        expected.insert(Column::new("col1", 2));
560        expected.insert(Column::new("col2", 5));
561        assert_eq!(collect_columns(&expr3), expected);
562        Ok(())
563    }
564}