datafusion_spark/function/string/
elt.rs1use std::sync::Arc;
19
20use arrow::array::{
21 Array, ArrayRef, AsArray, PrimitiveArray, StringArray, StringBuilder,
22};
23use arrow::compute::{can_cast_types, cast};
24use arrow::datatypes::DataType::{Int64, Utf8};
25use arrow::datatypes::{DataType, Int64Type};
26use datafusion_common::cast::as_string_array;
27use datafusion_common::{DataFusionError, Result, plan_datafusion_err};
28use datafusion_expr::{
29 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
30};
31use datafusion_functions::utils::make_scalar_function;
32
33#[derive(Debug, PartialEq, Eq, Hash)]
34pub struct SparkElt {
35 signature: Signature,
36}
37
38impl Default for SparkElt {
39 fn default() -> Self {
40 SparkElt::new()
41 }
42}
43
44impl SparkElt {
45 pub fn new() -> Self {
46 Self {
47 signature: Signature::user_defined(Volatility::Immutable),
48 }
49 }
50}
51
52impl ScalarUDFImpl for SparkElt {
53 fn name(&self) -> &str {
54 "elt"
55 }
56
57 fn signature(&self) -> &Signature {
58 &self.signature
59 }
60
61 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
62 Ok(Utf8)
63 }
64
65 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
66 make_scalar_function(elt, vec![])(&args.args)
67 }
68
69 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
70 let length = arg_types.len();
71 if length < 2 {
72 plan_datafusion_err!(
73 "ELT function expects at least 2 arguments: index, value1"
74 );
75 }
76
77 let idx_dt: &DataType = &arg_types[0];
78 if *idx_dt != Int64 && !can_cast_types(idx_dt, &Int64) {
79 return Err(DataFusionError::Plan(format!(
80 "ELT index must be Int64 (or castable to Int64), got {idx_dt:?}"
81 )));
82 }
83 let mut coerced = Vec::with_capacity(arg_types.len());
84 coerced.push(Int64);
85
86 for _ in 1..length {
87 coerced.push(Utf8);
88 }
89
90 Ok(coerced)
91 }
92}
93
94fn elt(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
95 let n_rows = args[0].len();
96
97 let idx: &PrimitiveArray<Int64Type> =
98 args[0].as_primitive_opt::<Int64Type>().ok_or_else(|| {
99 DataFusionError::Plan(format!(
100 "ELT function: first argument must be Int64 (got {:?})",
101 args[0].data_type()
102 ))
103 })?;
104
105 let num_values = args.len() - 1;
106 let mut cols: Vec<Arc<StringArray>> = Vec::with_capacity(num_values);
107 for a in args.iter().skip(1) {
108 let casted = cast(a, &Utf8)?;
109 let sa = as_string_array(&casted)?;
110 cols.push(Arc::new(sa.clone()));
111 }
112
113 let mut builder = StringBuilder::new();
114
115 for i in 0..n_rows {
116 if idx.is_null(i) {
117 builder.append_null();
118 continue;
119 }
120
121 let index = idx.value(i);
122
123 if index < 1 || (index as usize) > num_values {
127 builder.append_null();
128 continue;
129 }
130
131 let value_idx = (index as usize) - 1;
132 let col = &cols[value_idx];
133
134 if col.is_null(i) {
135 builder.append_null();
136 } else {
137 builder.append_value(col.value(i));
138 }
139 }
140
141 Ok(Arc::new(builder.finish()) as ArrayRef)
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int64Array;

    /// Runs the `elt` kernel on `arrs` and returns the result as a
    /// `StringArray`, failing if the kernel produced any other array type.
    fn run_elt_arrays(arrs: Vec<ArrayRef>) -> Result<Arc<StringArray>> {
        let arr = elt(&arrs)?;
        let string_array = arr
            .as_any()
            .downcast_ref::<StringArray>()
            .ok_or_else(|| DataFusionError::Internal("expected Utf8".into()))?;
        Ok(Arc::new(string_array.clone()))
    }

    /// In-range indices pick the matching column; an index past the last
    /// column, a zero index and a null index all produce null.
    #[test]
    fn elt_utf8_basic() -> Result<()> {
        let idx: ArrayRef = Arc::new(Int64Array::from(vec![
            Some(1),
            Some(2),
            Some(3),
            Some(4),
            Some(0),
            None,
        ]));
        let v1: ArrayRef = Arc::new(StringArray::from(vec![
            Some("a1"),
            Some("a2"),
            Some("a3"),
            Some("a4"),
            Some("a5"),
            Some("a6"),
        ]));
        let v2: ArrayRef = Arc::new(StringArray::from(vec![
            Some("b1"),
            Some("b2"),
            None,
            Some("b4"),
            Some("b5"),
            Some("b6"),
        ]));
        let v3: ArrayRef = Arc::new(StringArray::from(vec![
            Some("c1"),
            Some("c2"),
            Some("c3"),
            None,
            Some("c5"),
            Some("c6"),
        ]));

        let out = run_elt_arrays(vec![idx, v1, v2, v3])?;
        assert_eq!(out.len(), 6);
        assert_eq!(out.value(0), "a1");
        assert_eq!(out.value(1), "b2");
        assert_eq!(out.value(2), "c3");
        assert!(out.is_null(3));
        assert!(out.is_null(4));
        assert!(out.is_null(5));
        Ok(())
    }

    /// Non-string value columns are rendered as their Utf8 cast.
    #[test]
    fn elt_int64_basic() -> Result<()> {
        let idx: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(1), Some(2)]));
        let v1: ArrayRef =
            Arc::new(Int64Array::from(vec![Some(10), Some(20), Some(30)]));
        let v2: ArrayRef = Arc::new(Int64Array::from(vec![Some(100), None, Some(300)]));

        let out = run_elt_arrays(vec![idx, v1, v2])?;
        assert_eq!(out.len(), 3);
        assert_eq!(out.value(0), "100");
        assert_eq!(out.value(1), "20");
        assert_eq!(out.value(2), "300");
        Ok(())
    }

    /// Indices above the column count, negative indices and zero are all null.
    #[test]
    fn elt_out_of_range_all_null() -> Result<()> {
        let idx: ArrayRef = Arc::new(Int64Array::from(vec![Some(5), Some(-1), Some(0)]));
        let v1: ArrayRef =
            Arc::new(StringArray::from(vec![Some("x"), Some("y"), Some("z")]));
        let v2: ArrayRef =
            Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")]));

        let out = run_elt_arrays(vec![idx, v1, v2])?;
        assert!(out.is_null(0));
        assert!(out.is_null(1));
        assert!(out.is_null(2));
        Ok(())
    }

    /// A valid index that lands on a null value yields null, while other rows
    /// of the same column are unaffected.
    #[test]
    fn elt_selected_value_null() -> Result<()> {
        let idx: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(2), Some(1)]));
        let v1: ArrayRef =
            Arc::new(StringArray::from(vec![Some("x"), Some("y"), None]));
        let v2: ArrayRef =
            Arc::new(StringArray::from(vec![Some("a"), None, Some("c")]));

        let out = run_elt_arrays(vec![idx, v1, v2])?;
        assert_eq!(out.len(), 3);
        assert_eq!(out.value(0), "a");
        assert!(out.is_null(1));
        assert!(out.is_null(2));
        Ok(())
    }

    /// The kernel's output array type is `Utf8`.
    #[test]
    fn elt_utf8_returns_utf8() -> Result<()> {
        let idx: ArrayRef = Arc::new(Int64Array::from(vec![Some(1)]));
        let v1: ArrayRef = Arc::new(StringArray::from(vec![Some("scala")]));
        let v2: ArrayRef = Arc::new(StringArray::from(vec![Some("java")]));

        let out = run_elt_arrays(vec![idx, v1, v2])?;
        assert_eq!(out.data_type(), &Utf8);
        Ok(())
    }
}