Skip to main content

datafusion_functions_nested/
map_values.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`ScalarUDFImpl`] definitions for the `map_values` function.
19
20use crate::utils::{get_map_entry_field, make_scalar_function};
21use arrow::array::{Array, ArrayRef, ListArray};
22use arrow::datatypes::{DataType, Field, FieldRef};
23use datafusion_common::utils::take_function_args;
24use datafusion_common::{Result, cast::as_map_array, exec_err, internal_err};
25use datafusion_expr::{
26    ArrayFunctionSignature, ColumnarValue, Documentation, ScalarFunctionArgs,
27    ScalarUDFImpl, Signature, TypeSignature, Volatility,
28};
29use datafusion_macros::user_doc;
30use std::ops::Deref;
31use std::sync::Arc;
32
// Declares the expression-level entry points for `MapValuesFunc`: presumably a
// `map_values(map)` expression helper and a `map_values_udf` accessor for the
// shared UDF instance. The macro is defined elsewhere in this crate — confirm
// the exact generated signatures there.
make_udf_expr_and_func!(
    MapValuesFunc,
    map_values,
    map,
    "Return a list of all values in the map.",
    map_values_udf
);
40
#[user_doc(
    doc_section(label = "Map Functions"),
    description = "Returns a list of all values in the map.",
    syntax_example = "map_values(map)",
    sql_example = r#"```sql
SELECT map_values(MAP {'a': 1, 'b': NULL, 'c': 3});
----
[1, , 3]

SELECT map_values(map([100, 5], [42, 43]));
----
[42, 43]
```"#,
    argument(
        name = "map",
        description = "Map expression. Can be a constant, column, or function, and any combination of map operators."
    )
)]
/// Scalar UDF implementing `map_values`: returns the values of a map argument
/// as a list, one list per input row, with values in the map's entry order.
#[derive(Debug, PartialEq, Eq, Hash)]
pub(crate) struct MapValuesFunc {
    // Type signature reported to the planner; built in `MapValuesFunc::new`.
    signature: Signature,
}
63
64impl Default for MapValuesFunc {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl MapValuesFunc {
71    pub fn new() -> Self {
72        Self {
73            signature: Signature::new(
74                TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray),
75                Volatility::Immutable,
76            ),
77        }
78    }
79}
80
81impl ScalarUDFImpl for MapValuesFunc {
82    fn name(&self) -> &str {
83        "map_values"
84    }
85
86    fn signature(&self) -> &Signature {
87        &self.signature
88    }
89
90    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
91        internal_err!("return_field_from_args should be used instead")
92    }
93
94    fn return_field_from_args(
95        &self,
96        args: datafusion_expr::ReturnFieldArgs,
97    ) -> Result<FieldRef> {
98        let [map_type] = take_function_args(self.name(), args.arg_fields)?;
99
100        Ok(Field::new(
101            self.name(),
102            DataType::List(get_map_values_field_as_list_field(map_type.data_type())?),
103            // Nullable if the map is nullable
104            args.arg_fields.iter().any(|x| x.is_nullable()),
105        )
106        .into())
107    }
108
109    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
110        make_scalar_function(map_values_inner)(&args.args)
111    }
112
113    fn documentation(&self) -> Option<&Documentation> {
114        self.doc()
115    }
116}
117
118fn map_values_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
119    let [map_arg] = take_function_args("map_values", args)?;
120
121    let map_array = match map_arg.data_type() {
122        DataType::Map(_, _) => as_map_array(&map_arg)?,
123        _ => return exec_err!("Argument for map_values should be a map"),
124    };
125
126    Ok(Arc::new(ListArray::new(
127        get_map_values_field_as_list_field(map_arg.data_type())?,
128        map_array.offsets().clone(),
129        Arc::clone(map_array.values()),
130        map_array.nulls().cloned(),
131    )))
132}
133
134fn get_map_values_field_as_list_field(map_type: &DataType) -> Result<FieldRef> {
135    let map_fields = get_map_entry_field(map_type)?;
136
137    let values_field = map_fields
138        .last()
139        .unwrap()
140        .deref()
141        .clone()
142        .with_name(Field::LIST_FIELD_DEFAULT_NAME);
143
144    Ok(Arc::new(values_field))
145}
146
#[cfg(test)]
mod tests {
    use crate::map_values::MapValuesFunc;
    use arrow::datatypes::{DataType, Field, FieldRef};
    use datafusion_common::ScalarValue;
    use datafusion_expr::ScalarUDFImpl;
    use std::sync::Arc;

    /// Builds a `Map<Utf8, LargeUtf8>` field named `something`, with the given
    /// nullability for the map itself, its keys, and its values.
    fn map_field(map_nullable: bool, keys_nullable: bool, values_nullable: bool) -> FieldRef {
        let keys = Arc::new(Field::new("keys", DataType::Utf8, keys_nullable));
        let values = Arc::new(Field::new("values", DataType::LargeUtf8, values_nullable));
        Arc::new(Field::new_map(
            "something",
            "entries",
            keys,
            values,
            false,
            map_nullable,
        ))
    }

    /// Builds a `List` field named `name` whose items use the default list
    /// item name.
    fn list_field(
        name: &str,
        list_nullable: bool,
        item_type: DataType,
        items_nullable: bool,
    ) -> FieldRef {
        let item = Arc::new(Field::new_list_field(item_type, items_nullable));
        Arc::new(Field::new_list(name, item, list_nullable))
    }

    /// Runs `MapValuesFunc::return_field_from_args` on a single argument field.
    fn return_field_for(arg: FieldRef) -> FieldRef {
        let udf = MapValuesFunc::new();
        let arg_fields = [arg];
        let scalar_arguments = [None::<&ScalarValue>];
        let args = datafusion_expr::ReturnFieldArgs {
            arg_fields: &arg_fields,
            scalar_arguments: &scalar_arguments,
        };
        udf.return_field_from_args(args).unwrap()
    }

    #[test]
    fn return_type_field() {
        // Walks all 8 combinations of (map, keys, values) nullability.
        // Expectations:
        // - The result list is nullable iff the map is nullable.
        // - The list items are nullable iff the map values are nullable.
        // - Key nullability never affects the result; it is varied only to
        //   show that.
        for map_nullable in [false, true] {
            for keys_nullable in [false, true] {
                for values_nullable in [false, true] {
                    let actual =
                        return_field_for(map_field(map_nullable, keys_nullable, values_nullable));
                    let expected = list_field(
                        "map_values",
                        map_nullable,
                        DataType::LargeUtf8,
                        values_nullable,
                    );
                    assert_eq!(actual, expected);
                }
            }
        }
    }
}