datafusion_functions/core/
arrow_cast.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`ArrowCastFunc`]: Implementation of the `arrow_cast` function
19
20use arrow::datatypes::{DataType, Field, FieldRef};
21use arrow::error::ArrowError;
22use datafusion_common::types::logical_string;
23use datafusion_common::{
24    Result, ScalarValue, arrow_datafusion_err, exec_err, internal_err,
25};
26use datafusion_common::{exec_datafusion_err, utils::take_function_args};
27use std::any::Any;
28
29use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
30use datafusion_expr::{
31    Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs,
32    ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
33};
34use datafusion_macros::user_doc;
35
36/// Implements casting to arbitrary arrow types (rather than SQL types)
37///
38/// Note that the `arrow_cast` function is somewhat special in that its
39/// return depends only on the *value* of its second argument (not its type)
40///
41/// It is implemented by calling the same underlying arrow `cast` kernel as
42/// normal SQL casts.
43///
44/// For example to cast to `int` using SQL  (which is then mapped to the arrow
45/// type `Int32`)
46///
47/// ```sql
48/// select cast(column_x as int) ...
49/// ```
50///
51/// Use the `arrow_cast` function to cast to a specific arrow type
52///
53/// For example
54/// ```sql
55/// select arrow_cast(column_x, 'Float64')
56/// ```
57#[user_doc(
58    doc_section(label = "Other Functions"),
59    description = "Casts a value to a specific Arrow data type.",
60    syntax_example = "arrow_cast(expression, datatype)",
61    sql_example = r#"```sql
62> select
63  arrow_cast(-5,    'Int8') as a,
64  arrow_cast('foo', 'Dictionary(Int32, Utf8)') as b,
65  arrow_cast('bar', 'LargeUtf8') as c;
66
67+----+-----+-----+
68| a  | b   | c   |
69+----+-----+-----+
70| -5 | foo | bar |
71+----+-----+-----+
72
73> select
74  arrow_cast('2023-01-02T12:53:02', 'Timestamp(µs, "+08:00")') as d,
75  arrow_cast('2023-01-02T12:53:02', 'Timestamp(µs)') as e;
76
77+---------------------------+---------------------+
78| d                         | e                   |
79+---------------------------+---------------------+
80| 2023-01-02T12:53:02+08:00 | 2023-01-02T12:53:02 |
81+---------------------------+---------------------+
82```"#,
83    argument(
84        name = "expression",
85        description = "Expression to cast. The expression can be a constant, column, or function, and any combination of operators."
86    ),
87    argument(
88        name = "datatype",
89        description = "[Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`]"
90    )
91)]
92#[derive(Debug, PartialEq, Eq, Hash)]
93pub struct ArrowCastFunc {
94    signature: Signature,
95}
96
97impl Default for ArrowCastFunc {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103impl ArrowCastFunc {
104    pub fn new() -> Self {
105        Self {
106            signature: Signature::coercible(
107                vec![
108                    Coercion::new_exact(TypeSignatureClass::Any),
109                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
110                ],
111                Volatility::Immutable,
112            ),
113        }
114    }
115}
116
117impl ScalarUDFImpl for ArrowCastFunc {
118    fn as_any(&self) -> &dyn Any {
119        self
120    }
121
122    fn name(&self) -> &str {
123        "arrow_cast"
124    }
125
126    fn signature(&self) -> &Signature {
127        &self.signature
128    }
129
130    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
131        internal_err!("return_field_from_args should be called instead")
132    }
133
134    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
135        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
136
137        let [_, type_arg] = take_function_args(self.name(), args.scalar_arguments)?;
138
139        type_arg
140            .and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty()))
141            .map_or_else(
142                || {
143                    exec_err!(
144                        "{} requires its second argument to be a non-empty constant string",
145                        self.name()
146                    )
147                },
148                |casted_type| match casted_type.parse::<DataType>() {
149                    Ok(data_type) => Ok(Field::new(self.name(), data_type, nullable).into()),
150                    Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")),
151                    Err(e) => Err(arrow_datafusion_err!(e)),
152                },
153            )
154    }
155
156    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
157        internal_err!("arrow_cast should have been simplified to cast")
158    }
159
160    fn simplify(
161        &self,
162        mut args: Vec<Expr>,
163        info: &dyn SimplifyInfo,
164    ) -> Result<ExprSimplifyResult> {
165        // convert this into a real cast
166        let target_type = data_type_from_args(&args)?;
167        // remove second (type) argument
168        args.pop().unwrap();
169        let arg = args.pop().unwrap();
170
171        let source_type = info.get_data_type(&arg)?;
172        let new_expr = if source_type == target_type {
173            // the argument's data type is already the correct type
174            arg
175        } else {
176            // Use an actual cast to get the correct type
177            Expr::Cast(datafusion_expr::Cast {
178                expr: Box::new(arg),
179                data_type: target_type,
180            })
181        };
182        // return the newly written argument to DataFusion
183        Ok(ExprSimplifyResult::Simplified(new_expr))
184    }
185
186    fn documentation(&self) -> Option<&Documentation> {
187        self.doc()
188    }
189}
190
191/// Returns the requested type from the arguments
192fn data_type_from_args(args: &[Expr]) -> Result<DataType> {
193    let [_, type_arg] = take_function_args("arrow_cast", args)?;
194
195    let Expr::Literal(ScalarValue::Utf8(Some(val)), _) = type_arg else {
196        return exec_err!(
197            "arrow_cast requires its second argument to be a constant string, got {:?}",
198            type_arg
199        );
200    };
201
202    val.parse().map_err(|e| match e {
203        // If the data type cannot be parsed, return a Plan error to signal an
204        // error in the input rather than a more general ArrowError
205        ArrowError::ParseError(e) => exec_datafusion_err!("{e}"),
206        e => arrow_datafusion_err!(e),
207    })
208}