datafusion_spark/function/map/
map_from_entries.rs1use std::sync::Arc;
19
20use crate::function::map::utils::{
21 get_list_offsets, get_list_values, map_from_keys_values_offsets_nulls,
22 map_type_from_key_value_types,
23};
24use arrow::array::{Array, ArrayRef, NullBufferBuilder, StructArray};
25use arrow::buffer::NullBuffer;
26use arrow::datatypes::{DataType, Field, FieldRef};
27use datafusion_common::utils::take_function_args;
28use datafusion_common::{Result, exec_err, internal_err};
29use datafusion_expr::{
30 ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
31 Volatility,
32};
33use datafusion_functions::utils::make_scalar_function;
34
35#[derive(Debug, PartialEq, Eq, Hash)]
38pub struct MapFromEntries {
39 signature: Signature,
40}
41
42impl Default for MapFromEntries {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl MapFromEntries {
49 pub fn new() -> Self {
50 Self {
51 signature: Signature::array(Volatility::Immutable),
52 }
53 }
54}
55
56impl ScalarUDFImpl for MapFromEntries {
57 fn name(&self) -> &str {
58 "map_from_entries"
59 }
60
61 fn signature(&self) -> &Signature {
62 &self.signature
63 }
64
65 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
66 internal_err!("return_field_from_args should be used instead")
67 }
68
69 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
70 let [entries_field] = args.arg_fields else {
71 return exec_err!("map_from_entries: expected one argument");
72 };
73
74 let (entries_element_field, entries_element_type) =
75 match entries_field.data_type() {
76 DataType::List(field)
77 | DataType::LargeList(field)
78 | DataType::FixedSizeList(field, _) => {
79 Ok((field.as_ref(), field.data_type()))
80 }
81 wrong_type => exec_err!(
82 "map_from_entries: expected array<struct<key, value>>, got {:?}",
83 wrong_type
84 ),
85 }?;
86
87 let (keys_type, values_type) = match entries_element_type {
88 DataType::Struct(fields) if fields.len() == 2 => {
89 Ok((fields[0].data_type(), fields[1].data_type()))
90 }
91 wrong_type => exec_err!(
92 "map_from_entries: expected array<struct<key, value>>, got {:?}",
93 wrong_type
94 ),
95 }?;
96
97 let map_type = map_type_from_key_value_types(keys_type, values_type);
98 let nullable = entries_field.is_nullable() || entries_element_field.is_nullable();
99
100 Ok(Arc::new(Field::new(self.name(), map_type, nullable)))
101 }
102
103 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
104 make_scalar_function(map_from_entries_inner, vec![])(&args.args)
105 }
106}
107
108fn map_from_entries_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
109 let [entries] = take_function_args("map_from_entries", args)?;
110 let entries_offsets = get_list_offsets(entries)?;
111 let entries_values = get_list_values(entries)?;
112
113 let (flat_keys, flat_values) =
114 match entries_values.as_any().downcast_ref::<StructArray>() {
115 Some(a) => Ok((a.column(0), a.column(1))),
116 None => exec_err!(
117 "map_from_entries: expected array<struct<key, value>>, got {:?}",
118 entries_values.data_type()
119 ),
120 }?;
121
122 let entries_with_nulls = entries_values.nulls().and_then(|entries_inner_nulls| {
123 let mut builder = NullBufferBuilder::new_with_len(0);
124 let mut cur_offset = entries_offsets
125 .first()
126 .map(|offset| *offset as usize)
127 .unwrap_or(0);
128
129 for next_offset in entries_offsets.iter().skip(1) {
130 let num_entries = *next_offset as usize - cur_offset;
131 builder.append(
132 entries_inner_nulls
133 .slice(cur_offset, num_entries)
134 .null_count()
135 == 0,
136 );
137 cur_offset = *next_offset as usize;
138 }
139 builder.finish()
140 });
141
142 let res_nulls = NullBuffer::union(entries.nulls(), entries_with_nulls.as_ref());
143
144 map_from_keys_values_offsets_nulls(
145 flat_keys,
146 flat_values,
147 &entries_offsets,
148 &entries_offsets,
149 None,
150 res_nulls.as_ref(),
151 )
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int32Array, ListArray, StringArray};
    use arrow::buffer::OffsetBuffer;
    use arrow::datatypes::Fields;

    /// Entry struct type shared by the tests: `struct<key: int32 not null, value: utf8>`.
    fn entry_fields() -> Fields {
        Fields::from(vec![
            Field::new("key", DataType::Int32, false),
            Field::new("value", DataType::Utf8, true),
        ])
    }

    /// `list<entry struct>` argument field with configurable nullability.
    fn make_entries_field(array_nullable: bool, element_nullable: bool) -> FieldRef {
        Arc::new(Field::new(
            "entries",
            DataType::List(Arc::new(Field::new(
                "item",
                DataType::Struct(entry_fields()),
                element_nullable,
            ))),
            array_nullable,
        ))
    }

    /// Runs return-field inference for a single argument field.
    fn infer_return_field(entries_field: FieldRef) -> Result<FieldRef> {
        MapFromEntries::new().return_field_from_args(ReturnFieldArgs {
            arg_fields: &[entries_field],
            scalar_arguments: &[None],
        })
    }

    #[test]
    fn test_map_from_entries_nullability_matches_input() {
        let expected_type =
            map_type_from_key_value_types(&DataType::Int32, &DataType::Utf8);

        // (array nullable, element nullable, expected result nullable)
        let cases = [
            (false, false, false),
            (false, true, true),
            (true, false, true),
            (true, true, true),
        ];
        for (array_nullable, element_nullable, expected_nullable) in cases {
            let result =
                infer_return_field(make_entries_field(array_nullable, element_nullable))
                    .expect("should infer field");
            assert_eq!(
                result.is_nullable(),
                expected_nullable,
                "array_nullable={array_nullable}, element_nullable={element_nullable}"
            );
            assert_eq!(result.data_type(), &expected_type);
        }
    }

    #[test]
    fn test_map_from_entries_accepts_large_list() {
        let field = Arc::new(Field::new(
            "entries",
            DataType::LargeList(Arc::new(Field::new(
                "item",
                DataType::Struct(entry_fields()),
                false,
            ))),
            false,
        ));
        let result = infer_return_field(field).expect("should infer field");
        assert!(!result.is_nullable());
        assert_eq!(
            result.data_type(),
            &map_type_from_key_value_types(&DataType::Int32, &DataType::Utf8)
        );
    }

    #[test]
    fn test_map_from_entries_rejects_invalid_arguments() {
        // Not a list at all.
        let not_a_list = Arc::new(Field::new("entries", DataType::Int32, false));
        assert!(infer_return_field(not_a_list).is_err());

        // List whose elements are not structs.
        let list_of_ints = Arc::new(Field::new(
            "entries",
            DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
            false,
        ));
        assert!(infer_return_field(list_of_ints).is_err());

        // List of structs with the wrong number of fields.
        let three_fields = Fields::from(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
        ]);
        let list_of_wide_structs = Arc::new(Field::new(
            "entries",
            DataType::List(Arc::new(Field::new(
                "item",
                DataType::Struct(three_fields),
                true,
            ))),
            false,
        ));
        assert!(infer_return_field(list_of_wide_structs).is_err());
    }

    #[test]
    fn test_map_from_entries_null_entry_nulls_row() -> Result<()> {
        let keys: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let values: ArrayRef =
            Arc::new(StringArray::from(vec![Some("a"), Some("b"), None]));
        // The second struct entry is null.
        let entries = StructArray::try_new(
            entry_fields(),
            vec![keys, values],
            Some(NullBuffer::from(vec![true, false, true])),
        )?;

        // Row 0 -> [entry 0]; row 1 -> [entry 1 (null), entry 2].
        let list = ListArray::try_new(
            Arc::new(Field::new("item", DataType::Struct(entry_fields()), true)),
            OffsetBuffer::from_lengths([1, 2]),
            Arc::new(entries),
            None,
        )?;

        let args: Vec<ArrayRef> = vec![Arc::new(list)];
        let result = map_from_entries_inner(&args)?;

        assert_eq!(result.len(), 2);
        assert!(result.is_valid(0));
        assert!(result.is_null(1));
        Ok(())
    }
}