datafusion_spark/function/map/map_from_entries.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

18use std::any::Any;
19use std::sync::Arc;
20
21use crate::function::map::utils::{
22    get_list_offsets, get_list_values, map_from_keys_values_offsets_nulls,
23    map_type_from_key_value_types,
24};
25use arrow::array::{Array, ArrayRef, NullBufferBuilder, StructArray};
26use arrow::buffer::NullBuffer;
27use arrow::datatypes::{DataType, Field, FieldRef};
28use datafusion_common::utils::take_function_args;
29use datafusion_common::{Result, exec_err, internal_err};
30use datafusion_expr::{
31    ColumnarValue, ReturnFieldArgs, ScalarUDFImpl, Signature, Volatility,
32};
33use datafusion_functions::utils::make_scalar_function;
34
35/// Spark-compatible `map_from_entries` expression
36/// <https://spark.apache.org/docs/latest/api/sql/index.html#map_from_entries>
37#[derive(Debug, PartialEq, Eq, Hash)]
38pub struct MapFromEntries {
39    signature: Signature,
40}
41
42impl Default for MapFromEntries {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl MapFromEntries {
49    pub fn new() -> Self {
50        Self {
51            signature: Signature::array(Volatility::Immutable),
52        }
53    }
54}
55
56impl ScalarUDFImpl for MapFromEntries {
57    fn as_any(&self) -> &dyn Any {
58        self
59    }
60
61    fn name(&self) -> &str {
62        "map_from_entries"
63    }
64
65    fn signature(&self) -> &Signature {
66        &self.signature
67    }
68
69    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
70        internal_err!("return_field_from_args should be used instead")
71    }
72
73    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
74        let [entries_field] = args.arg_fields else {
75            return exec_err!("map_from_entries: expected one argument");
76        };
77
78        let (entries_element_field, entries_element_type) =
79            match entries_field.data_type() {
80                DataType::List(field)
81                | DataType::LargeList(field)
82                | DataType::FixedSizeList(field, _) => {
83                    Ok((field.as_ref(), field.data_type()))
84                }
85                wrong_type => exec_err!(
86                    "map_from_entries: expected array<struct<key, value>>, got {:?}",
87                    wrong_type
88                ),
89            }?;
90
91        let (keys_type, values_type) = match entries_element_type {
92            DataType::Struct(fields) if fields.len() == 2 => {
93                Ok((fields[0].data_type(), fields[1].data_type()))
94            }
95            wrong_type => exec_err!(
96                "map_from_entries: expected array<struct<key, value>>, got {:?}",
97                wrong_type
98            ),
99        }?;
100
101        let map_type = map_type_from_key_value_types(keys_type, values_type);
102        let nullable = entries_field.is_nullable() || entries_element_field.is_nullable();
103
104        Ok(Arc::new(Field::new(self.name(), map_type, nullable)))
105    }
106
107    fn invoke_with_args(
108        &self,
109        args: datafusion_expr::ScalarFunctionArgs,
110    ) -> Result<ColumnarValue> {
111        make_scalar_function(map_from_entries_inner, vec![])(&args.args)
112    }
113}
114
115fn map_from_entries_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
116    let [entries] = take_function_args("map_from_entries", args)?;
117    let entries_offsets = get_list_offsets(entries)?;
118    let entries_values = get_list_values(entries)?;
119
120    let (flat_keys, flat_values) =
121        match entries_values.as_any().downcast_ref::<StructArray>() {
122            Some(a) => Ok((a.column(0), a.column(1))),
123            None => exec_err!(
124                "map_from_entries: expected array<struct<key, value>>, got {:?}",
125                entries_values.data_type()
126            ),
127        }?;
128
129    let entries_with_nulls = entries_values.nulls().and_then(|entries_inner_nulls| {
130        let mut builder = NullBufferBuilder::new_with_len(0);
131        let mut cur_offset = entries_offsets
132            .first()
133            .map(|offset| *offset as usize)
134            .unwrap_or(0);
135
136        for next_offset in entries_offsets.iter().skip(1) {
137            let num_entries = *next_offset as usize - cur_offset;
138            builder.append(
139                entries_inner_nulls
140                    .slice(cur_offset, num_entries)
141                    .null_count()
142                    == 0,
143            );
144            cur_offset = *next_offset as usize;
145        }
146        builder.finish()
147    });
148
149    let res_nulls = NullBuffer::union(entries.nulls(), entries_with_nulls.as_ref());
150
151    map_from_keys_values_offsets_nulls(
152        flat_keys,
153        flat_values,
154        &entries_offsets,
155        &entries_offsets,
156        None,
157        res_nulls.as_ref(),
158    )
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::Fields;
    use datafusion_expr::ReturnFieldArgs;

    /// Builds an `array<struct<key: Int32, value: Utf8>>` argument field with
    /// the requested nullability on the list and on its elements.
    fn make_entries_field(array_nullable: bool, element_nullable: bool) -> FieldRef {
        let entry_fields = Fields::from(vec![
            Field::new("key", DataType::Int32, false),
            Field::new("value", DataType::Utf8, true),
        ]);
        let item = Field::new("item", DataType::Struct(entry_fields), element_nullable);
        Arc::new(Field::new(
            "entries",
            DataType::List(Arc::new(item)),
            array_nullable,
        ))
    }

    #[test]
    fn test_map_from_entries_nullability_matches_input() {
        let func = MapFromEntries::new();
        let expected_type =
            map_type_from_key_value_types(&DataType::Int32, &DataType::Utf8);

        // (array nullable, element nullable, expected result nullable)
        let cases = [
            // Neither the array nor its elements can be null.
            (false, false, false),
            // Nullable elements alone make the result nullable.
            (false, true, true),
            // A nullable array makes the result nullable as well.
            (true, false, true),
        ];

        for (array_nullable, element_nullable, expect_nullable) in cases {
            let entries_field = make_entries_field(array_nullable, element_nullable);
            let result = func
                .return_field_from_args(ReturnFieldArgs {
                    arg_fields: &[Arc::clone(&entries_field)],
                    scalar_arguments: &[None],
                })
                .expect("should infer field");
            assert_eq!(
                result.is_nullable(),
                expect_nullable,
                "array_nullable={array_nullable}, element_nullable={element_nullable}"
            );
            assert_eq!(result.data_type(), &expected_type);
        }
    }
}