datafusion_functions_nested/
distance.rs1use crate::utils::make_scalar_function;
21use arrow::array::{
22 Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait,
23};
24use arrow::datatypes::{
25 DataType,
26 DataType::{FixedSizeList, Float64, LargeList, List},
27};
28use datafusion_common::cast::{
29 as_float32_array, as_float64_array, as_generic_list_array, as_int32_array,
30 as_int64_array,
31};
32use datafusion_common::utils::coerced_fixed_size_list_to_list;
33use datafusion_common::{
34 exec_err, internal_datafusion_err, utils::take_function_args, Result,
35};
36use datafusion_expr::{
37 ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
38};
39use datafusion_functions::{downcast_arg, downcast_named_arg};
40use datafusion_macros::user_doc;
41use std::any::Any;
42use std::sync::Arc;
43
// Invokes the `make_udf_expr_and_func!` macro (defined outside this file) for
// the `ArrayDistance` type declared below.
// NOTE(review): presumably this generates the `array_distance(array)` expression
// helper and the `array_distance_udf()` accessor; confirm against the macro
// definition.
make_udf_expr_and_func!(
    ArrayDistance,
    array_distance,
    array,
    "returns the Euclidean distance between two numeric arrays.",
    array_distance_udf
);
51
52#[user_doc(
53 doc_section(label = "Array Functions"),
54 description = "Returns the Euclidean distance between two input arrays of equal length.",
55 syntax_example = "array_distance(array1, array2)",
56 sql_example = r#"```sql
57> select array_distance([1, 2], [1, 4]);
58+------------------------------------+
59| array_distance(List([1,2], [1,4])) |
60+------------------------------------+
61| 2.0 |
62+------------------------------------+
63```"#,
64 argument(
65 name = "array1",
66 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
67 ),
68 argument(
69 name = "array2",
70 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
71 )
72)]
73#[derive(Debug)]
74pub struct ArrayDistance {
75 signature: Signature,
76 aliases: Vec<String>,
77}
78
79impl Default for ArrayDistance {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85impl ArrayDistance {
86 pub fn new() -> Self {
87 Self {
88 signature: Signature::user_defined(Volatility::Immutable),
89 aliases: vec!["list_distance".to_string()],
90 }
91 }
92}
93
94impl ScalarUDFImpl for ArrayDistance {
95 fn as_any(&self) -> &dyn Any {
96 self
97 }
98
99 fn name(&self) -> &str {
100 "array_distance"
101 }
102
103 fn signature(&self) -> &Signature {
104 &self.signature
105 }
106
107 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
108 match arg_types[0] {
109 List(_) | LargeList(_) | FixedSizeList(_, _) => Ok(Float64),
110 _ => exec_err!("The array_distance function can only accept List/LargeList/FixedSizeList."),
111 }
112 }
113
114 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
115 let [_, _] = take_function_args(self.name(), arg_types)?;
116 let mut result = Vec::new();
117 for arg_type in arg_types {
118 match arg_type {
119 List(_) | LargeList(_) | FixedSizeList(_, _) => result.push(coerced_fixed_size_list_to_list(arg_type)),
120 _ => return exec_err!("The array_distance function can only accept List/LargeList/FixedSizeList."),
121 }
122 }
123
124 Ok(result)
125 }
126
127 fn invoke_with_args(
128 &self,
129 args: datafusion_expr::ScalarFunctionArgs,
130 ) -> Result<ColumnarValue> {
131 make_scalar_function(array_distance_inner)(&args.args)
132 }
133
134 fn aliases(&self) -> &[String] {
135 &self.aliases
136 }
137
138 fn documentation(&self) -> Option<&Documentation> {
139 self.doc()
140 }
141}
142
143pub fn array_distance_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
144 let [array1, array2] = take_function_args("array_distance", args)?;
145
146 match (&array1.data_type(), &array2.data_type()) {
147 (List(_), List(_)) => general_array_distance::<i32>(args),
148 (LargeList(_), LargeList(_)) => general_array_distance::<i64>(args),
149 (array_type1, array_type2) => {
150 exec_err!("array_distance does not support types '{array_type1:?}' and '{array_type2:?}'")
151 }
152 }
153}
154
155fn general_array_distance<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
156 let list_array1 = as_generic_list_array::<O>(&arrays[0])?;
157 let list_array2 = as_generic_list_array::<O>(&arrays[1])?;
158
159 let result = list_array1
160 .iter()
161 .zip(list_array2.iter())
162 .map(|(arr1, arr2)| compute_array_distance(arr1, arr2))
163 .collect::<Result<Float64Array>>()?;
164
165 Ok(Arc::new(result) as ArrayRef)
166}
167
168fn compute_array_distance(
170 arr1: Option<ArrayRef>,
171 arr2: Option<ArrayRef>,
172) -> Result<Option<f64>> {
173 let value1 = match arr1 {
174 Some(arr) => arr,
175 None => return Ok(None),
176 };
177 let value2 = match arr2 {
178 Some(arr) => arr,
179 None => return Ok(None),
180 };
181
182 let mut value1 = value1;
183 let mut value2 = value2;
184
185 loop {
186 match value1.data_type() {
187 List(_) => {
188 if downcast_arg!(value1, ListArray).null_count() > 0 {
189 return Ok(None);
190 }
191 value1 = downcast_arg!(value1, ListArray).value(0);
192 }
193 LargeList(_) => {
194 if downcast_arg!(value1, LargeListArray).null_count() > 0 {
195 return Ok(None);
196 }
197 value1 = downcast_arg!(value1, LargeListArray).value(0);
198 }
199 _ => break,
200 }
201
202 match value2.data_type() {
203 List(_) => {
204 if downcast_arg!(value2, ListArray).null_count() > 0 {
205 return Ok(None);
206 }
207 value2 = downcast_arg!(value2, ListArray).value(0);
208 }
209 LargeList(_) => {
210 if downcast_arg!(value2, LargeListArray).null_count() > 0 {
211 return Ok(None);
212 }
213 value2 = downcast_arg!(value2, LargeListArray).value(0);
214 }
215 _ => break,
216 }
217 }
218
219 if value1.null_count() != 0 || value2.null_count() != 0 {
221 return Ok(None);
222 }
223
224 let values1 = convert_to_f64_array(&value1)?;
225 let values2 = convert_to_f64_array(&value2)?;
226
227 if values1.len() != values2.len() {
228 return exec_err!("Both arrays must have the same length");
229 }
230
231 let sum_squares: f64 = values1
232 .iter()
233 .zip(values2.iter())
234 .map(|(v1, v2)| {
235 let diff = v1.unwrap_or(0.0) - v2.unwrap_or(0.0);
236 diff * diff
237 })
238 .sum();
239
240 Ok(Some(sum_squares.sqrt()))
241}
242
243fn convert_to_f64_array(array: &ArrayRef) -> Result<Float64Array> {
245 match array.data_type() {
246 Float64 => Ok(as_float64_array(array)?.clone()),
247 DataType::Float32 => {
248 let array = as_float32_array(array)?;
249 let converted: Float64Array =
250 array.iter().map(|v| v.map(|v| v as f64)).collect();
251 Ok(converted)
252 }
253 DataType::Int64 => {
254 let array = as_int64_array(array)?;
255 let converted: Float64Array =
256 array.iter().map(|v| v.map(|v| v as f64)).collect();
257 Ok(converted)
258 }
259 DataType::Int32 => {
260 let array = as_int32_array(array)?;
261 let converted: Float64Array =
262 array.iter().map(|v| v.map(|v| v as f64)).collect();
263 Ok(converted)
264 }
265 _ => exec_err!("Unsupported array type for conversion to Float64Array"),
266 }
267}