// Source file: datafusion_spark/function/map/str_to_map.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::collections::HashSet;
20use std::sync::Arc;
21
22use arrow::array::{
23    Array, ArrayRef, MapBuilder, MapFieldNames, StringArrayType, StringBuilder,
24};
25use arrow::buffer::NullBuffer;
26use arrow::datatypes::{DataType, Field, FieldRef};
27use datafusion_common::cast::{
28    as_large_string_array, as_string_array, as_string_view_array,
29};
30use datafusion_common::{Result, exec_err, internal_err};
31use datafusion_expr::{
32    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
33    TypeSignature, Volatility,
34};
35
36use crate::function::map::utils::map_type_from_key_value_types;
37
/// Default delimiter between key-value pairs when `pairDelim` is not supplied.
const DEFAULT_PAIR_DELIM: &str = ",";
/// Default delimiter between a key and its value when `keyValueDelim` is not supplied.
const DEFAULT_KV_DELIM: &str = ":";
40
41/// Spark-compatible `str_to_map` expression
42/// <https://spark.apache.org/docs/latest/api/sql/index.html#str_to_map>
43///
44/// Creates a map from a string by splitting on delimiters.
45/// str_to_map(text[, pairDelim[, keyValueDelim]]) -> Map<String, String>
46///
47/// - text: The input string
48/// - pairDelim: Delimiter between key-value pairs (default: ',')
49/// - keyValueDelim: Delimiter between key and value (default: ':')
50///
51/// # Duplicate Key Handling
52/// Uses EXCEPTION behavior (Spark 3.0+ default): errors on duplicate keys.
53/// See `spark.sql.mapKeyDedupPolicy`:
54/// <https://github.com/apache/spark/blob/v4.0.0/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4502-L4511>
55///
56/// TODO: Support configurable `spark.sql.mapKeyDedupPolicy` (LAST_WIN) in a follow-up PR.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkStrToMap {
    // Accepted argument arities/types: 1, 2, or 3 string arguments.
    signature: Signature,
}
61
62impl Default for SparkStrToMap {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68impl SparkStrToMap {
69    pub fn new() -> Self {
70        Self {
71            signature: Signature::one_of(
72                vec![
73                    // str_to_map(text)
74                    TypeSignature::String(1),
75                    // str_to_map(text, pairDelim)
76                    TypeSignature::String(2),
77                    // str_to_map(text, pairDelim, keyValueDelim)
78                    TypeSignature::String(3),
79                ],
80                Volatility::Immutable,
81            ),
82        }
83    }
84}
85
86impl ScalarUDFImpl for SparkStrToMap {
87    fn as_any(&self) -> &dyn Any {
88        self
89    }
90
91    fn name(&self) -> &str {
92        "str_to_map"
93    }
94
95    fn signature(&self) -> &Signature {
96        &self.signature
97    }
98
99    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
100        internal_err!("return_field_from_args should be used instead")
101    }
102
103    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
104        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
105        let map_type = map_type_from_key_value_types(&DataType::Utf8, &DataType::Utf8);
106        Ok(Arc::new(Field::new(self.name(), map_type, nullable)))
107    }
108
109    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
110        let arrays: Vec<ArrayRef> = ColumnarValue::values_to_arrays(&args.args)?;
111        let result = str_to_map_inner(&arrays)?;
112        Ok(ColumnarValue::Array(result))
113    }
114}
115
116fn str_to_map_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
117    match args.len() {
118        1 => match args[0].data_type() {
119            DataType::Utf8 => str_to_map_impl(as_string_array(&args[0])?, None, None),
120            DataType::LargeUtf8 => {
121                str_to_map_impl(as_large_string_array(&args[0])?, None, None)
122            }
123            DataType::Utf8View => {
124                str_to_map_impl(as_string_view_array(&args[0])?, None, None)
125            }
126            other => exec_err!(
127                "Unsupported data type {other:?} for str_to_map, \
128                expected Utf8, LargeUtf8, or Utf8View"
129            ),
130        },
131        2 => match (args[0].data_type(), args[1].data_type()) {
132            (DataType::Utf8, DataType::Utf8) => str_to_map_impl(
133                as_string_array(&args[0])?,
134                Some(as_string_array(&args[1])?),
135                None,
136            ),
137            (DataType::LargeUtf8, DataType::LargeUtf8) => str_to_map_impl(
138                as_large_string_array(&args[0])?,
139                Some(as_large_string_array(&args[1])?),
140                None,
141            ),
142            (DataType::Utf8View, DataType::Utf8View) => str_to_map_impl(
143                as_string_view_array(&args[0])?,
144                Some(as_string_view_array(&args[1])?),
145                None,
146            ),
147            (t1, t2) => exec_err!(
148                "Unsupported data types ({t1:?}, {t2:?}) for str_to_map, \
149                expected matching Utf8, LargeUtf8, or Utf8View"
150            ),
151        },
152        3 => match (
153            args[0].data_type(),
154            args[1].data_type(),
155            args[2].data_type(),
156        ) {
157            (DataType::Utf8, DataType::Utf8, DataType::Utf8) => str_to_map_impl(
158                as_string_array(&args[0])?,
159                Some(as_string_array(&args[1])?),
160                Some(as_string_array(&args[2])?),
161            ),
162            (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => {
163                str_to_map_impl(
164                    as_large_string_array(&args[0])?,
165                    Some(as_large_string_array(&args[1])?),
166                    Some(as_large_string_array(&args[2])?),
167                )
168            }
169            (DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => {
170                str_to_map_impl(
171                    as_string_view_array(&args[0])?,
172                    Some(as_string_view_array(&args[1])?),
173                    Some(as_string_view_array(&args[2])?),
174                )
175            }
176            (t1, t2, t3) => exec_err!(
177                "Unsupported data types ({t1:?}, {t2:?}, {t3:?}) for str_to_map, \
178                expected matching Utf8, LargeUtf8, or Utf8View"
179            ),
180        },
181        n => exec_err!("str_to_map expects 1-3 arguments, got {n}"),
182    }
183}
184
/// Builds a `Map<Utf8, Utf8>` array by splitting each row of `text_array` on
/// a pair delimiter and a key-value delimiter.
///
/// `pair_delim_array` / `kv_delim_array` are `None` when the caller omitted
/// the corresponding argument; the Spark defaults (`","` / `":"`) are used
/// instead. When the arrays are present, delimiters may differ per row.
///
/// Returns an error (Spark's EXCEPTION dedup policy) when a row contains a
/// duplicate key.
fn str_to_map_impl<'a, V: StringArrayType<'a> + Copy>(
    text_array: V,
    pair_delim_array: Option<V>,
    kv_delim_array: Option<V>,
) -> Result<ArrayRef> {
    let num_rows = text_array.len();

    // Precompute combined null buffer from all input arrays.
    // NullBuffer::union performs a bitmap-level AND, which is more efficient
    // than checking per-row nullability inline.
    let text_nulls = text_array.nulls().cloned();
    let pair_nulls = pair_delim_array.and_then(|a| a.nulls().cloned());
    let kv_nulls = kv_delim_array.and_then(|a| a.nulls().cloned());
    let combined_nulls = [text_nulls.as_ref(), pair_nulls.as_ref(), kv_nulls.as_ref()]
        .into_iter()
        .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls));

    // Use field names matching map_type_from_key_value_types: "key" and "value"
    let field_names = MapFieldNames {
        entry: "entries".to_string(),
        key: "key".to_string(),
        value: "value".to_string(),
    };
    let mut map_builder = MapBuilder::new(
        Some(field_names),
        StringBuilder::new(),
        StringBuilder::new(),
    );

    // Reused across rows to avoid reallocation; cleared before each row.
    let mut seen_keys = HashSet::new();
    for row_idx in 0..num_rows {
        // If any input (text or either delimiter) is NULL, the output row is NULL.
        if combined_nulls.as_ref().is_some_and(|n| n.is_null(row_idx)) {
            map_builder.append(false)?;
            continue;
        }

        // Per-row delimiter extraction
        let pair_delim =
            pair_delim_array.map_or(DEFAULT_PAIR_DELIM, |a| a.value(row_idx));
        let kv_delim = kv_delim_array.map_or(DEFAULT_KV_DELIM, |a| a.value(row_idx));

        let text = text_array.value(row_idx);
        if text.is_empty() {
            // Empty string -> map with empty key and NULL value (Spark behavior)
            map_builder.keys().append_value("");
            map_builder.values().append_null();
            map_builder.append(true)?;
            continue;
        }

        seen_keys.clear();
        for pair in text.split(pair_delim) {
            // NOTE(review): empty pairs (e.g. "a:1,,b:2") are silently skipped
            // here; Spark's StringToMap keeps them as empty keys (which can then
            // trip the duplicate-key error) — confirm this divergence is intended.
            if pair.is_empty() {
                continue;
            }

            // Split on the first kv_delim only: "a:b:c" -> key "a", value "b:c".
            let mut kv_iter = pair.splitn(2, kv_delim);
            let key = kv_iter.next().unwrap_or("");
            // No kv_delim in the pair -> value is NULL (e.g. "a" -> {"a": null}).
            let value = kv_iter.next();

            // TODO: Support LAST_WIN policy via spark.sql.mapKeyDedupPolicy config
            // EXCEPTION policy: error on duplicate keys (Spark 3.0+ default)
            if !seen_keys.insert(key) {
                return exec_err!(
                    "Duplicate map key '{key}' was found, please check the input data. \
                    If you want to remove the duplicated keys, you can set \
                    spark.sql.mapKeyDedupPolicy to \"LAST_WIN\" so that the key \
                    inserted at last takes precedence."
                );
            }

            map_builder.keys().append_value(key);
            match value {
                Some(v) => map_builder.values().append_value(v),
                None => map_builder.values().append_null(),
            }
        }
        map_builder.append(true)?;
    }

    Ok(Arc::new(map_builder.finish()))
}