datafusion_functions_nested/
map_values.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for map_values function.
19
20use crate::utils::{get_map_entry_field, make_scalar_function};
21use arrow::array::{Array, ArrayRef, ListArray};
22use arrow::datatypes::{DataType, Field, FieldRef};
23use datafusion_common::utils::take_function_args;
24use datafusion_common::{cast::as_map_array, exec_err, internal_err, Result};
25use datafusion_expr::{
26    ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature,
27    TypeSignature, Volatility,
28};
29use datafusion_macros::user_doc;
30use std::any::Any;
31use std::ops::Deref;
32use std::sync::Arc;
33
// Wires `MapValuesFunc` into the crate through `make_udf_expr_and_func!`, passing the
// function name `map_values`, its single argument `map`, a one-line description and
// the accessor name `map_values_udf`.
// NOTE(review): the macro is defined outside this file; presumably it generates the
// `map_values` expression helper and the `map_values_udf` accessor — confirm against
// the macro definition.
make_udf_expr_and_func!(
    MapValuesFunc,
    map_values,
    map,
    "Return a list of all values in the map.",
    map_values_udf
);
41
#[user_doc(
    doc_section(label = "Map Functions"),
    description = "Returns a list of all values in the map.",
    syntax_example = "map_values(map)",
    sql_example = r#"```sql
SELECT map_values(MAP {'a': 1, 'b': NULL, 'c': 3});
----
[1, , 3]

SELECT map_values(map([100, 5], [42, 43]));
----
[42, 43]
```"#,
    argument(
        name = "map",
        description = "Map expression. Can be a constant, column, or function, and any combination of map operators."
    )
)]
#[derive(Debug)]
/// Scalar UDF implementation backing the `map_values` function.
///
/// The user-facing documentation (description, syntax, SQL example and argument
/// description) is declared in the `user_doc` attribute on this type.
pub(crate) struct MapValuesFunc {
    /// Signature reported by `ScalarUDFImpl::signature` for this function.
    signature: Signature,
}
64
65impl Default for MapValuesFunc {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71impl MapValuesFunc {
72    pub fn new() -> Self {
73        Self {
74            signature: Signature::new(
75                TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray),
76                Volatility::Immutable,
77            ),
78        }
79    }
80}
81
82impl ScalarUDFImpl for MapValuesFunc {
83    fn as_any(&self) -> &dyn Any {
84        self
85    }
86
87    fn name(&self) -> &str {
88        "map_values"
89    }
90
91    fn signature(&self) -> &Signature {
92        &self.signature
93    }
94
95    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
96        internal_err!("return_field_from_args should be used instead")
97    }
98
99    fn return_field_from_args(
100        &self,
101        args: datafusion_expr::ReturnFieldArgs,
102    ) -> Result<FieldRef> {
103        let [map_type] = take_function_args(self.name(), args.arg_fields)?;
104
105        Ok(Field::new(
106            self.name(),
107            DataType::List(get_map_values_field_as_list_field(map_type.data_type())?),
108            // Nullable if the map is nullable
109            args.arg_fields.iter().any(|x| x.is_nullable()),
110        )
111        .into())
112    }
113
114    fn invoke_with_args(
115        &self,
116        args: datafusion_expr::ScalarFunctionArgs,
117    ) -> Result<ColumnarValue> {
118        make_scalar_function(map_values_inner)(&args.args)
119    }
120
121    fn documentation(&self) -> Option<&Documentation> {
122        self.doc()
123    }
124}
125
126fn map_values_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
127    let [map_arg] = take_function_args("map_values", args)?;
128
129    let map_array = match map_arg.data_type() {
130        DataType::Map(_, _) => as_map_array(&map_arg)?,
131        _ => return exec_err!("Argument for map_values should be a map"),
132    };
133
134    Ok(Arc::new(ListArray::new(
135        get_map_values_field_as_list_field(map_arg.data_type())?,
136        map_array.offsets().clone(),
137        Arc::clone(map_array.values()),
138        map_array.nulls().cloned(),
139    )))
140}
141
142fn get_map_values_field_as_list_field(map_type: &DataType) -> Result<FieldRef> {
143    let map_fields = get_map_entry_field(map_type)?;
144
145    let values_field = map_fields
146        .last()
147        .unwrap()
148        .deref()
149        .clone()
150        .with_name(Field::LIST_FIELD_DEFAULT_NAME);
151
152    Ok(Arc::new(values_field))
153}
154
#[cfg(test)]
mod tests {
    use crate::map_values::MapValuesFunc;
    use arrow::datatypes::{DataType, Field, FieldRef};
    use datafusion_common::ScalarValue;
    use datafusion_expr::ScalarUDFImpl;
    use std::sync::Arc;

    /// Builds a `Utf8 -> LargeUtf8` map field with the requested nullability flags.
    fn get_map_field(
        is_map_nullable: bool,
        is_keys_nullable: bool,
        is_values_nullable: bool,
    ) -> FieldRef {
        let keys = Arc::new(Field::new("keys", DataType::Utf8, is_keys_nullable));
        let values = Arc::new(Field::new(
            "values",
            DataType::LargeUtf8,
            is_values_nullable,
        ));
        let map = Field::new_map(
            "something",
            "entries",
            keys,
            values,
            false,
            is_map_nullable,
        );

        Arc::new(map)
    }

    /// Builds the list field expected as the function's return field.
    fn get_list_field(
        name: &str,
        is_list_nullable: bool,
        list_item_type: DataType,
        is_list_items_nullable: bool,
    ) -> FieldRef {
        let item =
            Arc::new(Field::new_list_field(list_item_type, is_list_items_nullable));

        Arc::new(Field::new_list(name, item, is_list_nullable))
    }

    /// Runs `return_field_from_args` on a single argument field.
    fn get_return_field(field: FieldRef) -> FieldRef {
        let udf = MapValuesFunc::new();
        let return_field_args = datafusion_expr::ReturnFieldArgs {
            arg_fields: &[field],
            scalar_arguments: &[None::<&ScalarValue>],
        };

        udf.return_field_from_args(return_field_args).unwrap()
    }

    #[test]
    fn return_type_field() {
        // Test cases:
        //
        // |                      Input Map                         ||                   Expected Output                     |
        // | ------------------------------------------------------ || ----------------------------------------------------- |
        // | map nullable | map keys nullable | map values nullable || expected list nullable | expected list items nullable |
        // | ------------ | ----------------- | ------------------- || ---------------------- | ---------------------------- |
        // | false        | false             | false               || false                  | false                        |
        // | false        | false             | true                || false                  | true                         |
        // | false        | true              | false               || false                  | false                        |
        // | false        | true              | true                || false                  | true                         |
        // | true         | false             | false               || true                   | false                        |
        // | true         | false             | true                || true                   | true                         |
        // | true         | true              | false               || true                   | false                        |
        // | true         | true              | true                || true                   | true                         |
        //
        // ---------------
        // Key nullability is part of the matrix to show that it affects neither the
        // nullability of the list nor that of the list items.
        //
        // The nested loops visit the rows of the table above in order.
        for is_map_nullable in [false, true] {
            for is_keys_nullable in [false, true] {
                for is_values_nullable in [false, true] {
                    let actual = get_return_field(get_map_field(
                        is_map_nullable,
                        is_keys_nullable,
                        is_values_nullable,
                    ));
                    // List nullability follows the map; item nullability follows
                    // the map values.
                    let expected = get_list_field(
                        "map_values",
                        is_map_nullable,
                        DataType::LargeUtf8,
                        is_values_nullable,
                    );

                    assert_eq!(
                        actual, expected,
                        "map nullable: {is_map_nullable}, keys nullable: {is_keys_nullable}, values nullable: {is_values_nullable}"
                    );
                }
            }
        }
    }
}