datafusion_spark/function/map/
str_to_map.rs1use std::collections::HashSet;
19use std::sync::Arc;
20
21use arrow::array::{
22 Array, ArrayRef, MapBuilder, MapFieldNames, StringArrayType, StringBuilder,
23};
24use arrow::buffer::NullBuffer;
25use arrow::datatypes::{DataType, Field, FieldRef};
26use datafusion_common::cast::{
27 as_large_string_array, as_string_array, as_string_view_array,
28};
29use datafusion_common::{Result, exec_err, internal_err};
30use datafusion_expr::{
31 ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
32 TypeSignature, Volatility,
33};
34
35use crate::function::map::utils::map_type_from_key_value_types;
36
/// Separator between key/value pairs, used when no pair-delimiter argument is supplied.
const DEFAULT_PAIR_DELIM: &str = ",";
/// Separator between a key and its value, used when no key/value-delimiter argument is supplied.
const DEFAULT_KV_DELIM: &str = ":";
39
40#[derive(Debug, PartialEq, Eq, Hash)]
57pub struct SparkStrToMap {
58 signature: Signature,
59}
60
61impl Default for SparkStrToMap {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67impl SparkStrToMap {
68 pub fn new() -> Self {
69 Self {
70 signature: Signature::one_of(
71 vec![
72 TypeSignature::String(1),
74 TypeSignature::String(2),
76 TypeSignature::String(3),
78 ],
79 Volatility::Immutable,
80 ),
81 }
82 }
83}
84
85impl ScalarUDFImpl for SparkStrToMap {
86 fn name(&self) -> &str {
87 "str_to_map"
88 }
89
90 fn signature(&self) -> &Signature {
91 &self.signature
92 }
93
94 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
95 internal_err!("return_field_from_args should be used instead")
96 }
97
98 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
99 let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
100 let map_type = map_type_from_key_value_types(&DataType::Utf8, &DataType::Utf8);
101 Ok(Arc::new(Field::new(self.name(), map_type, nullable)))
102 }
103
104 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
105 let arrays: Vec<ArrayRef> = ColumnarValue::values_to_arrays(&args.args)?;
106 let result = str_to_map_inner(&arrays)?;
107 Ok(ColumnarValue::Array(result))
108 }
109}
110
111fn str_to_map_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
112 match args.len() {
113 1 => match args[0].data_type() {
114 DataType::Utf8 => str_to_map_impl(as_string_array(&args[0])?, None, None),
115 DataType::LargeUtf8 => {
116 str_to_map_impl(as_large_string_array(&args[0])?, None, None)
117 }
118 DataType::Utf8View => {
119 str_to_map_impl(as_string_view_array(&args[0])?, None, None)
120 }
121 other => exec_err!(
122 "Unsupported data type {other:?} for str_to_map, \
123 expected Utf8, LargeUtf8, or Utf8View"
124 ),
125 },
126 2 => match (args[0].data_type(), args[1].data_type()) {
127 (DataType::Utf8, DataType::Utf8) => str_to_map_impl(
128 as_string_array(&args[0])?,
129 Some(as_string_array(&args[1])?),
130 None,
131 ),
132 (DataType::LargeUtf8, DataType::LargeUtf8) => str_to_map_impl(
133 as_large_string_array(&args[0])?,
134 Some(as_large_string_array(&args[1])?),
135 None,
136 ),
137 (DataType::Utf8View, DataType::Utf8View) => str_to_map_impl(
138 as_string_view_array(&args[0])?,
139 Some(as_string_view_array(&args[1])?),
140 None,
141 ),
142 (t1, t2) => exec_err!(
143 "Unsupported data types ({t1:?}, {t2:?}) for str_to_map, \
144 expected matching Utf8, LargeUtf8, or Utf8View"
145 ),
146 },
147 3 => match (
148 args[0].data_type(),
149 args[1].data_type(),
150 args[2].data_type(),
151 ) {
152 (DataType::Utf8, DataType::Utf8, DataType::Utf8) => str_to_map_impl(
153 as_string_array(&args[0])?,
154 Some(as_string_array(&args[1])?),
155 Some(as_string_array(&args[2])?),
156 ),
157 (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => {
158 str_to_map_impl(
159 as_large_string_array(&args[0])?,
160 Some(as_large_string_array(&args[1])?),
161 Some(as_large_string_array(&args[2])?),
162 )
163 }
164 (DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => {
165 str_to_map_impl(
166 as_string_view_array(&args[0])?,
167 Some(as_string_view_array(&args[1])?),
168 Some(as_string_view_array(&args[2])?),
169 )
170 }
171 (t1, t2, t3) => exec_err!(
172 "Unsupported data types ({t1:?}, {t2:?}, {t3:?}) for str_to_map, \
173 expected matching Utf8, LargeUtf8, or Utf8View"
174 ),
175 },
176 n => exec_err!("str_to_map expects 1-3 arguments, got {n}"),
177 }
178}
179
180fn str_to_map_impl<'a, V: StringArrayType<'a> + Copy>(
181 text_array: V,
182 pair_delim_array: Option<V>,
183 kv_delim_array: Option<V>,
184) -> Result<ArrayRef> {
185 let num_rows = text_array.len();
186
187 let combined_nulls = NullBuffer::union_many([
191 text_array.nulls(),
192 pair_delim_array.as_ref().and_then(|a| a.nulls()),
193 kv_delim_array.as_ref().and_then(|a| a.nulls()),
194 ]);
195
196 let field_names = MapFieldNames {
198 entry: "entries".to_string(),
199 key: "key".to_string(),
200 value: "value".to_string(),
201 };
202 let mut map_builder = MapBuilder::new(
203 Some(field_names),
204 StringBuilder::new(),
205 StringBuilder::new(),
206 );
207
208 let mut seen_keys = HashSet::new();
209 for row_idx in 0..num_rows {
210 if combined_nulls.as_ref().is_some_and(|n| n.is_null(row_idx)) {
211 map_builder.append(false)?;
212 continue;
213 }
214
215 let pair_delim =
217 pair_delim_array.map_or(DEFAULT_PAIR_DELIM, |a| a.value(row_idx));
218 let kv_delim = kv_delim_array.map_or(DEFAULT_KV_DELIM, |a| a.value(row_idx));
219
220 let text = text_array.value(row_idx);
221 if text.is_empty() {
222 map_builder.keys().append_value("");
224 map_builder.values().append_null();
225 map_builder.append(true)?;
226 continue;
227 }
228
229 seen_keys.clear();
230 for pair in text.split(pair_delim) {
231 if pair.is_empty() {
232 continue;
233 }
234
235 let mut kv_iter = pair.splitn(2, kv_delim);
236 let key = kv_iter.next().unwrap_or("");
237 let value = kv_iter.next();
238
239 if !seen_keys.insert(key) {
242 return exec_err!(
243 "Duplicate map key '{key}' was found, please check the input data. \
244 If you want to remove the duplicated keys, you can set \
245 spark.sql.mapKeyDedupPolicy to \"LAST_WIN\" so that the key \
246 inserted at last takes precedence."
247 );
248 }
249
250 map_builder.keys().append_value(key);
251 match value {
252 Some(v) => map_builder.values().append_value(v),
253 None => map_builder.values().append_null(),
254 }
255 }
256 map_builder.append(true)?;
257 }
258
259 Ok(Arc::new(map_builder.finish()))
260}