Skip to main content

datafusion_spark/function/string/
elt.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::{
21    Array, ArrayRef, AsArray, PrimitiveArray, StringArray, StringBuilder,
22};
23use arrow::compute::{can_cast_types, cast};
24use arrow::datatypes::DataType::{Int64, Utf8};
25use arrow::datatypes::{DataType, Int64Type};
26use datafusion_common::cast::as_string_array;
27use datafusion_common::{DataFusionError, Result, plan_datafusion_err};
28use datafusion_expr::{
29    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
30};
31use datafusion_functions::utils::make_scalar_function;
32
/// Scalar UDF exposed under the name `elt`: given a 1-based index followed by
/// one or more values, it yields the value at that position as a `Utf8` string.
///
/// The struct only carries the function's [`Signature`]; the row-wise work is
/// done by the free function `elt` in this module.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkElt {
    /// Signature returned from `ScalarUDFImpl::signature`.
    signature: Signature,
}
37
38impl Default for SparkElt {
39    fn default() -> Self {
40        SparkElt::new()
41    }
42}
43
44impl SparkElt {
45    pub fn new() -> Self {
46        Self {
47            signature: Signature::user_defined(Volatility::Immutable),
48        }
49    }
50}
51
52impl ScalarUDFImpl for SparkElt {
53    fn name(&self) -> &str {
54        "elt"
55    }
56
57    fn signature(&self) -> &Signature {
58        &self.signature
59    }
60
61    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
62        Ok(Utf8)
63    }
64
65    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
66        make_scalar_function(elt, vec![])(&args.args)
67    }
68
69    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
70        let length = arg_types.len();
71        if length < 2 {
72            plan_datafusion_err!(
73                "ELT function expects at least 2 arguments: index, value1"
74            );
75        }
76
77        let idx_dt: &DataType = &arg_types[0];
78        if *idx_dt != Int64 && !can_cast_types(idx_dt, &Int64) {
79            return Err(DataFusionError::Plan(format!(
80                "ELT index must be Int64 (or castable to Int64), got {idx_dt:?}"
81            )));
82        }
83        let mut coerced = Vec::with_capacity(arg_types.len());
84        coerced.push(Int64);
85
86        for _ in 1..length {
87            coerced.push(Utf8);
88        }
89
90        Ok(coerced)
91    }
92}
93
94fn elt(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
95    let n_rows = args[0].len();
96
97    let idx: &PrimitiveArray<Int64Type> =
98        args[0].as_primitive_opt::<Int64Type>().ok_or_else(|| {
99            DataFusionError::Plan(format!(
100                "ELT function: first argument must be Int64 (got {:?})",
101                args[0].data_type()
102            ))
103        })?;
104
105    let num_values = args.len() - 1;
106    let mut cols: Vec<Arc<StringArray>> = Vec::with_capacity(num_values);
107    for a in args.iter().skip(1) {
108        let casted = cast(a, &Utf8)?;
109        let sa = as_string_array(&casted)?;
110        cols.push(Arc::new(sa.clone()));
111    }
112
113    let mut builder = StringBuilder::new();
114
115    for i in 0..n_rows {
116        if idx.is_null(i) {
117            builder.append_null();
118            continue;
119        }
120
121        let index = idx.value(i);
122
123        // TODO: if spark.sql.ansi.enabled is true,
124        //  throw ArrayIndexOutOfBoundsException for invalid indices;
125        //  if false, return NULL instead (current behavior).
126        if index < 1 || (index as usize) > num_values {
127            builder.append_null();
128            continue;
129        }
130
131        let value_idx = (index as usize) - 1;
132        let col = &cols[value_idx];
133
134        if col.is_null(i) {
135            builder.append_null();
136        } else {
137            builder.append_value(col.value(i));
138        }
139    }
140
141    Ok(Arc::new(builder.finish()) as ArrayRef)
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int64Array;

    /// Runs the `elt` kernel and hands back the result as a `StringArray`,
    /// failing with an internal error if the output has another type.
    fn run_elt_arrays(arrs: Vec<ArrayRef>) -> Result<Arc<StringArray>> {
        let result = elt(&arrs)?;
        match result.as_any().downcast_ref::<StringArray>() {
            Some(strings) => Ok(Arc::new(strings.clone())),
            None => Err(DataFusionError::Internal("expected Utf8".into())),
        }
    }

    /// In-range indices pick the matching column; a NULL picked value, an
    /// out-of-range index and a NULL index all yield NULL.
    #[test]
    fn elt_utf8_basic() -> Result<()> {
        let idx: ArrayRef = Arc::new(Int64Array::from(vec![
            Some(1),
            Some(2),
            Some(3),
            Some(4),
            Some(0),
            None,
        ]));
        let v1: ArrayRef = Arc::new(StringArray::from(vec![
            Some("a1"),
            Some("a2"),
            Some("a3"),
            Some("a4"),
            Some("a5"),
            Some("a6"),
        ]));
        let v2: ArrayRef = Arc::new(StringArray::from(vec![
            Some("b1"),
            Some("b2"),
            None,
            Some("b4"),
            Some("b5"),
            Some("b6"),
        ]));
        let v3: ArrayRef = Arc::new(StringArray::from(vec![
            Some("c1"),
            Some("c2"),
            Some("c3"),
            None,
            Some("c5"),
            Some("c6"),
        ]));

        let out = run_elt_arrays(vec![idx, v1, v2, v3])?;
        assert_eq!(out.len(), 6);

        let expected_values = ["a1", "b2", "c3"];
        for (row, expected) in expected_values.iter().enumerate() {
            assert_eq!(out.value(row), *expected);
        }
        for row in 3..6 {
            assert!(out.is_null(row));
        }
        Ok(())
    }

    /// Non-string value columns are rendered as their string form.
    #[test]
    fn elt_int64_basic() -> Result<()> {
        let idx: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(1), Some(2)]));
        let v1: ArrayRef =
            Arc::new(Int64Array::from(vec![Some(10), Some(20), Some(30)]));
        let v2: ArrayRef = Arc::new(Int64Array::from(vec![Some(100), None, Some(300)]));

        let out = run_elt_arrays(vec![idx, v1, v2])?;
        assert_eq!(out.len(), 3);

        let expected_values = ["100", "20", "300"];
        for (row, expected) in expected_values.iter().enumerate() {
            assert_eq!(out.value(row), *expected);
        }
        Ok(())
    }

    /// Indices above the column count, negative indices and zero give NULL.
    #[test]
    fn elt_out_of_range_all_null() -> Result<()> {
        let idx: ArrayRef = Arc::new(Int64Array::from(vec![Some(5), Some(-1), Some(0)]));
        let v1: ArrayRef =
            Arc::new(StringArray::from(vec![Some("x"), Some("y"), Some("z")]));
        let v2: ArrayRef =
            Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")]));

        let out = run_elt_arrays(vec![idx, v1, v2])?;
        for row in 0..3 {
            assert!(out.is_null(row));
        }
        Ok(())
    }

    /// The output array type is `Utf8`.
    #[test]
    fn elt_utf8_returns_utf8() -> Result<()> {
        let idx: ArrayRef = Arc::new(Int64Array::from(vec![Some(1)]));
        let v1: ArrayRef = Arc::new(StringArray::from(vec![Some("scala")]));
        let v2: ArrayRef = Arc::new(StringArray::from(vec![Some("java")]));

        let out = run_elt_arrays(vec![idx, v1, v2])?;
        assert_eq!(out.data_type(), &Utf8);
        Ok(())
    }
}
241}