datafusion_functions_nested/
distance.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`ScalarUDFImpl`] definitions for the `array_distance` function.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{
22    Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait,
23};
24use arrow::datatypes::{
25    DataType,
26    DataType::{FixedSizeList, Float64, LargeList, List},
27};
28use datafusion_common::cast::{
29    as_float32_array, as_float64_array, as_generic_list_array, as_int32_array,
30    as_int64_array,
31};
32use datafusion_common::utils::coerced_fixed_size_list_to_list;
33use datafusion_common::{
34    exec_err, internal_datafusion_err, utils::take_function_args, Result,
35};
36use datafusion_expr::{
37    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
38};
39use datafusion_functions::{downcast_arg, downcast_named_arg};
40use datafusion_macros::user_doc;
41use std::any::Any;
42use std::sync::Arc;
43
// Wires `ArrayDistance` into the crate through `make_udf_expr_and_func!`,
// which is defined outside this file. Judging by the argument names, the
// macro is expected to emit an `array_distance` expression helper and an
// `array_distance_udf` accessor -- confirm against the macro definition.
// NOTE(review): only one argument name (`array`) is listed here although
// `coerce_types` requires exactly two arguments -- verify that the generated
// expression helper is usable as intended.
make_udf_expr_and_func!(
    ArrayDistance,
    array_distance,
    array,
    "returns the Euclidean distance between two numeric arrays.",
    array_distance_udf
);
51
52#[user_doc(
53    doc_section(label = "Array Functions"),
54    description = "Returns the Euclidean distance between two input arrays of equal length.",
55    syntax_example = "array_distance(array1, array2)",
56    sql_example = r#"```sql
57> select array_distance([1, 2], [1, 4]);
58+------------------------------------+
59| array_distance(List([1,2], [1,4])) |
60+------------------------------------+
61| 2.0                                |
62+------------------------------------+
63```"#,
64    argument(
65        name = "array1",
66        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
67    ),
68    argument(
69        name = "array2",
70        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
71    )
72)]
73#[derive(Debug)]
74pub struct ArrayDistance {
75    signature: Signature,
76    aliases: Vec<String>,
77}
78
79impl Default for ArrayDistance {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl ArrayDistance {
86    pub fn new() -> Self {
87        Self {
88            signature: Signature::user_defined(Volatility::Immutable),
89            aliases: vec!["list_distance".to_string()],
90        }
91    }
92}
93
94impl ScalarUDFImpl for ArrayDistance {
95    fn as_any(&self) -> &dyn Any {
96        self
97    }
98
99    fn name(&self) -> &str {
100        "array_distance"
101    }
102
103    fn signature(&self) -> &Signature {
104        &self.signature
105    }
106
107    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
108        match arg_types[0] {
109            List(_) | LargeList(_) | FixedSizeList(_, _) => Ok(Float64),
110            _ => exec_err!("The array_distance function can only accept List/LargeList/FixedSizeList."),
111        }
112    }
113
114    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
115        let [_, _] = take_function_args(self.name(), arg_types)?;
116        let mut result = Vec::new();
117        for arg_type in arg_types {
118            match arg_type {
119                List(_) | LargeList(_) | FixedSizeList(_, _) => result.push(coerced_fixed_size_list_to_list(arg_type)),
120                _ => return exec_err!("The array_distance function can only accept List/LargeList/FixedSizeList."),
121            }
122        }
123
124        Ok(result)
125    }
126
127    fn invoke_with_args(
128        &self,
129        args: datafusion_expr::ScalarFunctionArgs,
130    ) -> Result<ColumnarValue> {
131        make_scalar_function(array_distance_inner)(&args.args)
132    }
133
134    fn aliases(&self) -> &[String] {
135        &self.aliases
136    }
137
138    fn documentation(&self) -> Option<&Documentation> {
139        self.doc()
140    }
141}
142
143pub fn array_distance_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
144    let [array1, array2] = take_function_args("array_distance", args)?;
145
146    match (&array1.data_type(), &array2.data_type()) {
147        (List(_), List(_)) => general_array_distance::<i32>(args),
148        (LargeList(_), LargeList(_)) => general_array_distance::<i64>(args),
149        (array_type1, array_type2) => {
150            exec_err!("array_distance does not support types '{array_type1:?}' and '{array_type2:?}'")
151        }
152    }
153}
154
155fn general_array_distance<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
156    let list_array1 = as_generic_list_array::<O>(&arrays[0])?;
157    let list_array2 = as_generic_list_array::<O>(&arrays[1])?;
158
159    let result = list_array1
160        .iter()
161        .zip(list_array2.iter())
162        .map(|(arr1, arr2)| compute_array_distance(arr1, arr2))
163        .collect::<Result<Float64Array>>()?;
164
165    Ok(Arc::new(result) as ArrayRef)
166}
167
168/// Computes the Euclidean distance between two arrays
169fn compute_array_distance(
170    arr1: Option<ArrayRef>,
171    arr2: Option<ArrayRef>,
172) -> Result<Option<f64>> {
173    let value1 = match arr1 {
174        Some(arr) => arr,
175        None => return Ok(None),
176    };
177    let value2 = match arr2 {
178        Some(arr) => arr,
179        None => return Ok(None),
180    };
181
182    let mut value1 = value1;
183    let mut value2 = value2;
184
185    loop {
186        match value1.data_type() {
187            List(_) => {
188                if downcast_arg!(value1, ListArray).null_count() > 0 {
189                    return Ok(None);
190                }
191                value1 = downcast_arg!(value1, ListArray).value(0);
192            }
193            LargeList(_) => {
194                if downcast_arg!(value1, LargeListArray).null_count() > 0 {
195                    return Ok(None);
196                }
197                value1 = downcast_arg!(value1, LargeListArray).value(0);
198            }
199            _ => break,
200        }
201
202        match value2.data_type() {
203            List(_) => {
204                if downcast_arg!(value2, ListArray).null_count() > 0 {
205                    return Ok(None);
206                }
207                value2 = downcast_arg!(value2, ListArray).value(0);
208            }
209            LargeList(_) => {
210                if downcast_arg!(value2, LargeListArray).null_count() > 0 {
211                    return Ok(None);
212                }
213                value2 = downcast_arg!(value2, LargeListArray).value(0);
214            }
215            _ => break,
216        }
217    }
218
219    // Check for NULL values inside the arrays
220    if value1.null_count() != 0 || value2.null_count() != 0 {
221        return Ok(None);
222    }
223
224    let values1 = convert_to_f64_array(&value1)?;
225    let values2 = convert_to_f64_array(&value2)?;
226
227    if values1.len() != values2.len() {
228        return exec_err!("Both arrays must have the same length");
229    }
230
231    let sum_squares: f64 = values1
232        .iter()
233        .zip(values2.iter())
234        .map(|(v1, v2)| {
235            let diff = v1.unwrap_or(0.0) - v2.unwrap_or(0.0);
236            diff * diff
237        })
238        .sum();
239
240    Ok(Some(sum_squares.sqrt()))
241}
242
243/// Converts an array of any numeric type to a Float64Array.
244fn convert_to_f64_array(array: &ArrayRef) -> Result<Float64Array> {
245    match array.data_type() {
246        Float64 => Ok(as_float64_array(array)?.clone()),
247        DataType::Float32 => {
248            let array = as_float32_array(array)?;
249            let converted: Float64Array =
250                array.iter().map(|v| v.map(|v| v as f64)).collect();
251            Ok(converted)
252        }
253        DataType::Int64 => {
254            let array = as_int64_array(array)?;
255            let converted: Float64Array =
256                array.iter().map(|v| v.map(|v| v as f64)).collect();
257            Ok(converted)
258        }
259        DataType::Int32 => {
260            let array = as_int32_array(array)?;
261            let converted: Float64Array =
262                array.iter().map(|v| v.map(|v| v as f64)).collect();
263            Ok(converted)
264        }
265        _ => exec_err!("Unsupported array type for conversion to Float64Array"),
266    }
267}