Skip to main content

datafusion_spark/function/map/
map_from_arrays.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19
20use crate::function::map::utils::{
21    get_element_type, get_list_offsets, get_list_values,
22    map_from_keys_values_offsets_nulls, map_type_from_key_value_types,
23};
24use arrow::array::{Array, ArrayRef, NullArray};
25use arrow::compute::kernels::cast;
26use arrow::datatypes::{DataType, Field, FieldRef};
27use datafusion_common::utils::take_function_args;
28use datafusion_common::{Result, internal_err};
29use datafusion_expr::{
30    ColumnarValue, ReturnFieldArgs, ScalarUDFImpl, Signature, Volatility,
31};
32use datafusion_functions::utils::make_scalar_function;
33use std::sync::Arc;
34
35/// Spark-compatible `map_from_arrays` expression
36/// <https://spark.apache.org/docs/latest/api/sql/index.html#map_from_arrays>
37#[derive(Debug, PartialEq, Eq, Hash)]
38pub struct MapFromArrays {
39    signature: Signature,
40}
41
42impl Default for MapFromArrays {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl MapFromArrays {
49    pub fn new() -> Self {
50        Self {
51            signature: Signature::any(2, Volatility::Immutable),
52        }
53    }
54}
55
56impl ScalarUDFImpl for MapFromArrays {
57    fn as_any(&self) -> &dyn Any {
58        self
59    }
60
61    fn name(&self) -> &str {
62        "map_from_arrays"
63    }
64
65    fn signature(&self) -> &Signature {
66        &self.signature
67    }
68
69    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
70        internal_err!("return_field_from_args should be used instead")
71    }
72
73    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
74        let [keys_field, values_field] = args.arg_fields else {
75            return internal_err!("map_from_arrays expects exactly 2 arguments");
76        };
77
78        let map_type = map_type_from_key_value_types(
79            get_element_type(keys_field.data_type())?,
80            get_element_type(values_field.data_type())?,
81        );
82        // Spark marks map_from_arrays as null intolerant, so the output is
83        // nullable if either input is nullable.
84        let nullable = keys_field.is_nullable() || values_field.is_nullable();
85        Ok(Arc::new(Field::new(self.name(), map_type, nullable)))
86    }
87
88    fn invoke_with_args(
89        &self,
90        args: datafusion_expr::ScalarFunctionArgs,
91    ) -> Result<ColumnarValue> {
92        make_scalar_function(map_from_arrays_inner, vec![])(&args.args)
93    }
94}
95
96fn map_from_arrays_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
97    let [keys, values] = take_function_args("map_from_arrays", args)?;
98
99    if matches!(keys.data_type(), DataType::Null)
100        || matches!(values.data_type(), DataType::Null)
101    {
102        return Ok(cast(
103            &NullArray::new(keys.len()),
104            &map_type_from_key_value_types(
105                get_element_type(keys.data_type())?,
106                get_element_type(values.data_type())?,
107            ),
108        )?);
109    }
110
111    map_from_keys_values_offsets_nulls(
112        get_list_values(keys)?,
113        get_list_values(values)?,
114        &get_list_offsets(keys)?,
115        &get_list_offsets(values)?,
116        keys.nulls(),
117        values.nulls(),
118    )
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::Field;
    use datafusion_expr::ReturnFieldArgs;

    /// Builds a `List` field named `name` with items of type `item_type`.
    fn list_field(
        name: &str,
        item_type: DataType,
        item_nullable: bool,
        nullable: bool,
    ) -> FieldRef {
        let item = Arc::new(Field::new("item", item_type, item_nullable));
        Arc::new(Field::new(name, DataType::List(item), nullable))
    }

    /// Resolves the output field of `map_from_arrays` for `arg_fields`.
    fn output_field(func: &MapFromArrays, arg_fields: &[FieldRef]) -> FieldRef {
        func.return_field_from_args(ReturnFieldArgs {
            arg_fields,
            scalar_arguments: &[None, None],
        })
        .expect("return_field_from_args should succeed")
    }

    #[test]
    fn test_map_from_arrays_nullability_and_type() {
        let func = MapFromArrays::new();

        // Both list inputs are non-nullable.
        let keys = list_field("keys", DataType::Int32, false, false);
        let values = list_field("values", DataType::Utf8, true, false);
        let out = output_field(&func, &[Arc::clone(&keys), Arc::clone(&values)]);

        let expected_type =
            map_type_from_key_value_types(&DataType::Int32, &DataType::Utf8);
        assert_eq!(out.data_type(), &expected_type);
        assert!(
            !out.is_nullable(),
            "map_from_arrays should be non-nullable when both inputs are non-nullable"
        );

        // Nullable keys with non-nullable values.
        let nullable_keys = list_field("keys", DataType::Int32, false, true);
        let out_nullable = output_field(&func, &[nullable_keys, values]);
        assert!(
            out_nullable.is_nullable(),
            "map_from_arrays should be nullable when any input is nullable"
        );
    }
}