1use std::sync::Arc;
35
36use arrow::datatypes::{DataType, Schema};
37use datafusion_common::{Result, ScalarValue, tree_node::Transformed};
38use datafusion_expr::Operator;
39use datafusion_expr_common::casts::try_cast_literal_to_type;
40
41use crate::PhysicalExpr;
42use crate::expressions::{BinaryExpr, CastExpr, Literal, TryCastExpr, lit};
43
44pub(crate) fn unwrap_cast_in_comparison(
46 expr: Arc<dyn PhysicalExpr>,
47 schema: &Schema,
48) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
49 if let Some(binary) = expr.as_any().downcast_ref::<BinaryExpr>()
50 && let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)?
51 {
52 return Ok(Transformed::yes(unwrapped));
53 }
54 Ok(Transformed::no(expr))
55}
56
57fn try_unwrap_cast_binary(
59 binary: &BinaryExpr,
60 schema: &Schema,
61) -> Result<Option<Arc<dyn PhysicalExpr>>> {
62 if let (Some((inner_expr, _cast_type)), Some(literal)) = (
64 extract_cast_info(binary.left()),
65 binary.right().as_any().downcast_ref::<Literal>(),
66 ) && binary.op().supports_propagation()
67 && let Some(unwrapped) = try_unwrap_cast_comparison(
68 Arc::clone(inner_expr),
69 literal.value(),
70 *binary.op(),
71 schema,
72 )?
73 {
74 return Ok(Some(unwrapped));
75 }
76
77 if let (Some(literal), Some((inner_expr, _cast_type))) = (
79 binary.left().as_any().downcast_ref::<Literal>(),
80 extract_cast_info(binary.right()),
81 ) {
82 if let Some(swapped_op) = binary.op().swap()
84 && binary.op().supports_propagation()
85 && let Some(unwrapped) = try_unwrap_cast_comparison(
86 Arc::clone(inner_expr),
87 literal.value(),
88 swapped_op,
89 schema,
90 )?
91 {
92 return Ok(Some(unwrapped));
93 }
94 }
97
98 Ok(None)
99}
100
101fn extract_cast_info(
106 expr: &Arc<dyn PhysicalExpr>,
107) -> Option<(&Arc<dyn PhysicalExpr>, &DataType)> {
108 if let Some(cast) = expr.as_any().downcast_ref::<CastExpr>() {
109 Some((cast.expr(), cast.cast_type()))
110 } else if let Some(try_cast) = expr.as_any().downcast_ref::<TryCastExpr>() {
111 Some((try_cast.expr(), try_cast.cast_type()))
112 } else {
113 None
114 }
115}
116
117fn try_unwrap_cast_comparison(
119 inner_expr: Arc<dyn PhysicalExpr>,
120 literal_value: &ScalarValue,
121 op: Operator,
122 schema: &Schema,
123) -> Result<Option<Arc<dyn PhysicalExpr>>> {
124 let inner_type = inner_expr.data_type(schema)?;
126
127 if let Some(casted_literal) = try_cast_literal_to_type(literal_value, &inner_type) {
129 let literal_expr = lit(casted_literal);
130 let binary_expr = BinaryExpr::new(inner_expr, op, literal_expr);
131 return Ok(Some(Arc::new(binary_expr)));
132 }
133
134 Ok(None)
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{col, lit};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{ScalarValue, tree_node::TreeNode};
    use datafusion_expr::Operator;

    /// Returns `true` if `expr` is a `CastExpr` or a `TryCastExpr`.
    fn is_cast_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
        expr.as_any().downcast_ref::<CastExpr>().is_some()
            || expr.as_any().downcast_ref::<TryCastExpr>().is_some()
    }

    /// Returns `true` if `binary` has the shape `cast op literal` or
    /// `literal op cast`.
    fn is_binary_expr_with_cast_and_literal(binary: &BinaryExpr) -> bool {
        let left_cast_right_literal = is_cast_expr(binary.left())
            && binary.right().as_any().downcast_ref::<Literal>().is_some();

        let left_literal_right_cast =
            binary.left().as_any().downcast_ref::<Literal>().is_some()
                && is_cast_expr(binary.right());

        left_cast_right_literal || left_literal_right_cast
    }

    /// Downcasts `expr` to a `BinaryExpr`, panicking with a descriptive
    /// message if it is any other kind of expression.
    fn as_binary(expr: &Arc<dyn PhysicalExpr>) -> &BinaryExpr {
        expr.as_any()
            .downcast_ref::<BinaryExpr>()
            .expect("expected a BinaryExpr")
    }

    /// Returns the value held by `expr`, panicking if it is not a `Literal`.
    fn literal_value(expr: &Arc<dyn PhysicalExpr>) -> &ScalarValue {
        expr.as_any()
            .downcast_ref::<Literal>()
            .expect("expected a Literal")
            .value()
    }

    /// Schema shared by several tests: non-nullable `c1: Int32`, `c2: Int64`
    /// and `c3: Utf8`.
    fn test_schema() -> Schema {
        Schema::new(vec![
            Field::new("c1", DataType::Int32, false),
            Field::new("c2", DataType::Int64, false),
            Field::new("c3", DataType::Utf8, false),
        ])
    }

    #[test]
    fn test_unwrap_cast_in_binary_comparison() {
        let schema = test_schema();

        // CAST(c1 AS Int64) > 10_i64  =>  c1 > 10_i32
        let column_expr = col("c1", &schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
        let literal_expr = lit(10i64);
        let binary_expr =
            Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr));

        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(result.transformed);

        let optimized_binary = as_binary(&result.data);
        assert_eq!(*optimized_binary.op(), Operator::Gt);
        assert!(!is_cast_expr(optimized_binary.left()));
        assert_eq!(
            literal_value(optimized_binary.right()),
            &ScalarValue::Int32(Some(10))
        );
    }

    #[test]
    fn test_unwrap_cast_with_literal_on_left() {
        let schema = test_schema();

        // 10_i64 < CAST(c1 AS Int64)  =>  c1 > 10_i32
        let column_expr = col("c1", &schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
        let literal_expr = lit(10i64);
        let binary_expr =
            Arc::new(BinaryExpr::new(literal_expr, Operator::Lt, cast_expr));

        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(result.transformed);

        let optimized_binary = as_binary(&result.data);
        assert_eq!(*optimized_binary.op(), Operator::Gt);
        assert!(!is_cast_expr(optimized_binary.left()));
        assert_eq!(
            literal_value(optimized_binary.right()),
            &ScalarValue::Int32(Some(10))
        );
    }

    #[test]
    fn test_no_unwrap_when_types_unsupported() {
        let schema = Schema::new(vec![Field::new("f1", DataType::Float32, false)]);

        // Float literals are not converted, so the cast must stay in place.
        let column_expr = col("f1", &schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Float64, None));
        let literal_expr = lit(10.5f64);
        let binary_expr =
            Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr));

        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_is_binary_expr_with_cast_and_literal() {
        let schema = test_schema();

        let cast_expr: Arc<dyn PhysicalExpr> = Arc::new(CastExpr::new(
            col("c1", &schema).unwrap(),
            DataType::Int64,
            None,
        ));

        let cast_then_literal =
            BinaryExpr::new(Arc::clone(&cast_expr), Operator::Gt, lit(10i64));
        assert!(is_binary_expr_with_cast_and_literal(&cast_then_literal));

        let literal_then_cast =
            BinaryExpr::new(lit(10i64), Operator::Lt, Arc::clone(&cast_expr));
        assert!(is_binary_expr_with_cast_and_literal(&literal_then_cast));

        let column_then_literal =
            BinaryExpr::new(col("c1", &schema).unwrap(), Operator::Gt, lit(10i32));
        assert!(!is_binary_expr_with_cast_and_literal(&column_then_literal));
    }

    #[test]
    fn test_unwrap_cast_literal_on_left_side() {
        let schema = Schema::new(vec![Field::new(
            "decimal_col",
            DataType::Decimal128(9, 2),
            true,
        )]);

        // 4.00 <= CAST(decimal_col AS Decimal128(22, 2))
        //   =>  decimal_col >= 4.00 (as Decimal128(9, 2))
        let column_expr = col("decimal_col", &schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(
            column_expr,
            DataType::Decimal128(22, 2),
            None,
        ));
        let literal_expr = lit(ScalarValue::Decimal128(Some(400), 22, 2));
        let binary_expr =
            Arc::new(BinaryExpr::new(literal_expr, Operator::LtEq, cast_expr));

        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(result.transformed);

        let optimized_binary = as_binary(&result.data);
        assert_eq!(*optimized_binary.op(), Operator::GtEq);
        assert!(!is_cast_expr(optimized_binary.left()));
        assert_eq!(
            literal_value(optimized_binary.right()).data_type(),
            DataType::Decimal128(9, 2)
        );
    }

    #[test]
    fn test_unwrap_cast_with_different_comparison_operators() {
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]);

        // (operator with the literal on the left, expected mirrored operator)
        let operators = vec![
            (Operator::Lt, Operator::Gt),
            (Operator::LtEq, Operator::GtEq),
            (Operator::Gt, Operator::Lt),
            (Operator::GtEq, Operator::LtEq),
            (Operator::Eq, Operator::Eq),
            (Operator::NotEq, Operator::NotEq),
        ];

        for (original_op, expected_op) in operators {
            let column_expr = col("int_col", &schema).unwrap();
            let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
            let literal_expr = lit(100i64);
            let binary_expr =
                Arc::new(BinaryExpr::new(literal_expr, original_op, cast_expr));

            let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
            assert!(result.transformed);

            let optimized_binary = as_binary(&result.data);
            assert_eq!(
                *optimized_binary.op(),
                expected_op,
                "Failed for operator {original_op:?} -> {expected_op:?}"
            );
            assert!(!is_cast_expr(optimized_binary.left()));
            assert_eq!(
                literal_value(optimized_binary.right()),
                &ScalarValue::Int32(Some(100))
            );
        }
    }

    #[test]
    fn test_unwrap_cast_with_decimal_types() {
        // (column precision, column scale, cast precision, cast scale, value)
        let test_cases = vec![
            (9, 2, 22, 2, 400),
            (10, 3, 20, 3, 1000),
            (5, 1, 10, 1, 99),
        ];

        for (col_p, col_s, cast_p, cast_s, value) in test_cases {
            let schema = Schema::new(vec![Field::new(
                "decimal_col",
                DataType::Decimal128(col_p, col_s),
                true,
            )]);
            let column_expr = col("decimal_col", &schema).unwrap();
            // Same scale on both sides, so the unscaled value carries over.
            let expected = ScalarValue::Decimal128(Some(value), col_p, col_s);

            // CAST(decimal_col) > literal
            let cast_expr = Arc::new(CastExpr::new(
                Arc::clone(&column_expr),
                DataType::Decimal128(cast_p, cast_s),
                None,
            ));
            let literal_expr = lit(ScalarValue::Decimal128(Some(value), cast_p, cast_s));
            let binary_expr =
                Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr));

            let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
            assert!(result.transformed);
            let optimized_binary = as_binary(&result.data);
            assert_eq!(*optimized_binary.op(), Operator::Gt);
            assert_eq!(literal_value(optimized_binary.right()), &expected);

            // literal < CAST(decimal_col)  =>  decimal_col > literal
            let cast_expr = Arc::new(CastExpr::new(
                column_expr,
                DataType::Decimal128(cast_p, cast_s),
                None,
            ));
            let literal_expr = lit(ScalarValue::Decimal128(Some(value), cast_p, cast_s));
            let binary_expr =
                Arc::new(BinaryExpr::new(literal_expr, Operator::Lt, cast_expr));

            let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
            assert!(result.transformed);
            let optimized_binary = as_binary(&result.data);
            assert_eq!(*optimized_binary.op(), Operator::Gt);
            assert_eq!(literal_value(optimized_binary.right()), &expected);
        }
    }

    #[test]
    fn test_unwrap_cast_with_null_literals() {
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, true)]);

        // CAST(int_col AS Int64) = NULL  =>  int_col = NULL::Int32
        let column_expr = col("int_col", &schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
        let null_literal = lit(ScalarValue::Int64(None));
        let binary_expr =
            Arc::new(BinaryExpr::new(cast_expr, Operator::Eq, null_literal));

        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(result.transformed);

        let optimized_binary = as_binary(&result.data);
        assert_eq!(
            literal_value(optimized_binary.right()),
            &ScalarValue::Int32(None)
        );
    }

    #[test]
    fn test_unwrap_cast_with_try_cast() {
        let schema = Schema::new(vec![Field::new("str_col", DataType::Utf8, true)]);

        // An Int64 literal cannot be turned into a Utf8 literal, so the
        // TRY_CAST must stay.
        let column_expr = col("str_col", &schema).unwrap();
        let try_cast_expr = Arc::new(TryCastExpr::new(column_expr, DataType::Int64));
        let literal_expr = lit(100i64);
        let binary_expr =
            Arc::new(BinaryExpr::new(try_cast_expr, Operator::Gt, literal_expr));

        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_unwrap_cast_preserves_non_comparison_operators() {
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]);

        // (CAST(int_col) > 10) AND (CAST(int_col) < 20)
        let column_expr = col("int_col", &schema).unwrap();

        let cast1 = Arc::new(CastExpr::new(
            Arc::clone(&column_expr),
            DataType::Int64,
            None,
        ));
        let lit1 = lit(10i64);
        let compare1 = Arc::new(BinaryExpr::new(cast1, Operator::Gt, lit1));

        let cast2 = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
        let lit2 = lit(20i64);
        let compare2 = Arc::new(BinaryExpr::new(cast2, Operator::Lt, lit2));

        let and_expr = Arc::new(BinaryExpr::new(compare1, Operator::And, compare2));

        let result = (and_expr as Arc<dyn PhysicalExpr>)
            .transform_down(|node| unwrap_cast_in_comparison(node, &schema))
            .unwrap();
        assert!(result.transformed);

        let and_binary = as_binary(&result.data);
        assert_eq!(*and_binary.op(), Operator::And);

        let left_binary = as_binary(and_binary.left());
        let right_binary = as_binary(and_binary.right());

        assert!(!is_cast_expr(left_binary.left()));
        assert!(!is_cast_expr(right_binary.left()));
        assert_eq!(
            literal_value(left_binary.right()),
            &ScalarValue::Int32(Some(10))
        );
        assert_eq!(
            literal_value(right_binary.right()),
            &ScalarValue::Int32(Some(20))
        );
    }

    #[test]
    fn test_try_cast_unwrapping() {
        let schema = test_schema();

        // TRY_CAST(c1 AS Int64) <= 100_i64  =>  c1 <= 100_i32
        let column_expr = col("c1", &schema).unwrap();
        let try_cast_expr = Arc::new(TryCastExpr::new(column_expr, DataType::Int64));
        let literal_expr = lit(100i64);
        let binary_expr =
            Arc::new(BinaryExpr::new(try_cast_expr, Operator::LtEq, literal_expr));

        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(result.transformed);

        let optimized_binary = as_binary(&result.data);
        assert!(!is_cast_expr(optimized_binary.left()));
        assert_eq!(
            literal_value(optimized_binary.right()),
            &ScalarValue::Int32(Some(100))
        );
    }

    #[test]
    fn test_non_swappable_operator() {
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]);

        // `+` is not a comparison, so `10 + CAST(int_col)` is left alone.
        let column_expr = col("int_col", &schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
        let literal_expr = lit(10i64);
        let binary_expr =
            Arc::new(BinaryExpr::new(literal_expr, Operator::Plus, cast_expr));

        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_non_comparison_operator_with_cast_on_left() {
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]);

        // `CAST(int_col) + 10` is not a comparison and must not be rewritten.
        let column_expr = col("int_col", &schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
        let binary_expr =
            Arc::new(BinaryExpr::new(cast_expr, Operator::Plus, lit(10i64)));

        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_non_binary_expression_is_untouched() {
        let schema = test_schema();

        let column_expr = col("c1", &schema).unwrap();
        let result = unwrap_cast_in_comparison(column_expr, &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_cast_that_cannot_be_unwrapped_overflow() {
        let schema = Schema::new(vec![Field::new("small_int", DataType::Int8, false)]);

        // 1000 does not fit in Int8, so the cast must stay.
        let column_expr = col("small_int", &schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
        let literal_expr = lit(1000i64);
        let binary_expr =
            Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr));

        let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
        assert!(!result.transformed);
    }

    #[test]
    fn test_complex_nested_expression() {
        let schema = test_schema();

        // (CAST(c1 AS Int64) > 10_i64) AND (CAST(c2 AS Int32) = 20_i32)
        let c1_expr = col("c1", &schema).unwrap();
        let c1_cast = Arc::new(CastExpr::new(c1_expr, DataType::Int64, None));
        let c1_literal = lit(10i64);
        let c1_binary = Arc::new(BinaryExpr::new(c1_cast, Operator::Gt, c1_literal));

        let c2_expr = col("c2", &schema).unwrap();
        let c2_cast = Arc::new(CastExpr::new(c2_expr, DataType::Int32, None));
        let c2_literal = lit(20i32);
        let c2_binary = Arc::new(BinaryExpr::new(c2_cast, Operator::Eq, c2_literal));

        let and_expr = Arc::new(BinaryExpr::new(c1_binary, Operator::And, c2_binary));

        let result = (and_expr as Arc<dyn PhysicalExpr>)
            .transform_down(|node| unwrap_cast_in_comparison(node, &schema))
            .unwrap();
        assert!(result.transformed);

        let and_binary = as_binary(&result.data);

        let left_binary = as_binary(and_binary.left());
        assert!(!is_cast_expr(left_binary.left()));
        assert_eq!(
            literal_value(left_binary.right()),
            &ScalarValue::Int32(Some(10))
        );

        let right_binary = as_binary(and_binary.right());
        assert!(!is_cast_expr(right_binary.left()));
        assert_eq!(
            literal_value(right_binary.right()),
            &ScalarValue::Int64(Some(20))
        );
    }
}