datafusion_functions_nested/
map_keys.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for map_keys function.
19
20use crate::utils::{get_map_entry_field, make_scalar_function};
21use arrow::array::{Array, ArrayRef, ListArray};
22use arrow::datatypes::{DataType, Field};
23use datafusion_common::utils::take_function_args;
24use datafusion_common::{cast::as_map_array, exec_err, Result};
25use datafusion_expr::{
26    ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature,
27    TypeSignature, Volatility,
28};
29use datafusion_macros::user_doc;
30use std::any::Any;
31use std::sync::Arc;
32
// Registers `MapKeysFunc` through the crate's `make_udf_expr_and_func!` macro.
// The arguments are, in order: the UDF struct, the expression function name
// (`map_keys`), its single argument name (`map`), a one-line description, and
// the name `map_keys_udf`.
// NOTE(review): the macro is defined outside this file; presumably it emits a
// `map_keys(map)` expression builder and a `map_keys_udf()` accessor — confirm
// against the macro definition.
make_udf_expr_and_func!(
    MapKeysFunc,
    map_keys,
    map,
    "Return a list of all keys in the map.",
    map_keys_udf
);
40
// Scalar UDF implementation behind the `map_keys` function: given a map it
// yields a list holding that map's keys (see the `ScalarUDFImpl` impl in this
// file). The `user_doc` attribute below carries the user-facing documentation
// (section, description, syntax, SQL examples and argument description).
#[user_doc(
    doc_section(label = "Map Functions"),
    description = "Returns a list of all keys in the map.",
    syntax_example = "map_keys(map)",
    sql_example = r#"```sql
SELECT map_keys(MAP {'a': 1, 'b': NULL, 'c': 3});
----
[a, b, c]

SELECT map_keys(map([100, 5], [42, 43]));
----
[100, 5]
```"#,
    argument(
        name = "map",
        description = "Map expression. Can be a constant, column, or function, and any combination of map operators."
    )
)]
#[derive(Debug)]
pub struct MapKeysFunc {
    // Type signature handed out by `ScalarUDFImpl::signature`; built once in `new`.
    signature: Signature,
}
63
64impl Default for MapKeysFunc {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl MapKeysFunc {
71    pub fn new() -> Self {
72        Self {
73            signature: Signature::new(
74                TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray),
75                Volatility::Immutable,
76            ),
77        }
78    }
79}
80
81impl ScalarUDFImpl for MapKeysFunc {
82    fn as_any(&self) -> &dyn Any {
83        self
84    }
85
86    fn name(&self) -> &str {
87        "map_keys"
88    }
89
90    fn signature(&self) -> &Signature {
91        &self.signature
92    }
93
94    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
95        let [map_type] = take_function_args(self.name(), arg_types)?;
96        let map_fields = get_map_entry_field(map_type)?;
97        Ok(DataType::List(Arc::new(Field::new_list_field(
98            map_fields.first().unwrap().data_type().clone(),
99            false,
100        ))))
101    }
102
103    fn invoke_with_args(
104        &self,
105        args: datafusion_expr::ScalarFunctionArgs,
106    ) -> Result<ColumnarValue> {
107        make_scalar_function(map_keys_inner)(&args.args)
108    }
109
110    fn documentation(&self) -> Option<&Documentation> {
111        self.doc()
112    }
113}
114
115fn map_keys_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
116    let [map_arg] = take_function_args("map_keys", args)?;
117
118    let map_array = match map_arg.data_type() {
119        DataType::Map(_, _) => as_map_array(&map_arg)?,
120        _ => return exec_err!("Argument for map_keys should be a map"),
121    };
122
123    Ok(Arc::new(ListArray::new(
124        Arc::new(Field::new_list_field(map_array.key_type().clone(), false)),
125        map_array.offsets().clone(),
126        Arc::clone(map_array.keys()),
127        map_array.nulls().cloned(),
128    )))
129}