datafusion_expr/expr_rewriter/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Expression rewriter
19
20use std::collections::HashMap;
21use std::collections::HashSet;
22use std::fmt::Debug;
23use std::sync::Arc;
24
25use crate::expr::{Alias, Sort, Unnest};
26use crate::logical_plan::Projection;
27use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};
28
29use datafusion_common::TableReference;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
32use datafusion_common::{Column, DFSchema, Result};
33
34mod guarantees;
35pub use guarantees::GuaranteeRewriter;
36pub use guarantees::rewrite_with_guarantees;
37pub use guarantees::rewrite_with_guarantees_map;
38mod order_by;
39
40pub use order_by::rewrite_sort_cols_by_aggs;
41
/// Trait for rewriting [`Expr`]s into function calls.
///
/// This trait is used with `FunctionRegistry::register_function_rewrite` to
/// evaluate `Expr`s using functions that may not be built in to DataFusion.
///
/// For example, concatenating arrays `a || b` is represented as
/// `Operator::ArrowAt`, but can be implemented by calling a function
/// `array_concat` from the `functions-nested` crate.
// NOTE(review): `a || b` is normally represented as `Operator::StringConcat`;
// confirm which operator the example above is meant to name.
// This is not used in datafusion internally, but it is still helpful for downstream projects, so don't remove it.
pub trait FunctionRewrite: Debug {
    /// Return a human readable name for this rewrite
    fn name(&self) -> &str;

    /// Potentially rewrite `expr` to some other expression
    ///
    /// Note that recursion is handled by the caller -- this method should only
    /// handle `expr`, not recurse to its children.
    ///
    /// # Arguments
    /// * `expr` - the expression to (possibly) rewrite; taken by value
    /// * `schema` - schema made available to the rewrite (presumably the
    ///   schema `expr` is resolved against -- confirm with the caller)
    /// * `config` - configuration options made available to the rewrite
    ///
    /// # Returns
    /// The resulting expression wrapped in [`Transformed`], which records
    /// whether a rewrite took place, or an error if the rewrite failed.
    fn rewrite(
        &self,
        expr: Expr,
        schema: &DFSchema,
        config: &ConfigOptions,
    ) -> Result<Transformed<Expr>>;
}
66
67/// Recursively call `LogicalPlanBuilder::normalize` on all [`Column`] expressions
68/// in the `expr` expression tree.
69pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
70    expr.transform(|expr| {
71        Ok({
72            if let Expr::Column(c) = expr {
73                let col = LogicalPlanBuilder::normalize(plan, c)?;
74                Transformed::yes(Expr::Column(col))
75            } else {
76                Transformed::no(expr)
77            }
78        })
79    })
80    .data()
81}
82
83/// See [`Column::normalize_with_schemas_and_ambiguity_check`] for usage
84pub fn normalize_col_with_schemas_and_ambiguity_check(
85    expr: Expr,
86    schemas: &[&[&DFSchema]],
87    using_columns: &[HashSet<Column>],
88) -> Result<Expr> {
89    // Normalize column inside Unnest
90    if let Expr::Unnest(Unnest { expr }) = expr {
91        let e = normalize_col_with_schemas_and_ambiguity_check(
92            expr.as_ref().clone(),
93            schemas,
94            using_columns,
95        )?;
96        return Ok(Expr::Unnest(Unnest { expr: Box::new(e) }));
97    }
98
99    expr.transform(|expr| {
100        Ok({
101            if let Expr::Column(c) = expr {
102                let col =
103                    c.normalize_with_schemas_and_ambiguity_check(schemas, using_columns)?;
104                Transformed::yes(Expr::Column(col))
105            } else {
106                Transformed::no(expr)
107            }
108        })
109    })
110    .data()
111}
112
113/// Recursively normalize all [`Column`] expressions in a list of expression trees
114pub fn normalize_cols(
115    exprs: impl IntoIterator<Item = impl Into<Expr>>,
116    plan: &LogicalPlan,
117) -> Result<Vec<Expr>> {
118    exprs
119        .into_iter()
120        .map(|e| normalize_col(e.into(), plan))
121        .collect()
122}
123
124pub fn normalize_sorts(
125    sorts: impl IntoIterator<Item = impl Into<Sort>>,
126    plan: &LogicalPlan,
127) -> Result<Vec<Sort>> {
128    sorts
129        .into_iter()
130        .map(|e| {
131            let sort = e.into();
132            normalize_col(sort.expr, plan)
133                .map(|expr| Sort::new(expr, sort.asc, sort.nulls_first))
134        })
135        .collect()
136}
137
138/// Recursively replace all [`Column`] expressions in a given expression tree with
139/// `Column` expressions provided by the hash map argument.
140pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<Expr> {
141    expr.transform(|expr| {
142        Ok({
143            if let Expr::Column(c) = &expr {
144                match replace_map.get(c) {
145                    Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())),
146                    None => Transformed::no(expr),
147                }
148            } else {
149                Transformed::no(expr)
150            }
151        })
152    })
153    .data()
154}
155
156/// Recursively 'unnormalize' (remove all qualifiers) from an
157/// expression tree.
158///
159/// For example, if there were expressions like `foo.bar` this would
160/// rewrite it to just `bar`.
161pub fn unnormalize_col(expr: Expr) -> Expr {
162    expr.transform(|expr| {
163        Ok({
164            if let Expr::Column(c) = expr {
165                let col = Column::new_unqualified(c.name);
166                Transformed::yes(Expr::Column(col))
167            } else {
168                Transformed::no(expr)
169            }
170        })
171    })
172    .data()
173    .expect("Unnormalize is infallible")
174}
175
176/// Create a Column from the Scalar Expr
177pub fn create_col_from_scalar_expr(
178    scalar_expr: &Expr,
179    subqry_alias: String,
180) -> Result<Column> {
181    match scalar_expr {
182        Expr::Alias(Alias { name, .. }) => Ok(Column::new(
183            Some::<TableReference>(subqry_alias.into()),
184            name,
185        )),
186        Expr::Column(col) => Ok(col.with_relation(subqry_alias.into())),
187        _ => {
188            let scalar_column = scalar_expr.schema_name().to_string();
189            Ok(Column::new(
190                Some::<TableReference>(subqry_alias.into()),
191                scalar_column,
192            ))
193        }
194    }
195}
196
197/// Recursively un-normalize all [`Column`] expressions in a list of expression trees
198#[inline]
199pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> {
200    exprs.into_iter().map(unnormalize_col).collect()
201}
202
203/// Recursively remove all the ['OuterReferenceColumn'] and return the inside Column
204/// in the expression tree.
205pub fn strip_outer_reference(expr: Expr) -> Expr {
206    expr.transform(|expr| {
207        Ok({
208            if let Expr::OuterReferenceColumn(_, col) = expr {
209                Transformed::yes(Expr::Column(col))
210            } else {
211                Transformed::no(expr)
212            }
213        })
214    })
215    .data()
216    .expect("strip_outer_reference is infallible")
217}
218
219/// Returns plan with expressions coerced to types compatible with
220/// schema types
221pub fn coerce_plan_expr_for_schema(
222    plan: LogicalPlan,
223    schema: &DFSchema,
224) -> Result<LogicalPlan> {
225    match plan {
226        // special case Projection to avoid adding multiple projections
227        LogicalPlan::Projection(Projection { expr, input, .. }) => {
228            let new_exprs = coerce_exprs_for_schema(expr, input.schema(), schema)?;
229            let projection = Projection::try_new(new_exprs, input)?;
230            Ok(LogicalPlan::Projection(projection))
231        }
232        _ => {
233            let exprs: Vec<Expr> = plan.schema().iter().map(Expr::from).collect();
234            let new_exprs = coerce_exprs_for_schema(exprs, plan.schema(), schema)?;
235            let add_project = new_exprs.iter().any(|expr| expr.try_as_col().is_none());
236            if add_project {
237                let projection = Projection::try_new(new_exprs, Arc::new(plan))?;
238                Ok(LogicalPlan::Projection(projection))
239            } else {
240                Ok(plan)
241            }
242        }
243    }
244}
245
246fn coerce_exprs_for_schema(
247    exprs: Vec<Expr>,
248    src_schema: &DFSchema,
249    dst_schema: &DFSchema,
250) -> Result<Vec<Expr>> {
251    exprs
252        .into_iter()
253        .enumerate()
254        .map(|(idx, expr)| {
255            let new_type = dst_schema.field(idx).data_type();
256            if new_type != &expr.get_type(src_schema)? {
257                match expr {
258                    Expr::Alias(Alias { expr, name, .. }) => {
259                        Ok(expr.cast_to(new_type, src_schema)?.alias(name))
260                    }
261                    #[expect(deprecated)]
262                    Expr::Wildcard { .. } => Ok(expr),
263                    _ => {
264                        // maintain the original name when casting
265                        let name = dst_schema.field(idx).name();
266                        Ok(expr.cast_to(new_type, src_schema)?.alias(name))
267                    }
268                }
269            } else {
270                Ok(expr)
271            }
272        })
273        .collect::<Result<_>>()
274}
275
276/// Recursively un-alias an expressions
277#[inline]
278pub fn unalias(expr: Expr) -> Expr {
279    match expr {
280        Expr::Alias(Alias { expr, .. }) => unalias(*expr),
281        _ => expr,
282    }
283}
284
/// Handles ensuring the name of rewritten expressions is not changed.
///
/// This is important when optimizing plans to ensure the output
/// schema of plan nodes doesn't change after optimization.
/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the
/// expression should be preserved: `3 as "1 + 2"`
///
/// See <https://github.com/apache/datafusion/issues/3555> for details
pub struct NamePreserver {
    /// When `true`, [`NamePreserver::save`] records the qualified name of the
    /// expression; when `false`, it returns [`SavedName::None`].
    use_alias: bool,
}
296
/// If the qualified name of an expression is remembered, it will be preserved
/// when rewriting the expression.
///
/// Values are produced by [`NamePreserver::save`] and consumed by
/// [`SavedName::restore`].
#[derive(Debug)]
pub enum SavedName {
    /// Saved qualified name to be preserved
    Saved {
        /// Qualifier part of the saved name, if the expression had one
        relation: Option<TableReference>,
        /// Name part of the saved name
        name: String,
    },
    /// Name is not preserved
    None,
}
309
310impl NamePreserver {
311    /// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan
312    pub fn new(plan: &LogicalPlan) -> Self {
313        Self {
314            // The expressions of these plans do not contribute to their output schema,
315            // so there is no need to preserve expression names to prevent a schema change.
316            use_alias: !matches!(
317                plan,
318                LogicalPlan::Filter(_)
319                    | LogicalPlan::Join(_)
320                    | LogicalPlan::TableScan(_)
321                    | LogicalPlan::Limit(_)
322                    | LogicalPlan::Statement(_)
323            ),
324        }
325    }
326
327    /// Create a new NamePreserver for rewriting the `expr`s in `Projection`
328    ///
329    /// This will use aliases
330    pub fn new_for_projection() -> Self {
331        Self { use_alias: true }
332    }
333
334    pub fn save(&self, expr: &Expr) -> SavedName {
335        if self.use_alias {
336            let (relation, name) = expr.qualified_name();
337            SavedName::Saved { relation, name }
338        } else {
339            SavedName::None
340        }
341    }
342}
343
344impl SavedName {
345    /// Ensures the qualified name of the rewritten expression is preserved
346    pub fn restore(self, expr: Expr) -> Expr {
347        match self {
348            SavedName::Saved { relation, name } => {
349                let (new_relation, new_name) = expr.qualified_name();
350                if new_relation != relation || new_name != name {
351                    expr.alias_qualified(relation, name)
352                } else {
353                    expr
354                }
355            }
356            SavedName::None => expr,
357        }
358    }
359}
360
#[cfg(test)]
mod test {
    use std::ops::Add;

    use super::*;
    use crate::literal::lit_with_metadata;
    use crate::{Cast, col, lit};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::ScalarValue;
    use datafusion_common::tree_node::TreeNodeRewriter;

    /// Rewriter that never changes the expression and only records, in call
    /// order, each `f_down` ("Previsited") and `f_up` ("Mutated") invocation.
    #[derive(Default)]
    struct RecordingRewriter {
        /// Log of the visits, one formatted entry per callback invocation
        v: Vec<String>,
    }

    impl TreeNodeRewriter for RecordingRewriter {
        type Node = Expr;

        fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
            self.v.push(format!("Previsited {expr}"));
            Ok(Transformed::no(expr))
        }

        fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
            self.v.push(format!("Mutated {expr}"));
            Ok(Transformed::no(expr))
        }
    }

    /// `transform` applies the closure to literals nested inside an expression.
    #[test]
    fn rewriter_rewrite() {
        // rewrites all "foo" string literals to "bar"
        // (every non-null Utf8 literal is reported as `Transformed::yes`,
        // even when its value is left as is)
        let transformer = |expr: Expr| -> Result<Transformed<Expr>> {
            match expr {
                Expr::Literal(ScalarValue::Utf8(Some(utf8_val)), metadata) => {
                    let utf8_val = if utf8_val == "foo" {
                        "bar".to_string()
                    } else {
                        utf8_val
                    };
                    Ok(Transformed::yes(lit_with_metadata(utf8_val, metadata)))
                }
                // otherwise, leave the expression unchanged
                _ => Ok(Transformed::no(expr)),
            }
        };

        // rewrites "foo" --> "bar"
        let rewritten = col("state")
            .eq(lit("foo"))
            .transform(transformer)
            .data()
            .unwrap();
        assert_eq!(rewritten, col("state").eq(lit("bar")));

        // doesn't rewrite
        let rewritten = col("state")
            .eq(lit("baz"))
            .transform(transformer)
            .data()
            .unwrap();
        assert_eq!(rewritten, col("state").eq(lit("baz")));
    }

    /// Unqualified columns are qualified with the table of the schema that
    /// contains them.
    #[test]
    fn normalize_cols() {
        let expr = col("a") + col("b") + col("c");

        // Schemas with some matching and some non matching cols
        let schema_a = make_schema_with_empty_metadata(
            vec![Some("tableA".into()), Some("tableA".into())],
            vec!["a", "aa"],
        );
        let schema_c = make_schema_with_empty_metadata(
            vec![Some("tableC".into()), Some("tableC".into())],
            vec!["cc", "c"],
        );
        let schema_b =
            make_schema_with_empty_metadata(vec![Some("tableB".into())], vec!["b"]);
        // non matching
        let schema_f = make_schema_with_empty_metadata(
            vec![Some("tableC".into()), Some("tableC".into())],
            vec!["f", "ff"],
        );
        let schemas = [schema_c, schema_f, schema_b, schema_a];
        let schemas = schemas.iter().collect::<Vec<_>>();

        let normalized_expr =
            normalize_col_with_schemas_and_ambiguity_check(expr, &[&schemas], &[])
                .unwrap();
        assert_eq!(
            normalized_expr,
            col("tableA.a") + col("tableB.b") + col("tableC.c")
        );
    }

    /// Normalizing a column that is in none of the schemas is a schema error.
    #[test]
    fn normalize_cols_non_exist() {
        // test normalizing columns when the name doesn't exist
        let expr = col("a") + col("b");
        let schema_a =
            make_schema_with_empty_metadata(vec![Some("\"tableA\"".into())], vec!["a"]);
        let schemas = [schema_a];
        let schemas = schemas.iter().collect::<Vec<_>>();

        let error =
            normalize_col_with_schemas_and_ambiguity_check(expr, &[&schemas], &[])
                .unwrap_err()
                .strip_backtrace();
        let expected = "Schema error: No field named b. \
            Valid fields are \"tableA\".a.";
        assert_eq!(error, expected);
    }

    /// `unnormalize_col` drops the table qualifier from every column.
    #[test]
    fn unnormalize_cols() {
        let expr = col("tableA.a") + col("tableB.b");
        let unnormalized_expr = unnormalize_col(expr);
        assert_eq!(unnormalized_expr, col("a") + col("b"));
    }

    /// Builds a `DFSchema` of non-nullable `Int8` fields named `fields`, each
    /// paired with the qualifier at the same position in `qualifiers`.
    fn make_schema_with_empty_metadata(
        qualifiers: Vec<Option<TableReference>>,
        fields: Vec<&str>,
    ) -> DFSchema {
        let fields = fields
            .iter()
            .map(|f| Arc::new(Field::new((*f).to_string(), DataType::Int8, false)))
            .collect::<Vec<_>>();
        let schema = Arc::new(Schema::new(fields));
        DFSchema::from_field_specific_qualified_schema(qualifiers, &schema).unwrap()
    }

    /// `rewrite` calls `f_down` on a node before its children and `f_up`
    /// after them, as shown by the recorded order.
    #[test]
    fn rewriter_visit() {
        let mut rewriter = RecordingRewriter::default();
        col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap();

        assert_eq!(
            rewriter.v,
            vec![
                "Previsited state = Utf8(\"CO\")",
                "Previsited state",
                "Mutated state",
                "Previsited Utf8(\"CO\")",
                "Mutated Utf8(\"CO\")",
                "Mutated state = Utf8(\"CO\")"
            ]
        )
    }

    /// The qualified name survives a variety of rewrites when saved and
    /// restored through `NamePreserver`.
    #[test]
    fn test_rewrite_preserving_name() {
        // rewrite to the same column
        test_rewrite(col("a"), col("a"));

        // rewrite to a different column
        test_rewrite(col("a"), col("b"));

        // cast data types
        test_rewrite(
            col("a"),
            Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)),
        );

        // change literal type from i32 to i64
        test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64)));

        // test preserve qualifier
        test_rewrite(
            Expr::Column(Column::new(Some("test"), "a")),
            Expr::Column(Column::new_unqualified("test.a")),
        );
        test_rewrite(
            Expr::Column(Column::new_unqualified("test.a")),
            Expr::Column(Column::new(Some("test"), "a")),
        );
    }

    /// rewrites `expr_from` to `rewrite_to` while preserving the original qualified name
    /// by using the `NamePreserver`
    fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
        /// Rewriter that replaces every node it visits with `rewrite_to`
        struct TestRewriter {
            rewrite_to: Expr,
        }

        impl TreeNodeRewriter for TestRewriter {
            type Node = Expr;

            fn f_up(&mut self, _: Expr) -> Result<Transformed<Expr>> {
                Ok(Transformed::yes(self.rewrite_to.clone()))
            }
        }

        let mut rewriter = TestRewriter {
            rewrite_to: rewrite_to.clone(),
        };
        // Save the name before rewriting, then restore it on the result.
        let saved_name = NamePreserver { use_alias: true }.save(&expr_from);
        let new_expr = expr_from.clone().rewrite(&mut rewriter).unwrap().data;
        let new_expr = saved_name.restore(new_expr);

        let original_name = expr_from.qualified_name();
        let new_name = new_expr.qualified_name();
        assert_eq!(
            original_name, new_name,
            "mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
        )
    }
}