datafusion_optimizer/simplify_expressions/
simplify_literal.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Parses and simplifies an expression to a literal of a given type.
//!
//! This module provides functionality to parse and simplify constant expressions
//! used in SQL constructs such as `TABLESAMPLE (10 + 50 * 2)`. When such values are
//! needed during planning (rather than execution), they must first be reduced to
//! literals of a given type.
23
24use crate::simplify_expressions::ExprSimplifier;
25use arrow::datatypes::ArrowPrimitiveType;
26use datafusion_common::{
27    DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, plan_datafusion_err,
28    plan_err,
29};
30use datafusion_expr::Expr;
31use datafusion_expr::execution_props::ExecutionProps;
32use datafusion_expr::simplify::SimplifyContext;
33use std::sync::Arc;
34
35/// Parse and simplifies an expression to a numeric literal,
36/// corresponding to an arrow primitive type `T` (for example, Float64Type).
37///
38/// This function simplifies and coerces the expression, then extracts the underlying
39/// native type using `TryFrom<ScalarValue>`.
40///
41/// # Example
42/// ```ignore
43/// let value: f64 = parse_literal::<Float64Type>(expr)?;
44/// ```
45pub fn parse_literal<T>(expr: &Expr) -> Result<T::Native>
46where
47    T: ArrowPrimitiveType,
48    T::Native: TryFrom<ScalarValue, Error = DataFusionError>,
49{
50    // Empty schema is sufficient because it parses only literal expressions
51    let schema = DFSchemaRef::new(DFSchema::empty());
52
53    log::debug!("Parsing expr {:?} to type {}", expr, T::DATA_TYPE);
54
55    let execution_props = ExecutionProps::new();
56    let simplifier = ExprSimplifier::new(
57        SimplifyContext::new(&execution_props).with_schema(Arc::clone(&schema)),
58    );
59
60    // Simplify and coerce expression in case of constant arithmetic operations (e.g., 10 + 5)
61    let simplified_expr: Expr = simplifier
62        .simplify(expr.clone())
63        .map_err(|err| plan_datafusion_err!("Cannot simplify {expr:?}: {err}"))?;
64    let coerced_expr: Expr = simplifier.coerce(simplified_expr, schema.as_ref())?;
65    log::debug!("Coerced expression: {:?}", &coerced_expr);
66
67    match coerced_expr {
68        Expr::Literal(scalar_value, _) => {
69            // It is a literal - proceed to the underlying value
70            // Cast to the target type if needed
71            let casted_scalar = scalar_value.cast_to(&T::DATA_TYPE)?;
72
73            // Extract the native type
74            T::Native::try_from(casted_scalar).map_err(|err| {
75                plan_datafusion_err!(
76                    "Cannot extract {} from scalar value: {err}",
77                    std::any::type_name::<T>()
78                )
79            })
80        }
81        actual => {
82            plan_err!(
83                "Cannot extract literal from coerced {actual:?} expression given {expr:?} expression"
84            )
85        }
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{Float64Type, Int64Type};
    use datafusion_expr::{BinaryExpr, lit};
    use datafusion_expr_common::operator::Operator;

    /// Builds the binary expression `left <op> right`.
    fn binary(left: Expr, op: Operator, right: Expr) -> Expr {
        Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
    }

    /// Float literals, constant float arithmetic, and numeric strings
    /// (including scientific notation) all reduce to the expected `f64`.
    #[test]
    fn test_parse_sql_float_literal() {
        let float = |v: f64| Expr::Literal(ScalarValue::Float64(Some(v)), None);
        let utf8 = |s: &str| Expr::Literal(ScalarValue::Utf8(Some(s.into())), None);

        let cases = [
            (float(0.0), 0.0),
            (float(1.0), 1.0),
            (binary(lit(50.0), Operator::Minus, lit(10.0)), 40.0),
            (utf8("1e2"), 100.0),
            (utf8("2.5e-1"), 0.25),
        ];

        for (expr, expected) in cases {
            let value = parse_literal::<Float64Type>(&expr)
                .unwrap_or_else(|e| panic!("Failed to parse expression '{expr}': {e}"));

            assert!(
                (value - expected).abs() < 1e-10,
                "For expression '{expr}': expected {expected}, got {value}",
            );
        }
    }

    /// Constant integer arithmetic (`2 + 4`) folds to `6` as an `i64`.
    #[test]
    fn test_parse_sql_integer_literal() {
        let expr = binary(lit(2), Operator::Plus, lit(4));

        let value = parse_literal::<Int64Type>(&expr)
            .unwrap_or_else(|e| panic!("Failed to parse expression: {e}"));

        assert_eq!(6, value);
    }
}