Skip to main content

datafusion_functions/core/
arrow_try_cast.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
//! [`ArrowTryCastFunc`]: Implementation of the `arrow_try_cast` function.
19
20use arrow::datatypes::{DataType, Field, FieldRef};
21use arrow::error::ArrowError;
22use datafusion_common::{
23    Result, arrow_datafusion_err, datatype::DataTypeExt, exec_datafusion_err, exec_err,
24    internal_err, types::logical_string, utils::take_function_args,
25};
26
27use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
28use datafusion_expr::{
29    Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs,
30    ScalarUDFImpl, Signature, TypeSignatureClass, Volatility,
31};
32use datafusion_macros::user_doc;
33
34use super::arrow_cast::data_type_from_type_arg;
35
36/// Like [`arrow_cast`](super::arrow_cast::ArrowCastFunc) but returns NULL on cast failure instead of erroring.
37///
38/// This is implemented by simplifying `arrow_try_cast(expr, 'Type')` into
39/// `Expr::TryCast` during optimization.
40#[user_doc(
41    doc_section(label = "Other Functions"),
42    description = "Casts a value to a specific Arrow data type, returning NULL if the cast fails.",
43    syntax_example = "arrow_try_cast(expression, datatype)",
44    sql_example = r#"```sql
45> select arrow_try_cast('123', 'Int64') as a,
46         arrow_try_cast('not_a_number', 'Int64') as b;
47
48+-----+------+
49| a   | b    |
50+-----+------+
51| 123 | NULL |
52+-----+------+
53```"#,
54    argument(
55        name = "expression",
56        description = "Expression to cast. The expression can be a constant, column, or function, and any combination of operators."
57    ),
58    argument(
59        name = "datatype",
60        description = "[Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`]"
61    )
62)]
63#[derive(Debug, PartialEq, Eq, Hash)]
64pub struct ArrowTryCastFunc {
65    signature: Signature,
66}
67
68impl Default for ArrowTryCastFunc {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
74impl ArrowTryCastFunc {
75    pub fn new() -> Self {
76        Self {
77            signature: Signature::coercible(
78                vec![
79                    Coercion::new_exact(TypeSignatureClass::Any),
80                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
81                ],
82                Volatility::Immutable,
83            ),
84        }
85    }
86}
87
88impl ScalarUDFImpl for ArrowTryCastFunc {
89    fn name(&self) -> &str {
90        "arrow_try_cast"
91    }
92
93    fn signature(&self) -> &Signature {
94        &self.signature
95    }
96
97    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
98        internal_err!("return_field_from_args should be called instead")
99    }
100
101    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
102        // TryCast can always return NULL (on cast failure), so always nullable
103        let [_, type_arg] = take_function_args(self.name(), args.scalar_arguments)?;
104
105        type_arg
106            .and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty()))
107            .map_or_else(
108                || {
109                    exec_err!(
110                        "{} requires its second argument to be a non-empty constant string",
111                        self.name()
112                    )
113                },
114                |casted_type| match casted_type.parse::<DataType>() {
115                    Ok(data_type) => {
116                        Ok(Field::new(self.name(), data_type, true).into())
117                    }
118                    Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")),
119                    Err(e) => Err(arrow_datafusion_err!(e)),
120                },
121            )
122    }
123
124    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
125        internal_err!("arrow_try_cast should have been simplified to try_cast")
126    }
127
128    fn simplify(
129        &self,
130        args: Vec<Expr>,
131        info: &SimplifyContext,
132    ) -> Result<ExprSimplifyResult> {
133        let [source_arg, type_arg] = take_function_args(self.name(), args)?;
134        let target_type = data_type_from_type_arg(self.name(), &type_arg)?;
135
136        let source_type = info.get_data_type(&source_arg)?;
137        let new_expr = if source_type == target_type {
138            source_arg
139        } else {
140            Expr::TryCast(datafusion_expr::TryCast {
141                expr: Box::new(source_arg),
142                field: target_type.into_nullable_field_ref(),
143            })
144        };
145        Ok(ExprSimplifyResult::Simplified(new_expr))
146    }
147
148    fn documentation(&self) -> Option<&Documentation> {
149        self.doc()
150    }
151}