datafusion_spark/function/map/
map_from_arrays.rs1use crate::function::map::utils::{
19 get_element_type, get_list_offsets, get_list_values,
20 map_from_keys_values_offsets_nulls, map_type_from_key_value_types,
21};
22use arrow::array::{Array, ArrayRef, NullArray};
23use arrow::compute::kernels::cast;
24use arrow::datatypes::{DataType, Field, FieldRef};
25use datafusion_common::utils::take_function_args;
26use datafusion_common::{Result, internal_err};
27use datafusion_expr::{
28 ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
29 Volatility,
30};
31use datafusion_functions::utils::make_scalar_function;
32use std::sync::Arc;
33
34#[derive(Debug, PartialEq, Eq, Hash)]
37pub struct MapFromArrays {
38 signature: Signature,
39}
40
41impl Default for MapFromArrays {
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47impl MapFromArrays {
48 pub fn new() -> Self {
49 Self {
50 signature: Signature::any(2, Volatility::Immutable),
51 }
52 }
53}
54
55impl ScalarUDFImpl for MapFromArrays {
56 fn name(&self) -> &str {
57 "map_from_arrays"
58 }
59
60 fn signature(&self) -> &Signature {
61 &self.signature
62 }
63
64 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
65 internal_err!("return_field_from_args should be used instead")
66 }
67
68 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
69 let [keys_field, values_field] = args.arg_fields else {
70 return internal_err!("map_from_arrays expects exactly 2 arguments");
71 };
72
73 let map_type = map_type_from_key_value_types(
74 get_element_type(keys_field.data_type())?,
75 get_element_type(values_field.data_type())?,
76 );
77 let nullable = keys_field.is_nullable() || values_field.is_nullable();
80 Ok(Arc::new(Field::new(self.name(), map_type, nullable)))
81 }
82
83 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
84 make_scalar_function(map_from_arrays_inner, vec![])(&args.args)
85 }
86}
87
88fn map_from_arrays_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
89 let [keys, values] = take_function_args("map_from_arrays", args)?;
90
91 if *keys.data_type() == DataType::Null || *values.data_type() == DataType::Null {
92 return Ok(cast(
93 &NullArray::new(keys.len()),
94 &map_type_from_key_value_types(
95 get_element_type(keys.data_type())?,
96 get_element_type(values.data_type())?,
97 ),
98 )?);
99 }
100
101 map_from_keys_values_offsets_nulls(
102 get_list_values(keys)?,
103 get_list_values(values)?,
104 &get_list_offsets(keys)?,
105 &get_list_offsets(values)?,
106 keys.nulls(),
107 values.nulls(),
108 )
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `List` field whose element field is named `"item"`.
    fn list_field(
        name: &str,
        item_type: DataType,
        item_nullable: bool,
        nullable: bool,
    ) -> FieldRef {
        let item = Arc::new(Field::new("item", item_type, item_nullable));
        Arc::new(Field::new(name, DataType::List(item), nullable))
    }

    /// Checks the reported map type and that output nullability follows the
    /// nullability of the input fields.
    #[test]
    fn test_map_from_arrays_nullability_and_type() {
        let func = MapFromArrays::new();

        let keys_field = list_field("keys", DataType::Int32, false, false);
        let values_field = list_field("values", DataType::Utf8, true, false);

        // Case 1: both inputs are non-nullable.
        let non_nullable_inputs = [Arc::clone(&keys_field), Arc::clone(&values_field)];
        let out = func
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &non_nullable_inputs,
                scalar_arguments: &[None, None],
            })
            .expect("return_field_from_args should succeed");

        let expected_type =
            map_type_from_key_value_types(&DataType::Int32, &DataType::Utf8);
        assert_eq!(out.data_type(), &expected_type);
        assert!(
            !out.is_nullable(),
            "map_from_arrays should be non-nullable when both inputs are non-nullable"
        );

        // Case 2: the keys input is nullable.
        let nullable_keys = list_field("keys", DataType::Int32, false, true);
        let mixed_inputs = [nullable_keys, values_field];
        let out_nullable = func
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &mixed_inputs,
                scalar_arguments: &[None, None],
            })
            .expect("return_field_from_args should succeed");

        assert!(
            out_nullable.is_nullable(),
            "map_from_arrays should be nullable when any input is nullable"
        );
    }
}