datafusion_spark/function/string/
elt.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22    Array, ArrayRef, AsArray, PrimitiveArray, StringArray, StringBuilder,
23};
24use arrow::compute::{can_cast_types, cast};
25use arrow::datatypes::DataType::{Int64, Utf8};
26use arrow::datatypes::{DataType, Int64Type};
27use datafusion_common::cast::as_string_array;
28use datafusion_common::{plan_datafusion_err, DataFusionError, Result};
29use datafusion_expr::{
30    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
31};
32use datafusion_functions::utils::make_scalar_function;
33
/// Scalar UDF implementation registered under the name `elt`.
///
/// Selects, per row, one of the value arguments according to a 1-based index
/// given as the first argument (see the `elt` kernel in this file).
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkElt {
    /// Function signature; constructed as user-defined and immutable in `new`.
    signature: Signature,
}
38
39impl Default for SparkElt {
40    fn default() -> Self {
41        SparkElt::new()
42    }
43}
44
45impl SparkElt {
46    pub fn new() -> Self {
47        Self {
48            signature: Signature::user_defined(Volatility::Immutable),
49        }
50    }
51}
52
53impl ScalarUDFImpl for SparkElt {
54    fn as_any(&self) -> &dyn Any {
55        self
56    }
57
58    fn name(&self) -> &str {
59        "elt"
60    }
61
62    fn signature(&self) -> &Signature {
63        &self.signature
64    }
65
66    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
67        Ok(Utf8)
68    }
69
70    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
71        make_scalar_function(elt, vec![])(&args.args)
72    }
73
74    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
75        let length = arg_types.len();
76        if length < 2 {
77            plan_datafusion_err!(
78                "ELT function expects at least 2 arguments: index, value1"
79            );
80        }
81
82        let idx_dt: &DataType = &arg_types[0];
83        if *idx_dt != Int64 && !can_cast_types(idx_dt, &Int64) {
84            return Err(DataFusionError::Plan(format!(
85                "ELT index must be Int64 (or castable to Int64), got {idx_dt:?}"
86            )));
87        }
88        let mut coerced = Vec::with_capacity(arg_types.len());
89        coerced.push(Int64);
90
91        for _ in 1..length {
92            coerced.push(Utf8);
93        }
94
95        Ok(coerced)
96    }
97}
98
99fn elt(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
100    let n_rows = args[0].len();
101
102    let idx: &PrimitiveArray<Int64Type> =
103        args[0].as_primitive_opt::<Int64Type>().ok_or_else(|| {
104            DataFusionError::Plan(format!(
105                "ELT function: first argument must be Int64 (got {:?})",
106                args[0].data_type()
107            ))
108        })?;
109
110    let num_values = args.len() - 1;
111    let mut cols: Vec<Arc<StringArray>> = Vec::with_capacity(num_values);
112    for a in args.iter().skip(1) {
113        let casted = cast(a, &Utf8)?;
114        let sa = as_string_array(&casted)?;
115        cols.push(Arc::new(sa.clone()));
116    }
117
118    let mut builder = StringBuilder::new();
119
120    for i in 0..n_rows {
121        if idx.is_null(i) {
122            builder.append_null();
123            continue;
124        }
125
126        let index = idx.value(i);
127
128        // TODO: if spark.sql.ansi.enabled is true,
129        //  throw ArrayIndexOutOfBoundsException for invalid indices;
130        //  if false, return NULL instead (current behavior).
131        if index < 1 || (index as usize) > num_values {
132            builder.append_null();
133            continue;
134        }
135
136        let value_idx = (index as usize) - 1;
137        let col = &cols[value_idx];
138
139        if col.is_null(i) {
140            builder.append_null();
141        } else {
142            builder.append_value(col.value(i));
143        }
144    }
145
146    Ok(Arc::new(builder.finish()) as ArrayRef)
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int64Array;
    use datafusion_common::Result;

    use arrow::array::{ArrayRef, StringArray};
    use datafusion_common::DataFusionError;
    use std::sync::Arc;

    /// Runs the `elt` kernel and hands back the result as a `StringArray`,
    /// failing with an internal error if the output has another type.
    fn run_elt_arrays(arrs: Vec<ArrayRef>) -> Result<Arc<StringArray>> {
        let result = elt(&arrs)?;
        match result.as_any().downcast_ref::<StringArray>() {
            Some(strings) => Ok(Arc::new(strings.clone())),
            None => Err(DataFusionError::Internal("expected Utf8".into())),
        }
    }

    /// In-range indices pick the matching column; NULL values, index 0,
    /// an index past the last column, and a NULL index all yield NULL.
    #[test]
    fn elt_utf8_basic() -> Result<()> {
        let idx: ArrayRef = Arc::new(Int64Array::from(vec![
            Some(1),
            Some(2),
            Some(3),
            Some(4),
            Some(0),
            None,
        ]));
        let v1: ArrayRef = Arc::new(StringArray::from(vec![
            Some("a1"),
            Some("a2"),
            Some("a3"),
            Some("a4"),
            Some("a5"),
            Some("a6"),
        ]));
        let v2: ArrayRef = Arc::new(StringArray::from(vec![
            Some("b1"),
            Some("b2"),
            None,
            Some("b4"),
            Some("b5"),
            Some("b6"),
        ]));
        let v3: ArrayRef = Arc::new(StringArray::from(vec![
            Some("c1"),
            Some("c2"),
            Some("c3"),
            None,
            Some("c5"),
            Some("c6"),
        ]));

        let out = run_elt_arrays(vec![idx, v1, v2, v3])?;
        let got: Vec<Option<&str>> = out.iter().collect();
        assert_eq!(
            got,
            vec![Some("a1"), Some("b2"), Some("c3"), None, None, None]
        );
        Ok(())
    }

    /// Non-string value columns are rendered as strings.
    #[test]
    fn elt_int64_basic() -> Result<()> {
        let idx: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(1), Some(2)]));
        let v1: ArrayRef =
            Arc::new(Int64Array::from(vec![Some(10), Some(20), Some(30)]));
        let v2: ArrayRef = Arc::new(Int64Array::from(vec![Some(100), None, Some(300)]));

        let out = run_elt_arrays(vec![idx, v1, v2])?;
        let got: Vec<Option<&str>> = out.iter().collect();
        assert_eq!(got, vec![Some("100"), Some("20"), Some("300")]);
        Ok(())
    }

    /// Indices above the column count, negative indices and zero are NULL.
    #[test]
    fn elt_out_of_range_all_null() -> Result<()> {
        let idx: ArrayRef = Arc::new(Int64Array::from(vec![Some(5), Some(-1), Some(0)]));
        let v1: ArrayRef =
            Arc::new(StringArray::from(vec![Some("x"), Some("y"), Some("z")]));
        let v2: ArrayRef =
            Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")]));

        let out = run_elt_arrays(vec![idx, v1, v2])?;
        for row in 0..3 {
            assert!(out.is_null(row));
        }
        Ok(())
    }

    /// The output array is reported as `Utf8`.
    #[test]
    fn elt_utf8_returns_utf8() -> Result<()> {
        let idx: ArrayRef = Arc::new(Int64Array::from(vec![Some(1)]));
        let v1: ArrayRef = Arc::new(StringArray::from(vec![Some("scala")]));
        let v2: ArrayRef = Arc::new(StringArray::from(vec![Some("java")]));

        let out = run_elt_arrays(vec![idx, v1, v2])?;
        assert_eq!(out.data_type(), &Utf8);
        Ok(())
    }
}