datafusion_optimizer/simplify_expressions/
simplify_literal.rs1use crate::simplify_expressions::ExprSimplifier;
25use arrow::datatypes::ArrowPrimitiveType;
26use datafusion_common::{
27 DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, plan_datafusion_err,
28 plan_err,
29};
30use datafusion_expr::Expr;
31use datafusion_expr::simplify::SimplifyContext;
32use std::sync::Arc;
33
34pub fn parse_literal<T>(expr: &Expr) -> Result<T::Native>
45where
46 T: ArrowPrimitiveType,
47 T::Native: TryFrom<ScalarValue, Error = DataFusionError>,
48{
49 let schema = DFSchemaRef::new(DFSchema::empty());
51
52 log::debug!("Parsing expr {:?} to type {}", expr, T::DATA_TYPE);
53
54 let simplifier =
55 ExprSimplifier::new(SimplifyContext::default().with_schema(Arc::clone(&schema)));
56
57 let simplified_expr: Expr = simplifier
59 .simplify(expr.clone())
60 .map_err(|err| plan_datafusion_err!("Cannot simplify {expr:?}: {err}"))?;
61 let coerced_expr: Expr = simplifier.coerce(simplified_expr, schema.as_ref())?;
62 log::debug!("Coerced expression: {:?}", &coerced_expr);
63
64 match coerced_expr {
65 Expr::Literal(scalar_value, _) => {
66 let casted_scalar = scalar_value.cast_to(&T::DATA_TYPE)?;
69
70 T::Native::try_from(casted_scalar).map_err(|err| {
72 plan_datafusion_err!(
73 "Cannot extract {} from scalar value: {err}",
74 std::any::type_name::<T>()
75 )
76 })
77 }
78 actual => {
79 plan_err!(
80 "Cannot extract literal from coerced {actual:?} expression given {expr:?} expression"
81 )
82 }
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{Float64Type, Int64Type};
    use datafusion_expr::{BinaryExpr, lit};
    use datafusion_expr_common::operator::Operator;

    /// Float literals, constant float arithmetic and numeric strings must all resolve
    /// to the expected `f64` (compared with a small tolerance).
    #[test]
    fn test_parse_sql_float_literal() {
        let float_lit = |v: f64| Expr::Literal(ScalarValue::Float64(Some(v)), None);
        let utf8_lit =
            |s: &str| Expr::Literal(ScalarValue::Utf8(Some(String::from(s))), None);
        let fifty_minus_ten = Expr::BinaryExpr(BinaryExpr::new(
            Box::new(lit(50.0)),
            Operator::Minus,
            Box::new(lit(10.0)),
        ));

        let cases = [
            (float_lit(0.0), 0.0),
            (float_lit(1.0), 1.0),
            (fifty_minus_ten, 40.0),
            (utf8_lit("1e2"), 100.0),
            (utf8_lit("2.5e-1"), 0.25),
        ];

        for (expr, expected) in cases {
            let value = parse_literal::<Float64Type>(&expr)
                .unwrap_or_else(|e| panic!("Failed to parse expression '{expr}': {e}"));

            assert!(
                (value - expected).abs() < 1e-10,
                "For expression '{expr}': expected {expected}, got {value}",
            );
        }
    }

    /// Constant integer arithmetic (`2 + 4`) must fold to a single `i64`.
    #[test]
    fn test_parse_sql_integer_literal() {
        let two_plus_four = Expr::BinaryExpr(BinaryExpr::new(
            Box::new(lit(2)),
            Operator::Plus,
            Box::new(lit(4)),
        ));

        let value: i64 = parse_literal::<Int64Type>(&two_plus_four)
            .unwrap_or_else(|e| panic!("Failed to parse expression: {e}"));

        assert_eq!(6, value);
    }
}