1use std::sync::Arc;
35
36use arrow::datatypes::{DataType, Schema};
37use datafusion_common::{Result, ScalarValue, tree_node::Transformed};
38use datafusion_expr::Operator;
39use datafusion_expr_common::casts::try_cast_literal_to_type;
40
41use crate::PhysicalExpr;
42use crate::expressions::{BinaryExpr, CastExpr, Literal, TryCastExpr, lit};
43
44pub(crate) fn unwrap_cast_in_comparison(
46 expr: Arc<dyn PhysicalExpr>,
47 schema: &Schema,
48) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
49 if let Some(binary) = expr.downcast_ref::<BinaryExpr>()
50 && let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)?
51 {
52 return Ok(Transformed::yes(unwrapped));
53 }
54 Ok(Transformed::no(expr))
55}
56
57fn try_unwrap_cast_binary(
59 binary: &BinaryExpr,
60 schema: &Schema,
61) -> Result<Option<Arc<dyn PhysicalExpr>>> {
62 if let (Some((inner_expr, _cast_type)), Some(literal)) = (
64 extract_cast_info(binary.left()),
65 binary.right().downcast_ref::<Literal>(),
66 ) && binary.op().supports_propagation()
67 && let Some(unwrapped) = try_unwrap_cast_comparison(
68 Arc::clone(inner_expr),
69 literal.value(),
70 *binary.op(),
71 schema,
72 )?
73 {
74 return Ok(Some(unwrapped));
75 }
76
77 if let (Some(literal), Some((inner_expr, _cast_type))) = (
79 binary.left().downcast_ref::<Literal>(),
80 extract_cast_info(binary.right()),
81 ) {
82 if let Some(swapped_op) = binary.op().swap()
84 && binary.op().supports_propagation()
85 && let Some(unwrapped) = try_unwrap_cast_comparison(
86 Arc::clone(inner_expr),
87 literal.value(),
88 swapped_op,
89 schema,
90 )?
91 {
92 return Ok(Some(unwrapped));
93 }
94 }
97
98 Ok(None)
99}
100
101fn extract_cast_info(
106 expr: &Arc<dyn PhysicalExpr>,
107) -> Option<(&Arc<dyn PhysicalExpr>, &DataType)> {
108 if let Some(cast) = expr.downcast_ref::<CastExpr>() {
109 Some((cast.expr(), cast.cast_type()))
110 } else if let Some(try_cast) = expr.downcast_ref::<TryCastExpr>() {
111 Some((try_cast.expr(), try_cast.cast_type()))
112 } else {
113 None
114 }
115}
116
117fn try_unwrap_cast_comparison(
119 inner_expr: Arc<dyn PhysicalExpr>,
120 literal_value: &ScalarValue,
121 op: Operator,
122 schema: &Schema,
123) -> Result<Option<Arc<dyn PhysicalExpr>>> {
124 let inner_type = inner_expr.data_type(schema)?;
126
127 if let Some(casted_literal) = try_cast_literal_to_type(literal_value, &inner_type) {
129 let literal_expr = lit(casted_literal);
130 let binary_expr = BinaryExpr::new(inner_expr, op, literal_expr);
131 return Ok(Some(Arc::new(binary_expr)));
132 }
133
134 Ok(None)
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::col;
    use arrow::datatypes::Field;
    use datafusion_common::tree_node::TreeNode;

    /// Returns `true` if `expr` is a `CAST` or a `TRY_CAST`.
    fn is_cast_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
        expr.downcast_ref::<CastExpr>().is_some()
            || expr.downcast_ref::<TryCastExpr>().is_some()
    }

    /// Returns `true` if `binary` has the shape `cast op literal` or
    /// `literal op cast`.
    fn is_binary_expr_with_cast_and_literal(binary: &BinaryExpr) -> bool {
        let left_cast_right_literal = is_cast_expr(binary.left())
            && binary.right().downcast_ref::<Literal>().is_some();

        let left_literal_right_cast = binary.left().downcast_ref::<Literal>().is_some()
            && is_cast_expr(binary.right());

        left_cast_right_literal || left_literal_right_cast
    }

    /// Schema shared by several tests: `c1: Int32`, `c2: Int64`, `c3: Utf8`.
    fn test_schema() -> Schema {
        Schema::new(vec![
            Field::new("c1", DataType::Int32, false),
            Field::new("c2", DataType::Int64, false),
            Field::new("c3", DataType::Utf8, false),
        ])
    }

    /// `CAST(expr AS to)` with default cast options.
    fn cast_to(expr: Arc<dyn PhysicalExpr>, to: DataType) -> Arc<dyn PhysicalExpr> {
        Arc::new(CastExpr::new(expr, to, None))
    }

    /// `TRY_CAST(expr AS to)`.
    fn try_cast_to(expr: Arc<dyn PhysicalExpr>, to: DataType) -> Arc<dyn PhysicalExpr> {
        Arc::new(TryCastExpr::new(expr, to))
    }

    /// `left op right`.
    fn make_binary(
        left: Arc<dyn PhysicalExpr>,
        op: Operator,
        right: Arc<dyn PhysicalExpr>,
    ) -> Arc<dyn PhysicalExpr> {
        Arc::new(BinaryExpr::new(left, op, right))
    }

    /// Downcasts `expr` to a `BinaryExpr`, panicking if it is anything else.
    fn as_binary(expr: &Arc<dyn PhysicalExpr>) -> &BinaryExpr {
        expr.downcast_ref::<BinaryExpr>()
            .expect("expected a BinaryExpr")
    }

    /// Value of the literal on the right-hand side of `binary`, panicking if
    /// the right-hand side is not a literal.
    fn right_literal(binary: &BinaryExpr) -> &ScalarValue {
        binary
            .right()
            .downcast_ref::<Literal>()
            .expect("expected a Literal on the right-hand side")
            .value()
    }

    #[test]
    fn test_unwrap_cast_in_binary_comparison() {
        let schema = test_schema();
        // CAST(c1 AS Int64) > 10_i64  ==>  c1 > 10_i32
        let expr = make_binary(
            cast_to(col("c1", &schema).unwrap(), DataType::Int64),
            Operator::Gt,
            lit(10i64),
        );

        let result = unwrap_cast_in_comparison(expr, &schema).unwrap();
        assert!(result.transformed);

        let optimized = as_binary(&result.data);
        assert!(!is_cast_expr(optimized.left()));
        assert_eq!(right_literal(optimized), &ScalarValue::Int32(Some(10)));
    }

    #[test]
    fn test_unwrap_cast_with_literal_on_left() {
        let schema = test_schema();
        // 10_i64 < CAST(c1 AS Int64)  ==>  c1 > 10_i32
        let expr = make_binary(
            lit(10i64),
            Operator::Lt,
            cast_to(col("c1", &schema).unwrap(), DataType::Int64),
        );

        let result = unwrap_cast_in_comparison(expr, &schema).unwrap();
        assert!(result.transformed);
        assert_eq!(*as_binary(&result.data).op(), Operator::Gt);
    }

    #[test]
    fn test_no_unwrap_when_types_unsupported() {
        let schema = Schema::new(vec![Field::new("f1", DataType::Float32, false)]);
        // CAST(f1 AS Float64) > 10.5_f64 is expected to stay as-is.
        let expr = make_binary(
            cast_to(col("f1", &schema).unwrap(), DataType::Float64),
            Operator::Gt,
            lit(10.5f64),
        );

        let result = unwrap_cast_in_comparison(expr, &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_is_binary_expr_with_cast_and_literal() {
        let schema = test_schema();
        let cast_expr = cast_to(col("c1", &schema).unwrap(), DataType::Int64);

        let cast_on_left =
            BinaryExpr::new(Arc::clone(&cast_expr), Operator::Gt, lit(10i64));
        assert!(is_binary_expr_with_cast_and_literal(&cast_on_left));

        let cast_on_right = BinaryExpr::new(lit(10i64), Operator::Lt, cast_expr);
        assert!(is_binary_expr_with_cast_and_literal(&cast_on_right));

        let no_cast =
            BinaryExpr::new(col("c1", &schema).unwrap(), Operator::Gt, lit(10i32));
        assert!(!is_binary_expr_with_cast_and_literal(&no_cast));
    }

    #[test]
    fn test_unwrap_cast_literal_on_left_side() {
        let schema = Schema::new(vec![Field::new(
            "decimal_col",
            DataType::Decimal128(9, 2),
            true,
        )]);
        // Decimal128(22, 2) 4.00 <= CAST(decimal_col AS Decimal128(22, 2))
        //   ==>  decimal_col >= Decimal128(9, 2) 4.00
        let expr = make_binary(
            lit(ScalarValue::Decimal128(Some(400), 22, 2)),
            Operator::LtEq,
            cast_to(
                col("decimal_col", &schema).unwrap(),
                DataType::Decimal128(22, 2),
            ),
        );

        let result = unwrap_cast_in_comparison(expr, &schema).unwrap();
        assert!(result.transformed);

        let optimized = as_binary(&result.data);
        assert_eq!(*optimized.op(), Operator::GtEq);
        assert!(!is_cast_expr(optimized.left()));
        assert_eq!(
            right_literal(optimized).data_type(),
            DataType::Decimal128(9, 2)
        );
    }

    #[test]
    fn test_unwrap_cast_with_different_comparison_operators() {
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]);

        // (operator in `literal op CAST(col)`, expected operator in `col op literal`)
        let cases = [
            (Operator::Lt, Operator::Gt),
            (Operator::LtEq, Operator::GtEq),
            (Operator::Gt, Operator::Lt),
            (Operator::GtEq, Operator::LtEq),
            (Operator::Eq, Operator::Eq),
            (Operator::NotEq, Operator::NotEq),
        ];

        for (original_op, expected_op) in cases {
            let expr = make_binary(
                lit(100i64),
                original_op,
                cast_to(col("int_col", &schema).unwrap(), DataType::Int64),
            );

            let result = unwrap_cast_in_comparison(expr, &schema).unwrap();
            assert!(result.transformed, "not transformed for {original_op:?}");

            let optimized = as_binary(&result.data);
            assert_eq!(
                *optimized.op(),
                expected_op,
                "Failed for operator {original_op:?} -> {expected_op:?}"
            );
            assert!(!is_cast_expr(optimized.left()));
            assert_eq!(right_literal(optimized), &ScalarValue::Int32(Some(100)));
        }
    }

    #[test]
    fn test_unwrap_cast_with_decimal_types() {
        // (column precision, column scale, cast precision, cast scale, literal value)
        let cases = [(9, 2, 22, 2, 400), (10, 3, 20, 3, 1000), (5, 1, 10, 1, 99)];

        for (col_p, col_s, cast_p, cast_s, value) in cases {
            let schema = Schema::new(vec![Field::new(
                "decimal_col",
                DataType::Decimal128(col_p, col_s),
                true,
            )]);
            let column_expr = col("decimal_col", &schema).unwrap();
            let cast_type = DataType::Decimal128(cast_p, cast_s);
            let literal = ScalarValue::Decimal128(Some(value), cast_p, cast_s);

            // CAST(decimal_col) > literal
            let cast_on_left = make_binary(
                cast_to(Arc::clone(&column_expr), cast_type.clone()),
                Operator::Gt,
                lit(literal.clone()),
            );
            let result = unwrap_cast_in_comparison(cast_on_left, &schema).unwrap();
            assert!(result.transformed);

            // literal < CAST(decimal_col)
            let cast_on_right =
                make_binary(lit(literal), Operator::Lt, cast_to(column_expr, cast_type));
            let result = unwrap_cast_in_comparison(cast_on_right, &schema).unwrap();
            assert!(result.transformed);
        }
    }

    #[test]
    fn test_unwrap_cast_with_null_literals() {
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, true)]);
        // CAST(int_col AS Int64) = NULL::Int64  ==>  int_col = NULL::Int32
        let expr = make_binary(
            cast_to(col("int_col", &schema).unwrap(), DataType::Int64),
            Operator::Eq,
            lit(ScalarValue::Int64(None)),
        );

        let result = unwrap_cast_in_comparison(expr, &schema).unwrap();
        assert!(result.transformed);
        assert_eq!(
            right_literal(as_binary(&result.data)),
            &ScalarValue::Int32(None)
        );
    }

    #[test]
    fn test_no_unwrap_try_cast_from_string_column() {
        let schema = Schema::new(vec![Field::new("str_col", DataType::Utf8, true)]);
        // TRY_CAST(str_col AS Int64) > 100_i64: the Int64 literal is not
        // converted to Utf8, so the cast is expected to stay.
        let expr = make_binary(
            try_cast_to(col("str_col", &schema).unwrap(), DataType::Int64),
            Operator::Gt,
            lit(100i64),
        );

        let result = unwrap_cast_in_comparison(expr, &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_unwrap_cast_preserves_non_comparison_operators() {
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]);
        let column_expr = col("int_col", &schema).unwrap();

        // (CAST(int_col) > 10) AND (CAST(int_col) < 20)
        let compare1 = make_binary(
            cast_to(Arc::clone(&column_expr), DataType::Int64),
            Operator::Gt,
            lit(10i64),
        );
        let compare2 = make_binary(
            cast_to(column_expr, DataType::Int64),
            Operator::Lt,
            lit(20i64),
        );
        let and_expr = make_binary(compare1, Operator::And, compare2);

        let result = and_expr
            .transform_down(|node| unwrap_cast_in_comparison(node, &schema))
            .unwrap();
        assert!(result.transformed);

        let and_binary = as_binary(&result.data);
        assert_eq!(*and_binary.op(), Operator::And);
        assert!(!is_cast_expr(as_binary(and_binary.left()).left()));
        assert!(!is_cast_expr(as_binary(and_binary.right()).left()));
    }

    #[test]
    fn test_try_cast_unwrapping() {
        let schema = test_schema();
        // TRY_CAST(c1 AS Int64) <= 100_i64  ==>  c1 <= 100_i32
        let expr = make_binary(
            try_cast_to(col("c1", &schema).unwrap(), DataType::Int64),
            Operator::LtEq,
            lit(100i64),
        );

        let result = unwrap_cast_in_comparison(expr, &schema).unwrap();
        assert!(result.transformed);

        let optimized = as_binary(&result.data);
        assert!(!is_cast_expr(optimized.left()));
        assert_eq!(right_literal(optimized), &ScalarValue::Int32(Some(100)));
    }

    #[test]
    fn test_try_cast_unwrapping_with_literal_on_left() {
        let schema = test_schema();
        // 100_i64 >= TRY_CAST(c1 AS Int64)  ==>  c1 <= 100_i32
        let expr = make_binary(
            lit(100i64),
            Operator::GtEq,
            try_cast_to(col("c1", &schema).unwrap(), DataType::Int64),
        );

        let result = unwrap_cast_in_comparison(expr, &schema).unwrap();
        assert!(result.transformed);

        let optimized = as_binary(&result.data);
        assert_eq!(*optimized.op(), Operator::LtEq);
        assert!(!is_cast_expr(optimized.left()));
        assert_eq!(right_literal(optimized), &ScalarValue::Int32(Some(100)));
    }

    #[test]
    fn test_non_swappable_operator() {
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]);
        // 10_i64 + CAST(int_col AS Int64): arithmetic is left untouched.
        let expr = make_binary(
            lit(10i64),
            Operator::Plus,
            cast_to(col("int_col", &schema).unwrap(), DataType::Int64),
        );

        let result = unwrap_cast_in_comparison(expr, &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_no_unwrap_for_non_comparison_operator_with_cast_on_left() {
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]);
        // CAST(int_col AS Int64) + 10_i64: the operator guard must also apply
        // when the cast is on the left-hand side.
        let expr = make_binary(
            cast_to(col("int_col", &schema).unwrap(), DataType::Int64),
            Operator::Plus,
            lit(10i64),
        );

        let result = unwrap_cast_in_comparison(expr, &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_no_unwrap_without_cast() {
        let schema = test_schema();
        // c1 > 10_i32 has no cast to remove.
        let expr = make_binary(col("c1", &schema).unwrap(), Operator::Gt, lit(10i32));

        let result = unwrap_cast_in_comparison(expr, &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_no_unwrap_for_non_binary_expression() {
        let schema = test_schema();
        let result =
            unwrap_cast_in_comparison(col("c1", &schema).unwrap(), &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_cast_that_cannot_be_unwrapped_overflow() {
        let schema = Schema::new(vec![Field::new("small_int", DataType::Int8, false)]);
        // 1000 does not fit in Int8, so CAST(small_int AS Int64) > 1000 stays.
        let expr = make_binary(
            cast_to(col("small_int", &schema).unwrap(), DataType::Int64),
            Operator::Gt,
            lit(1000i64),
        );

        let result = unwrap_cast_in_comparison(expr, &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_complex_nested_expression() {
        let schema = test_schema();

        // (CAST(c1 AS Int64) > 10_i64) AND (CAST(c2 AS Int32) = 20_i32)
        let c1_binary = make_binary(
            cast_to(col("c1", &schema).unwrap(), DataType::Int64),
            Operator::Gt,
            lit(10i64),
        );
        let c2_binary = make_binary(
            cast_to(col("c2", &schema).unwrap(), DataType::Int32),
            Operator::Eq,
            lit(20i32),
        );
        let and_expr = make_binary(c1_binary, Operator::And, c2_binary);

        let result = and_expr
            .transform_down(|node| unwrap_cast_in_comparison(node, &schema))
            .unwrap();
        assert!(result.transformed);

        let and_binary = as_binary(&result.data);

        let left_binary = as_binary(and_binary.left());
        assert!(!is_cast_expr(left_binary.left()));
        assert_eq!(right_literal(left_binary), &ScalarValue::Int32(Some(10)));

        let right_binary = as_binary(and_binary.right());
        assert!(!is_cast_expr(right_binary.left()));
        assert_eq!(right_literal(right_binary), &ScalarValue::Int64(Some(20)));
    }
}