Skip to main content

datafusion_spark/function/map/
map_from_arrays.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::function::map::utils::{
19    get_element_type, get_list_offsets, get_list_values,
20    map_from_keys_values_offsets_nulls, map_type_from_key_value_types,
21};
22use arrow::array::{Array, ArrayRef, NullArray};
23use arrow::compute::kernels::cast;
24use arrow::datatypes::{DataType, Field, FieldRef};
25use datafusion_common::utils::take_function_args;
26use datafusion_common::{Result, internal_err};
27use datafusion_expr::{
28    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
29    Volatility,
30};
31use datafusion_functions::utils::make_scalar_function;
32use std::sync::Arc;
33
34/// Spark-compatible `map_from_arrays` expression
35/// <https://spark.apache.org/docs/latest/api/sql/index.html#map_from_arrays>
36#[derive(Debug, PartialEq, Eq, Hash)]
37pub struct MapFromArrays {
38    signature: Signature,
39}
40
41impl Default for MapFromArrays {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
47impl MapFromArrays {
48    pub fn new() -> Self {
49        Self {
50            signature: Signature::any(2, Volatility::Immutable),
51        }
52    }
53}
54
55impl ScalarUDFImpl for MapFromArrays {
56    fn name(&self) -> &str {
57        "map_from_arrays"
58    }
59
60    fn signature(&self) -> &Signature {
61        &self.signature
62    }
63
64    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
65        internal_err!("return_field_from_args should be used instead")
66    }
67
68    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
69        let [keys_field, values_field] = args.arg_fields else {
70            return internal_err!("map_from_arrays expects exactly 2 arguments");
71        };
72
73        let map_type = map_type_from_key_value_types(
74            get_element_type(keys_field.data_type())?,
75            get_element_type(values_field.data_type())?,
76        );
77        // Spark marks map_from_arrays as null intolerant, so the output is
78        // nullable if either input is nullable.
79        let nullable = keys_field.is_nullable() || values_field.is_nullable();
80        Ok(Arc::new(Field::new(self.name(), map_type, nullable)))
81    }
82
83    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
84        make_scalar_function(map_from_arrays_inner, vec![])(&args.args)
85    }
86}
87
88fn map_from_arrays_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
89    let [keys, values] = take_function_args("map_from_arrays", args)?;
90
91    if *keys.data_type() == DataType::Null || *values.data_type() == DataType::Null {
92        return Ok(cast(
93            &NullArray::new(keys.len()),
94            &map_type_from_key_value_types(
95                get_element_type(keys.data_type())?,
96                get_element_type(values.data_type())?,
97            ),
98        )?);
99    }
100
101    map_from_keys_values_offsets_nulls(
102        get_list_values(keys)?,
103        get_list_values(values)?,
104        &get_list_offsets(keys)?,
105        &get_list_offsets(values)?,
106        keys.nulls(),
107        values.nulls(),
108    )
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `List` field named `name` whose child field is called "item".
    fn list_field(
        name: &str,
        item_type: DataType,
        item_nullable: bool,
        nullable: bool,
    ) -> FieldRef {
        let item = Arc::new(Field::new("item", item_type, item_nullable));
        Arc::new(Field::new(name, DataType::List(item), nullable))
    }

    /// Resolves the output field for the given keys/values argument fields.
    fn resolve_output(
        func: &MapFromArrays,
        keys: FieldRef,
        values: FieldRef,
    ) -> FieldRef {
        func.return_field_from_args(ReturnFieldArgs {
            arg_fields: &[keys, values],
            scalar_arguments: &[None, None],
        })
        .expect("return_field_from_args should succeed")
    }

    /// Checks the derived map type and that output nullability follows the
    /// nullability of the inputs.
    #[test]
    fn test_map_from_arrays_nullability_and_type() {
        let udf = MapFromArrays::new();

        let strict_keys = list_field("keys", DataType::Int32, false, false);
        let strict_values = list_field("values", DataType::Utf8, true, false);

        // Both inputs non-nullable: expect a non-nullable Map<Int32, Utf8>.
        let strict_out = resolve_output(
            &udf,
            Arc::clone(&strict_keys),
            Arc::clone(&strict_values),
        );
        let expected_type =
            map_type_from_key_value_types(&DataType::Int32, &DataType::Utf8);
        assert_eq!(strict_out.data_type(), &expected_type);
        assert!(
            !strict_out.is_nullable(),
            "map_from_arrays should be non-nullable when both inputs are non-nullable"
        );

        // A nullable keys input must make the output nullable.
        let nullable_keys = list_field("keys", DataType::Int32, false, true);
        let nullable_out = resolve_output(&udf, nullable_keys, strict_values);
        assert!(
            nullable_out.is_nullable(),
            "map_from_arrays should be nullable when any input is nullable"
        );
    }
}