// datafusion_spark/function/map/str_to_map.rs

use std::any::Any;
use std::collections::HashSet;
use std::sync::Arc;

use arrow::array::{
    Array, ArrayRef, MapBuilder, MapFieldNames, StringArrayType, StringBuilder,
};
use arrow::buffer::NullBuffer;
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::cast::{
    as_large_string_array, as_string_array, as_string_view_array,
};
use datafusion_common::{Result, exec_err, internal_err};
use datafusion_expr::{
    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
    TypeSignature, Volatility,
};

use crate::function::map::utils::map_type_from_key_value_types;

const DEFAULT_PAIR_DELIM: &str = ",";
const DEFAULT_KV_DELIM: &str = ":";
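
/// Spark-compatible `str_to_map(text[, pairDelim[, keyValueDelim]])`.
///
/// Splits `text` into entries on `pairDelim` (default `,`), then splits each
/// entry into a key/value pair on `keyValueDelim` (default `:`). For example,
/// `str_to_map('a:1,b:2')` and `str_to_map('a=1;b=2', ';', '=')` both produce
/// `{"a":"1","b":"2"}`.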
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkStrToMap {
    signature: Signature,
}

impl Default for SparkStrToMap {
    fn default() -> Self {
        Self::new()
    }
}

impl SparkStrToMap {
    pub fn new() -> Self {
        Self {
            // Accept one to three string arguments:
            // str_to_map(text), str_to_map(text, pairDelim),
            // str_to_map(text, pairDelim, keyValueDelim).
            signature: Signature::one_of(
                vec![
                    TypeSignature::String(1),
                    TypeSignature::String(2),
                    TypeSignature::String(3),
                ],
                Volatility::Immutable,
            ),
        }
    }
}

impl ScalarUDFImpl for SparkStrToMap {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "str_to_map"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }
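
    // DataFusion obtains the return field (type and nullability) from
    // `return_field_from_args` below, so `return_type` is intentionally
    // unreachable.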
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        internal_err!("return_field_from_args should be used instead")
    }

    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
        let map_type = map_type_from_key_value_types(&DataType::Utf8, &DataType::Utf8);
        Ok(Arc::new(Field::new(self.name(), map_type, nullable)))
    }
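
    // Evaluation: normalize scalar/array inputs to arrays, then delegate to
    // the row-wise kernel.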
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let arrays: Vec<ArrayRef> = ColumnarValue::values_to_arrays(&args.args)?;
        let result = str_to_map_inner(&arrays)?;
        Ok(ColumnarValue::Array(result))
    }
}
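
/// Dispatches on argument count and the concrete string type (Utf8,
/// LargeUtf8, or Utf8View); all provided arguments must share one string
/// type.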
fn str_to_map_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
    match args.len() {
        1 => match args[0].data_type() {
            DataType::Utf8 => str_to_map_impl(as_string_array(&args[0])?, None, None),
            DataType::LargeUtf8 => {
                str_to_map_impl(as_large_string_array(&args[0])?, None, None)
            }
            DataType::Utf8View => {
                str_to_map_impl(as_string_view_array(&args[0])?, None, None)
            }
            other => exec_err!(
                "Unsupported data type {other:?} for str_to_map, \
                expected Utf8, LargeUtf8, or Utf8View"
            ),
        },
        2 => match (args[0].data_type(), args[1].data_type()) {
            (DataType::Utf8, DataType::Utf8) => str_to_map_impl(
                as_string_array(&args[0])?,
                Some(as_string_array(&args[1])?),
                None,
            ),
            (DataType::LargeUtf8, DataType::LargeUtf8) => str_to_map_impl(
                as_large_string_array(&args[0])?,
                Some(as_large_string_array(&args[1])?),
                None,
            ),
            (DataType::Utf8View, DataType::Utf8View) => str_to_map_impl(
                as_string_view_array(&args[0])?,
                Some(as_string_view_array(&args[1])?),
                None,
            ),
            (t1, t2) => exec_err!(
                "Unsupported data types ({t1:?}, {t2:?}) for str_to_map, \
                expected matching Utf8, LargeUtf8, or Utf8View"
            ),
        },
        3 => match (
            args[0].data_type(),
            args[1].data_type(),
            args[2].data_type(),
        ) {
            (DataType::Utf8, DataType::Utf8, DataType::Utf8) => str_to_map_impl(
                as_string_array(&args[0])?,
                Some(as_string_array(&args[1])?),
                Some(as_string_array(&args[2])?),
            ),
            (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => {
                str_to_map_impl(
                    as_large_string_array(&args[0])?,
                    Some(as_large_string_array(&args[1])?),
                    Some(as_large_string_array(&args[2])?),
                )
            }
            (DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => {
                str_to_map_impl(
                    as_string_view_array(&args[0])?,
                    Some(as_string_view_array(&args[1])?),
                    Some(as_string_view_array(&args[2])?),
                )
            }
            (t1, t2, t3) => exec_err!(
                "Unsupported data types ({t1:?}, {t2:?}, {t3:?}) for str_to_map, \
                expected matching Utf8, LargeUtf8, or Utf8View"
            ),
        },
        n => exec_err!("str_to_map expects 1-3 arguments, got {n}"),
    }
}
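
/// Row-wise kernel, generic over the concrete Arrow string array type.
/// `None` delimiter arrays fall back to the Spark defaults (`,` and `:`).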
fn str_to_map_impl<'a, V: StringArrayType<'a> + Copy>(
    text_array: V,
    pair_delim_array: Option<V>,
    kv_delim_array: Option<V>,
) -> Result<ArrayRef> {
    let num_rows = text_array.len();

    // A result row is null if the text or either delimiter is null for that
    // row, so union the null buffers of all provided arguments up front.
    let text_nulls = text_array.nulls().cloned();
    let pair_nulls = pair_delim_array.and_then(|a| a.nulls().cloned());
    let kv_nulls = kv_delim_array.and_then(|a| a.nulls().cloned());
    let combined_nulls = [text_nulls.as_ref(), pair_nulls.as_ref(), kv_nulls.as_ref()]
        .into_iter()
        .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls));

    let field_names = MapFieldNames {
        entry: "entries".to_string(),
        key: "key".to_string(),
        value: "value".to_string(),
    };
    let mut map_builder = MapBuilder::new(
        Some(field_names),
        StringBuilder::new(),
        StringBuilder::new(),
    );

    let mut seen_keys = HashSet::new();
    for row_idx in 0..num_rows {
        if combined_nulls.as_ref().is_some_and(|n| n.is_null(row_idx)) {
            map_builder.append(false)?;
            continue;
        }

        let pair_delim =
            pair_delim_array.map_or(DEFAULT_PAIR_DELIM, |a| a.value(row_idx));
        let kv_delim = kv_delim_array.map_or(DEFAULT_KV_DELIM, |a| a.value(row_idx));

        let text = text_array.value(row_idx);
        if text.is_empty() {
            // Spark maps an empty input string to the single entry {"": null}.
            map_builder.keys().append_value("");
            map_builder.values().append_null();
            map_builder.append(true)?;
            continue;
        }

        seen_keys.clear();
        for pair in text.split(pair_delim) {
            if pair.is_empty() {
                continue;
            }

            // Split each entry into a key and an optional value; an entry
            // without the key/value delimiter maps to a null value.
            let mut kv_iter = pair.splitn(2, kv_delim);
            let key = kv_iter.next().unwrap_or("");
            let value = kv_iter.next();

            // Mirror Spark's default map-key dedup policy (EXCEPTION).
            if !seen_keys.insert(key) {
                return exec_err!(
                    "Duplicate map key '{key}' was found, please check the input data. \
                    If you want to remove the duplicated keys, you can set \
                    spark.sql.mapKeyDedupPolicy to \"LAST_WIN\" so that the key \
                    inserted at last takes precedence."
                );
            }

            map_builder.keys().append_value(key);
            match value {
                Some(v) => map_builder.values().append_value(v),
                None => map_builder.values().append_null(),
            }
        }
        map_builder.append(true)?;
    }

    Ok(Arc::new(map_builder.finish()))
}
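
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{MapArray, StringArray};

    // Illustrative sketch (not part of the upstream file): drives the
    // row-wise kernel through `str_to_map_inner` with default delimiters
    // and checks the entry count and null propagation.
    #[test]
    fn str_to_map_default_delimiters() -> Result<()> {
        let text: ArrayRef = Arc::new(StringArray::from(vec![Some("a:1,b:2"), None]));
        let result = str_to_map_inner(&[text])?;
        let map = result.as_any().downcast_ref::<MapArray>().unwrap();

        // Row 0 splits into two entries; the null input row stays null.
        let offsets = map.value_offsets();
        assert_eq!(offsets[1] - offsets[0], 2);
        assert!(map.is_null(1));
        Ok(())
    }
}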