Skip to main content

datafusion_functions/core/
union_extract.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::array::Array;
19use arrow::datatypes::{DataType, Field, FieldRef, UnionFields};
20use datafusion_common::cast::as_union_array;
21use datafusion_common::utils::take_function_args;
22use datafusion_common::{
23    Result, ScalarValue, exec_datafusion_err, exec_err, internal_err,
24};
25use datafusion_doc::Documentation;
26use datafusion_expr::{ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs};
27use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
28use datafusion_macros::user_doc;
29
30#[user_doc(
31    doc_section(label = "Union Functions"),
32    description = "Returns the value of the given field in the union when selected, or NULL otherwise.",
33    syntax_example = "union_extract(union, field_name)",
34    sql_example = r#"```sql
35❯ select union_column, union_extract(union_column, 'a'), union_extract(union_column, 'b') from table_with_union;
36+--------------+----------------------------------+----------------------------------+
37| union_column | union_extract(union_column, 'a') | union_extract(union_column, 'b') |
38+--------------+----------------------------------+----------------------------------+
39| {a=1}        | 1                                |                                  |
40| {b=3.0}      |                                  | 3.0                              |
41| {a=4}        | 4                                |                                  |
42| {b=}         |                                  |                                  |
43| {a=}         |                                  |                                  |
44+--------------+----------------------------------+----------------------------------+
45```"#,
46    standard_argument(name = "union", prefix = "Union"),
47    argument(
48        name = "field_name",
49        description = "String expression to operate on. Must be a constant."
50    )
51)]
52#[derive(Debug, PartialEq, Eq, Hash)]
53pub struct UnionExtractFun {
54    signature: Signature,
55}
56
57impl Default for UnionExtractFun {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63impl UnionExtractFun {
64    pub fn new() -> Self {
65        Self {
66            signature: Signature::any(2, Volatility::Immutable),
67        }
68    }
69}
70
71impl ScalarUDFImpl for UnionExtractFun {
72    fn as_any(&self) -> &dyn std::any::Any {
73        self
74    }
75
76    fn name(&self) -> &str {
77        "union_extract"
78    }
79
80    fn signature(&self) -> &Signature {
81        &self.signature
82    }
83
84    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
85        // should be using return_field_from_args and not calling the default implementation
86        internal_err!("union_extract should return type from args")
87    }
88
89    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
90        if args.arg_fields.len() != 2 {
91            return exec_err!(
92                "union_extract expects 2 arguments, got {} instead",
93                args.arg_fields.len()
94            );
95        }
96
97        let DataType::Union(fields, _) = &args.arg_fields[0].data_type() else {
98            return exec_err!(
99                "union_extract first argument must be a union, got {} instead",
100                args.arg_fields[0].data_type()
101            );
102        };
103
104        let Some(ScalarValue::Utf8(Some(field_name))) = &args.scalar_arguments[1] else {
105            return exec_err!(
106                "union_extract second argument must be a non-null string literal, got {} instead",
107                args.arg_fields[1].data_type()
108            );
109        };
110
111        let field = find_field(fields, field_name)?.1;
112
113        Ok(Field::new(self.name(), field.data_type().clone(), true).into())
114    }
115
116    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
117        let [array, target_name] = take_function_args("union_extract", args.args)?;
118
119        let target_name = match target_name {
120            ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_name))) => {
121                Ok(target_name)
122            }
123            ColumnarValue::Scalar(ScalarValue::Utf8(None)) => exec_err!(
124                "union_extract second argument must be a non-null string literal, got a null instead"
125            ),
126            _ => exec_err!(
127                "union_extract second argument must be a non-null string literal, got {} instead",
128                target_name.data_type()
129            ),
130        }?;
131
132        match array {
133            ColumnarValue::Array(array) => {
134                let union_array = as_union_array(&array).map_err(|_| {
135                    exec_datafusion_err!(
136                        "union_extract first argument must be a union, got {} instead",
137                        array.data_type()
138                    )
139                })?;
140
141                Ok(ColumnarValue::Array(
142                    arrow::compute::kernels::union_extract::union_extract(
143                        union_array,
144                        &target_name,
145                    )?,
146                ))
147            }
148            ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => {
149                let (target_type_id, target) = find_field(&fields, &target_name)?;
150
151                let result = match value {
152                    Some((type_id, value)) if target_type_id == type_id => *value,
153                    _ => ScalarValue::try_new_null(target.data_type())?,
154                };
155
156                Ok(ColumnarValue::Scalar(result))
157            }
158            other => exec_err!(
159                "union_extract first argument must be a union, got {} instead",
160                other.data_type()
161            ),
162        }
163    }
164
165    fn documentation(&self) -> Option<&Documentation> {
166        self.doc()
167    }
168}
169
170fn find_field<'a>(fields: &'a UnionFields, name: &str) -> Result<(i8, &'a FieldRef)> {
171    fields
172        .iter()
173        .find(|field| field.1.name() == name)
174        .ok_or_else(|| exec_datafusion_err!("field {name} not found on union"))
175}
176
#[cfg(test)]
mod tests {
    use arrow::datatypes::{DataType, Field, UnionFields, UnionMode};
    use datafusion_common::config::ConfigOptions;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    use super::UnionExtractFun;

    /// Union layout shared by the tests: `str: Utf8` with type id 1 and
    /// `int: Int32` with type id 3.
    fn union_fields() -> Result<UnionFields> {
        Ok(UnionFields::try_new(
            vec![1, 3],
            vec![
                Field::new("str", DataType::Utf8, false),
                Field::new("int", DataType::Int32, false),
            ],
        )?)
    }

    /// Invokes `union_extract` on a dense union scalar holding `value`, using
    /// `field_name` as the second argument.
    fn extract(
        value: Option<(i8, Box<ScalarValue>)>,
        field_name: ScalarValue,
    ) -> Result<ColumnarValue> {
        let args = vec![
            ColumnarValue::Scalar(ScalarValue::Union(
                value,
                union_fields()?,
                UnionMode::Dense,
            )),
            ColumnarValue::Scalar(field_name),
        ];
        let arg_fields = args
            .iter()
            .map(|arg| Field::new("a", arg.data_type(), true).into())
            .collect::<Vec<_>>();

        UnionExtractFun::new().invoke_with_args(ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows: 1,
            return_field: Field::new("f", DataType::Utf8, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        })
    }

    // when it becomes possible to construct union scalars in SQL, this should go to sqllogictests
    #[test]
    fn test_scalar_value() -> Result<()> {
        // Null union: extracting any field yields a typed null.
        let result = extract(None, ScalarValue::new_utf8("str"))?;
        assert_scalar(result, ScalarValue::Utf8(None));

        // A different variant (`int`) is selected: `str` is null.
        let result = extract(
            Some((3, Box::new(ScalarValue::Int32(Some(42))))),
            ScalarValue::new_utf8("str"),
        )?;
        assert_scalar(result, ScalarValue::Utf8(None));

        // The requested variant is selected: its value is returned.
        let result = extract(
            Some((1, Box::new(ScalarValue::new_utf8("42")))),
            ScalarValue::new_utf8("str"),
        )?;
        assert_scalar(result, ScalarValue::new_utf8("42"));

        Ok(())
    }

    #[test]
    fn test_invalid_field_name() {
        // Field that does not exist on the union.
        assert!(extract(None, ScalarValue::new_utf8("missing")).is_err());
        // Null field name.
        assert!(extract(None, ScalarValue::Utf8(None)).is_err());
        // Field name that is not a string.
        assert!(extract(None, ScalarValue::Int32(Some(1))).is_err());
    }

    fn assert_scalar(value: ColumnarValue, expected: ScalarValue) {
        match value {
            ColumnarValue::Array(array) => panic!("expected scalar got {array:?}"),
            ColumnarValue::Scalar(scalar) => assert_eq!(scalar, expected),
        }
    }
}