datafusion_functions/core/
union_extract.rs1use arrow::array::Array;
19use arrow::datatypes::{DataType, Field, FieldRef, UnionFields};
20use datafusion_common::cast::as_union_array;
21use datafusion_common::utils::take_function_args;
22use datafusion_common::{
23 Result, ScalarValue, exec_datafusion_err, exec_err, internal_err,
24};
25use datafusion_doc::Documentation;
26use datafusion_expr::{ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs};
27use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
28use datafusion_macros::user_doc;
29
/// Scalar UDF behind the SQL function `union_extract(union, field_name)`.
///
/// Given a union value and the name of one of its fields, the function yields
/// the value stored in that field when it is the selected variant, and NULL
/// otherwise. The user-facing documentation is generated from the `user_doc`
/// attribute below.
#[user_doc(
    doc_section(label = "Union Functions"),
    description = "Returns the value of the given field in the union when selected, or NULL otherwise.",
    syntax_example = "union_extract(union, field_name)",
    sql_example = r#"```sql
❯ select union_column, union_extract(union_column, 'a'), union_extract(union_column, 'b') from table_with_union;
+--------------+----------------------------------+----------------------------------+
| union_column | union_extract(union_column, 'a') | union_extract(union_column, 'b') |
+--------------+----------------------------------+----------------------------------+
| {a=1} | 1 | |
| {b=3.0} | | 3.0 |
| {a=4} | 4 | |
| {b=} | | |
| {a=} | | |
+--------------+----------------------------------+----------------------------------+
```"#,
    standard_argument(name = "union", prefix = "Union"),
    argument(
        name = "field_name",
        description = "String expression to operate on. Must be a constant."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct UnionExtractFun {
    /// Signature reported to the planner via `ScalarUDFImpl::signature`.
    signature: Signature,
}
56
57impl Default for UnionExtractFun {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl UnionExtractFun {
64 pub fn new() -> Self {
65 Self {
66 signature: Signature::any(2, Volatility::Immutable),
67 }
68 }
69}
70
71impl ScalarUDFImpl for UnionExtractFun {
72 fn as_any(&self) -> &dyn std::any::Any {
73 self
74 }
75
76 fn name(&self) -> &str {
77 "union_extract"
78 }
79
80 fn signature(&self) -> &Signature {
81 &self.signature
82 }
83
84 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
85 internal_err!("union_extract should return type from args")
87 }
88
89 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
90 if args.arg_fields.len() != 2 {
91 return exec_err!(
92 "union_extract expects 2 arguments, got {} instead",
93 args.arg_fields.len()
94 );
95 }
96
97 let DataType::Union(fields, _) = &args.arg_fields[0].data_type() else {
98 return exec_err!(
99 "union_extract first argument must be a union, got {} instead",
100 args.arg_fields[0].data_type()
101 );
102 };
103
104 let Some(ScalarValue::Utf8(Some(field_name))) = &args.scalar_arguments[1] else {
105 return exec_err!(
106 "union_extract second argument must be a non-null string literal, got {} instead",
107 args.arg_fields[1].data_type()
108 );
109 };
110
111 let field = find_field(fields, field_name)?.1;
112
113 Ok(Field::new(self.name(), field.data_type().clone(), true).into())
114 }
115
116 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
117 let [array, target_name] = take_function_args("union_extract", args.args)?;
118
119 let target_name = match target_name {
120 ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_name))) => {
121 Ok(target_name)
122 }
123 ColumnarValue::Scalar(ScalarValue::Utf8(None)) => exec_err!(
124 "union_extract second argument must be a non-null string literal, got a null instead"
125 ),
126 _ => exec_err!(
127 "union_extract second argument must be a non-null string literal, got {} instead",
128 target_name.data_type()
129 ),
130 }?;
131
132 match array {
133 ColumnarValue::Array(array) => {
134 let union_array = as_union_array(&array).map_err(|_| {
135 exec_datafusion_err!(
136 "union_extract first argument must be a union, got {} instead",
137 array.data_type()
138 )
139 })?;
140
141 Ok(ColumnarValue::Array(
142 arrow::compute::kernels::union_extract::union_extract(
143 union_array,
144 &target_name,
145 )?,
146 ))
147 }
148 ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => {
149 let (target_type_id, target) = find_field(&fields, &target_name)?;
150
151 let result = match value {
152 Some((type_id, value)) if target_type_id == type_id => *value,
153 _ => ScalarValue::try_new_null(target.data_type())?,
154 };
155
156 Ok(ColumnarValue::Scalar(result))
157 }
158 other => exec_err!(
159 "union_extract first argument must be a union, got {} instead",
160 other.data_type()
161 ),
162 }
163 }
164
165 fn documentation(&self) -> Option<&Documentation> {
166 self.doc()
167 }
168}
169
170fn find_field<'a>(fields: &'a UnionFields, name: &str) -> Result<(i8, &'a FieldRef)> {
171 fields
172 .iter()
173 .find(|field| field.1.name() == name)
174 .ok_or_else(|| exec_datafusion_err!("field {name} not found on union"))
175}
176
#[cfg(test)]
mod tests {
    use arrow::datatypes::{DataType, Field, UnionFields, UnionMode};
    use datafusion_common::config::ConfigOptions;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    use super::UnionExtractFun;

    /// Invokes `fun` on a dense scalar union holding `value`, asking for the
    /// field named `"str"`, and returns whatever the function produced.
    fn extract_str_field(
        fun: &UnionExtractFun,
        fields: &UnionFields,
        value: Option<(i8, Box<ScalarValue>)>,
    ) -> Result<ColumnarValue> {
        let args = vec![
            ColumnarValue::Scalar(ScalarValue::Union(
                value,
                fields.clone(),
                UnionMode::Dense,
            )),
            ColumnarValue::Scalar(ScalarValue::new_utf8("str")),
        ];
        let arg_fields = args
            .iter()
            .map(|arg| Field::new("a", arg.data_type().clone(), true).into())
            .collect::<Vec<_>>();

        fun.invoke_with_args(ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows: 1,
            return_field: Field::new("f", DataType::Utf8, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        })
    }

    /// Covers the three scalar cases: an empty union, a union holding a
    /// different variant, and a union holding the requested variant.
    #[test]
    fn test_scalar_value() -> Result<()> {
        let fun = UnionExtractFun::new();

        let fields = UnionFields::new(
            vec![1, 3],
            vec![
                Field::new("str", DataType::Utf8, false),
                Field::new("int", DataType::Int32, false),
            ],
        );

        // No value stored in the union: result is a null string.
        let result = extract_str_field(&fun, &fields, None)?;
        assert_scalar(result, ScalarValue::Utf8(None));

        // The "int" variant (type id 3) is selected: "str" extracts as null.
        let int_variant = Some((3, Box::new(ScalarValue::Int32(Some(42)))));
        let result = extract_str_field(&fun, &fields, int_variant)?;
        assert_scalar(result, ScalarValue::Utf8(None));

        // The "str" variant (type id 1) is selected: its value is returned.
        let str_variant = Some((1, Box::new(ScalarValue::new_utf8("42"))));
        let result = extract_str_field(&fun, &fields, str_variant)?;
        assert_scalar(result, ScalarValue::new_utf8("42"));

        Ok(())
    }

    /// Asserts that `value` is a scalar equal to `expected`; panics on arrays.
    fn assert_scalar(value: ColumnarValue, expected: ScalarValue) {
        let scalar = match value {
            ColumnarValue::Scalar(scalar) => scalar,
            ColumnarValue::Array(array) => panic!("expected scalar got {array:?}"),
        };

        assert_eq!(scalar, expected);
    }
}