datafusion_functions_nested/
distance.rs1use crate::utils::make_scalar_function;
21use arrow::array::{
22 Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait,
23};
24use arrow::datatypes::{
25 DataType,
26 DataType::{FixedSizeList, LargeList, List, Null},
27};
28use datafusion_common::cast::{
29 as_float32_array, as_float64_array, as_generic_list_array, as_int32_array,
30 as_int64_array,
31};
32use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
33use datafusion_common::{Result, exec_err, plan_err, utils::take_function_args};
34use datafusion_expr::{
35 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
36 Volatility,
37};
38use datafusion_functions::downcast_arg;
39use datafusion_macros::user_doc;
40use itertools::Itertools;
41use std::sync::Arc;
42
// Presumably generates the expression-builder function `array_distance` and
// the UDF accessor `array_distance_udf` for `ArrayDistance`; see the crate's
// `make_udf_expr_and_func!` macro for the exact expansion.
//
// NOTE(review): only one parameter identifier (`array`) is listed here, while
// `ArrayDistance::coerce_types` requires exactly two arguments. If the macro
// turns each listed identifier into one `Expr` parameter, the generated
// `array_distance` builder can only produce a one-argument call that fails
// argument-count checking at planning time; `array1 array2` would match the
// SQL signature. Changing it alters the public builder's arity, so confirm
// against the macro definition and existing callers first.
make_udf_expr_and_func!(
    ArrayDistance,
    array_distance,
    array,
    "returns the Euclidean distance between two numeric arrays.",
    array_distance_udf
);
50
/// Scalar UDF `array_distance` (alias `list_distance`): the Euclidean
/// distance between two numeric arrays, always returned as `Float64`.
///
/// Argument validation and coercion happen in `coerce_types`; the per-row
/// computation is done by `array_distance_inner`.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Returns the Euclidean distance between two input arrays of equal length.",
    syntax_example = "array_distance(array1, array2)",
    sql_example = r#"```sql
> select array_distance([1, 2], [1, 4]);
+------------------------------------+
| array_distance(List([1,2], [1,4])) |
+------------------------------------+
| 2.0                                |
+------------------------------------+
```"#,
    argument(
        name = "array1",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "array2",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayDistance {
    /// User-defined signature; argument types are checked and coerced by
    /// `coerce_types` rather than by a fixed type list.
    signature: Signature,
    /// Alternative SQL names for this function (set to `list_distance` in `new`).
    aliases: Vec<String>,
}
77
78impl Default for ArrayDistance {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
84impl ArrayDistance {
85 pub fn new() -> Self {
86 Self {
87 signature: Signature::user_defined(Volatility::Immutable),
88 aliases: vec!["list_distance".to_string()],
89 }
90 }
91}
92
93impl ScalarUDFImpl for ArrayDistance {
94 fn name(&self) -> &str {
95 "array_distance"
96 }
97
98 fn signature(&self) -> &Signature {
99 &self.signature
100 }
101
102 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
103 Ok(DataType::Float64)
104 }
105
106 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
107 let [_, _] = take_function_args(self.name(), arg_types)?;
108 let coercion = Some(&ListCoercion::FixedSizedListToList);
109 let arg_types = arg_types.iter().map(|arg_type| {
110 if matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
111 Ok(coerced_type_with_base_type_only(
112 arg_type,
113 &DataType::Float64,
114 coercion,
115 ))
116 } else {
117 plan_err!("{} does not support type {arg_type}", self.name())
118 }
119 });
120
121 arg_types.try_collect()
122 }
123
124 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
125 make_scalar_function(array_distance_inner)(&args.args)
126 }
127
128 fn aliases(&self) -> &[String] {
129 &self.aliases
130 }
131
132 fn documentation(&self) -> Option<&Documentation> {
133 self.doc()
134 }
135}
136
137fn array_distance_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
138 let [array1, array2] = take_function_args("array_distance", args)?;
139 match (array1.data_type(), array2.data_type()) {
140 (List(_), List(_)) => general_array_distance::<i32>(args),
141 (LargeList(_), LargeList(_)) => general_array_distance::<i64>(args),
142 (arg_type1, arg_type2) => {
143 exec_err!("array_distance does not support types {arg_type1} and {arg_type2}")
144 }
145 }
146}
147
148fn general_array_distance<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
149 let list_array1 = as_generic_list_array::<O>(&arrays[0])?;
150 let list_array2 = as_generic_list_array::<O>(&arrays[1])?;
151
152 let result = list_array1
153 .iter()
154 .zip(list_array2.iter())
155 .map(|(arr1, arr2)| compute_array_distance(arr1, arr2))
156 .collect::<Result<Float64Array>>()?;
157
158 Ok(Arc::new(result) as ArrayRef)
159}
160
161fn compute_array_distance(
163 arr1: Option<ArrayRef>,
164 arr2: Option<ArrayRef>,
165) -> Result<Option<f64>> {
166 let value1 = match arr1 {
167 Some(arr) => arr,
168 None => return Ok(None),
169 };
170 let value2 = match arr2 {
171 Some(arr) => arr,
172 None => return Ok(None),
173 };
174
175 let mut value1 = value1;
176 let mut value2 = value2;
177
178 loop {
179 match value1.data_type() {
180 List(_) => {
181 if downcast_arg!(value1, ListArray).null_count() > 0 {
182 return Ok(None);
183 }
184 value1 = downcast_arg!(value1, ListArray).value(0);
185 }
186 LargeList(_) => {
187 if downcast_arg!(value1, LargeListArray).null_count() > 0 {
188 return Ok(None);
189 }
190 value1 = downcast_arg!(value1, LargeListArray).value(0);
191 }
192 _ => break,
193 }
194
195 match value2.data_type() {
196 List(_) => {
197 if downcast_arg!(value2, ListArray).null_count() > 0 {
198 return Ok(None);
199 }
200 value2 = downcast_arg!(value2, ListArray).value(0);
201 }
202 LargeList(_) => {
203 if downcast_arg!(value2, LargeListArray).null_count() > 0 {
204 return Ok(None);
205 }
206 value2 = downcast_arg!(value2, LargeListArray).value(0);
207 }
208 _ => break,
209 }
210 }
211
212 if value1.null_count() != 0 || value2.null_count() != 0 {
214 return Ok(None);
215 }
216
217 let values1 = convert_to_f64_array(&value1)?;
218 let values2 = convert_to_f64_array(&value2)?;
219
220 if values1.len() != values2.len() {
221 return exec_err!("Both arrays must have the same length");
222 }
223
224 let sum_squares: f64 = values1
225 .iter()
226 .zip(values2.iter())
227 .map(|(v1, v2)| {
228 let diff = v1.unwrap_or(0.0) - v2.unwrap_or(0.0);
229 diff * diff
230 })
231 .sum();
232
233 Ok(Some(sum_squares.sqrt()))
234}
235
236fn convert_to_f64_array(array: &ArrayRef) -> Result<Float64Array> {
238 match array.data_type() {
239 DataType::Float64 => Ok(as_float64_array(array)?.clone()),
240 DataType::Float32 => {
241 let array = as_float32_array(array)?;
242 let converted: Float64Array =
243 array.iter().map(|v| v.map(|v| v as f64)).collect();
244 Ok(converted)
245 }
246 DataType::Int64 => {
247 let array = as_int64_array(array)?;
248 let converted: Float64Array =
249 array.iter().map(|v| v.map(|v| v as f64)).collect();
250 Ok(converted)
251 }
252 DataType::Int32 => {
253 let array = as_int32_array(array)?;
254 let converted: Float64Array =
255 array.iter().map(|v| v.map(|v| v as f64)).collect();
256 Ok(converted)
257 }
258 _ => exec_err!("Unsupported array type for conversion to Float64Array"),
259 }
260}