1use std::sync::Arc;
19
20use datafusion_common::{
21 DataFusionError, HashMap,
22 tree_node::{Transformed, TreeNodeRecursion, TreeNodeRewriter},
23};
24use datafusion_physical_expr::{PhysicalExpr, expressions::Column};
25
26pub struct PhysicalColumnRewriter<'a> {
35 pub column_map: &'a HashMap<Column, Arc<dyn PhysicalExpr>>,
37}
38
39impl<'a> PhysicalColumnRewriter<'a> {
40 pub fn new(column_map: &'a HashMap<Column, Arc<dyn PhysicalExpr>>) -> Self {
42 Self { column_map }
43 }
44}
45
46impl<'a> TreeNodeRewriter for PhysicalColumnRewriter<'a> {
47 type Node = Arc<dyn PhysicalExpr>;
48
49 fn f_down(
50 &mut self,
51 node: Self::Node,
52 ) -> datafusion_common::Result<Transformed<Self::Node>> {
53 if let Some(column) = node.as_any().downcast_ref::<Column>() {
54 if let Some(new_column) = self.column_map.get(column) {
55 return Ok(Transformed::new(
57 Arc::clone(new_column),
58 true,
59 TreeNodeRecursion::Jump,
60 ));
61 } else {
62 return Err(DataFusionError::Internal(format!(
64 "Column {column:?} not found in column mapping {:?}",
65 self.column_map
66 )));
67 }
68 }
69 Ok(Transformed::no(node))
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{DataFusionError, Result, tree_node::TreeNode};
    use datafusion_expr::Operator;
    use datafusion_physical_expr::{
        PhysicalExpr,
        expressions::{Column, binary, col, lit},
    };
    use std::sync::Arc;

    /// Schema shared by every test: eight nullable `Int32` columns, indexed 0..=7.
    fn create_test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
            Field::new("d", DataType::Int32, true),
            Field::new("e", DataType::Int32, true),
            Field::new("new_col", DataType::Int32, true),
            Field::new("inner_col", DataType::Int32, true),
            Field::new("another_col", DataType::Int32, true),
        ]))
    }

    /// Builds a rewriter mapping from `(column name, replacement)` pairs,
    /// resolving each name to its indexed [`Column`] in `schema`.
    fn build_column_map(
        schema: &Schema,
        entries: Vec<(&str, Arc<dyn PhysicalExpr>)>,
    ) -> HashMap<Column, Arc<dyn PhysicalExpr>> {
        entries
            .into_iter()
            .map(|(name, replacement)| {
                (Column::new_with_schema(name, schema).unwrap(), replacement)
            })
            .collect()
    }

    /// Builds `(a + b) * (c - d) + e`.
    fn create_complex_expression(schema: &Schema) -> Arc<dyn PhysicalExpr> {
        let col_a = col("a", schema).unwrap();
        let col_b = col("b", schema).unwrap();
        let col_c = col("c", schema).unwrap();
        let col_d = col("d", schema).unwrap();
        let col_e = col("e", schema).unwrap();

        let add_expr = binary(col_a, Operator::Plus, col_b, schema).unwrap();
        let sub_expr = binary(col_c, Operator::Minus, col_d, schema).unwrap();
        let mul_expr = binary(add_expr, Operator::Multiply, sub_expr, schema).unwrap();
        binary(mul_expr, Operator::Plus, col_e, schema).unwrap()
    }

    /// Builds the right-nested sum `a + (b + (c + (d + e)))`.
    fn create_deeply_nested_expression(schema: &Schema) -> Arc<dyn PhysicalExpr> {
        let col_a = col("a", schema).unwrap();
        let col_b = col("b", schema).unwrap();
        let col_c = col("c", schema).unwrap();
        let col_d = col("d", schema).unwrap();
        let col_e = col("e", schema).unwrap();

        let inner1 = binary(col_d, Operator::Plus, col_e, schema).unwrap();
        let inner2 = binary(col_c, Operator::Plus, inner1, schema).unwrap();
        let inner3 = binary(col_b, Operator::Plus, inner2, schema).unwrap();
        binary(col_a, Operator::Plus, inner3, schema).unwrap()
    }

    #[test]
    fn test_simple_column_replacement_with_jump() -> Result<()> {
        let schema = create_test_schema();
        // `a` and `b` become literals; `c`, `d` and `e` map to themselves.
        let column_map = build_column_map(
            &schema,
            vec![
                ("a", lit(42i32)),
                ("b", lit("replaced_b")),
                ("c", col("c", &schema)?),
                ("d", col("d", &schema)?),
                ("e", col("e", &schema)?),
            ],
        );

        let mut rewriter = PhysicalColumnRewriter::new(&column_map);
        let result = create_complex_expression(&schema).rewrite(&mut rewriter)?;

        assert!(result.transformed);
        assert_eq!(
            result.data.to_string(),
            "(42 + replaced_b) * (c@2 - d@3) + e@4"
        );

        Ok(())
    }

    #[test]
    fn test_nested_column_replacement_with_jump() -> Result<()> {
        let schema = create_test_schema();
        // `c` deep inside the nested sum is replaced with `100 + new_col`.
        let replacement_expr =
            binary(lit(100i32), Operator::Plus, col("new_col", &schema)?, &schema)?;
        let column_map = build_column_map(
            &schema,
            vec![
                ("a", col("a", &schema)?),
                ("b", col("b", &schema)?),
                ("c", replacement_expr),
                ("d", col("d", &schema)?),
                ("e", col("e", &schema)?),
            ],
        );

        let mut rewriter = PhysicalColumnRewriter::new(&column_map);
        let result = create_deeply_nested_expression(&schema).rewrite(&mut rewriter)?;

        assert!(result.transformed);
        assert_eq!(
            result.data.to_string(),
            "a@0 + b@1 + 100 + new_col@5 + d@3 + e@4"
        );

        Ok(())
    }

    #[test]
    fn test_circular_reference_prevention() -> Result<()> {
        let schema = create_test_schema();
        // `a -> b` and `b -> a`: without Jump the replacements would be
        // rewritten again; with it each column is swapped exactly once.
        let column_map = build_column_map(
            &schema,
            vec![("a", col("b", &schema)?), ("b", col("a", &schema)?)],
        );

        let mut rewriter = PhysicalColumnRewriter::new(&column_map);
        let expr = binary(
            col("a", &schema)?,
            Operator::Plus,
            col("b", &schema)?,
            &schema,
        )?;

        let result = expr.rewrite(&mut rewriter)?;

        assert!(result.transformed);
        assert_eq!(result.data.to_string(), "b@1 + a@0");

        Ok(())
    }

    #[test]
    fn test_multiple_replacements_in_same_expression() -> Result<()> {
        let schema = create_test_schema();
        let column_map = build_column_map(
            &schema,
            vec![
                ("a", lit(10i32)),
                ("c", lit(20i32)),
                ("e", lit(30i32)),
                ("b", col("b", &schema)?),
                ("d", col("d", &schema)?),
            ],
        );

        let mut rewriter = PhysicalColumnRewriter::new(&column_map);
        let result = create_complex_expression(&schema).rewrite(&mut rewriter)?;

        assert!(result.transformed);
        assert_eq!(result.data.to_string(), "(10 + b@1) * (20 - d@3) + 30");

        Ok(())
    }

    #[test]
    fn test_jump_with_complex_replacement_expression() -> Result<()> {
        let schema = create_test_schema();
        // `a` is replaced by an expression that itself references `a`; Jump
        // must keep that inner `a` from being replaced again.
        let inner_expr = binary(lit(5i32), Operator::Multiply, col("a", &schema)?, &schema)?;
        let middle_expr = binary(inner_expr, Operator::Plus, lit(3i32), &schema)?;
        let complex_replacement = binary(
            middle_expr,
            Operator::Minus,
            col("another_col", &schema)?,
            &schema,
        )?;
        let column_map = build_column_map(
            &schema,
            vec![("a", complex_replacement), ("b", col("b", &schema)?)],
        );

        let mut rewriter = PhysicalColumnRewriter::new(&column_map);
        let expr = binary(
            col("a", &schema)?,
            Operator::Plus,
            col("b", &schema)?,
            &schema,
        )?;

        let result = expr.rewrite(&mut rewriter)?;

        assert_eq!(
            result.data.to_string(),
            "5 * a@0 + 3 - another_col@7 + b@1"
        );
        assert!(result.transformed);

        Ok(())
    }

    #[test]
    fn test_unmapped_columns_detection() -> Result<()> {
        let schema = create_test_schema();
        // Only `a` is mapped, so reaching `b` must fail.
        let column_map = build_column_map(&schema, vec![("a", lit(42i32))]);

        let mut rewriter = PhysicalColumnRewriter::new(&column_map);
        let expr = binary(
            col("a", &schema)?,
            Operator::Plus,
            col("b", &schema)?,
            &schema,
        )?;

        let err = expr.rewrite(&mut rewriter).unwrap_err();
        assert!(matches!(err, DataFusionError::Internal(_)));
        assert!(err.to_string().contains("not found in column mapping"));

        Ok(())
    }
}