Skip to main content

datafusion_optimizer/simplify_expressions/
simplify_literal.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Parses and simplifies an expression to a literal of a given type.
//!
//! This module provides functionality to parse and simplify static expressions
//! used in SQL constructs such as `FROM t TABLESAMPLE (10 + 50 * 2)`. Because such
//! values are needed during planning (not execution), they must be reduced to
//! literals of a given type.
23
24use crate::simplify_expressions::ExprSimplifier;
25use arrow::datatypes::ArrowPrimitiveType;
26use datafusion_common::{
27    DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, plan_datafusion_err,
28    plan_err,
29};
30use datafusion_expr::Expr;
31use datafusion_expr::simplify::SimplifyContext;
32use std::sync::Arc;
33
34/// Parse and simplifies an expression to a numeric literal,
35/// corresponding to an arrow primitive type `T` (for example, Float64Type).
36///
37/// This function simplifies and coerces the expression, then extracts the underlying
38/// native type using `TryFrom<ScalarValue>`.
39///
40/// # Example
41/// ```ignore
42/// let value: f64 = parse_literal::<Float64Type>(expr)?;
43/// ```
44pub fn parse_literal<T>(expr: &Expr) -> Result<T::Native>
45where
46    T: ArrowPrimitiveType,
47    T::Native: TryFrom<ScalarValue, Error = DataFusionError>,
48{
49    // Empty schema is sufficient because it parses only literal expressions
50    let schema = DFSchemaRef::new(DFSchema::empty());
51
52    log::debug!("Parsing expr {:?} to type {}", expr, T::DATA_TYPE);
53
54    let simplifier =
55        ExprSimplifier::new(SimplifyContext::default().with_schema(Arc::clone(&schema)));
56
57    // Simplify and coerce expression in case of constant arithmetic operations (e.g., 10 + 5)
58    let simplified_expr: Expr = simplifier
59        .simplify(expr.clone())
60        .map_err(|err| plan_datafusion_err!("Cannot simplify {expr:?}: {err}"))?;
61    let coerced_expr: Expr = simplifier.coerce(simplified_expr, schema.as_ref())?;
62    log::debug!("Coerced expression: {:?}", &coerced_expr);
63
64    match coerced_expr {
65        Expr::Literal(scalar_value, _) => {
66            // It is a literal - proceed to the underlying value
67            // Cast to the target type if needed
68            let casted_scalar = scalar_value.cast_to(&T::DATA_TYPE)?;
69
70            // Extract the native type
71            T::Native::try_from(casted_scalar).map_err(|err| {
72                plan_datafusion_err!(
73                    "Cannot extract {} from scalar value: {err}",
74                    std::any::type_name::<T>()
75                )
76            })
77        }
78        actual => {
79            plan_err!(
80                "Cannot extract literal from coerced {actual:?} expression given {expr:?} expression"
81            )
82        }
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{Float64Type, Int64Type};
    use datafusion_expr::{BinaryExpr, lit};
    use datafusion_expr_common::operator::Operator;

    /// Builds the (not yet simplified) expression `left <op> right`.
    fn binary(left: Expr, op: Operator, right: Expr) -> Expr {
        Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
    }

    /// Builds a `Float64` literal expression without field metadata.
    fn f64_lit(value: f64) -> Expr {
        Expr::Literal(ScalarValue::Float64(Some(value)), None)
    }

    /// Builds a `Utf8` literal expression without field metadata.
    fn utf8_lit(text: &str) -> Expr {
        Expr::Literal(ScalarValue::Utf8(Some(text.to_string())), None)
    }

    #[test]
    fn test_parse_sql_float_literal() {
        // (input expression, expected f64 after parsing)
        let cases: [(Expr, f64); 5] = [
            (f64_lit(0.0), 0.0),
            (f64_lit(1.0), 1.0),
            (binary(lit(50.0), Operator::Minus, lit(10.0)), 40.0),
            (utf8_lit("1e2"), 100.0),
            (utf8_lit("2.5e-1"), 0.25),
        ];

        for (expr, expected) in cases {
            let value = parse_literal::<Float64Type>(&expr)
                .unwrap_or_else(|e| panic!("Failed to parse expression '{expr}': {e}"));

            assert!(
                (value - expected).abs() < 1e-10,
                "For expression '{expr}': expected {expected}, got {value}",
            );
        }
    }

    #[test]
    fn test_parse_sql_integer_literal() {
        let expr = binary(lit(2), Operator::Plus, lit(4));

        let value: i64 = parse_literal::<Int64Type>(&expr)
            .unwrap_or_else(|e| panic!("Failed to parse expression: {e}"));

        assert_eq!(6, value);
    }
}