Skip to main content

datafusion_functions/core/
arrow_cast.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ArrowCastFunc`]: Implementation of the `arrow_cast` function
19
20use arrow::datatypes::{DataType, Field, FieldRef};
21use arrow::error::ArrowError;
22use datafusion_common::{
23    Result, ScalarValue, arrow_datafusion_err, datatype::DataTypeExt,
24    exec_datafusion_err, exec_err, internal_err, types::logical_string,
25    utils::take_function_args,
26};
27
28use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
29use datafusion_expr::{
30    Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs,
31    ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
32};
33use datafusion_macros::user_doc;
34
/// Implements casting to arbitrary arrow types (rather than SQL types)
///
/// Note that the `arrow_cast` function is somewhat special in that its
/// return type depends only on the *value* of its second argument (not its
/// type): that argument must be a non-empty constant string, which is parsed
/// into the Arrow `DataType` to cast to.
///
/// It is implemented by calling the same underlying arrow `cast` kernel as
/// normal SQL casts: during simplification the call is rewritten into an
/// `Expr::Cast` (or into the bare argument when it already has the requested
/// type), so the function itself is never invoked directly.
///
/// For example to cast to `int` using SQL  (which is then mapped to the arrow
/// type `Int32`)
///
/// ```sql
/// select cast(column_x as int) ...
/// ```
///
/// Use the `arrow_cast` function to cast to a specific arrow type
///
/// For example
/// ```sql
/// select arrow_cast(column_x, 'Float64')
/// ```
#[user_doc(
    doc_section(label = "Other Functions"),
    description = "Casts a value to a specific Arrow data type.",
    syntax_example = "arrow_cast(expression, datatype)",
    sql_example = r#"```sql
> select
  arrow_cast(-5,    'Int8') as a,
  arrow_cast('foo', 'Dictionary(Int32, Utf8)') as b,
  arrow_cast('bar', 'LargeUtf8') as c;

+----+-----+-----+
| a  | b   | c   |
+----+-----+-----+
| -5 | foo | bar |
+----+-----+-----+

> select
  arrow_cast('2023-01-02T12:53:02', 'Timestamp(µs, "+08:00")') as d,
  arrow_cast('2023-01-02T12:53:02', 'Timestamp(µs)') as e;

+---------------------------+---------------------+
| d                         | e                   |
+---------------------------+---------------------+
| 2023-01-02T12:53:02+08:00 | 2023-01-02T12:53:02 |
+---------------------------+---------------------+
```"#,
    argument(
        name = "expression",
        description = "Expression to cast. The expression can be a constant, column, or function, and any combination of operators."
    ),
    argument(
        name = "datatype",
        description = "[Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`]"
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrowCastFunc {
    /// Two-argument signature: the expression to cast (any type), followed by
    /// the target type name (a string).
    signature: Signature,
}
95
96impl Default for ArrowCastFunc {
97    fn default() -> Self {
98        Self::new()
99    }
100}
101
102impl ArrowCastFunc {
103    pub fn new() -> Self {
104        Self {
105            signature: Signature::coercible(
106                vec![
107                    Coercion::new_exact(TypeSignatureClass::Any),
108                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
109                ],
110                Volatility::Immutable,
111            ),
112        }
113    }
114}
115
116impl ScalarUDFImpl for ArrowCastFunc {
117    fn name(&self) -> &str {
118        "arrow_cast"
119    }
120
121    fn signature(&self) -> &Signature {
122        &self.signature
123    }
124
125    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
126        internal_err!("return_field_from_args should be called instead")
127    }
128
129    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
130        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
131
132        let [_, type_arg] = take_function_args(self.name(), args.scalar_arguments)?;
133
134        type_arg
135            .and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty()))
136            .map_or_else(
137                || {
138                    exec_err!(
139                        "{} requires its second argument to be a non-empty constant string",
140                        self.name()
141                    )
142                },
143                |casted_type| match casted_type.parse::<DataType>() {
144                    Ok(data_type) => Ok(Field::new(self.name(), data_type, nullable).into()),
145                    Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")),
146                    Err(e) => Err(arrow_datafusion_err!(e)),
147                },
148            )
149    }
150
151    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
152        internal_err!("arrow_cast should have been simplified to cast")
153    }
154
155    fn simplify(
156        &self,
157        args: Vec<Expr>,
158        info: &SimplifyContext,
159    ) -> Result<ExprSimplifyResult> {
160        // convert this into a real cast
161        let [source_arg, type_arg] = take_function_args(self.name(), args)?;
162        let target_type = data_type_from_type_arg(self.name(), &type_arg)?;
163        let source_type = info.get_data_type(&source_arg)?;
164        let new_expr = if source_type == target_type {
165            // the argument's data type is already the correct type
166            source_arg
167        } else {
168            // Use an actual cast to get the correct type
169            Expr::Cast(datafusion_expr::Cast {
170                expr: Box::new(source_arg),
171                field: target_type.into_nullable_field_ref(),
172            })
173        };
174        // return the newly written argument to DataFusion
175        Ok(ExprSimplifyResult::Simplified(new_expr))
176    }
177
178    fn documentation(&self) -> Option<&Documentation> {
179        self.doc()
180    }
181}
182
183/// Returns the requested type from the type argument
184pub(crate) fn data_type_from_type_arg(name: &str, type_arg: &Expr) -> Result<DataType> {
185    let Expr::Literal(ScalarValue::Utf8(Some(val)), _) = type_arg else {
186        return exec_err!(
187            "{name} requires its second argument to be a constant string, got {:?}",
188            type_arg
189        );
190    };
191
192    val.parse().map_err(|e| match e {
193        // If the data type cannot be parsed, return a Plan error to signal an
194        // error in the input rather than a more general ArrowError
195        ArrowError::ParseError(e) => exec_datafusion_err!("{e}"),
196        e => arrow_datafusion_err!(e),
197    })
198}