datafusion_spark/function/string/
elt.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22 Array, ArrayRef, AsArray, PrimitiveArray, StringArray, StringBuilder,
23};
24use arrow::compute::{can_cast_types, cast};
25use arrow::datatypes::DataType::{Int64, Utf8};
26use arrow::datatypes::{DataType, Int64Type};
27use datafusion_common::cast::as_string_array;
28use datafusion_common::{plan_datafusion_err, DataFusionError, Result};
29use datafusion_expr::{
30 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
31};
32use datafusion_functions::utils::make_scalar_function;
33
34#[derive(Debug, PartialEq, Eq, Hash)]
35pub struct SparkElt {
36 signature: Signature,
37}
38
39impl Default for SparkElt {
40 fn default() -> Self {
41 SparkElt::new()
42 }
43}
44
45impl SparkElt {
46 pub fn new() -> Self {
47 Self {
48 signature: Signature::user_defined(Volatility::Immutable),
49 }
50 }
51}
52
53impl ScalarUDFImpl for SparkElt {
54 fn as_any(&self) -> &dyn Any {
55 self
56 }
57
58 fn name(&self) -> &str {
59 "elt"
60 }
61
62 fn signature(&self) -> &Signature {
63 &self.signature
64 }
65
66 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
67 Ok(Utf8)
68 }
69
70 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
71 make_scalar_function(elt, vec![])(&args.args)
72 }
73
74 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
75 let length = arg_types.len();
76 if length < 2 {
77 plan_datafusion_err!(
78 "ELT function expects at least 2 arguments: index, value1"
79 );
80 }
81
82 let idx_dt: &DataType = &arg_types[0];
83 if *idx_dt != Int64 && !can_cast_types(idx_dt, &Int64) {
84 return Err(DataFusionError::Plan(format!(
85 "ELT index must be Int64 (or castable to Int64), got {idx_dt:?}"
86 )));
87 }
88 let mut coerced = Vec::with_capacity(arg_types.len());
89 coerced.push(Int64);
90
91 for _ in 1..length {
92 coerced.push(Utf8);
93 }
94
95 Ok(coerced)
96 }
97}
98
99fn elt(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
100 let n_rows = args[0].len();
101
102 let idx: &PrimitiveArray<Int64Type> =
103 args[0].as_primitive_opt::<Int64Type>().ok_or_else(|| {
104 DataFusionError::Plan(format!(
105 "ELT function: first argument must be Int64 (got {:?})",
106 args[0].data_type()
107 ))
108 })?;
109
110 let num_values = args.len() - 1;
111 let mut cols: Vec<Arc<StringArray>> = Vec::with_capacity(num_values);
112 for a in args.iter().skip(1) {
113 let casted = cast(a, &Utf8)?;
114 let sa = as_string_array(&casted)?;
115 cols.push(Arc::new(sa.clone()));
116 }
117
118 let mut builder = StringBuilder::new();
119
120 for i in 0..n_rows {
121 if idx.is_null(i) {
122 builder.append_null();
123 continue;
124 }
125
126 let index = idx.value(i);
127
128 if index < 1 || (index as usize) > num_values {
132 builder.append_null();
133 continue;
134 }
135
136 let value_idx = (index as usize) - 1;
137 let col = &cols[value_idx];
138
139 if col.is_null(i) {
140 builder.append_null();
141 } else {
142 builder.append_value(col.value(i));
143 }
144 }
145
146 Ok(Arc::new(builder.finish()) as ArrayRef)
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{ArrayRef, Int64Array, StringArray};
    use datafusion_common::{DataFusionError, Result};
    use std::sync::Arc;

    /// Builds an `Int64` column from optional values.
    fn int64_col(values: Vec<Option<i64>>) -> ArrayRef {
        Arc::new(Int64Array::from(values))
    }

    /// Builds a `Utf8` column from optional string slices.
    fn utf8_col(values: Vec<Option<&str>>) -> ArrayRef {
        Arc::new(StringArray::from(values))
    }

    /// Runs the `elt` kernel and downcasts its output to a `StringArray`.
    fn run_elt_arrays(arrs: Vec<ArrayRef>) -> Result<Arc<StringArray>> {
        let result = elt(&arrs)?;
        match result.as_any().downcast_ref::<StringArray>() {
            Some(strings) => Ok(Arc::new(strings.clone())),
            None => Err(DataFusionError::Internal("expected Utf8".into())),
        }
    }

    /// In-range indices pick the matching column; index 4 (of 3 values),
    /// index 0 and a null index all yield null.
    #[test]
    fn elt_utf8_basic() -> Result<()> {
        let idx = int64_col(vec![Some(1), Some(2), Some(3), Some(4), Some(0), None]);
        let v1 = utf8_col(vec![
            Some("a1"),
            Some("a2"),
            Some("a3"),
            Some("a4"),
            Some("a5"),
            Some("a6"),
        ]);
        let v2 = utf8_col(vec![
            Some("b1"),
            Some("b2"),
            None,
            Some("b4"),
            Some("b5"),
            Some("b6"),
        ]);
        let v3 = utf8_col(vec![
            Some("c1"),
            Some("c2"),
            Some("c3"),
            None,
            Some("c5"),
            Some("c6"),
        ]);

        let out = run_elt_arrays(vec![idx, v1, v2, v3])?;
        assert_eq!(out.len(), 6);
        for (row, expected) in ["a1", "b2", "c3"].iter().enumerate() {
            assert_eq!(out.value(row), *expected);
        }
        for row in 3..6 {
            assert!(out.is_null(row));
        }
        Ok(())
    }

    /// Non-string value columns are rendered as their string form.
    #[test]
    fn elt_int64_basic() -> Result<()> {
        let idx = int64_col(vec![Some(2), Some(1), Some(2)]);
        let v1 = int64_col(vec![Some(10), Some(20), Some(30)]);
        let v2 = int64_col(vec![Some(100), None, Some(300)]);

        let out = run_elt_arrays(vec![idx, v1, v2])?;
        assert_eq!(out.len(), 3);
        for (row, expected) in ["100", "20", "300"].iter().enumerate() {
            assert_eq!(out.value(row), *expected);
        }
        Ok(())
    }

    /// Indices above the value count, negative indices and zero give null.
    #[test]
    fn elt_out_of_range_all_null() -> Result<()> {
        let idx = int64_col(vec![Some(5), Some(-1), Some(0)]);
        let v1 = utf8_col(vec![Some("x"), Some("y"), Some("z")]);
        let v2 = utf8_col(vec![Some("a"), Some("b"), Some("c")]);

        let out = run_elt_arrays(vec![idx, v1, v2])?;
        for row in 0..3 {
            assert!(out.is_null(row));
        }
        Ok(())
    }

    /// The output array's data type is `Utf8`.
    #[test]
    fn elt_utf8_returns_utf8() -> Result<()> {
        let idx = int64_col(vec![Some(1)]);
        let v1 = utf8_col(vec![Some("scala")]);
        let v2 = utf8_col(vec![Some("java")]);

        let out = run_elt_arrays(vec![idx, v1, v2])?;
        assert_eq!(out.data_type(), &Utf8);
        Ok(())
    }
}