datafusion_expr/expr_rewriter/
mod.rs1use std::collections::HashMap;
21use std::collections::HashSet;
22use std::fmt::Debug;
23use std::sync::Arc;
24
25use crate::expr::{Alias, Sort, Unnest};
26use crate::logical_plan::Projection;
27use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};
28
29use datafusion_common::TableReference;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
32use datafusion_common::{Column, DFSchema, Result};
33
34mod guarantees;
35pub use guarantees::GuaranteeRewriter;
36pub use guarantees::rewrite_with_guarantees;
37pub use guarantees::rewrite_with_guarantees_map;
38mod order_by;
39
40pub use order_by::rewrite_sort_cols_by_aggs;
41
42pub trait FunctionRewrite: Debug {
52 fn name(&self) -> &str;
54
55 fn rewrite(
60 &self,
61 expr: Expr,
62 schema: &DFSchema,
63 config: &ConfigOptions,
64 ) -> Result<Transformed<Expr>>;
65}
66
67pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
70 expr.transform(|expr| {
71 Ok({
72 if let Expr::Column(c) = expr {
73 let col = LogicalPlanBuilder::normalize(plan, c)?;
74 Transformed::yes(Expr::Column(col))
75 } else {
76 Transformed::no(expr)
77 }
78 })
79 })
80 .data()
81}
82
83pub fn normalize_col_with_schemas_and_ambiguity_check(
85 expr: Expr,
86 schemas: &[&[&DFSchema]],
87 using_columns: &[HashSet<Column>],
88) -> Result<Expr> {
89 if let Expr::Unnest(Unnest { expr }) = expr {
91 let e = normalize_col_with_schemas_and_ambiguity_check(
92 expr.as_ref().clone(),
93 schemas,
94 using_columns,
95 )?;
96 return Ok(Expr::Unnest(Unnest { expr: Box::new(e) }));
97 }
98
99 expr.transform(|expr| {
100 Ok({
101 if let Expr::Column(c) = expr {
102 let col =
103 c.normalize_with_schemas_and_ambiguity_check(schemas, using_columns)?;
104 Transformed::yes(Expr::Column(col))
105 } else {
106 Transformed::no(expr)
107 }
108 })
109 })
110 .data()
111}
112
113pub fn normalize_cols(
115 exprs: impl IntoIterator<Item = impl Into<Expr>>,
116 plan: &LogicalPlan,
117) -> Result<Vec<Expr>> {
118 exprs
119 .into_iter()
120 .map(|e| normalize_col(e.into(), plan))
121 .collect()
122}
123
124pub fn normalize_sorts(
125 sorts: impl IntoIterator<Item = impl Into<Sort>>,
126 plan: &LogicalPlan,
127) -> Result<Vec<Sort>> {
128 sorts
129 .into_iter()
130 .map(|e| {
131 let sort = e.into();
132 normalize_col(sort.expr, plan)
133 .map(|expr| Sort::new(expr, sort.asc, sort.nulls_first))
134 })
135 .collect()
136}
137
138pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<Expr> {
141 expr.transform(|expr| {
142 Ok({
143 if let Expr::Column(c) = &expr {
144 match replace_map.get(c) {
145 Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())),
146 None => Transformed::no(expr),
147 }
148 } else {
149 Transformed::no(expr)
150 }
151 })
152 })
153 .data()
154}
155
156pub fn unnormalize_col(expr: Expr) -> Expr {
162 expr.transform(|expr| {
163 Ok({
164 if let Expr::Column(c) = expr {
165 let col = Column::new_unqualified(c.name);
166 Transformed::yes(Expr::Column(col))
167 } else {
168 Transformed::no(expr)
169 }
170 })
171 })
172 .data()
173 .expect("Unnormalize is infallible")
174}
175
176pub fn create_col_from_scalar_expr(
178 scalar_expr: &Expr,
179 subqry_alias: String,
180) -> Result<Column> {
181 match scalar_expr {
182 Expr::Alias(Alias { name, .. }) => Ok(Column::new(
183 Some::<TableReference>(subqry_alias.into()),
184 name,
185 )),
186 Expr::Column(col) => Ok(col.with_relation(subqry_alias.into())),
187 _ => {
188 let scalar_column = scalar_expr.schema_name().to_string();
189 Ok(Column::new(
190 Some::<TableReference>(subqry_alias.into()),
191 scalar_column,
192 ))
193 }
194 }
195}
196
197#[inline]
199pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> {
200 exprs.into_iter().map(unnormalize_col).collect()
201}
202
203pub fn strip_outer_reference(expr: Expr) -> Expr {
206 expr.transform(|expr| {
207 Ok({
208 if let Expr::OuterReferenceColumn(_, col) = expr {
209 Transformed::yes(Expr::Column(col))
210 } else {
211 Transformed::no(expr)
212 }
213 })
214 })
215 .data()
216 .expect("strip_outer_reference is infallible")
217}
218
219pub fn coerce_plan_expr_for_schema(
222 plan: LogicalPlan,
223 schema: &DFSchema,
224) -> Result<LogicalPlan> {
225 match plan {
226 LogicalPlan::Projection(Projection { expr, input, .. }) => {
228 let new_exprs = coerce_exprs_for_schema(expr, input.schema(), schema)?;
229 let projection = Projection::try_new(new_exprs, input)?;
230 Ok(LogicalPlan::Projection(projection))
231 }
232 _ => {
233 let exprs: Vec<Expr> = plan.schema().iter().map(Expr::from).collect();
234 let new_exprs = coerce_exprs_for_schema(exprs, plan.schema(), schema)?;
235 let add_project = new_exprs.iter().any(|expr| expr.try_as_col().is_none());
236 if add_project {
237 let projection = Projection::try_new(new_exprs, Arc::new(plan))?;
238 Ok(LogicalPlan::Projection(projection))
239 } else {
240 Ok(plan)
241 }
242 }
243 }
244}
245
246fn coerce_exprs_for_schema(
247 exprs: Vec<Expr>,
248 src_schema: &DFSchema,
249 dst_schema: &DFSchema,
250) -> Result<Vec<Expr>> {
251 exprs
252 .into_iter()
253 .enumerate()
254 .map(|(idx, expr)| {
255 let new_type = dst_schema.field(idx).data_type();
256 if new_type != &expr.get_type(src_schema)? {
257 match expr {
258 Expr::Alias(Alias { expr, name, .. }) => {
259 Ok(expr.cast_to(new_type, src_schema)?.alias(name))
260 }
261 #[expect(deprecated)]
262 Expr::Wildcard { .. } => Ok(expr),
263 _ => {
264 match expr {
265 Expr::Column(ref column) => {
269 let name = column.name().to_owned();
270 Ok(expr.cast_to(new_type, src_schema)?.alias(name))
271 }
272 _ => Ok(expr.cast_to(new_type, src_schema)?),
273 }
274 }
275 }
276 } else {
277 Ok(expr)
278 }
279 })
280 .collect::<Result<_>>()
281}
282
283#[inline]
285pub fn unalias(expr: Expr) -> Expr {
286 match expr {
287 Expr::Alias(Alias { expr, .. }) => unalias(*expr),
288 _ => expr,
289 }
290}
291
/// Decides whether the name of an expression is recorded before a rewrite so
/// that it can be restored afterwards.
///
/// Derives [`Debug`](std::fmt::Debug), like [`SavedName`], so the type can be
/// logged or embedded in other `Debug` types.
#[derive(Debug)]
pub struct NamePreserver {
    /// When `true`, `save` records the expression's qualified name; when
    /// `false`, `save` records nothing.
    use_alias: bool,
}
303
304#[derive(Debug)]
307pub enum SavedName {
308 Saved {
310 relation: Option<TableReference>,
311 name: String,
312 },
313 None,
315}
316
317impl NamePreserver {
318 pub fn new(plan: &LogicalPlan) -> Self {
320 Self {
321 use_alias: !matches!(
324 plan,
325 LogicalPlan::Filter(_)
326 | LogicalPlan::Join(_)
327 | LogicalPlan::TableScan(_)
328 | LogicalPlan::Limit(_)
329 | LogicalPlan::Statement(_)
330 ),
331 }
332 }
333
334 pub fn new_for_projection() -> Self {
338 Self { use_alias: true }
339 }
340
341 pub fn save(&self, expr: &Expr) -> SavedName {
342 if self.use_alias {
343 match expr {
344 Expr::Alias(alias) => SavedName::Saved {
345 relation: alias.relation.clone(),
346 name: alias.name.clone(),
347 },
348 _ => {
349 let (relation, name) = expr.qualified_name();
350 SavedName::Saved { relation, name }
351 }
352 }
353 } else {
354 SavedName::None
355 }
356 }
357}
358
359impl SavedName {
360 pub fn restore(self, expr: Expr) -> Expr {
362 match self {
363 SavedName::Saved { relation, name } => {
364 let (new_relation, new_name) = expr.qualified_name();
365 if new_relation != relation || new_name != name {
366 expr.alias_qualified(relation, name)
367 } else {
368 expr
369 }
370 }
371 SavedName::None => expr,
372 }
373 }
374}
375
376#[cfg(test)]
377mod test {
378 use std::ops::Add;
379
380 use super::*;
381 use crate::literal::lit_with_metadata;
382 use crate::{Cast, col, lit};
383 use arrow::datatypes::{DataType, Field, Schema};
384 use datafusion_common::ScalarValue;
385 use datafusion_common::tree_node::TreeNodeRewriter;
386
387 #[derive(Default)]
388 struct RecordingRewriter {
389 v: Vec<String>,
390 }
391
392 impl TreeNodeRewriter for RecordingRewriter {
393 type Node = Expr;
394
395 fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
396 self.v.push(format!("Previsited {expr}"));
397 Ok(Transformed::no(expr))
398 }
399
400 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
401 self.v.push(format!("Mutated {expr}"));
402 Ok(Transformed::no(expr))
403 }
404 }
405
406 #[test]
407 fn rewriter_rewrite() {
408 let transformer = |expr: Expr| -> Result<Transformed<Expr>> {
410 match expr {
411 Expr::Literal(ScalarValue::Utf8(Some(utf8_val)), metadata) => {
412 let utf8_val = if utf8_val == "foo" {
413 "bar".to_string()
414 } else {
415 utf8_val
416 };
417 Ok(Transformed::yes(lit_with_metadata(utf8_val, metadata)))
418 }
419 _ => Ok(Transformed::no(expr)),
421 }
422 };
423
424 let rewritten = col("state")
426 .eq(lit("foo"))
427 .transform(transformer)
428 .data()
429 .unwrap();
430 assert_eq!(rewritten, col("state").eq(lit("bar")));
431
432 let rewritten = col("state")
434 .eq(lit("baz"))
435 .transform(transformer)
436 .data()
437 .unwrap();
438 assert_eq!(rewritten, col("state").eq(lit("baz")));
439 }
440
441 #[test]
442 fn normalize_cols() {
443 let expr = col("a") + col("b") + col("c");
444
445 let schema_a = make_schema_with_empty_metadata(
447 vec![Some("tableA".into()), Some("tableA".into())],
448 vec!["a", "aa"],
449 );
450 let schema_c = make_schema_with_empty_metadata(
451 vec![Some("tableC".into()), Some("tableC".into())],
452 vec!["cc", "c"],
453 );
454 let schema_b =
455 make_schema_with_empty_metadata(vec![Some("tableB".into())], vec!["b"]);
456 let schema_f = make_schema_with_empty_metadata(
458 vec![Some("tableC".into()), Some("tableC".into())],
459 vec!["f", "ff"],
460 );
461 let schemas = [schema_c, schema_f, schema_b, schema_a];
462 let schemas = schemas.iter().collect::<Vec<_>>();
463
464 let normalized_expr =
465 normalize_col_with_schemas_and_ambiguity_check(expr, &[&schemas], &[])
466 .unwrap();
467 assert_eq!(
468 normalized_expr,
469 col("tableA.a") + col("tableB.b") + col("tableC.c")
470 );
471 }
472
473 #[test]
474 fn normalize_cols_non_exist() {
475 let expr = col("a") + col("b");
477 let schema_a =
478 make_schema_with_empty_metadata(vec![Some("\"tableA\"".into())], vec!["a"]);
479 let schemas = [schema_a];
480 let schemas = schemas.iter().collect::<Vec<_>>();
481
482 let error =
483 normalize_col_with_schemas_and_ambiguity_check(expr, &[&schemas], &[])
484 .unwrap_err()
485 .strip_backtrace();
486 let expected = "Schema error: No field named b. \
487 Valid fields are \"tableA\".a.";
488 assert_eq!(error, expected);
489 }
490
491 #[test]
492 fn unnormalize_cols() {
493 let expr = col("tableA.a") + col("tableB.b");
494 let unnormalized_expr = unnormalize_col(expr);
495 assert_eq!(unnormalized_expr, col("a") + col("b"));
496 }
497
498 fn make_schema_with_empty_metadata(
499 qualifiers: Vec<Option<TableReference>>,
500 fields: Vec<&str>,
501 ) -> DFSchema {
502 let fields = fields
503 .iter()
504 .map(|f| Arc::new(Field::new((*f).to_string(), DataType::Int8, false)))
505 .collect::<Vec<_>>();
506 let schema = Arc::new(Schema::new(fields));
507 DFSchema::from_field_specific_qualified_schema(qualifiers, &schema).unwrap()
508 }
509
510 #[test]
511 fn rewriter_visit() {
512 let mut rewriter = RecordingRewriter::default();
513 col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap();
514
515 assert_eq!(
516 rewriter.v,
517 vec![
518 "Previsited state = Utf8(\"CO\")",
519 "Previsited state",
520 "Mutated state",
521 "Previsited Utf8(\"CO\")",
522 "Mutated Utf8(\"CO\")",
523 "Mutated state = Utf8(\"CO\")"
524 ]
525 )
526 }
527
528 #[test]
529 fn test_rewrite_preserving_name() {
530 test_rewrite(col("a"), col("a"));
531
532 test_rewrite(col("a"), col("b"));
533
534 test_rewrite(
536 col("a"),
537 Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)),
538 );
539
540 test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64)));
542
543 test_rewrite(
545 Expr::Column(Column::new(Some("test"), "a")),
546 Expr::Column(Column::new_unqualified("test.a")),
547 );
548 test_rewrite(
549 Expr::Column(Column::new_unqualified("test.a")),
550 Expr::Column(Column::new(Some("test"), "a")),
551 );
552 }
553
554 fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
557 struct TestRewriter {
558 rewrite_to: Expr,
559 }
560
561 impl TreeNodeRewriter for TestRewriter {
562 type Node = Expr;
563
564 fn f_up(&mut self, _: Expr) -> Result<Transformed<Expr>> {
565 Ok(Transformed::yes(self.rewrite_to.clone()))
566 }
567 }
568
569 let mut rewriter = TestRewriter {
570 rewrite_to: rewrite_to.clone(),
571 };
572 let saved_name = NamePreserver { use_alias: true }.save(&expr_from);
573 let new_expr = expr_from.clone().rewrite(&mut rewriter).unwrap().data;
574 let new_expr = saved_name.restore(new_expr);
575
576 let original_name = expr_from.qualified_name();
577 let new_name = new_expr.qualified_name();
578 assert_eq!(
579 original_name, new_name,
580 "mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
581 )
582 }
583}