Skip to main content

datafusion_physical_plan/
column_rewriter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use datafusion_common::{
21    DataFusionError, HashMap,
22    tree_node::{Transformed, TreeNodeRecursion, TreeNodeRewriter},
23};
24use datafusion_physical_expr::{PhysicalExpr, expressions::Column};
25
26/// Rewrite column references in a physical expr according to a mapping.
27///
28/// This rewriter traverses the expression tree and replaces [`Column`] nodes
29/// with the corresponding expression found in the `column_map`.
30///
31/// If a column is found in the map, it is replaced by the mapped expression.
32/// If a column is NOT found in the map, a `DataFusionError::Internal` is
33/// returned.
34pub struct PhysicalColumnRewriter<'a> {
35    /// Mapping from original column to new column.
36    pub column_map: &'a HashMap<Column, Arc<dyn PhysicalExpr>>,
37}
38
39impl<'a> PhysicalColumnRewriter<'a> {
40    /// Create a new PhysicalColumnRewriter with the given column mapping.
41    pub fn new(column_map: &'a HashMap<Column, Arc<dyn PhysicalExpr>>) -> Self {
42        Self { column_map }
43    }
44}
45
46impl<'a> TreeNodeRewriter for PhysicalColumnRewriter<'a> {
47    type Node = Arc<dyn PhysicalExpr>;
48
49    fn f_down(
50        &mut self,
51        node: Self::Node,
52    ) -> datafusion_common::Result<Transformed<Self::Node>> {
53        if let Some(column) = node.as_any().downcast_ref::<Column>() {
54            if let Some(new_column) = self.column_map.get(column) {
55                // jump to prevent rewriting the new sub-expression again
56                return Ok(Transformed::new(
57                    Arc::clone(new_column),
58                    true,
59                    TreeNodeRecursion::Jump,
60                ));
61            } else {
62                // Column not found in mapping
63                return Err(DataFusionError::Internal(format!(
64                    "Column {column:?} not found in column mapping {:?}",
65                    self.column_map
66                )));
67            }
68        }
69        Ok(Transformed::no(node))
70    }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{DataFusionError, Result, tree_node::TreeNode};
    use datafusion_expr::Operator;
    use datafusion_physical_expr::{
        PhysicalExpr,
        expressions::{Column, binary, col, lit},
    };
    use std::sync::Arc;

    /// The mapping type consumed by [`PhysicalColumnRewriter`].
    type ColumnMap = HashMap<Column, Arc<dyn PhysicalExpr>>;

    /// Helper function to create a test schema
    fn create_test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Int32, true),
            Field::new("new_col", DataType::Int32, true),
            Field::new("inner_col", DataType::Int32, true),
            Field::new("another_col", DataType::Int32, true),
        ]))
    }

    /// Resolves `name` against `schema` and returns it as a mapping key.
    fn key(name: &str, schema: &Schema) -> Column {
        Column::new_with_schema(name, schema).unwrap()
    }

    /// Maps every column in `names` to itself.
    ///
    /// The rewriter errors on any column missing from the mapping, so columns
    /// that should pass through unchanged still need an explicit entry.
    fn map_to_self(column_map: &mut ColumnMap, names: &[&str], schema: &Schema) {
        for &name in names {
            column_map.insert(key(name, schema), col(name, schema).unwrap());
        }
    }

    /// Helper function to create a complex nested expression with multiple columns
    /// Create: (col_a + col_b) * (col_c - col_d) + col_e
    fn create_complex_expression(schema: &Schema) -> Arc<dyn PhysicalExpr> {
        let col_a = col("a", schema).unwrap();
        let col_b = col("b", schema).unwrap();
        let col_c = col("c", schema).unwrap();
        let col_d = col("d", schema).unwrap();
        let col_e = col("e", schema).unwrap();

        let add_expr = binary(col_a, Operator::Plus, col_b, schema).unwrap();
        let sub_expr = binary(col_c, Operator::Minus, col_d, schema).unwrap();
        let mul_expr = binary(add_expr, Operator::Multiply, sub_expr, schema).unwrap();
        binary(mul_expr, Operator::Plus, col_e, schema).unwrap()
    }

    /// Helper function to create a deeply nested expression
    /// Create: col_a + (col_b + (col_c + (col_d + col_e)))
    fn create_deeply_nested_expression(schema: &Schema) -> Arc<dyn PhysicalExpr> {
        let col_a = col("a", schema).unwrap();
        let col_b = col("b", schema).unwrap();
        let col_c = col("c", schema).unwrap();
        let col_d = col("d", schema).unwrap();
        let col_e = col("e", schema).unwrap();

        let inner1 = binary(col_d, Operator::Plus, col_e, schema).unwrap();
        let inner2 = binary(col_c, Operator::Plus, inner1, schema).unwrap();
        let inner3 = binary(col_b, Operator::Plus, inner2, schema).unwrap();
        binary(col_a, Operator::Plus, inner3, schema).unwrap()
    }

    #[test]
    fn test_simple_column_replacement_with_jump() -> Result<()> {
        let schema = create_test_schema();

        // Test that Jump prevents re-processing of replaced columns
        let mut column_map = ColumnMap::new();
        column_map.insert(key("a", &schema), lit(42i32));
        column_map.insert(key("b", &schema), lit("replaced_b"));
        map_to_self(&mut column_map, &["c", "d", "e"], &schema);

        let mut rewriter = PhysicalColumnRewriter::new(&column_map);
        let expr = create_complex_expression(&schema);

        let result = expr.rewrite(&mut rewriter)?;

        // Verify the transformation occurred
        assert!(result.transformed);
        assert_eq!(
            result.data.to_string(),
            "(42 + replaced_b) * (c@2 - d@3) + e@4"
        );

        Ok(())
    }

    #[test]
    fn test_nested_column_replacement_with_jump() -> Result<()> {
        let schema = create_test_schema();

        // Test Jump behavior with deeply nested expressions
        let mut column_map = ColumnMap::new();

        // Replace col_c with a complex expression containing new columns
        let replacement_expr = binary(
            lit(100i32),
            Operator::Plus,
            col("new_col", &schema).unwrap(),
            &schema,
        )
        .unwrap();
        column_map.insert(key("c", &schema), replacement_expr);
        map_to_self(&mut column_map, &["a", "b", "d", "e"], &schema);

        let mut rewriter = PhysicalColumnRewriter::new(&column_map);
        let expr = create_deeply_nested_expression(&schema);

        let result = expr.rewrite(&mut rewriter)?;

        // Verify transformation occurred
        assert!(result.transformed);
        assert_eq!(
            result.data.to_string(),
            "a@0 + b@1 + 100 + new_col@5 + d@3 + e@4"
        );

        Ok(())
    }

    #[test]
    fn test_circular_reference_prevention() -> Result<()> {
        let schema = create_test_schema();

        // Test that Jump prevents infinite recursion with circular references
        let mut column_map = ColumnMap::new();

        // Create a circular reference: col_a -> col_b -> col_a
        // (but Jump should prevent the second visit)
        column_map.insert(key("a", &schema), col("b", &schema).unwrap());
        column_map.insert(key("b", &schema), col("a", &schema).unwrap());

        let mut rewriter = PhysicalColumnRewriter::new(&column_map);

        // Start with the expression col_a + col_b
        let expr = binary(
            col("a", &schema).unwrap(),
            Operator::Plus,
            col("b", &schema).unwrap(),
            &schema,
        )
        .unwrap();

        let result = expr.rewrite(&mut rewriter)?;

        // Verify transformation occurred
        assert!(result.transformed);
        assert_eq!(result.data.to_string(), "b@1 + a@0");

        Ok(())
    }

    #[test]
    fn test_multiple_replacements_in_same_expression() -> Result<()> {
        let schema = create_test_schema();

        // Test multiple column replacements in the same complex expression
        let mut column_map = ColumnMap::new();

        // Replace multiple columns with literals
        column_map.insert(key("a", &schema), lit(10i32));
        column_map.insert(key("c", &schema), lit(20i32));
        column_map.insert(key("e", &schema), lit(30i32));
        map_to_self(&mut column_map, &["b", "d"], &schema);

        let mut rewriter = PhysicalColumnRewriter::new(&column_map);
        // (col_a + col_b) * (col_c - col_d) + col_e
        let expr = create_complex_expression(&schema);

        let result = expr.rewrite(&mut rewriter)?;

        // Verify transformation occurred
        assert!(result.transformed);
        assert_eq!(result.data.to_string(), "(10 + b@1) * (20 - d@3) + 30");

        Ok(())
    }

    #[test]
    fn test_jump_with_complex_replacement_expression() -> Result<()> {
        let schema = create_test_schema();

        // Test Jump behavior when replacing with very complex expressions
        let mut column_map = ColumnMap::new();

        // Replace col_a with a complex nested expression that itself mentions
        // col_a: (5 * col_a + 3) - another_col
        let inner_expr = binary(
            lit(5i32),
            Operator::Multiply,
            col("a", &schema).unwrap(),
            &schema,
        )
        .unwrap();
        let middle_expr =
            binary(inner_expr, Operator::Plus, lit(3i32), &schema).unwrap();
        let complex_replacement = binary(
            middle_expr,
            Operator::Minus,
            col("another_col", &schema).unwrap(),
            &schema,
        )
        .unwrap();

        column_map.insert(key("a", &schema), complex_replacement);
        map_to_self(&mut column_map, &["b"], &schema);

        let mut rewriter = PhysicalColumnRewriter::new(&column_map);

        // Create expression: col_a + col_b
        let expr = binary(
            col("a", &schema).unwrap(),
            Operator::Plus,
            col("b", &schema).unwrap(),
            &schema,
        )
        .unwrap();

        let result = expr.rewrite(&mut rewriter)?;

        // Verify transformation occurred
        assert!(result.transformed);
        assert_eq!(
            result.data.to_string(),
            "5 * a@0 + 3 - another_col@7 + b@1"
        );

        Ok(())
    }

    #[test]
    fn test_unmapped_columns_detection() -> Result<()> {
        let schema = create_test_schema();
        let mut column_map = ColumnMap::new();

        // Only map col_a, leave col_b unmapped
        column_map.insert(key("a", &schema), lit(42i32));

        let mut rewriter = PhysicalColumnRewriter::new(&column_map);

        // Create expression: col_a + col_b
        let expr = binary(
            col("a", &schema).unwrap(),
            Operator::Plus,
            col("b", &schema).unwrap(),
            &schema,
        )
        .unwrap();

        // The rewrite must fail with the rewriter's own "not found" error,
        // not merely with some internal error.
        let err = expr.rewrite(&mut rewriter).unwrap_err();
        assert!(matches!(
            &err,
            DataFusionError::Internal(msg) if msg.contains("not found in column mapping")
        ));

        Ok(())
    }
}