datafusion_spark/function/map/
map_from_entries.rs1use std::any::Any;
19use std::sync::Arc;
20
21use crate::function::map::utils::{
22 get_list_offsets, get_list_values, map_from_keys_values_offsets_nulls,
23 map_type_from_key_value_types,
24};
25use arrow::array::{Array, ArrayRef, NullBufferBuilder, StructArray};
26use arrow::buffer::NullBuffer;
27use arrow::datatypes::{DataType, Field, FieldRef};
28use datafusion_common::utils::take_function_args;
29use datafusion_common::{Result, exec_err, internal_err};
30use datafusion_expr::{
31 ColumnarValue, ReturnFieldArgs, ScalarUDFImpl, Signature, Volatility,
32};
33use datafusion_functions::utils::make_scalar_function;
34
35#[derive(Debug, PartialEq, Eq, Hash)]
38pub struct MapFromEntries {
39 signature: Signature,
40}
41
42impl Default for MapFromEntries {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl MapFromEntries {
49 pub fn new() -> Self {
50 Self {
51 signature: Signature::array(Volatility::Immutable),
52 }
53 }
54}
55
56impl ScalarUDFImpl for MapFromEntries {
57 fn as_any(&self) -> &dyn Any {
58 self
59 }
60
61 fn name(&self) -> &str {
62 "map_from_entries"
63 }
64
65 fn signature(&self) -> &Signature {
66 &self.signature
67 }
68
69 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
70 internal_err!("return_field_from_args should be used instead")
71 }
72
73 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
74 let [entries_field] = args.arg_fields else {
75 return exec_err!("map_from_entries: expected one argument");
76 };
77
78 let (entries_element_field, entries_element_type) =
79 match entries_field.data_type() {
80 DataType::List(field)
81 | DataType::LargeList(field)
82 | DataType::FixedSizeList(field, _) => {
83 Ok((field.as_ref(), field.data_type()))
84 }
85 wrong_type => exec_err!(
86 "map_from_entries: expected array<struct<key, value>>, got {:?}",
87 wrong_type
88 ),
89 }?;
90
91 let (keys_type, values_type) = match entries_element_type {
92 DataType::Struct(fields) if fields.len() == 2 => {
93 Ok((fields[0].data_type(), fields[1].data_type()))
94 }
95 wrong_type => exec_err!(
96 "map_from_entries: expected array<struct<key, value>>, got {:?}",
97 wrong_type
98 ),
99 }?;
100
101 let map_type = map_type_from_key_value_types(keys_type, values_type);
102 let nullable = entries_field.is_nullable() || entries_element_field.is_nullable();
103
104 Ok(Arc::new(Field::new(self.name(), map_type, nullable)))
105 }
106
107 fn invoke_with_args(
108 &self,
109 args: datafusion_expr::ScalarFunctionArgs,
110 ) -> Result<ColumnarValue> {
111 make_scalar_function(map_from_entries_inner, vec![])(&args.args)
112 }
113}
114
115fn map_from_entries_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
116 let [entries] = take_function_args("map_from_entries", args)?;
117 let entries_offsets = get_list_offsets(entries)?;
118 let entries_values = get_list_values(entries)?;
119
120 let (flat_keys, flat_values) =
121 match entries_values.as_any().downcast_ref::<StructArray>() {
122 Some(a) => Ok((a.column(0), a.column(1))),
123 None => exec_err!(
124 "map_from_entries: expected array<struct<key, value>>, got {:?}",
125 entries_values.data_type()
126 ),
127 }?;
128
129 let entries_with_nulls = entries_values.nulls().and_then(|entries_inner_nulls| {
130 let mut builder = NullBufferBuilder::new_with_len(0);
131 let mut cur_offset = entries_offsets
132 .first()
133 .map(|offset| *offset as usize)
134 .unwrap_or(0);
135
136 for next_offset in entries_offsets.iter().skip(1) {
137 let num_entries = *next_offset as usize - cur_offset;
138 builder.append(
139 entries_inner_nulls
140 .slice(cur_offset, num_entries)
141 .null_count()
142 == 0,
143 );
144 cur_offset = *next_offset as usize;
145 }
146 builder.finish()
147 });
148
149 let res_nulls = NullBuffer::union(entries.nulls(), entries_with_nulls.as_ref());
150
151 map_from_keys_values_offsets_nulls(
152 flat_keys,
153 flat_values,
154 &entries_offsets,
155 &entries_offsets,
156 None,
157 res_nulls.as_ref(),
158 )
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::Fields;
    use datafusion_expr::ReturnFieldArgs;

    /// Builds an `entries` field of type `List<Struct<key: Int32, value: Utf8>>`
    /// with the requested nullability on the list field and on its element.
    fn make_entries_field(array_nullable: bool, element_nullable: bool) -> FieldRef {
        let key = Field::new("key", DataType::Int32, false);
        let value = Field::new("value", DataType::Utf8, true);
        let struct_type = DataType::Struct(Fields::from(vec![key, value]));
        let item = Arc::new(Field::new("item", struct_type, element_nullable));

        Arc::new(Field::new("entries", DataType::List(item), array_nullable))
    }

    #[test]
    fn test_map_from_entries_nullability_matches_input() {
        let func = MapFromEntries::new();
        let expected_type =
            map_type_from_key_value_types(&DataType::Int32, &DataType::Utf8);

        // (array nullable, element nullable, expected output nullability)
        let cases = [
            // Nothing nullable: the result is non-nullable.
            (false, false, false),
            // Nullable elements alone make the result nullable.
            (false, true, true),
            // A nullable array alone makes the result nullable.
            (true, false, true),
        ];

        for (array_nullable, element_nullable, expect_nullable) in cases {
            let entries_field = make_entries_field(array_nullable, element_nullable);
            let result = func
                .return_field_from_args(ReturnFieldArgs {
                    arg_fields: &[Arc::clone(&entries_field)],
                    scalar_arguments: &[None],
                })
                .expect("should infer field");

            assert_eq!(
                result.is_nullable(),
                expect_nullable,
                "array_nullable={array_nullable}, element_nullable={element_nullable}"
            );
            assert_eq!(result.data_type(), &expected_type);
        }
    }
}