datafusion_functions/core/
union_extract.rs1use arrow::array::Array;
19use arrow::datatypes::{DataType, Field, FieldRef, UnionFields};
20use datafusion_common::cast::as_union_array;
21use datafusion_common::utils::take_function_args;
22use datafusion_common::{
23 Result, ScalarValue, exec_datafusion_err, exec_err, internal_err,
24};
25use datafusion_doc::Documentation;
26use datafusion_expr::{ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs};
27use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
28use datafusion_macros::user_doc;
29
30#[user_doc(
31 doc_section(label = "Union Functions"),
32 description = "Returns the value of the given field in the union when selected, or NULL otherwise.",
33 syntax_example = "union_extract(union, field_name)",
34 sql_example = r#"```sql
35❯ select union_column, union_extract(union_column, 'a'), union_extract(union_column, 'b') from table_with_union;
36+--------------+----------------------------------+----------------------------------+
37| union_column | union_extract(union_column, 'a') | union_extract(union_column, 'b') |
38+--------------+----------------------------------+----------------------------------+
39| {a=1} | 1 | |
40| {b=3.0} | | 3.0 |
41| {a=4} | 4 | |
42| {b=} | | |
43| {a=} | | |
44+--------------+----------------------------------+----------------------------------+
45```"#,
46 standard_argument(name = "union", prefix = "Union"),
47 argument(
48 name = "field_name",
49 description = "String expression to operate on. Must be a constant."
50 )
51)]
52#[derive(Debug, PartialEq, Eq, Hash)]
53pub struct UnionExtractFun {
54 signature: Signature,
55}
56
57impl Default for UnionExtractFun {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl UnionExtractFun {
64 pub fn new() -> Self {
65 Self {
66 signature: Signature::any(2, Volatility::Immutable),
67 }
68 }
69}
70
71impl ScalarUDFImpl for UnionExtractFun {
72 fn name(&self) -> &str {
73 "union_extract"
74 }
75
76 fn signature(&self) -> &Signature {
77 &self.signature
78 }
79
80 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
81 internal_err!("union_extract should return type from args")
83 }
84
85 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
86 if args.arg_fields.len() != 2 {
87 return exec_err!(
88 "union_extract expects 2 arguments, got {} instead",
89 args.arg_fields.len()
90 );
91 }
92
93 let DataType::Union(fields, _) = &args.arg_fields[0].data_type() else {
94 return exec_err!(
95 "union_extract first argument must be a union, got {} instead",
96 args.arg_fields[0].data_type()
97 );
98 };
99
100 let Some(ScalarValue::Utf8(Some(field_name))) = &args.scalar_arguments[1] else {
101 return exec_err!(
102 "union_extract second argument must be a non-null string literal, got {} instead",
103 args.arg_fields[1].data_type()
104 );
105 };
106
107 let field = find_field(fields, field_name)?.1;
108
109 Ok(Field::new(self.name(), field.data_type().clone(), true).into())
110 }
111
112 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
113 let [array, target_name] = take_function_args("union_extract", args.args)?;
114
115 let target_name = match target_name {
116 ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_name))) => {
117 Ok(target_name)
118 }
119 ColumnarValue::Scalar(ScalarValue::Utf8(None)) => exec_err!(
120 "union_extract second argument must be a non-null string literal, got a null instead"
121 ),
122 _ => exec_err!(
123 "union_extract second argument must be a non-null string literal, got {} instead",
124 target_name.data_type()
125 ),
126 }?;
127
128 match array {
129 ColumnarValue::Array(array) => {
130 let union_array = as_union_array(&array).map_err(|_| {
131 exec_datafusion_err!(
132 "union_extract first argument must be a union, got {} instead",
133 array.data_type()
134 )
135 })?;
136
137 Ok(ColumnarValue::Array(
138 arrow::compute::kernels::union_extract::union_extract(
139 union_array,
140 &target_name,
141 )?,
142 ))
143 }
144 ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => {
145 let (target_type_id, target) = find_field(&fields, &target_name)?;
146
147 let result = match value {
148 Some((type_id, value)) if target_type_id == type_id => *value,
149 _ => ScalarValue::try_new_null(target.data_type())?,
150 };
151
152 Ok(ColumnarValue::Scalar(result))
153 }
154 other => exec_err!(
155 "union_extract first argument must be a union, got {} instead",
156 other.data_type()
157 ),
158 }
159 }
160
161 fn documentation(&self) -> Option<&Documentation> {
162 self.doc()
163 }
164}
165
166fn find_field<'a>(fields: &'a UnionFields, name: &str) -> Result<(i8, &'a FieldRef)> {
167 fields
168 .iter()
169 .find(|field| field.1.name() == name)
170 .ok_or_else(|| exec_datafusion_err!("field {name} not found on union"))
171}
172
#[cfg(test)]
mod tests {
    use arrow::datatypes::{DataType, Field, UnionFields, UnionMode};
    use datafusion_common::config::ConfigOptions;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    use super::UnionExtractFun;

    /// Invokes `fun` on a dense scalar union holding `value`, always asking
    /// for the field named `"str"`.
    fn extract_str_field(
        fun: &UnionExtractFun,
        fields: &UnionFields,
        value: Option<(i8, Box<ScalarValue>)>,
    ) -> Result<ColumnarValue> {
        let union_scalar = ScalarValue::Union(value, fields.clone(), UnionMode::Dense);
        let args = vec![
            ColumnarValue::Scalar(union_scalar),
            ColumnarValue::Scalar(ScalarValue::new_utf8("str")),
        ];
        let arg_fields = args
            .iter()
            .map(|arg| Field::new("a", arg.data_type().clone(), true).into())
            .collect::<Vec<_>>();

        fun.invoke_with_args(ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows: 1,
            return_field: Field::new("f", DataType::Utf8, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
        })
    }

    #[test]
    fn test_scalar_value() -> Result<()> {
        let fun = UnionExtractFun::new();

        let fields = UnionFields::try_new(
            vec![1, 3],
            vec![
                Field::new("str", DataType::Utf8, false),
                Field::new("int", DataType::Int32, false),
            ],
        )
        .unwrap();

        // (union payload, expected result of extracting "str")
        let cases: [(Option<(i8, Box<ScalarValue>)>, ScalarValue); 3] = [
            // Empty union scalar: typed NULL of the target field.
            (None, ScalarValue::Utf8(None)),
            // A different field ("int", type id 3) is selected: typed NULL.
            (
                Some((3, Box::new(ScalarValue::Int32(Some(42))))),
                ScalarValue::Utf8(None),
            ),
            // The requested field (type id 1) is selected: its value.
            (
                Some((1, Box::new(ScalarValue::new_utf8("42")))),
                ScalarValue::new_utf8("42"),
            ),
        ];

        for (value, expected) in cases {
            let result = extract_str_field(&fun, &fields, value)?;
            assert_scalar(result, expected);
        }

        Ok(())
    }

    /// Asserts that `value` is a scalar equal to `expected`; panics on arrays.
    fn assert_scalar(value: ColumnarValue, expected: ScalarValue) {
        let scalar = match value {
            ColumnarValue::Scalar(scalar) => scalar,
            ColumnarValue::Array(array) => panic!("expected scalar got {array:?}"),
        };
        assert_eq!(scalar, expected);
    }
}