1use std::sync::Arc;
35
36use arrow::datatypes::{DataType, Schema};
37use datafusion_common::{
38 Result, ScalarValue,
39 tree_node::{Transformed, TreeNode},
40};
41use datafusion_expr::Operator;
42use datafusion_expr_common::casts::try_cast_literal_to_type;
43
44use crate::PhysicalExpr;
45use crate::expressions::{BinaryExpr, CastExpr, Literal, TryCastExpr, lit};
46
47pub(crate) fn unwrap_cast_in_comparison(
49 expr: Arc<dyn PhysicalExpr>,
50 schema: &Schema,
51) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
52 expr.transform_down(|e| {
53 if let Some(binary) = e.as_any().downcast_ref::<BinaryExpr>()
54 && let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)?
55 {
56 return Ok(Transformed::yes(unwrapped));
57 }
58 Ok(Transformed::no(e))
59 })
60}
61
62fn try_unwrap_cast_binary(
64 binary: &BinaryExpr,
65 schema: &Schema,
66) -> Result<Option<Arc<dyn PhysicalExpr>>> {
67 if let (Some((inner_expr, _cast_type)), Some(literal)) = (
69 extract_cast_info(binary.left()),
70 binary.right().as_any().downcast_ref::<Literal>(),
71 ) && binary.op().supports_propagation()
72 && let Some(unwrapped) = try_unwrap_cast_comparison(
73 Arc::clone(inner_expr),
74 literal.value(),
75 *binary.op(),
76 schema,
77 )?
78 {
79 return Ok(Some(unwrapped));
80 }
81
82 if let (Some(literal), Some((inner_expr, _cast_type))) = (
84 binary.left().as_any().downcast_ref::<Literal>(),
85 extract_cast_info(binary.right()),
86 ) {
87 if let Some(swapped_op) = binary.op().swap()
89 && binary.op().supports_propagation()
90 && let Some(unwrapped) = try_unwrap_cast_comparison(
91 Arc::clone(inner_expr),
92 literal.value(),
93 swapped_op,
94 schema,
95 )?
96 {
97 return Ok(Some(unwrapped));
98 }
99 }
102
103 Ok(None)
104}
105
106fn extract_cast_info(
111 expr: &Arc<dyn PhysicalExpr>,
112) -> Option<(&Arc<dyn PhysicalExpr>, &DataType)> {
113 if let Some(cast) = expr.as_any().downcast_ref::<CastExpr>() {
114 Some((cast.expr(), cast.cast_type()))
115 } else if let Some(try_cast) = expr.as_any().downcast_ref::<TryCastExpr>() {
116 Some((try_cast.expr(), try_cast.cast_type()))
117 } else {
118 None
119 }
120}
121
122fn try_unwrap_cast_comparison(
124 inner_expr: Arc<dyn PhysicalExpr>,
125 literal_value: &ScalarValue,
126 op: Operator,
127 schema: &Schema,
128) -> Result<Option<Arc<dyn PhysicalExpr>>> {
129 let inner_type = inner_expr.data_type(schema)?;
131
132 if let Some(casted_literal) = try_cast_literal_to_type(literal_value, &inner_type) {
134 let literal_expr = lit(casted_literal);
135 let binary_expr = BinaryExpr::new(inner_expr, op, literal_expr);
136 return Ok(Some(Arc::new(binary_expr)));
137 }
138
139 Ok(None)
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{col, lit};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::ScalarValue;
    use datafusion_expr::Operator;

    /// Returns `true` when `expr` is a `CastExpr` or a `TryCastExpr`.
    fn is_cast_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
        let any = expr.as_any();
        any.is::<CastExpr>() || any.is::<TryCastExpr>()
    }

    /// Returns `true` when `binary` has a cast on one side and a literal on
    /// the other, in either order.
    fn is_binary_expr_with_cast_and_literal(binary: &BinaryExpr) -> bool {
        let is_literal = |side: &Arc<dyn PhysicalExpr>| side.as_any().is::<Literal>();
        let (lhs, rhs) = (binary.left(), binary.right());

        (is_cast_expr(lhs) && is_literal(rhs)) || (is_literal(lhs) && is_cast_expr(rhs))
    }

    /// Three-column schema shared by several tests:
    /// `c1: Int32`, `c2: Int64`, `c3: Utf8`, all non-nullable.
    fn test_schema() -> Schema {
        let fields = vec![
            Field::new("c1", DataType::Int32, false),
            Field::new("c2", DataType::Int64, false),
            Field::new("c3", DataType::Utf8, false),
        ];
        Schema::new(fields)
    }

    /// Schema consisting of exactly one field.
    fn single_field_schema(name: &str, data_type: DataType, nullable: bool) -> Schema {
        Schema::new(vec![Field::new(name, data_type, nullable)])
    }

    /// `cast(input as target)` with no cast options.
    fn cast_of(
        input: Arc<dyn PhysicalExpr>,
        target: DataType,
    ) -> Arc<dyn PhysicalExpr> {
        Arc::new(CastExpr::new(input, target, None))
    }

    /// `try_cast(input as target)`.
    fn try_cast_of(
        input: Arc<dyn PhysicalExpr>,
        target: DataType,
    ) -> Arc<dyn PhysicalExpr> {
        Arc::new(TryCastExpr::new(input, target))
    }

    /// `lhs op rhs` as a type-erased physical expression.
    fn binary_of(
        lhs: Arc<dyn PhysicalExpr>,
        op: Operator,
        rhs: Arc<dyn PhysicalExpr>,
    ) -> Arc<dyn PhysicalExpr> {
        Arc::new(BinaryExpr::new(lhs, op, rhs))
    }

    /// Runs the rewrite under test and unwraps its result.
    fn rewrite(
        expr: Arc<dyn PhysicalExpr>,
        schema: &Schema,
    ) -> Transformed<Arc<dyn PhysicalExpr>> {
        unwrap_cast_in_comparison(expr, schema).unwrap()
    }

    /// Downcasts `expr` to a `BinaryExpr`, panicking if it is something else.
    fn expect_binary(expr: &Arc<dyn PhysicalExpr>) -> &BinaryExpr {
        expr.as_any().downcast_ref::<BinaryExpr>().unwrap()
    }

    /// Downcasts `expr` to a `Literal` and returns its value, panicking if it
    /// is something else.
    fn expect_literal_value(expr: &Arc<dyn PhysicalExpr>) -> &ScalarValue {
        expr.as_any().downcast_ref::<Literal>().unwrap().value()
    }

    #[test]
    fn test_unwrap_cast_in_binary_comparison() {
        let schema = test_schema();

        // cast(c1 as INT64) > INT64(10)
        let c1 = col("c1", &schema).unwrap();
        let input = binary_of(cast_of(c1, DataType::Int64), Operator::Gt, lit(10i64));

        let outcome = rewrite(input, &schema);
        assert!(outcome.transformed);

        // Expected shape: c1 > INT32(10)
        let rewritten = expect_binary(&outcome.data);
        assert!(!is_cast_expr(rewritten.left()));
        assert_eq!(
            expect_literal_value(rewritten.right()),
            &ScalarValue::Int32(Some(10))
        );
    }

    #[test]
    fn test_unwrap_cast_with_literal_on_left() {
        let schema = test_schema();

        // INT64(10) < cast(c1 as INT64)
        let c1 = col("c1", &schema).unwrap();
        let input = binary_of(lit(10i64), Operator::Lt, cast_of(c1, DataType::Int64));

        let outcome = rewrite(input, &schema);
        assert!(outcome.transformed);

        // The operands trade places, so `<` has to become `>`.
        let rewritten = expect_binary(&outcome.data);
        assert_eq!(*rewritten.op(), Operator::Gt);
    }

    #[test]
    fn test_no_unwrap_when_types_unsupported() {
        let schema = single_field_schema("f1", DataType::Float32, false);

        // cast(f1 as FLOAT64) > FLOAT64(10.5)
        let f1 = col("f1", &schema).unwrap();
        let input = binary_of(cast_of(f1, DataType::Float64), Operator::Gt, lit(10.5f64));

        // No rewrite is expected for this type combination.
        let outcome = rewrite(input, &schema);
        assert!(!outcome.transformed);
    }

    #[test]
    fn test_is_binary_expr_with_cast_and_literal() {
        let schema = test_schema();

        // cast(c1 as INT64) > INT64(10)
        let c1 = col("c1", &schema).unwrap();
        let input = binary_of(cast_of(c1, DataType::Int64), Operator::Gt, lit(10i64));

        assert!(is_binary_expr_with_cast_and_literal(expect_binary(&input)));
    }

    #[test]
    fn test_unwrap_cast_literal_on_left_side() {
        let schema = single_field_schema("decimal_col", DataType::Decimal128(9, 2), true);

        // Decimal128(400, 22, 2) <= cast(decimal_col as Decimal128(22, 2))
        let decimal_col = col("decimal_col", &schema).unwrap();
        let input = binary_of(
            lit(ScalarValue::Decimal128(Some(400), 22, 2)),
            Operator::LtEq,
            cast_of(decimal_col, DataType::Decimal128(22, 2)),
        );

        let outcome = rewrite(input, &schema);
        assert!(outcome.transformed);

        let rewritten = expect_binary(&outcome.data);

        // `<=` is swapped to `>=` because the column moves to the left.
        assert_eq!(*rewritten.op(), Operator::GtEq);

        // The left side is the bare column, no cast around it.
        assert!(!is_cast_expr(rewritten.left()));

        // The literal now carries the column's decimal type.
        assert_eq!(
            expect_literal_value(rewritten.right()).data_type(),
            DataType::Decimal128(9, 2)
        );
    }

    #[test]
    fn test_unwrap_cast_with_different_comparison_operators() {
        let schema = single_field_schema("int_col", DataType::Int32, false);

        // (operator in `literal op cast(col)`, operator expected after rewrite)
        let cases = [
            (Operator::Lt, Operator::Gt),
            (Operator::LtEq, Operator::GtEq),
            (Operator::Gt, Operator::Lt),
            (Operator::GtEq, Operator::LtEq),
            (Operator::Eq, Operator::Eq),
            (Operator::NotEq, Operator::NotEq),
        ];

        for (original_op, expected_op) in cases {
            // INT64(100) op cast(int_col as INT64)
            let int_col = col("int_col", &schema).unwrap();
            let input =
                binary_of(lit(100i64), original_op, cast_of(int_col, DataType::Int64));

            let outcome = rewrite(input, &schema);
            assert!(outcome.transformed);

            let rewritten = expect_binary(&outcome.data);

            assert_eq!(
                *rewritten.op(),
                expected_op,
                "Failed for operator {original_op:?} -> {expected_op:?}"
            );

            // Column on the left without a cast, INT32 literal on the right.
            assert!(!is_cast_expr(rewritten.left()));
            assert_eq!(
                expect_literal_value(rewritten.right()),
                &ScalarValue::Int32(Some(100))
            );
        }
    }

    #[test]
    fn test_unwrap_cast_with_decimal_types() {
        // (column precision, column scale, cast precision, cast scale, value)
        let cases = [
            (9, 2, 22, 2, 400),
            (10, 3, 20, 3, 1000),
            (5, 1, 10, 1, 99),
        ];

        for (col_p, col_s, cast_p, cast_s, value) in cases {
            let schema = single_field_schema(
                "decimal_col",
                DataType::Decimal128(col_p, col_s),
                true,
            );
            let decimal_col = col("decimal_col", &schema).unwrap();

            // cast(decimal_col) > literal
            let cast_on_left = binary_of(
                cast_of(
                    Arc::clone(&decimal_col),
                    DataType::Decimal128(cast_p, cast_s),
                ),
                Operator::Gt,
                lit(ScalarValue::Decimal128(Some(value), cast_p, cast_s)),
            );
            assert!(rewrite(cast_on_left, &schema).transformed);

            // literal < cast(decimal_col)
            let cast_on_right = binary_of(
                lit(ScalarValue::Decimal128(Some(value), cast_p, cast_s)),
                Operator::Lt,
                cast_of(decimal_col, DataType::Decimal128(cast_p, cast_s)),
            );
            assert!(rewrite(cast_on_right, &schema).transformed);
        }
    }

    #[test]
    fn test_unwrap_cast_with_null_literals() {
        let schema = single_field_schema("int_col", DataType::Int32, true);

        // cast(int_col as INT64) = INT64(NULL)
        let int_col = col("int_col", &schema).unwrap();
        let input = binary_of(
            cast_of(int_col, DataType::Int64),
            Operator::Eq,
            lit(ScalarValue::Int64(None)),
        );

        let outcome = rewrite(input, &schema);
        assert!(outcome.transformed);

        // The NULL literal is retyped to INT32.
        let rewritten = expect_binary(&outcome.data);
        assert_eq!(
            expect_literal_value(rewritten.right()),
            &ScalarValue::Int32(None)
        );
    }

    #[test]
    fn test_unwrap_cast_with_try_cast() {
        let schema = single_field_schema("str_col", DataType::Utf8, true);

        // try_cast(str_col as INT64) > INT64(100)
        let str_col = col("str_col", &schema).unwrap();
        let input = binary_of(
            try_cast_of(str_col, DataType::Int64),
            Operator::Gt,
            lit(100i64),
        );

        // No rewrite is expected for a Utf8 column.
        let outcome = rewrite(input, &schema);
        assert!(!outcome.transformed);
    }

    #[test]
    fn test_unwrap_cast_preserves_non_comparison_operators() {
        let schema = single_field_schema("int_col", DataType::Int32, false);
        let int_col = col("int_col", &schema).unwrap();

        // cast(int_col as INT64) > INT64(10) AND cast(int_col as INT64) < INT64(20)
        let lower_bound = binary_of(
            cast_of(Arc::clone(&int_col), DataType::Int64),
            Operator::Gt,
            lit(10i64),
        );
        let upper_bound =
            binary_of(cast_of(int_col, DataType::Int64), Operator::Lt, lit(20i64));
        let input = binary_of(lower_bound, Operator::And, upper_bound);

        let outcome = rewrite(input, &schema);
        assert!(outcome.transformed);

        // The AND node itself is kept ...
        let conjunction = expect_binary(&outcome.data);
        assert_eq!(*conjunction.op(), Operator::And);

        // ... while both of its comparison children lose their casts.
        let left_cmp = expect_binary(conjunction.left());
        let right_cmp = expect_binary(conjunction.right());
        assert!(!is_cast_expr(left_cmp.left()));
        assert!(!is_cast_expr(right_cmp.left()));
    }

    #[test]
    fn test_try_cast_unwrapping() {
        let schema = test_schema();

        // try_cast(c1 as INT64) <= INT64(100)
        let c1 = col("c1", &schema).unwrap();
        let input = binary_of(
            try_cast_of(c1, DataType::Int64),
            Operator::LtEq,
            lit(100i64),
        );

        let outcome = rewrite(input, &schema);
        assert!(outcome.transformed);

        // Expected shape: c1 <= INT32(100)
        let rewritten = expect_binary(&outcome.data);
        assert!(!is_cast_expr(rewritten.left()));
        assert_eq!(
            expect_literal_value(rewritten.right()),
            &ScalarValue::Int32(Some(100))
        );
    }

    #[test]
    fn test_non_swappable_operator() {
        let schema = single_field_schema("int_col", DataType::Int32, false);

        // INT64(10) + cast(int_col as INT64): an arithmetic operator, not a
        // comparison, so the expression must be left alone.
        let int_col = col("int_col", &schema).unwrap();
        let input = binary_of(
            lit(10i64),
            Operator::Plus,
            cast_of(int_col, DataType::Int64),
        );

        let outcome = rewrite(input, &schema);
        assert!(!outcome.transformed);
    }

    #[test]
    fn test_cast_that_cannot_be_unwrapped_overflow() {
        let schema = single_field_schema("small_int", DataType::Int8, false);

        // cast(small_int as INT64) > INT64(1000); 1000 does not fit in an Int8.
        let small_int = col("small_int", &schema).unwrap();
        let input = binary_of(
            cast_of(small_int, DataType::Int64),
            Operator::Gt,
            lit(1000i64),
        );

        let outcome = rewrite(input, &schema);
        assert!(!outcome.transformed);
    }

    #[test]
    fn test_complex_nested_expression() {
        let schema = test_schema();

        // cast(c1 as INT64) > INT64(10)
        let c1 = col("c1", &schema).unwrap();
        let c1_cmp = binary_of(cast_of(c1, DataType::Int64), Operator::Gt, lit(10i64));

        // cast(c2 as INT32) = INT32(20)
        let c2 = col("c2", &schema).unwrap();
        let c2_cmp = binary_of(cast_of(c2, DataType::Int32), Operator::Eq, lit(20i32));

        let input = binary_of(c1_cmp, Operator::And, c2_cmp);

        let outcome = rewrite(input, &schema);
        assert!(outcome.transformed);

        let conjunction = expect_binary(&outcome.data);

        // Left child: c1 > INT32(10)
        let left_cmp = expect_binary(conjunction.left());
        assert!(!is_cast_expr(left_cmp.left()));
        assert_eq!(
            expect_literal_value(left_cmp.right()),
            &ScalarValue::Int32(Some(10))
        );

        // Right child: c2 = INT64(20)
        let right_cmp = expect_binary(conjunction.right());
        assert!(!is_cast_expr(right_cmp.left()));
        assert_eq!(
            expect_literal_value(right_cmp.right()),
            &ScalarValue::Int64(Some(20))
        );
    }
}