datafusion_functions_nested/
distance.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [ScalarUDFImpl] definitions for array_distance function.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{
22    Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait,
23};
24use arrow::datatypes::{
25    DataType,
26    DataType::{FixedSizeList, LargeList, List, Null},
27};
28use datafusion_common::cast::{
29    as_float32_array, as_float64_array, as_generic_list_array, as_int32_array,
30    as_int64_array,
31};
32use datafusion_common::utils::{coerced_type_with_base_type_only, ListCoercion};
33use datafusion_common::{
34    exec_err, internal_datafusion_err, plan_err, utils::take_function_args, Result,
35};
36use datafusion_expr::{
37    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
38};
39use datafusion_functions::{downcast_arg, downcast_named_arg};
40use datafusion_macros::user_doc;
41use itertools::Itertools;
42use std::any::Any;
43use std::sync::Arc;
44
// Wires `ArrayDistance` into the crate through the `make_udf_expr_and_func!` helper.
// NOTE(review): the macro is defined outside this file; presumably it generates
// the `array_distance(array)` expression builder (documented with the string
// below) and the `array_distance_udf` accessor — confirm against the macro
// definition.
make_udf_expr_and_func!(
    ArrayDistance,
    array_distance,
    array,
    "returns the Euclidean distance between two numeric arrays.",
    array_distance_udf
);
52
/// Scalar UDF implementation behind `array_distance` (alias `list_distance`).
///
/// The user-facing documentation is declared in the `user_doc` attribute below
/// and exposed through `ScalarUDFImpl::documentation`.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Returns the Euclidean distance between two input arrays of equal length.",
    syntax_example = "array_distance(array1, array2)",
    sql_example = r#"```sql
> select array_distance([1, 2], [1, 4]);
+------------------------------------+
| array_distance(List([1,2], [1,4])) |
+------------------------------------+
| 2.0                                |
+------------------------------------+
```"#,
    argument(
        name = "array1",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "array2",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    )
)]
#[derive(Debug)]
pub struct ArrayDistance {
    /// Signature handed out by `signature()`; `new` builds it as user-defined
    /// and immutable, so argument coercion is delegated to `coerce_types`.
    signature: Signature,
    /// Alternative names handed out by `aliases()`; `new` sets `list_distance`.
    aliases: Vec<String>,
}
79
80impl Default for ArrayDistance {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86impl ArrayDistance {
87    pub fn new() -> Self {
88        Self {
89            signature: Signature::user_defined(Volatility::Immutable),
90            aliases: vec!["list_distance".to_string()],
91        }
92    }
93}
94
95impl ScalarUDFImpl for ArrayDistance {
96    fn as_any(&self) -> &dyn Any {
97        self
98    }
99
100    fn name(&self) -> &str {
101        "array_distance"
102    }
103
104    fn signature(&self) -> &Signature {
105        &self.signature
106    }
107
108    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
109        Ok(DataType::Float64)
110    }
111
112    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
113        let [_, _] = take_function_args(self.name(), arg_types)?;
114        let coercion = Some(&ListCoercion::FixedSizedListToList);
115        let arg_types = arg_types.iter().map(|arg_type| {
116            if matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
117                Ok(coerced_type_with_base_type_only(
118                    arg_type,
119                    &DataType::Float64,
120                    coercion,
121                ))
122            } else {
123                plan_err!("{} does not support type {arg_type}", self.name())
124            }
125        });
126
127        arg_types.try_collect()
128    }
129
130    fn invoke_with_args(
131        &self,
132        args: datafusion_expr::ScalarFunctionArgs,
133    ) -> Result<ColumnarValue> {
134        make_scalar_function(array_distance_inner)(&args.args)
135    }
136
137    fn aliases(&self) -> &[String] {
138        &self.aliases
139    }
140
141    fn documentation(&self) -> Option<&Documentation> {
142        self.doc()
143    }
144}
145
146pub fn array_distance_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
147    let [array1, array2] = take_function_args("array_distance", args)?;
148    match (array1.data_type(), array2.data_type()) {
149        (List(_), List(_)) => general_array_distance::<i32>(args),
150        (LargeList(_), LargeList(_)) => general_array_distance::<i64>(args),
151        (arg_type1, arg_type2) => {
152            exec_err!("array_distance does not support types {arg_type1} and {arg_type2}")
153        }
154    }
155}
156
157fn general_array_distance<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
158    let list_array1 = as_generic_list_array::<O>(&arrays[0])?;
159    let list_array2 = as_generic_list_array::<O>(&arrays[1])?;
160
161    let result = list_array1
162        .iter()
163        .zip(list_array2.iter())
164        .map(|(arr1, arr2)| compute_array_distance(arr1, arr2))
165        .collect::<Result<Float64Array>>()?;
166
167    Ok(Arc::new(result) as ArrayRef)
168}
169
170/// Computes the Euclidean distance between two arrays
171fn compute_array_distance(
172    arr1: Option<ArrayRef>,
173    arr2: Option<ArrayRef>,
174) -> Result<Option<f64>> {
175    let value1 = match arr1 {
176        Some(arr) => arr,
177        None => return Ok(None),
178    };
179    let value2 = match arr2 {
180        Some(arr) => arr,
181        None => return Ok(None),
182    };
183
184    let mut value1 = value1;
185    let mut value2 = value2;
186
187    loop {
188        match value1.data_type() {
189            List(_) => {
190                if downcast_arg!(value1, ListArray).null_count() > 0 {
191                    return Ok(None);
192                }
193                value1 = downcast_arg!(value1, ListArray).value(0);
194            }
195            LargeList(_) => {
196                if downcast_arg!(value1, LargeListArray).null_count() > 0 {
197                    return Ok(None);
198                }
199                value1 = downcast_arg!(value1, LargeListArray).value(0);
200            }
201            _ => break,
202        }
203
204        match value2.data_type() {
205            List(_) => {
206                if downcast_arg!(value2, ListArray).null_count() > 0 {
207                    return Ok(None);
208                }
209                value2 = downcast_arg!(value2, ListArray).value(0);
210            }
211            LargeList(_) => {
212                if downcast_arg!(value2, LargeListArray).null_count() > 0 {
213                    return Ok(None);
214                }
215                value2 = downcast_arg!(value2, LargeListArray).value(0);
216            }
217            _ => break,
218        }
219    }
220
221    // Check for NULL values inside the arrays
222    if value1.null_count() != 0 || value2.null_count() != 0 {
223        return Ok(None);
224    }
225
226    let values1 = convert_to_f64_array(&value1)?;
227    let values2 = convert_to_f64_array(&value2)?;
228
229    if values1.len() != values2.len() {
230        return exec_err!("Both arrays must have the same length");
231    }
232
233    let sum_squares: f64 = values1
234        .iter()
235        .zip(values2.iter())
236        .map(|(v1, v2)| {
237            let diff = v1.unwrap_or(0.0) - v2.unwrap_or(0.0);
238            diff * diff
239        })
240        .sum();
241
242    Ok(Some(sum_squares.sqrt()))
243}
244
245/// Converts an array of any numeric type to a Float64Array.
246fn convert_to_f64_array(array: &ArrayRef) -> Result<Float64Array> {
247    match array.data_type() {
248        DataType::Float64 => Ok(as_float64_array(array)?.clone()),
249        DataType::Float32 => {
250            let array = as_float32_array(array)?;
251            let converted: Float64Array =
252                array.iter().map(|v| v.map(|v| v as f64)).collect();
253            Ok(converted)
254        }
255        DataType::Int64 => {
256            let array = as_int64_array(array)?;
257            let converted: Float64Array =
258                array.iter().map(|v| v.map(|v| v as f64)).collect();
259            Ok(converted)
260        }
261        DataType::Int32 => {
262            let array = as_int32_array(array)?;
263            let converted: Float64Array =
264                array.iter().map(|v| v.map(|v| v as f64)).collect();
265            Ok(converted)
266        }
267        _ => exec_err!("Unsupported array type for conversion to Float64Array"),
268    }
269}