datafusion_optimizer/simplify_expressions/
simplify_literal.rs1use crate::simplify_expressions::ExprSimplifier;
25use arrow::datatypes::ArrowPrimitiveType;
26use datafusion_common::{
27 DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, plan_datafusion_err,
28 plan_err,
29};
30use datafusion_expr::Expr;
31use datafusion_expr::execution_props::ExecutionProps;
32use datafusion_expr::simplify::SimplifyContext;
33use std::sync::Arc;
34
35pub fn parse_literal<T>(expr: &Expr) -> Result<T::Native>
46where
47 T: ArrowPrimitiveType,
48 T::Native: TryFrom<ScalarValue, Error = DataFusionError>,
49{
50 let schema = DFSchemaRef::new(DFSchema::empty());
52
53 log::debug!("Parsing expr {:?} to type {}", expr, T::DATA_TYPE);
54
55 let execution_props = ExecutionProps::new();
56 let simplifier = ExprSimplifier::new(
57 SimplifyContext::new(&execution_props).with_schema(Arc::clone(&schema)),
58 );
59
60 let simplified_expr: Expr = simplifier
62 .simplify(expr.clone())
63 .map_err(|err| plan_datafusion_err!("Cannot simplify {expr:?}: {err}"))?;
64 let coerced_expr: Expr = simplifier.coerce(simplified_expr, schema.as_ref())?;
65 log::debug!("Coerced expression: {:?}", &coerced_expr);
66
67 match coerced_expr {
68 Expr::Literal(scalar_value, _) => {
69 let casted_scalar = scalar_value.cast_to(&T::DATA_TYPE)?;
72
73 T::Native::try_from(casted_scalar).map_err(|err| {
75 plan_datafusion_err!(
76 "Cannot extract {} from scalar value: {err}",
77 std::any::type_name::<T>()
78 )
79 })
80 }
81 actual => {
82 plan_err!(
83 "Cannot extract literal from coerced {actual:?} expression given {expr:?} expression"
84 )
85 }
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{Float64Type, Int64Type};
    use datafusion_expr::{BinaryExpr, lit};
    use datafusion_expr_common::operator::Operator;

    /// Builds the binary expression `left <op> right`.
    fn binary(left: Expr, op: Operator, right: Expr) -> Expr {
        Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
    }

    /// Builds a `Float64` literal without metadata.
    fn float_literal(v: f64) -> Expr {
        Expr::Literal(ScalarValue::Float64(Some(v)), None)
    }

    /// Builds a `Utf8` literal without metadata.
    fn utf8_literal(s: &str) -> Expr {
        Expr::Literal(ScalarValue::Utf8(Some(s.into())), None)
    }

    /// Plain float literals, constant float arithmetic and numeric strings
    /// must all parse to the expected `f64`.
    #[test]
    fn test_parse_sql_float_literal() {
        let test_cases = [
            (float_literal(0.0), 0.0),
            (float_literal(1.0), 1.0),
            (binary(lit(50.0), Operator::Minus, lit(10.0)), 40.0),
            (utf8_literal("1e2"), 100.0),
            (utf8_literal("2.5e-1"), 0.25),
        ];

        for (expr, expected) in test_cases {
            let value: f64 = parse_literal::<Float64Type>(&expr)
                .unwrap_or_else(|e| panic!("Failed to parse expression '{expr}': {e}"));

            assert!(
                (value - expected).abs() < 1e-10,
                "For expression '{expr}': expected {expected}, got {value}",
            );
        }
    }

    /// Constant integer arithmetic must fold to a single `i64`.
    #[test]
    fn test_parse_sql_integer_literal() {
        let expr = binary(lit(2), Operator::Plus, lit(4));

        let value: i64 = parse_literal::<Int64Type>(&expr)
            .unwrap_or_else(|e| panic!("Failed to parse expression: {e}"));

        assert_eq!(6, value);
    }
}