Skip to main content

datafusion_expr/expr_rewriter/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Expression rewriter
19
20use std::collections::HashMap;
21use std::collections::HashSet;
22use std::fmt::Debug;
23use std::sync::Arc;
24
25use crate::expr::{Alias, Sort, Unnest};
26use crate::logical_plan::Projection;
27use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};
28
29use datafusion_common::TableReference;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
32use datafusion_common::{Column, DFSchema, Result};
33
34mod guarantees;
35pub use guarantees::GuaranteeRewriter;
36pub use guarantees::rewrite_with_guarantees;
37pub use guarantees::rewrite_with_guarantees_map;
38mod order_by;
39
40pub use order_by::rewrite_sort_cols_by_aggs;
41
42/// Trait for rewriting [`Expr`]s into function calls.
43///
44/// This trait is used with `FunctionRegistry::register_function_rewrite` to
45/// to evaluating `Expr`s using functions that may not be built in to DataFusion
46///
47/// For example, concatenating arrays `a || b` is represented as
48/// `Operator::ArrowAt`, but can be implemented by calling a function
49/// `array_concat` from the `functions-nested` crate.
50// This is not used in datafusion internally, but it is still helpful for downstream project so don't remove it.
51pub trait FunctionRewrite: Debug {
52    /// Return a human readable name for this rewrite
53    fn name(&self) -> &str;
54
55    /// Potentially rewrite `expr` to some other expression
56    ///
57    /// Note that recursion is handled by the caller -- this method should only
58    /// handle `expr`, not recurse to its children.
59    fn rewrite(
60        &self,
61        expr: Expr,
62        schema: &DFSchema,
63        config: &ConfigOptions,
64    ) -> Result<Transformed<Expr>>;
65}
66
67/// Recursively call `LogicalPlanBuilder::normalize` on all [`Column`] expressions
68/// in the `expr` expression tree.
69pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
70    expr.transform(|expr| {
71        Ok({
72            if let Expr::Column(c) = expr {
73                let col = LogicalPlanBuilder::normalize(plan, c)?;
74                Transformed::yes(Expr::Column(col))
75            } else {
76                Transformed::no(expr)
77            }
78        })
79    })
80    .data()
81}
82
83/// See [`Column::normalize_with_schemas_and_ambiguity_check`] for usage
84pub fn normalize_col_with_schemas_and_ambiguity_check(
85    expr: Expr,
86    schemas: &[&[&DFSchema]],
87    using_columns: &[HashSet<Column>],
88) -> Result<Expr> {
89    // Normalize column inside Unnest
90    if let Expr::Unnest(Unnest { expr }) = expr {
91        let e = normalize_col_with_schemas_and_ambiguity_check(
92            expr.as_ref().clone(),
93            schemas,
94            using_columns,
95        )?;
96        return Ok(Expr::Unnest(Unnest { expr: Box::new(e) }));
97    }
98
99    expr.transform(|expr| {
100        Ok({
101            if let Expr::Column(c) = expr {
102                let col =
103                    c.normalize_with_schemas_and_ambiguity_check(schemas, using_columns)?;
104                Transformed::yes(Expr::Column(col))
105            } else {
106                Transformed::no(expr)
107            }
108        })
109    })
110    .data()
111}
112
113/// Recursively normalize all [`Column`] expressions in a list of expression trees
114pub fn normalize_cols(
115    exprs: impl IntoIterator<Item = impl Into<Expr>>,
116    plan: &LogicalPlan,
117) -> Result<Vec<Expr>> {
118    exprs
119        .into_iter()
120        .map(|e| normalize_col(e.into(), plan))
121        .collect()
122}
123
124pub fn normalize_sorts(
125    sorts: impl IntoIterator<Item = impl Into<Sort>>,
126    plan: &LogicalPlan,
127) -> Result<Vec<Sort>> {
128    sorts
129        .into_iter()
130        .map(|e| {
131            let sort = e.into();
132            normalize_col(sort.expr, plan)
133                .map(|expr| Sort::new(expr, sort.asc, sort.nulls_first))
134        })
135        .collect()
136}
137
138/// Recursively replace all [`Column`] expressions in a given expression tree with
139/// `Column` expressions provided by the hash map argument.
140pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<Expr> {
141    expr.transform(|expr| {
142        Ok({
143            if let Expr::Column(c) = &expr {
144                match replace_map.get(c) {
145                    Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())),
146                    None => Transformed::no(expr),
147                }
148            } else {
149                Transformed::no(expr)
150            }
151        })
152    })
153    .data()
154}
155
156/// Recursively 'unnormalize' (remove all qualifiers) from an
157/// expression tree.
158///
159/// For example, if there were expressions like `foo.bar` this would
160/// rewrite it to just `bar`.
161pub fn unnormalize_col(expr: Expr) -> Expr {
162    expr.transform(|expr| {
163        Ok({
164            if let Expr::Column(c) = expr {
165                let col = Column::new_unqualified(c.name);
166                Transformed::yes(Expr::Column(col))
167            } else {
168                Transformed::no(expr)
169            }
170        })
171    })
172    .data()
173    .expect("Unnormalize is infallible")
174}
175
176/// Create a Column from the Scalar Expr
177pub fn create_col_from_scalar_expr(
178    scalar_expr: &Expr,
179    subqry_alias: String,
180) -> Result<Column> {
181    match scalar_expr {
182        Expr::Alias(Alias { name, .. }) => Ok(Column::new(
183            Some::<TableReference>(subqry_alias.into()),
184            name,
185        )),
186        Expr::Column(col) => Ok(col.with_relation(subqry_alias.into())),
187        _ => {
188            let scalar_column = scalar_expr.schema_name().to_string();
189            Ok(Column::new(
190                Some::<TableReference>(subqry_alias.into()),
191                scalar_column,
192            ))
193        }
194    }
195}
196
197/// Recursively un-normalize all [`Column`] expressions in a list of expression trees
198#[inline]
199pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> {
200    exprs.into_iter().map(unnormalize_col).collect()
201}
202
203/// Recursively remove all the ['OuterReferenceColumn'] and return the inside Column
204/// in the expression tree.
205pub fn strip_outer_reference(expr: Expr) -> Expr {
206    expr.transform(|expr| {
207        Ok({
208            if let Expr::OuterReferenceColumn(_, col) = expr {
209                Transformed::yes(Expr::Column(col))
210            } else {
211                Transformed::no(expr)
212            }
213        })
214    })
215    .data()
216    .expect("strip_outer_reference is infallible")
217}
218
219/// Returns plan with expressions coerced to types compatible with
220/// schema types
221pub fn coerce_plan_expr_for_schema(
222    plan: LogicalPlan,
223    schema: &DFSchema,
224) -> Result<LogicalPlan> {
225    match plan {
226        // special case Projection to avoid adding multiple projections
227        LogicalPlan::Projection(Projection { expr, input, .. }) => {
228            let new_exprs = coerce_exprs_for_schema(expr, input.schema(), schema)?;
229            let projection = Projection::try_new(new_exprs, input)?;
230            Ok(LogicalPlan::Projection(projection))
231        }
232        _ => {
233            let exprs: Vec<Expr> = plan.schema().iter().map(Expr::from).collect();
234            let new_exprs = coerce_exprs_for_schema(exprs, plan.schema(), schema)?;
235            let add_project = new_exprs.iter().any(|expr| expr.try_as_col().is_none());
236            if add_project {
237                let projection = Projection::try_new(new_exprs, Arc::new(plan))?;
238                Ok(LogicalPlan::Projection(projection))
239            } else {
240                Ok(plan)
241            }
242        }
243    }
244}
245
246fn coerce_exprs_for_schema(
247    exprs: Vec<Expr>,
248    src_schema: &DFSchema,
249    dst_schema: &DFSchema,
250) -> Result<Vec<Expr>> {
251    exprs
252        .into_iter()
253        .enumerate()
254        .map(|(idx, expr)| {
255            let new_type = dst_schema.field(idx).data_type();
256            if new_type != &expr.get_type(src_schema)? {
257                match expr {
258                    Expr::Alias(Alias { expr, name, .. }) => {
259                        Ok(expr.cast_to(new_type, src_schema)?.alias(name))
260                    }
261                    #[expect(deprecated)]
262                    Expr::Wildcard { .. } => Ok(expr),
263                    _ => {
264                        match expr {
265                            // maintain the original name when casting a column, to avoid the
266                            // tablename being added to it when not explicitly set by the query
267                            // (see: https://github.com/apache/datafusion/issues/18818)
268                            Expr::Column(ref column) => {
269                                let name = column.name().to_owned();
270                                Ok(expr.cast_to(new_type, src_schema)?.alias(name))
271                            }
272                            _ => Ok(expr.cast_to(new_type, src_schema)?),
273                        }
274                    }
275                }
276            } else {
277                Ok(expr)
278            }
279        })
280        .collect::<Result<_>>()
281}
282
283/// Recursively un-alias an expressions
284#[inline]
285pub fn unalias(expr: Expr) -> Expr {
286    match expr {
287        Expr::Alias(Alias { expr, .. }) => unalias(*expr),
288        _ => expr,
289    }
290}
291
292/// Handles ensuring the name of rewritten expressions is not changed.
293///
294/// This is important when optimizing plans to ensure the output
295/// schema of plan nodes don't change after optimization.
296/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the
297/// expression should be preserved: `3 as "1 + 2"`
298///
299/// See <https://github.com/apache/datafusion/issues/3555> for details
300pub struct NamePreserver {
301    use_alias: bool,
302}
303
304/// If the qualified name of an expression is remembered, it will be preserved
305/// when rewriting the expression
306#[derive(Debug)]
307pub enum SavedName {
308    /// Saved qualified name to be preserved
309    Saved {
310        relation: Option<TableReference>,
311        name: String,
312    },
313    /// Name is not preserved
314    None,
315}
316
317impl NamePreserver {
318    /// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan
319    pub fn new(plan: &LogicalPlan) -> Self {
320        Self {
321            // The expressions of these plans do not contribute to their output schema,
322            // so there is no need to preserve expression names to prevent a schema change.
323            use_alias: !matches!(
324                plan,
325                LogicalPlan::Filter(_)
326                    | LogicalPlan::Join(_)
327                    | LogicalPlan::TableScan(_)
328                    | LogicalPlan::Limit(_)
329                    | LogicalPlan::Statement(_)
330            ),
331        }
332    }
333
334    /// Create a new NamePreserver for rewriting the `expr`s in `Projection`
335    ///
336    /// This will use aliases
337    pub fn new_for_projection() -> Self {
338        Self { use_alias: true }
339    }
340
341    pub fn save(&self, expr: &Expr) -> SavedName {
342        if self.use_alias {
343            let (relation, name) = expr.qualified_name();
344            SavedName::Saved { relation, name }
345        } else {
346            SavedName::None
347        }
348    }
349}
350
351impl SavedName {
352    /// Ensures the qualified name of the rewritten expression is preserved
353    pub fn restore(self, expr: Expr) -> Expr {
354        match self {
355            SavedName::Saved { relation, name } => {
356                let (new_relation, new_name) = expr.qualified_name();
357                if new_relation != relation || new_name != name {
358                    expr.alias_qualified(relation, name)
359                } else {
360                    expr
361                }
362            }
363            SavedName::None => expr,
364        }
365    }
366}
367
368#[cfg(test)]
369mod test {
370    use std::ops::Add;
371
372    use super::*;
373    use crate::literal::lit_with_metadata;
374    use crate::{Cast, col, lit};
375    use arrow::datatypes::{DataType, Field, Schema};
376    use datafusion_common::ScalarValue;
377    use datafusion_common::tree_node::TreeNodeRewriter;
378
379    #[derive(Default)]
380    struct RecordingRewriter {
381        v: Vec<String>,
382    }
383
384    impl TreeNodeRewriter for RecordingRewriter {
385        type Node = Expr;
386
387        fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
388            self.v.push(format!("Previsited {expr}"));
389            Ok(Transformed::no(expr))
390        }
391
392        fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
393            self.v.push(format!("Mutated {expr}"));
394            Ok(Transformed::no(expr))
395        }
396    }
397
398    #[test]
399    fn rewriter_rewrite() {
400        // rewrites all "foo" string literals to "bar"
401        let transformer = |expr: Expr| -> Result<Transformed<Expr>> {
402            match expr {
403                Expr::Literal(ScalarValue::Utf8(Some(utf8_val)), metadata) => {
404                    let utf8_val = if utf8_val == "foo" {
405                        "bar".to_string()
406                    } else {
407                        utf8_val
408                    };
409                    Ok(Transformed::yes(lit_with_metadata(utf8_val, metadata)))
410                }
411                // otherwise, return None
412                _ => Ok(Transformed::no(expr)),
413            }
414        };
415
416        // rewrites "foo" --> "bar"
417        let rewritten = col("state")
418            .eq(lit("foo"))
419            .transform(transformer)
420            .data()
421            .unwrap();
422        assert_eq!(rewritten, col("state").eq(lit("bar")));
423
424        // doesn't rewrite
425        let rewritten = col("state")
426            .eq(lit("baz"))
427            .transform(transformer)
428            .data()
429            .unwrap();
430        assert_eq!(rewritten, col("state").eq(lit("baz")));
431    }
432
433    #[test]
434    fn normalize_cols() {
435        let expr = col("a") + col("b") + col("c");
436
437        // Schemas with some matching and some non matching cols
438        let schema_a = make_schema_with_empty_metadata(
439            vec![Some("tableA".into()), Some("tableA".into())],
440            vec!["a", "aa"],
441        );
442        let schema_c = make_schema_with_empty_metadata(
443            vec![Some("tableC".into()), Some("tableC".into())],
444            vec!["cc", "c"],
445        );
446        let schema_b =
447            make_schema_with_empty_metadata(vec![Some("tableB".into())], vec!["b"]);
448        // non matching
449        let schema_f = make_schema_with_empty_metadata(
450            vec![Some("tableC".into()), Some("tableC".into())],
451            vec!["f", "ff"],
452        );
453        let schemas = [schema_c, schema_f, schema_b, schema_a];
454        let schemas = schemas.iter().collect::<Vec<_>>();
455
456        let normalized_expr =
457            normalize_col_with_schemas_and_ambiguity_check(expr, &[&schemas], &[])
458                .unwrap();
459        assert_eq!(
460            normalized_expr,
461            col("tableA.a") + col("tableB.b") + col("tableC.c")
462        );
463    }
464
465    #[test]
466    fn normalize_cols_non_exist() {
467        // test normalizing columns when the name doesn't exist
468        let expr = col("a") + col("b");
469        let schema_a =
470            make_schema_with_empty_metadata(vec![Some("\"tableA\"".into())], vec!["a"]);
471        let schemas = [schema_a];
472        let schemas = schemas.iter().collect::<Vec<_>>();
473
474        let error =
475            normalize_col_with_schemas_and_ambiguity_check(expr, &[&schemas], &[])
476                .unwrap_err()
477                .strip_backtrace();
478        let expected = "Schema error: No field named b. \
479            Valid fields are \"tableA\".a.";
480        assert_eq!(error, expected);
481    }
482
483    #[test]
484    fn unnormalize_cols() {
485        let expr = col("tableA.a") + col("tableB.b");
486        let unnormalized_expr = unnormalize_col(expr);
487        assert_eq!(unnormalized_expr, col("a") + col("b"));
488    }
489
490    fn make_schema_with_empty_metadata(
491        qualifiers: Vec<Option<TableReference>>,
492        fields: Vec<&str>,
493    ) -> DFSchema {
494        let fields = fields
495            .iter()
496            .map(|f| Arc::new(Field::new((*f).to_string(), DataType::Int8, false)))
497            .collect::<Vec<_>>();
498        let schema = Arc::new(Schema::new(fields));
499        DFSchema::from_field_specific_qualified_schema(qualifiers, &schema).unwrap()
500    }
501
502    #[test]
503    fn rewriter_visit() {
504        let mut rewriter = RecordingRewriter::default();
505        col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap();
506
507        assert_eq!(
508            rewriter.v,
509            vec![
510                "Previsited state = Utf8(\"CO\")",
511                "Previsited state",
512                "Mutated state",
513                "Previsited Utf8(\"CO\")",
514                "Mutated Utf8(\"CO\")",
515                "Mutated state = Utf8(\"CO\")"
516            ]
517        )
518    }
519
520    #[test]
521    fn test_rewrite_preserving_name() {
522        test_rewrite(col("a"), col("a"));
523
524        test_rewrite(col("a"), col("b"));
525
526        // cast data types
527        test_rewrite(
528            col("a"),
529            Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)),
530        );
531
532        // change literal type from i32 to i64
533        test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64)));
534
535        // test preserve qualifier
536        test_rewrite(
537            Expr::Column(Column::new(Some("test"), "a")),
538            Expr::Column(Column::new_unqualified("test.a")),
539        );
540        test_rewrite(
541            Expr::Column(Column::new_unqualified("test.a")),
542            Expr::Column(Column::new(Some("test"), "a")),
543        );
544    }
545
546    /// rewrites `expr_from` to `rewrite_to` while preserving the original qualified name
547    /// by using the `NamePreserver`
548    fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
549        struct TestRewriter {
550            rewrite_to: Expr,
551        }
552
553        impl TreeNodeRewriter for TestRewriter {
554            type Node = Expr;
555
556            fn f_up(&mut self, _: Expr) -> Result<Transformed<Expr>> {
557                Ok(Transformed::yes(self.rewrite_to.clone()))
558            }
559        }
560
561        let mut rewriter = TestRewriter {
562            rewrite_to: rewrite_to.clone(),
563        };
564        let saved_name = NamePreserver { use_alias: true }.save(&expr_from);
565        let new_expr = expr_from.clone().rewrite(&mut rewriter).unwrap().data;
566        let new_expr = saved_name.restore(new_expr);
567
568        let original_name = expr_from.qualified_name();
569        let new_name = new_expr.qualified_name();
570        assert_eq!(
571            original_name, new_name,
572            "mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
573        )
574    }
575}