Skip to main content

datafusion_spark/function/map/
str_to_map.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::collections::HashSet;
19use std::sync::Arc;
20
21use arrow::array::{
22    Array, ArrayRef, MapBuilder, MapFieldNames, StringArrayType, StringBuilder,
23};
24use arrow::buffer::NullBuffer;
25use arrow::datatypes::{DataType, Field, FieldRef};
26use datafusion_common::cast::{
27    as_large_string_array, as_string_array, as_string_view_array,
28};
29use datafusion_common::{Result, exec_err, internal_err};
30use datafusion_expr::{
31    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
32    TypeSignature, Volatility,
33};
34
35use crate::function::map::utils::map_type_from_key_value_types;
36
/// Pair delimiter used when `str_to_map` is called without a `pairDelim` argument.
const DEFAULT_PAIR_DELIM: &str = ",";
/// Key/value delimiter used when `str_to_map` is called without a `keyValueDelim` argument.
const DEFAULT_KV_DELIM: &str = ":";
39
40/// Spark-compatible `str_to_map` expression
41/// <https://spark.apache.org/docs/latest/api/sql/index.html#str_to_map>
42///
43/// Creates a map from a string by splitting on delimiters.
44/// str_to_map(text[, pairDelim[, keyValueDelim]]) -> Map<String, String>
45///
46/// - text: The input string
47/// - pairDelim: Delimiter between key-value pairs (default: ',')
48/// - keyValueDelim: Delimiter between key and value (default: ':')
49///
50/// # Duplicate Key Handling
51/// Uses EXCEPTION behavior (Spark 3.0+ default): errors on duplicate keys.
52/// See `spark.sql.mapKeyDedupPolicy`:
53/// <https://github.com/apache/spark/blob/v4.0.0/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4502-L4511>
54///
55/// TODO: Support configurable `spark.sql.mapKeyDedupPolicy` (LAST_WIN) in a follow-up PR.
56#[derive(Debug, PartialEq, Eq, Hash)]
57pub struct SparkStrToMap {
58    signature: Signature,
59}
60
61impl Default for SparkStrToMap {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67impl SparkStrToMap {
68    pub fn new() -> Self {
69        Self {
70            signature: Signature::one_of(
71                vec![
72                    // str_to_map(text)
73                    TypeSignature::String(1),
74                    // str_to_map(text, pairDelim)
75                    TypeSignature::String(2),
76                    // str_to_map(text, pairDelim, keyValueDelim)
77                    TypeSignature::String(3),
78                ],
79                Volatility::Immutable,
80            ),
81        }
82    }
83}
84
85impl ScalarUDFImpl for SparkStrToMap {
86    fn name(&self) -> &str {
87        "str_to_map"
88    }
89
90    fn signature(&self) -> &Signature {
91        &self.signature
92    }
93
94    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
95        internal_err!("return_field_from_args should be used instead")
96    }
97
98    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
99        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
100        let map_type = map_type_from_key_value_types(&DataType::Utf8, &DataType::Utf8);
101        Ok(Arc::new(Field::new(self.name(), map_type, nullable)))
102    }
103
104    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
105        let arrays: Vec<ArrayRef> = ColumnarValue::values_to_arrays(&args.args)?;
106        let result = str_to_map_inner(&arrays)?;
107        Ok(ColumnarValue::Array(result))
108    }
109}
110
111fn str_to_map_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
112    match args.len() {
113        1 => match args[0].data_type() {
114            DataType::Utf8 => str_to_map_impl(as_string_array(&args[0])?, None, None),
115            DataType::LargeUtf8 => {
116                str_to_map_impl(as_large_string_array(&args[0])?, None, None)
117            }
118            DataType::Utf8View => {
119                str_to_map_impl(as_string_view_array(&args[0])?, None, None)
120            }
121            other => exec_err!(
122                "Unsupported data type {other:?} for str_to_map, \
123                expected Utf8, LargeUtf8, or Utf8View"
124            ),
125        },
126        2 => match (args[0].data_type(), args[1].data_type()) {
127            (DataType::Utf8, DataType::Utf8) => str_to_map_impl(
128                as_string_array(&args[0])?,
129                Some(as_string_array(&args[1])?),
130                None,
131            ),
132            (DataType::LargeUtf8, DataType::LargeUtf8) => str_to_map_impl(
133                as_large_string_array(&args[0])?,
134                Some(as_large_string_array(&args[1])?),
135                None,
136            ),
137            (DataType::Utf8View, DataType::Utf8View) => str_to_map_impl(
138                as_string_view_array(&args[0])?,
139                Some(as_string_view_array(&args[1])?),
140                None,
141            ),
142            (t1, t2) => exec_err!(
143                "Unsupported data types ({t1:?}, {t2:?}) for str_to_map, \
144                expected matching Utf8, LargeUtf8, or Utf8View"
145            ),
146        },
147        3 => match (
148            args[0].data_type(),
149            args[1].data_type(),
150            args[2].data_type(),
151        ) {
152            (DataType::Utf8, DataType::Utf8, DataType::Utf8) => str_to_map_impl(
153                as_string_array(&args[0])?,
154                Some(as_string_array(&args[1])?),
155                Some(as_string_array(&args[2])?),
156            ),
157            (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => {
158                str_to_map_impl(
159                    as_large_string_array(&args[0])?,
160                    Some(as_large_string_array(&args[1])?),
161                    Some(as_large_string_array(&args[2])?),
162                )
163            }
164            (DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => {
165                str_to_map_impl(
166                    as_string_view_array(&args[0])?,
167                    Some(as_string_view_array(&args[1])?),
168                    Some(as_string_view_array(&args[2])?),
169                )
170            }
171            (t1, t2, t3) => exec_err!(
172                "Unsupported data types ({t1:?}, {t2:?}, {t3:?}) for str_to_map, \
173                expected matching Utf8, LargeUtf8, or Utf8View"
174            ),
175        },
176        n => exec_err!("str_to_map expects 1-3 arguments, got {n}"),
177    }
178}
179
180fn str_to_map_impl<'a, V: StringArrayType<'a> + Copy>(
181    text_array: V,
182    pair_delim_array: Option<V>,
183    kv_delim_array: Option<V>,
184) -> Result<ArrayRef> {
185    let num_rows = text_array.len();
186
187    // Precompute combined null buffer from all input arrays.
188    // NullBuffer::union_many performs a bitmap-level AND, which is more
189    // efficient than checking per-row nullability inline.
190    let combined_nulls = NullBuffer::union_many([
191        text_array.nulls(),
192        pair_delim_array.as_ref().and_then(|a| a.nulls()),
193        kv_delim_array.as_ref().and_then(|a| a.nulls()),
194    ]);
195
196    // Use field names matching map_type_from_key_value_types: "key" and "value"
197    let field_names = MapFieldNames {
198        entry: "entries".to_string(),
199        key: "key".to_string(),
200        value: "value".to_string(),
201    };
202    let mut map_builder = MapBuilder::new(
203        Some(field_names),
204        StringBuilder::new(),
205        StringBuilder::new(),
206    );
207
208    let mut seen_keys = HashSet::new();
209    for row_idx in 0..num_rows {
210        if combined_nulls.as_ref().is_some_and(|n| n.is_null(row_idx)) {
211            map_builder.append(false)?;
212            continue;
213        }
214
215        // Per-row delimiter extraction
216        let pair_delim =
217            pair_delim_array.map_or(DEFAULT_PAIR_DELIM, |a| a.value(row_idx));
218        let kv_delim = kv_delim_array.map_or(DEFAULT_KV_DELIM, |a| a.value(row_idx));
219
220        let text = text_array.value(row_idx);
221        if text.is_empty() {
222            // Empty string -> map with empty key and NULL value (Spark behavior)
223            map_builder.keys().append_value("");
224            map_builder.values().append_null();
225            map_builder.append(true)?;
226            continue;
227        }
228
229        seen_keys.clear();
230        for pair in text.split(pair_delim) {
231            if pair.is_empty() {
232                continue;
233            }
234
235            let mut kv_iter = pair.splitn(2, kv_delim);
236            let key = kv_iter.next().unwrap_or("");
237            let value = kv_iter.next();
238
239            // TODO: Support LAST_WIN policy via spark.sql.mapKeyDedupPolicy config
240            // EXCEPTION policy: error on duplicate keys (Spark 3.0+ default)
241            if !seen_keys.insert(key) {
242                return exec_err!(
243                    "Duplicate map key '{key}' was found, please check the input data. \
244                    If you want to remove the duplicated keys, you can set \
245                    spark.sql.mapKeyDedupPolicy to \"LAST_WIN\" so that the key \
246                    inserted at last takes precedence."
247                );
248            }
249
250            map_builder.keys().append_value(key);
251            match value {
252                Some(v) => map_builder.values().append_value(v),
253                None => map_builder.values().append_null(),
254            }
255        }
256        map_builder.append(true)?;
257    }
258
259    Ok(Arc::new(map_builder.finish()))
260}