Skip to main content

datafusion_functions/core/
named_struct.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::getfield::GetFieldFunc;
19use arrow::array::StructArray;
20use arrow::datatypes::{DataType, Field, FieldRef, Fields};
21use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
22use datafusion_expr::{
23    ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
24    StructFieldMapping,
25};
26use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
27use datafusion_macros::user_doc;
28use std::sync::Arc;
29
30#[user_doc(
31    doc_section(label = "Struct Functions"),
32    description = "Returns an Arrow struct using the specified name and input expressions pairs.
33For information on comparing and ordering struct values (including `NULL` handling),
34see [Comparison and Ordering](struct_coercion.md#comparison-and-ordering).",
35    syntax_example = "named_struct(expression1_name, expression1_input[, ..., expression_n_name, expression_n_input])",
36    sql_example = r#"
37For example, this query converts two columns `a` and `b` to a single column with
38a struct type of fields `field_a` and `field_b`:
39```sql
40> select * from t;
41+---+---+
42| a | b |
43+---+---+
44| 1 | 2 |
45| 3 | 4 |
46+---+---+
47> select named_struct('field_a', a, 'field_b', b) from t;
48+-------------------------------------------------------+
49| named_struct(Utf8("field_a"),t.a,Utf8("field_b"),t.b) |
50+-------------------------------------------------------+
51| {field_a: 1, field_b: 2}                              |
52| {field_a: 3, field_b: 4}                              |
53+-------------------------------------------------------+
54```"#,
55    argument(
56        name = "expression_n_name",
57        description = "Name of the column field. Must be a constant string."
58    ),
59    argument(
60        name = "expression_n_input",
61        description = "Expression to include in the output struct. Can be a constant, column, or function, and any combination of arithmetic or string operators."
62    )
63)]
64#[derive(Debug, PartialEq, Eq, Hash)]
65pub struct NamedStructFunc {
66    signature: Signature,
67}
68
69impl Default for NamedStructFunc {
70    fn default() -> Self {
71        Self::new()
72    }
73}
74
75impl NamedStructFunc {
76    pub fn new() -> Self {
77        Self {
78            signature: Signature::variadic_any(Volatility::Immutable),
79        }
80    }
81}
82
83impl ScalarUDFImpl for NamedStructFunc {
84    fn name(&self) -> &str {
85        "named_struct"
86    }
87
88    fn signature(&self) -> &Signature {
89        &self.signature
90    }
91
92    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
93        internal_err!(
94            "named_struct: return_type called instead of return_field_from_args"
95        )
96    }
97
98    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
99        // do not accept 0 arguments.
100        if args.scalar_arguments.is_empty() {
101            return exec_err!(
102                "named_struct requires at least one pair of arguments, got 0 instead"
103            );
104        }
105
106        if !args.scalar_arguments.len().is_multiple_of(2) {
107            return exec_err!(
108                "named_struct requires an even number of arguments, got {} instead",
109                args.scalar_arguments.len()
110            );
111        }
112
113        let names = args
114            .scalar_arguments
115            .iter()
116            .enumerate()
117            .step_by(2)
118            .map(|(i, sv)|
119                sv.and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty()))
120                .map_or_else(
121                    ||
122                        exec_err!(
123                    "{} requires {i}-th (0-indexed) field name as non-empty constant string",
124                    self.name()
125                ),
126                Ok
127                )
128            )
129            .collect::<Result<Vec<_>>>()?;
130        let types = args
131            .arg_fields
132            .iter()
133            .skip(1)
134            .step_by(2)
135            .map(|f| f.data_type())
136            .collect::<Vec<_>>();
137
138        let return_fields = names
139            .into_iter()
140            .zip(types)
141            .map(|(name, data_type)| Ok(Field::new(name, data_type.to_owned(), true)))
142            .collect::<Result<Vec<Field>>>()?;
143
144        Ok(Field::new(
145            self.name(),
146            DataType::Struct(Fields::from(return_fields)),
147            true,
148        )
149        .into())
150    }
151
152    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
153        let DataType::Struct(fields) = args.return_type() else {
154            return internal_err!("incorrect named_struct return type");
155        };
156
157        assert_eq!(
158            fields.len(),
159            args.args.len() / 2,
160            "return type field count != argument count / 2"
161        );
162
163        let values: Vec<ColumnarValue> = args
164            .args
165            .chunks_exact(2)
166            .map(|chunk| chunk[1].clone())
167            .collect();
168        let arrays = ColumnarValue::values_to_arrays(&values)?;
169        Ok(ColumnarValue::Array(Arc::new(StructArray::new(
170            fields.clone(),
171            arrays,
172            None,
173        ))))
174    }
175
176    fn documentation(&self) -> Option<&Documentation> {
177        self.doc()
178    }
179
180    fn struct_field_mapping(
181        &self,
182        literal_args: &[Option<ScalarValue>],
183    ) -> Option<StructFieldMapping> {
184        if literal_args.is_empty() || !literal_args.len().is_multiple_of(2) {
185            return None;
186        }
187
188        let mut fields = Vec::with_capacity(literal_args.len() / 2);
189        for (i, chunk) in literal_args.chunks(2).enumerate() {
190            match chunk {
191                [Some(ScalarValue::Utf8(Some(name))), _] => {
192                    fields.push((
193                        vec![ScalarValue::Utf8(Some(name.clone()))],
194                        i * 2 + 1, // index of the value argument
195                    ));
196                }
197                _ => return None,
198            }
199        }
200
201        Some(StructFieldMapping {
202            field_accessor: Arc::new(ScalarUDF::from(GetFieldFunc::new())),
203            fields,
204        })
205    }
206}