datafusion_functions_nested/
map.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::collections::VecDeque;
20use std::sync::Arc;
21
22use arrow::array::{Array, ArrayData, ArrayRef, MapArray, OffsetSizeTrait, StructArray};
23use arrow::buffer::Buffer;
24use arrow::datatypes::{DataType, Field, SchemaBuilder, ToByteSlice};
25
26use datafusion_common::utils::{fixed_size_list_to_arrays, list_to_arrays};
27use datafusion_common::{
28    exec_err, utils::take_function_args, HashSet, Result, ScalarValue,
29};
30use datafusion_expr::expr::ScalarFunction;
31use datafusion_expr::{
32    ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility,
33};
34use datafusion_macros::user_doc;
35
36use crate::make_array::make_array;
37
38/// Returns a map created from a key list and a value list
39pub fn map(keys: Vec<Expr>, values: Vec<Expr>) -> Expr {
40    let keys = make_array(keys);
41    let values = make_array(values);
42    Expr::ScalarFunction(ScalarFunction::new_udf(map_udf(), vec![keys, values]))
43}
44
// Presumably expands to the `map_udf()` function (called by `map` to obtain the UDF
// passed to `ScalarFunction::new_udf`) wrapping `MapFunc`; the macro is defined
// elsewhere in this crate — NOTE(review): confirm the exact expansion.
create_func!(MapFunc, map_udf);
46
47/// Check if we can evaluate the expr to constant directly.
48///
49/// # Example
50/// ```sql
51/// SELECT make_map('type', 'test') from test
52/// ```
53/// We can evaluate the result of `make_map` directly.
54fn can_evaluate_to_const(args: &[ColumnarValue]) -> bool {
55    args.iter()
56        .all(|arg| matches!(arg, ColumnarValue::Scalar(_)))
57}
58
59fn make_map_batch(args: &[ColumnarValue]) -> Result<ColumnarValue> {
60    let [keys_arg, values_arg] = take_function_args("make_map", args)?;
61
62    let can_evaluate_to_const = can_evaluate_to_const(args);
63
64    // check the keys array is unique
65    let keys = get_first_array_ref(keys_arg)?;
66    if keys.null_count() > 0 {
67        return exec_err!("map key cannot be null");
68    }
69    let key_array = keys.as_ref();
70
71    match keys_arg {
72        ColumnarValue::Array(_) => {
73            let row_keys = match key_array.data_type() {
74                DataType::List(_) => list_to_arrays::<i32>(&keys),
75                DataType::LargeList(_) => list_to_arrays::<i64>(&keys),
76                DataType::FixedSizeList(_, _) => fixed_size_list_to_arrays(&keys),
77                data_type => {
78                    return exec_err!(
79                        "Expected list, large_list or fixed_size_list, got {:?}",
80                        data_type
81                    );
82                }
83            };
84
85            row_keys
86                .iter()
87                .try_for_each(|key| check_unique_keys(key.as_ref()))?;
88        }
89        ColumnarValue::Scalar(_) => {
90            check_unique_keys(key_array)?;
91        }
92    }
93
94    let values = get_first_array_ref(values_arg)?;
95    make_map_batch_internal(keys, values, can_evaluate_to_const, keys_arg.data_type())
96}
97
98fn check_unique_keys(array: &dyn Array) -> Result<()> {
99    let mut seen_keys = HashSet::with_capacity(array.len());
100
101    for i in 0..array.len() {
102        let key = ScalarValue::try_from_array(array, i)?;
103        if seen_keys.contains(&key) {
104            return exec_err!("map key must be unique, duplicate key found: {}", key);
105        }
106        seen_keys.insert(key);
107    }
108    Ok(())
109}
110
111fn get_first_array_ref(columnar_value: &ColumnarValue) -> Result<ArrayRef> {
112    match columnar_value {
113        ColumnarValue::Scalar(value) => match value {
114            ScalarValue::List(array) => Ok(array.value(0)),
115            ScalarValue::LargeList(array) => Ok(array.value(0)),
116            ScalarValue::FixedSizeList(array) => Ok(array.value(0)),
117            _ => exec_err!("Expected array, got {:?}", value),
118        },
119        ColumnarValue::Array(array) => Ok(array.to_owned()),
120    }
121}
122
123fn make_map_batch_internal(
124    keys: ArrayRef,
125    values: ArrayRef,
126    can_evaluate_to_const: bool,
127    data_type: DataType,
128) -> Result<ColumnarValue> {
129    if keys.len() != values.len() {
130        return exec_err!("map requires key and value lists to have the same length");
131    }
132
133    if !can_evaluate_to_const {
134        return if let DataType::LargeList(..) = data_type {
135            make_map_array_internal::<i64>(keys, values)
136        } else {
137            make_map_array_internal::<i32>(keys, values)
138        };
139    }
140
141    let key_field = Arc::new(Field::new("key", keys.data_type().clone(), false));
142    let value_field = Arc::new(Field::new("value", values.data_type().clone(), true));
143    let mut entry_struct_buffer: VecDeque<(Arc<Field>, ArrayRef)> = VecDeque::new();
144    let mut entry_offsets_buffer = VecDeque::new();
145    entry_offsets_buffer.push_back(0);
146
147    entry_struct_buffer.push_back((Arc::clone(&key_field), Arc::clone(&keys)));
148    entry_struct_buffer.push_back((Arc::clone(&value_field), Arc::clone(&values)));
149    entry_offsets_buffer.push_back(keys.len() as u32);
150
151    let entry_struct: Vec<(Arc<Field>, ArrayRef)> = entry_struct_buffer.into();
152    let entry_struct = StructArray::from(entry_struct);
153
154    let map_data_type = DataType::Map(
155        Arc::new(Field::new(
156            "entries",
157            entry_struct.data_type().clone(),
158            false,
159        )),
160        false,
161    );
162
163    let entry_offsets: Vec<u32> = entry_offsets_buffer.into();
164    let entry_offsets_buffer = Buffer::from(entry_offsets.to_byte_slice());
165
166    let map_data = ArrayData::builder(map_data_type)
167        .len(entry_offsets.len() - 1)
168        .add_buffer(entry_offsets_buffer)
169        .add_child_data(entry_struct.to_data())
170        .build()?;
171    let map_array = Arc::new(MapArray::from(map_data));
172
173    Ok(if can_evaluate_to_const {
174        ColumnarValue::Scalar(ScalarValue::try_from_array(map_array.as_ref(), 0)?)
175    } else {
176        ColumnarValue::Array(map_array)
177    })
178}
179
#[user_doc(
    doc_section(label = "Map Functions"),
    description = "Returns an Arrow map with the specified key-value pairs.\n\n\
    The `make_map` function creates a map from two lists: one for keys and one for values. Each key must be unique and non-null.",
    syntax_example = "map(key, value)\nmap(key: value)\nmake_map(['key1', 'key2'], ['value1', 'value2'])",
    sql_example = r#"
```sql
-- Using map function
SELECT MAP('type', 'test');
----
{type: test}

SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, null]);
----
{POST: 41, HEAD: 33, PATCH: NULL}

SELECT MAP([[1,2], [3,4]], ['a', 'b']);
----
{[1, 2]: a, [3, 4]: b}

SELECT MAP { 'a': 1, 'b': 2 };
----
{a: 1, b: 2}

-- Using make_map function
SELECT MAKE_MAP(['POST', 'HEAD'], [41, 33]);
----
{POST: 41, HEAD: 33}

SELECT MAKE_MAP(['key1', 'key2'], ['value1', null]);
----
{key1: value1, key2: }
```"#,
    argument(
        name = "key",
        description = "For `map`: Expression to be used for key. Can be a constant, column, function, or any combination of arithmetic or string operators.\n\
                        For `make_map`: The list of keys to be used in the map. Each key must be unique and non-null."
    ),
    argument(
        name = "value",
        description = "For `map`: Expression to be used for value. Can be a constant, column, function, or any combination of arithmetic or string operators.\n\
                        For `make_map`: The list of values to be mapped to the corresponding keys."
    )
)]
#[derive(Debug)]
/// [`ScalarUDFImpl`] for the `map` scalar function, which builds an Arrow `Map` value
/// from a list of keys and a list of values.
pub struct MapFunc {
    /// Argument signature reported to the planner; see [`MapFunc::new`].
    signature: Signature,
}
228
229impl Default for MapFunc {
230    fn default() -> Self {
231        Self::new()
232    }
233}
234
235impl MapFunc {
236    pub fn new() -> Self {
237        Self {
238            signature: Signature::variadic_any(Volatility::Immutable),
239        }
240    }
241}
242
243impl ScalarUDFImpl for MapFunc {
244    fn as_any(&self) -> &dyn Any {
245        self
246    }
247
248    fn name(&self) -> &str {
249        "map"
250    }
251
252    fn signature(&self) -> &Signature {
253        &self.signature
254    }
255
256    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
257        let [keys_arg, values_arg] = take_function_args(self.name(), arg_types)?;
258        let mut builder = SchemaBuilder::new();
259        builder.push(Field::new(
260            "key",
261            get_element_type(keys_arg)?.clone(),
262            false,
263        ));
264        builder.push(Field::new(
265            "value",
266            get_element_type(values_arg)?.clone(),
267            true,
268        ));
269        let fields = builder.finish().fields;
270        Ok(DataType::Map(
271            Arc::new(Field::new("entries", DataType::Struct(fields), false)),
272            false,
273        ))
274    }
275
276    fn invoke_with_args(
277        &self,
278        args: datafusion_expr::ScalarFunctionArgs,
279    ) -> Result<ColumnarValue> {
280        make_map_batch(&args.args)
281    }
282
283    fn documentation(&self) -> Option<&Documentation> {
284        self.doc()
285    }
286}
287
288fn get_element_type(data_type: &DataType) -> Result<&DataType> {
289    match data_type {
290        DataType::List(element) => Ok(element.data_type()),
291        DataType::LargeList(element) => Ok(element.data_type()),
292        DataType::FixedSizeList(element, _) => Ok(element.data_type()),
293        _ => exec_err!(
294            "Expected list, large_list or fixed_size_list, got {:?}",
295            data_type
296        ),
297    }
298}
299
300/// Helper function to create MapArray from array of values to support arrays for Map scalar function
301///
302/// ``` text
303/// Format of input KEYS and VALUES column
304///         keys                        values
305/// +---------------------+       +---------------------+
306/// | +-----------------+ |       | +-----------------+ |
307/// | | [k11, k12, k13] | |       | | [v11, v12, v13] | |
308/// | +-----------------+ |       | +-----------------+ |
309/// |                     |       |                     |
310/// | +-----------------+ |       | +-----------------+ |
311/// | | [k21, k22, k23] | |       | | [v21, v22, v23] | |
312/// | +-----------------+ |       | +-----------------+ |
313/// |                     |       |                     |
314/// | +-----------------+ |       | +-----------------+ |
315/// | |[k31, k32, k33]  | |       | |[v31, v32, v33]  | |
316/// | +-----------------+ |       | +-----------------+ |
317/// +---------------------+       +---------------------+
318/// ```
319/// Flattened keys and values array to user create `StructArray`,
320/// which serves as inner child for `MapArray`
321///
322/// ``` text
323/// Flattened           Flattened
324/// Keys                Values
325/// +-----------+      +-----------+
326/// | +-------+ |      | +-------+ |
327/// | |  k11  | |      | |  v11  | |
328/// | +-------+ |      | +-------+ |
329/// | +-------+ |      | +-------+ |
330/// | |  k12  | |      | |  v12  | |
331/// | +-------+ |      | +-------+ |
332/// | +-------+ |      | +-------+ |
333/// | |  k13  | |      | |  v13  | |
334/// | +-------+ |      | +-------+ |
335/// | +-------+ |      | +-------+ |
336/// | |  k21  | |      | |  v21  | |
337/// | +-------+ |      | +-------+ |
338/// | +-------+ |      | +-------+ |
339/// | |  k22  | |      | |  v22  | |
340/// | +-------+ |      | +-------+ |
341/// | +-------+ |      | +-------+ |
342/// | |  k23  | |      | |  v23  | |
343/// | +-------+ |      | +-------+ |
344/// | +-------+ |      | +-------+ |
345/// | |  k31  | |      | |  v31  | |
346/// | +-------+ |      | +-------+ |
347/// | +-------+ |      | +-------+ |
348/// | |  k32  | |      | |  v32  | |
349/// | +-------+ |      | +-------+ |
350/// | +-------+ |      | +-------+ |
351/// | |  k33  | |      | |  v33  | |
352/// | +-------+ |      | +-------+ |
353/// +-----------+      +-----------+
354/// ```text
355fn make_map_array_internal<O: OffsetSizeTrait>(
356    keys: ArrayRef,
357    values: ArrayRef,
358) -> Result<ColumnarValue> {
359    let mut offset_buffer = vec![O::zero()];
360    let mut running_offset = O::zero();
361
362    let keys = list_to_arrays::<O>(&keys);
363    let values = list_to_arrays::<O>(&values);
364
365    let mut key_array_vec = vec![];
366    let mut value_array_vec = vec![];
367    for (k, v) in keys.iter().zip(values.iter()) {
368        running_offset = running_offset.add(O::usize_as(k.len()));
369        offset_buffer.push(running_offset);
370        key_array_vec.push(k.as_ref());
371        value_array_vec.push(v.as_ref());
372    }
373
374    // concatenate all the arrays
375    let flattened_keys = arrow::compute::concat(key_array_vec.as_ref())?;
376    if flattened_keys.null_count() > 0 {
377        return exec_err!("keys cannot be null");
378    }
379    let flattened_values = arrow::compute::concat(value_array_vec.as_ref())?;
380
381    let fields = vec![
382        Arc::new(Field::new("key", flattened_keys.data_type().clone(), false)),
383        Arc::new(Field::new(
384            "value",
385            flattened_values.data_type().clone(),
386            true,
387        )),
388    ];
389
390    let struct_data = ArrayData::builder(DataType::Struct(fields.into()))
391        .len(flattened_keys.len())
392        .add_child_data(flattened_keys.to_data())
393        .add_child_data(flattened_values.to_data())
394        .build()?;
395
396    let map_data = ArrayData::builder(DataType::Map(
397        Arc::new(Field::new(
398            "entries",
399            struct_data.data_type().clone(),
400            false,
401        )),
402        false,
403    ))
404    .len(keys.len())
405    .add_child_data(struct_data)
406    .add_buffer(Buffer::from_slice_ref(offset_buffer.as_slice()))
407    .build()?;
408    Ok(ColumnarValue::Array(Arc::new(MapArray::from(map_data))))
409}