datafusion_functions_nested/
distance.rs1use crate::utils::make_scalar_function;
21use arrow::array::{
22 Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait,
23};
24use arrow::datatypes::{
25 DataType,
26 DataType::{FixedSizeList, LargeList, List, Null},
27};
28use datafusion_common::cast::{
29 as_float32_array, as_float64_array, as_generic_list_array, as_int32_array,
30 as_int64_array,
31};
32use datafusion_common::utils::{coerced_type_with_base_type_only, ListCoercion};
33use datafusion_common::{
34 exec_err, internal_datafusion_err, plan_err, utils::take_function_args, Result,
35};
36use datafusion_expr::{
37 ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
38};
39use datafusion_functions::{downcast_arg, downcast_named_arg};
40use datafusion_macros::user_doc;
41use itertools::Itertools;
42use std::any::Any;
43use std::sync::Arc;
44
45make_udf_expr_and_func!(
46 ArrayDistance,
47 array_distance,
48 array,
49 "returns the Euclidean distance between two numeric arrays.",
50 array_distance_udf
51);
52
53#[user_doc(
54 doc_section(label = "Array Functions"),
55 description = "Returns the Euclidean distance between two input arrays of equal length.",
56 syntax_example = "array_distance(array1, array2)",
57 sql_example = r#"```sql
58> select array_distance([1, 2], [1, 4]);
59+------------------------------------+
60| array_distance(List([1,2], [1,4])) |
61+------------------------------------+
62| 2.0 |
63+------------------------------------+
64```"#,
65 argument(
66 name = "array1",
67 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
68 ),
69 argument(
70 name = "array2",
71 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
72 )
73)]
74#[derive(Debug)]
75pub struct ArrayDistance {
76 signature: Signature,
77 aliases: Vec<String>,
78}
79
80impl Default for ArrayDistance {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86impl ArrayDistance {
87 pub fn new() -> Self {
88 Self {
89 signature: Signature::user_defined(Volatility::Immutable),
90 aliases: vec!["list_distance".to_string()],
91 }
92 }
93}
94
95impl ScalarUDFImpl for ArrayDistance {
96 fn as_any(&self) -> &dyn Any {
97 self
98 }
99
100 fn name(&self) -> &str {
101 "array_distance"
102 }
103
104 fn signature(&self) -> &Signature {
105 &self.signature
106 }
107
108 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
109 Ok(DataType::Float64)
110 }
111
112 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
113 let [_, _] = take_function_args(self.name(), arg_types)?;
114 let coercion = Some(&ListCoercion::FixedSizedListToList);
115 let arg_types = arg_types.iter().map(|arg_type| {
116 if matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
117 Ok(coerced_type_with_base_type_only(
118 arg_type,
119 &DataType::Float64,
120 coercion,
121 ))
122 } else {
123 plan_err!("{} does not support type {arg_type}", self.name())
124 }
125 });
126
127 arg_types.try_collect()
128 }
129
130 fn invoke_with_args(
131 &self,
132 args: datafusion_expr::ScalarFunctionArgs,
133 ) -> Result<ColumnarValue> {
134 make_scalar_function(array_distance_inner)(&args.args)
135 }
136
137 fn aliases(&self) -> &[String] {
138 &self.aliases
139 }
140
141 fn documentation(&self) -> Option<&Documentation> {
142 self.doc()
143 }
144}
145
146pub fn array_distance_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
147 let [array1, array2] = take_function_args("array_distance", args)?;
148 match (array1.data_type(), array2.data_type()) {
149 (List(_), List(_)) => general_array_distance::<i32>(args),
150 (LargeList(_), LargeList(_)) => general_array_distance::<i64>(args),
151 (arg_type1, arg_type2) => {
152 exec_err!("array_distance does not support types {arg_type1} and {arg_type2}")
153 }
154 }
155}
156
157fn general_array_distance<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
158 let list_array1 = as_generic_list_array::<O>(&arrays[0])?;
159 let list_array2 = as_generic_list_array::<O>(&arrays[1])?;
160
161 let result = list_array1
162 .iter()
163 .zip(list_array2.iter())
164 .map(|(arr1, arr2)| compute_array_distance(arr1, arr2))
165 .collect::<Result<Float64Array>>()?;
166
167 Ok(Arc::new(result) as ArrayRef)
168}
169
170fn compute_array_distance(
172 arr1: Option<ArrayRef>,
173 arr2: Option<ArrayRef>,
174) -> Result<Option<f64>> {
175 let value1 = match arr1 {
176 Some(arr) => arr,
177 None => return Ok(None),
178 };
179 let value2 = match arr2 {
180 Some(arr) => arr,
181 None => return Ok(None),
182 };
183
184 let mut value1 = value1;
185 let mut value2 = value2;
186
187 loop {
188 match value1.data_type() {
189 List(_) => {
190 if downcast_arg!(value1, ListArray).null_count() > 0 {
191 return Ok(None);
192 }
193 value1 = downcast_arg!(value1, ListArray).value(0);
194 }
195 LargeList(_) => {
196 if downcast_arg!(value1, LargeListArray).null_count() > 0 {
197 return Ok(None);
198 }
199 value1 = downcast_arg!(value1, LargeListArray).value(0);
200 }
201 _ => break,
202 }
203
204 match value2.data_type() {
205 List(_) => {
206 if downcast_arg!(value2, ListArray).null_count() > 0 {
207 return Ok(None);
208 }
209 value2 = downcast_arg!(value2, ListArray).value(0);
210 }
211 LargeList(_) => {
212 if downcast_arg!(value2, LargeListArray).null_count() > 0 {
213 return Ok(None);
214 }
215 value2 = downcast_arg!(value2, LargeListArray).value(0);
216 }
217 _ => break,
218 }
219 }
220
221 if value1.null_count() != 0 || value2.null_count() != 0 {
223 return Ok(None);
224 }
225
226 let values1 = convert_to_f64_array(&value1)?;
227 let values2 = convert_to_f64_array(&value2)?;
228
229 if values1.len() != values2.len() {
230 return exec_err!("Both arrays must have the same length");
231 }
232
233 let sum_squares: f64 = values1
234 .iter()
235 .zip(values2.iter())
236 .map(|(v1, v2)| {
237 let diff = v1.unwrap_or(0.0) - v2.unwrap_or(0.0);
238 diff * diff
239 })
240 .sum();
241
242 Ok(Some(sum_squares.sqrt()))
243}
244
245fn convert_to_f64_array(array: &ArrayRef) -> Result<Float64Array> {
247 match array.data_type() {
248 DataType::Float64 => Ok(as_float64_array(array)?.clone()),
249 DataType::Float32 => {
250 let array = as_float32_array(array)?;
251 let converted: Float64Array =
252 array.iter().map(|v| v.map(|v| v as f64)).collect();
253 Ok(converted)
254 }
255 DataType::Int64 => {
256 let array = as_int64_array(array)?;
257 let converted: Float64Array =
258 array.iter().map(|v| v.map(|v| v as f64)).collect();
259 Ok(converted)
260 }
261 DataType::Int32 => {
262 let array = as_int32_array(array)?;
263 let converted: Float64Array =
264 array.iter().map(|v| v.map(|v| v as f64)).collect();
265 Ok(converted)
266 }
267 _ => exec_err!("Unsupported array type for conversion to Float64Array"),
268 }
269}