1use std::sync::Arc;
21
22use arrow::compute::can_cast_types;
23use arrow::datatypes::{FieldRef, Schema, SchemaRef};
24use datafusion_common::{
25 exec_err,
26 tree_node::{Transformed, TransformedResult, TreeNode},
27 Result, ScalarValue,
28};
29use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
30
31use crate::expressions::{self, CastExpr, Column};
32
/// Rewrites physical expressions so that they can be evaluated against a
/// physical (file) schema that may differ from the logical (table) schema
/// the expressions were written against.
pub trait PhysicalExprAdapter: Send + Sync + std::fmt::Debug {
    /// Rewrites `expr` for evaluation against the physical schema.
    ///
    /// Returns an error if the expression cannot be adapted (for example,
    /// an implementation may reject columns that cannot be resolved).
    fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>>;

    /// Returns a new adapter that additionally knows the values of the given
    /// partition columns, so references to them can be replaced with literals
    /// during `rewrite`.
    fn with_partition_values(
        &self,
        partition_values: Vec<(FieldRef, ScalarValue)>,
    ) -> Arc<dyn PhysicalExprAdapter>;
}
122
/// Factory for creating [`PhysicalExprAdapter`] instances bound to a specific
/// pair of logical (table) and physical (file) schemas.
pub trait PhysicalExprAdapterFactory: Send + Sync + std::fmt::Debug {
    /// Creates an adapter that rewrites expressions written against
    /// `logical_file_schema` for evaluation against `physical_file_schema`.
    fn create(
        &self,
        logical_file_schema: SchemaRef,
        physical_file_schema: SchemaRef,
    ) -> Arc<dyn PhysicalExprAdapter>;
}
131
/// Default [`PhysicalExprAdapterFactory`]; produces [`DefaultPhysicalExprAdapter`]s
/// with no partition values attached.
#[derive(Debug, Clone)]
pub struct DefaultPhysicalExprAdapterFactory;
134
135impl PhysicalExprAdapterFactory for DefaultPhysicalExprAdapterFactory {
136 fn create(
137 &self,
138 logical_file_schema: SchemaRef,
139 physical_file_schema: SchemaRef,
140 ) -> Arc<dyn PhysicalExprAdapter> {
141 Arc::new(DefaultPhysicalExprAdapter {
142 logical_file_schema,
143 physical_file_schema,
144 partition_values: Vec::new(),
145 })
146 }
147}
148
/// Default [`PhysicalExprAdapter`]: maps column references from the logical
/// schema onto the physical schema, inserting casts, null literals, or
/// partition-value literals as needed.
#[derive(Debug, Clone)]
pub struct DefaultPhysicalExprAdapter {
    // Schema the expressions were written against.
    logical_file_schema: SchemaRef,
    // Schema of the data actually being read.
    physical_file_schema: SchemaRef,
    // Known partition column values; empty until `with_partition_values`.
    partition_values: Vec<(FieldRef, ScalarValue)>,
}
175
176impl DefaultPhysicalExprAdapter {
177 pub fn new(logical_file_schema: SchemaRef, physical_file_schema: SchemaRef) -> Self {
182 Self {
183 logical_file_schema,
184 physical_file_schema,
185 partition_values: Vec::new(),
186 }
187 }
188}
189
190impl PhysicalExprAdapter for DefaultPhysicalExprAdapter {
191 fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
192 let rewriter = DefaultPhysicalExprAdapterRewriter {
193 logical_file_schema: &self.logical_file_schema,
194 physical_file_schema: &self.physical_file_schema,
195 partition_fields: &self.partition_values,
196 };
197 expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr)))
198 .data()
199 }
200
201 fn with_partition_values(
202 &self,
203 partition_values: Vec<(FieldRef, ScalarValue)>,
204 ) -> Arc<dyn PhysicalExprAdapter> {
205 Arc::new(DefaultPhysicalExprAdapter {
206 partition_values,
207 ..self.clone()
208 })
209 }
210}
211
/// Borrowed view of the adapter's state used while walking one expression
/// tree; holds the per-column rewrite logic.
struct DefaultPhysicalExprAdapterRewriter<'a> {
    logical_file_schema: &'a Schema,
    physical_file_schema: &'a Schema,
    // Partition column (field, value) pairs searched by name.
    partition_fields: &'a [(FieldRef, ScalarValue)],
}
217
impl<'a> DefaultPhysicalExprAdapterRewriter<'a> {
    /// Rewrites a single node of the expression tree (invoked for every node
    /// by `TreeNode::transform`). Only `Column` nodes are rewritten; all
    /// other expressions pass through unchanged.
    fn rewrite_expr(
        &self,
        expr: Arc<dyn PhysicalExpr>,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
            return self.rewrite_column(Arc::clone(&expr), column);
        }

        Ok(Transformed::no(expr))
    }

    /// Maps a column reference written against the logical schema onto the
    /// physical schema. Outcomes, in order of precedence:
    /// - column is a known partition column -> replaced by its literal value
    /// - column missing from the physical schema -> null literal of the
    ///   logical type if nullable, otherwise an error
    /// - column present but with a different index and/or type -> re-indexed
    ///   `Column` and/or `CastExpr` to the logical type
    /// - already matching -> returned unchanged
    fn rewrite_column(
        &self,
        expr: Arc<dyn PhysicalExpr>,
        column: &Column,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        // Resolve the column's logical field; if the logical schema doesn't
        // know it, fall back to partition values, then to a physical-only
        // field, and only then propagate the original lookup error.
        let logical_field = match self.logical_file_schema.field_with_name(column.name())
        {
            Ok(field) => field,
            Err(e) => {
                // Partition columns are not part of either file schema;
                // substitute the known value directly.
                if let Some(partition_value) = self.get_partition_value(column.name()) {
                    return Ok(Transformed::yes(expressions::lit(partition_value)));
                }
                // Column exists only in the physical file: treat the
                // physical field as the target ("logical") field.
                if let Ok(physical_field) =
                    self.physical_file_schema.field_with_name(column.name())
                {
                    physical_field
                } else {
                    // Unknown everywhere: surface the logical-schema error.
                    return Err(e.into());
                }
            }
        };

        // Locate the column in the physical file. If absent, it can only be
        // satisfied by nulls, which requires the logical field be nullable.
        let physical_column_index =
            match self.physical_file_schema.index_of(column.name()) {
                Ok(index) => index,
                Err(_) => {
                    if !logical_field.is_nullable() {
                        return exec_err!(
                            "Non-nullable column '{}' is missing from the physical schema",
                            column.name()
                        );
                    }
                    // Typed null so downstream kernels see the logical type.
                    let null_value =
                        ScalarValue::Null.cast_to(logical_field.data_type())?;
                    return Ok(Transformed::yes(expressions::lit(null_value)));
                }
            };
        let physical_field = self.physical_file_schema.field(physical_column_index);

        // Decide whether the column reference itself must change, keyed on
        // (index already correct, data type already correct).
        let column = match (
            column.index() == physical_column_index,
            logical_field.data_type() == physical_field.data_type(),
        ) {
            // Index and type both match: nothing to do.
            (true, true) => return Ok(Transformed::no(expr)),
            // Index matches but type differs: keep the column, cast below.
            (true, _) => column.clone(),
            // Index differs: re-resolve the column against the physical
            // schema (type may or may not also differ).
            (false, _) => {
                Column::new_with_schema(logical_field.name(), self.physical_file_schema)?
            }
        };

        // Types already agree: only the index changed, no cast needed.
        if logical_field.data_type() == physical_field.data_type() {
            return Ok(Transformed::yes(Arc::new(column)));
        }

        // Refuse combinations Arrow cannot cast rather than failing later
        // at evaluation time.
        if !can_cast_types(physical_field.data_type(), logical_field.data_type()) {
            return exec_err!(
                "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)",
                column.name(),
                physical_field.data_type(),
                logical_field.data_type()
            );
        }

        // Wrap the (possibly re-indexed) column in a cast to the logical type.
        let cast_expr = Arc::new(CastExpr::new(
            Arc::new(column),
            logical_field.data_type().clone(),
            None,
        ));

        Ok(Transformed::yes(cast_expr))
    }

    /// Linear search of the partition values by column name; returns a clone
    /// of the matching value, if any.
    fn get_partition_value(&self, column_name: &str) -> Option<ScalarValue> {
        self.partition_fields
            .iter()
            .find(|(field, _)| field.name() == column_name)
            .map(|(_, value)| value.clone())
    }
}
332
#[cfg(test)]
mod tests {
    use crate::expressions::{col, lit};

    use super::*;
    use arrow::{
        array::{RecordBatch, RecordBatchOptions},
        datatypes::{DataType, Field, Schema, SchemaRef},
    };
    use datafusion_common::{record_batch, ScalarValue};
    use datafusion_expr::Operator;
    use itertools::Itertools;
    use std::sync::Arc;

    /// Returns `(physical_schema, logical_schema)` fixtures where the logical
    /// schema widens `a` from Int32 to Int64 and adds a nullable column `c`
    /// that is absent from the physical file.
    fn create_test_schema() -> (Schema, Schema) {
        let physical_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, true),
        ]);

        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, true),
            Field::new("c", DataType::Float64, true),
        ]);

        (physical_schema, logical_schema)
    }

    /// `a` is Int32 in the file but Int64 logically, so the rewrite must
    /// insert a cast.
    #[test]
    fn test_rewrite_column_with_type_cast() {
        let (physical_schema, logical_schema) = create_test_schema();

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("a", 0));

        let result = adapter.rewrite(column_expr).unwrap();

        assert!(result.as_any().downcast_ref::<CastExpr>().is_some());
    }

    // NOTE(review): "mulit" in the test name looks like a typo for "multi".
    /// Rewrites `(a + 5) OR (c > 0.0)`: `a` gets a cast to Int64 and `c`
    /// (missing from the physical schema, nullable) becomes a null literal.
    #[test]
    fn test_rewrite_mulit_column_expr_with_type_cast() {
        let (physical_schema, logical_schema) = create_test_schema();
        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));

        let column_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
        let column_c = Arc::new(Column::new("c", 2)) as Arc<dyn PhysicalExpr>;
        // Build (a + 5) OR (c > 0.0).
        let expr = expressions::BinaryExpr::new(
            Arc::clone(&column_a),
            Operator::Plus,
            Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))),
        );
        let expr = expressions::BinaryExpr::new(
            Arc::new(expr),
            Operator::Or,
            Arc::new(expressions::BinaryExpr::new(
                Arc::clone(&column_c),
                Operator::Gt,
                Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))),
            )),
        );

        let result = adapter.rewrite(Arc::new(expr)).unwrap();
        println!("Rewritten expression: {result}");

        let expected = expressions::BinaryExpr::new(
            Arc::new(CastExpr::new(
                Arc::new(Column::new("a", 0)),
                DataType::Int64,
                None,
            )),
            Operator::Plus,
            Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))),
        );
        let expected = Arc::new(expressions::BinaryExpr::new(
            Arc::new(expected),
            Operator::Or,
            Arc::new(expressions::BinaryExpr::new(
                lit(ScalarValue::Null),
                Operator::Gt,
                Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))),
            )),
        )) as Arc<dyn PhysicalExpr>;

        // Compare Display output since the trees are not structurally Eq.
        assert_eq!(
            result.to_string(),
            expected.to_string(),
            "The rewritten expression did not match the expected output"
        );
    }

    /// A nullable column absent from the physical schema is replaced by a
    /// typed null literal (Float64(None), matching the logical type of `c`).
    #[test]
    fn test_rewrite_missing_column() -> Result<()> {
        let (physical_schema, logical_schema) = create_test_schema();

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("c", 2));

        let result = adapter.rewrite(column_expr)?;

        if let Some(literal) = result.as_any().downcast_ref::<expressions::Literal>() {
            assert_eq!(*literal.value(), ScalarValue::Float64(None));
        } else {
            panic!("Expected literal expression");
        }

        Ok(())
    }

    /// A reference to a known partition column is replaced by the partition's
    /// literal value.
    #[test]
    fn test_rewrite_partition_column() -> Result<()> {
        let (physical_schema, logical_schema) = create_test_schema();

        let partition_field =
            Arc::new(Field::new("partition_col", DataType::Utf8, false));
        let partition_value = ScalarValue::Utf8(Some("test_value".to_string()));
        let partition_values = vec![(partition_field, partition_value)];

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let adapter = adapter.with_partition_values(partition_values);

        let column_expr = Arc::new(Column::new("partition_col", 0));
        let result = adapter.rewrite(column_expr)?;

        if let Some(literal) = result.as_any().downcast_ref::<expressions::Literal>() {
            assert_eq!(
                *literal.value(),
                ScalarValue::Utf8(Some("test_value".to_string()))
            );
        } else {
            panic!("Expected literal expression");
        }

        Ok(())
    }

    /// When index and type already match, the rewrite returns the ORIGINAL
    /// Arc unchanged — verified by pointer identity, not just equality.
    #[test]
    fn test_rewrite_no_change_needed() -> Result<()> {
        let (physical_schema, logical_schema) = create_test_schema();

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;

        let result = adapter.rewrite(Arc::clone(&column_expr))?;

        assert!(std::ptr::eq(
            column_expr.as_ref() as *const dyn PhysicalExpr,
            result.as_ref() as *const dyn PhysicalExpr
        ));

        Ok(())
    }

    /// A NON-nullable column missing from the physical schema cannot be
    /// filled with nulls and must produce an error.
    #[test]
    fn test_non_nullable_missing_column_error() {
        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("b", 1));

        let result = adapter.rewrite(column_expr);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Non-nullable column 'b' is missing"));
    }

    /// Evaluates each projection expression against `batch` and assembles the
    /// results into a new `RecordBatch` with the given `schema`.
    fn batch_project(
        expr: Vec<Arc<dyn PhysicalExpr>>,
        batch: &RecordBatch,
        schema: SchemaRef,
    ) -> Result<RecordBatch> {
        let arrays = expr
            .iter()
            .map(|expr| {
                expr.evaluate(batch)
                    .and_then(|v| v.into_array(batch.num_rows()))
            })
            .collect::<Result<Vec<_>>>()?;

        if arrays.is_empty() {
            // A zero-column batch needs an explicit row count, since there
            // are no arrays to infer it from.
            let options =
                RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
            RecordBatch::try_new_with_options(Arc::clone(&schema), arrays, &options)
                .map_err(Into::into)
        } else {
            RecordBatch::try_new(Arc::clone(&schema), arrays).map_err(Into::into)
        }
    }

    /// End-to-end: rewrites a projection (`b`, `a`) against a file that has
    /// `a` (Int32) and an unprojected `extra` column, then evaluates it —
    /// expecting `b` to come back as all-null Utf8 and `a` cast to Int64.
    #[test]
    fn test_adapt_batches() {
        let physical_batch = record_batch!(
            ("a", Int32, vec![Some(1), None, Some(3)]),
            ("extra", Utf8, vec![Some("x"), Some("y"), None])
        )
        .unwrap();

        let physical_schema = physical_batch.schema();

        let logical_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Utf8, true),
        ]));

        // Project columns in (b, a) order against the LOGICAL schema.
        let projection = vec![
            col("b", &logical_schema).unwrap(),
            col("a", &logical_schema).unwrap(),
        ];

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter =
            factory.create(Arc::clone(&logical_schema), Arc::clone(&physical_schema));

        let adapted_projection = projection
            .into_iter()
            .map(|expr| adapter.rewrite(expr).unwrap())
            .collect_vec();

        // Derive the output schema from the rewritten expressions themselves.
        let adapted_schema = Arc::new(Schema::new(
            adapted_projection
                .iter()
                .map(|expr| expr.return_field(&physical_schema).unwrap())
                .collect_vec(),
        ));

        let res = batch_project(
            adapted_projection,
            &physical_batch,
            Arc::clone(&adapted_schema),
        )
        .unwrap();

        assert_eq!(res.num_columns(), 2);
        assert_eq!(res.column(0).data_type(), &DataType::Utf8);
        assert_eq!(res.column(1).data_type(), &DataType::Int64);
        assert_eq!(
            res.column(0)
                .as_any()
                .downcast_ref::<arrow::array::StringArray>()
                .unwrap()
                .iter()
                .collect_vec(),
            vec![None, None, None]
        );
        assert_eq!(
            res.column(1)
                .as_any()
                .downcast_ref::<arrow::array::Int64Array>()
                .unwrap()
                .iter()
                .collect_vec(),
            vec![Some(1), None, Some(3)]
        );
    }
}