Skip to main content

datafusion_functions_nested/
distance.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [ScalarUDFImpl] definitions for array_distance function.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{
22    Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait,
23};
24use arrow::datatypes::{
25    DataType,
26    DataType::{FixedSizeList, LargeList, List, Null},
27};
28use datafusion_common::cast::{
29    as_float32_array, as_float64_array, as_generic_list_array, as_int32_array,
30    as_int64_array,
31};
32use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
33use datafusion_common::{Result, exec_err, plan_err, utils::take_function_args};
34use datafusion_expr::{
35    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
36    Volatility,
37};
38use datafusion_functions::downcast_arg;
39use datafusion_macros::user_doc;
40use itertools::Itertools;
41use std::sync::Arc;
42
// Invokes the crate's `make_udf_expr_and_func!` macro (defined elsewhere) for
// `ArrayDistance`. Judging from the arguments, it wires up an `array_distance`
// expression helper taking a single `array` argument list and an
// `array_distance_udf` accessor, using the given string as its description —
// NOTE(review): confirm the exact expansion against the macro definition.
make_udf_expr_and_func!(
    ArrayDistance,
    array_distance,
    array,
    "returns the Euclidean distance between two numeric arrays.",
    array_distance_udf
);
50
/// Scalar UDF implementing `array_distance` (alias `list_distance`): the
/// Euclidean distance between two list arguments, always returned as `Float64`.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Returns the Euclidean distance between two input arrays of equal length.",
    syntax_example = "array_distance(array1, array2)",
    sql_example = r#"```sql
> select array_distance([1, 2], [1, 4]);
+------------------------------------+
| array_distance(List([1,2], [1,4])) |
+------------------------------------+
| 2.0                                |
+------------------------------------+
```"#,
    argument(
        name = "array1",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "array2",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayDistance {
    /// User-defined signature built in `ArrayDistance::new`; the accepted
    /// argument types are decided by `ScalarUDFImpl::coerce_types`.
    signature: Signature,
    /// Alternative function names reported by `aliases()` (`list_distance`).
    aliases: Vec<String>,
}
77
78impl Default for ArrayDistance {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84impl ArrayDistance {
85    pub fn new() -> Self {
86        Self {
87            signature: Signature::user_defined(Volatility::Immutable),
88            aliases: vec!["list_distance".to_string()],
89        }
90    }
91}
92
93impl ScalarUDFImpl for ArrayDistance {
94    fn name(&self) -> &str {
95        "array_distance"
96    }
97
98    fn signature(&self) -> &Signature {
99        &self.signature
100    }
101
102    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
103        Ok(DataType::Float64)
104    }
105
106    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
107        let [_, _] = take_function_args(self.name(), arg_types)?;
108        let coercion = Some(&ListCoercion::FixedSizedListToList);
109        let arg_types = arg_types.iter().map(|arg_type| {
110            if matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
111                Ok(coerced_type_with_base_type_only(
112                    arg_type,
113                    &DataType::Float64,
114                    coercion,
115                ))
116            } else {
117                plan_err!("{} does not support type {arg_type}", self.name())
118            }
119        });
120
121        arg_types.try_collect()
122    }
123
124    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
125        make_scalar_function(array_distance_inner)(&args.args)
126    }
127
128    fn aliases(&self) -> &[String] {
129        &self.aliases
130    }
131
132    fn documentation(&self) -> Option<&Documentation> {
133        self.doc()
134    }
135}
136
137fn array_distance_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
138    let [array1, array2] = take_function_args("array_distance", args)?;
139    match (array1.data_type(), array2.data_type()) {
140        (List(_), List(_)) => general_array_distance::<i32>(args),
141        (LargeList(_), LargeList(_)) => general_array_distance::<i64>(args),
142        (arg_type1, arg_type2) => {
143            exec_err!("array_distance does not support types {arg_type1} and {arg_type2}")
144        }
145    }
146}
147
148fn general_array_distance<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
149    let list_array1 = as_generic_list_array::<O>(&arrays[0])?;
150    let list_array2 = as_generic_list_array::<O>(&arrays[1])?;
151
152    let result = list_array1
153        .iter()
154        .zip(list_array2.iter())
155        .map(|(arr1, arr2)| compute_array_distance(arr1, arr2))
156        .collect::<Result<Float64Array>>()?;
157
158    Ok(Arc::new(result) as ArrayRef)
159}
160
161/// Computes the Euclidean distance between two arrays
162fn compute_array_distance(
163    arr1: Option<ArrayRef>,
164    arr2: Option<ArrayRef>,
165) -> Result<Option<f64>> {
166    let value1 = match arr1 {
167        Some(arr) => arr,
168        None => return Ok(None),
169    };
170    let value2 = match arr2 {
171        Some(arr) => arr,
172        None => return Ok(None),
173    };
174
175    let mut value1 = value1;
176    let mut value2 = value2;
177
178    loop {
179        match value1.data_type() {
180            List(_) => {
181                if downcast_arg!(value1, ListArray).null_count() > 0 {
182                    return Ok(None);
183                }
184                value1 = downcast_arg!(value1, ListArray).value(0);
185            }
186            LargeList(_) => {
187                if downcast_arg!(value1, LargeListArray).null_count() > 0 {
188                    return Ok(None);
189                }
190                value1 = downcast_arg!(value1, LargeListArray).value(0);
191            }
192            _ => break,
193        }
194
195        match value2.data_type() {
196            List(_) => {
197                if downcast_arg!(value2, ListArray).null_count() > 0 {
198                    return Ok(None);
199                }
200                value2 = downcast_arg!(value2, ListArray).value(0);
201            }
202            LargeList(_) => {
203                if downcast_arg!(value2, LargeListArray).null_count() > 0 {
204                    return Ok(None);
205                }
206                value2 = downcast_arg!(value2, LargeListArray).value(0);
207            }
208            _ => break,
209        }
210    }
211
212    // Check for NULL values inside the arrays
213    if value1.null_count() != 0 || value2.null_count() != 0 {
214        return Ok(None);
215    }
216
217    let values1 = convert_to_f64_array(&value1)?;
218    let values2 = convert_to_f64_array(&value2)?;
219
220    if values1.len() != values2.len() {
221        return exec_err!("Both arrays must have the same length");
222    }
223
224    let sum_squares: f64 = values1
225        .iter()
226        .zip(values2.iter())
227        .map(|(v1, v2)| {
228            let diff = v1.unwrap_or(0.0) - v2.unwrap_or(0.0);
229            diff * diff
230        })
231        .sum();
232
233    Ok(Some(sum_squares.sqrt()))
234}
235
236/// Converts an array of any numeric type to a Float64Array.
237fn convert_to_f64_array(array: &ArrayRef) -> Result<Float64Array> {
238    match array.data_type() {
239        DataType::Float64 => Ok(as_float64_array(array)?.clone()),
240        DataType::Float32 => {
241            let array = as_float32_array(array)?;
242            let converted: Float64Array =
243                array.iter().map(|v| v.map(|v| v as f64)).collect();
244            Ok(converted)
245        }
246        DataType::Int64 => {
247            let array = as_int64_array(array)?;
248            let converted: Float64Array =
249                array.iter().map(|v| v.map(|v| v as f64)).collect();
250            Ok(converted)
251        }
252        DataType::Int32 => {
253            let array = as_int32_array(array)?;
254            let converted: Float64Array =
255                array.iter().map(|v| v.map(|v| v as f64)).collect();
256            Ok(converted)
257        }
258        _ => exec_err!("Unsupported array type for conversion to Float64Array"),
259    }
260}