Skip to main content

datafusion_functions_nested/
cosine_distance.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for cosine_distance function.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{Array, ArrayRef, Float64Array, OffsetSizeTrait};
22use arrow::datatypes::{
23    DataType,
24    DataType::{FixedSizeList, LargeList, List, Null},
25    Field,
26};
27use datafusion_common::cast::{as_float64_array, as_generic_list_array};
28use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
29use datafusion_common::{
30    Result, exec_err, internal_err, plan_err, utils::take_function_args,
31};
32use datafusion_expr::{
33    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
34    Volatility,
35};
36use datafusion_macros::user_doc;
37use std::sync::Arc;
38
39make_udf_expr_and_func!(
40    CosineDistance,
41    cosine_distance,
42    array1 array2,
43    "returns the cosine distance between two numeric arrays.",
44    cosine_distance_udf
45);
46
47#[user_doc(
48    doc_section(label = "Array Functions"),
49    description = "Returns the cosine distance between two input arrays of equal length. The cosine distance is defined as 1 - cosine_similarity, i.e. `1 - dot(a,b) / (||a|| * ||b||)`. Returns NULL if either array is NULL or contains only zeros.",
50    syntax_example = "cosine_distance(array1, array2)",
51    sql_example = r#"```sql
52> select cosine_distance([1.0, 0.0], [0.0, 1.0]);
53+-----------------------------------------------+
54| cosine_distance(List([1.0,0.0]),List([0.0,1.0])) |
55+-----------------------------------------------+
56| 1.0                                           |
57+-----------------------------------------------+
58```"#,
59    argument(
60        name = "array1",
61        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
62    ),
63    argument(
64        name = "array2",
65        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
66    )
67)]
68#[derive(Debug, PartialEq, Eq, Hash)]
69pub struct CosineDistance {
70    signature: Signature,
71}
72
73impl Default for CosineDistance {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl CosineDistance {
80    pub fn new() -> Self {
81        Self {
82            signature: Signature::user_defined(Volatility::Immutable),
83        }
84    }
85}
86
87impl ScalarUDFImpl for CosineDistance {
88    fn name(&self) -> &str {
89        "cosine_distance"
90    }
91
92    fn signature(&self) -> &Signature {
93        &self.signature
94    }
95
96    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
97        Ok(DataType::Float64)
98    }
99
100    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
101        let [_, _] = take_function_args(self.name(), arg_types)?;
102        let coercion = Some(&ListCoercion::FixedSizedListToList);
103
104        for arg_type in arg_types {
105            if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
106                return plan_err!("{} does not support type {arg_type}", self.name());
107            }
108        }
109
110        // If any input is `LargeList`, both sides must be widened to `LargeList`
111        // so the runtime dispatch in `cosine_distance_inner` sees a homogeneous
112        // pair. Follows the pattern in `ArrayConcat::coerce_types`.
113        let any_large_list = arg_types.iter().any(|t| matches!(t, LargeList(_)));
114
115        let coerced = arg_types
116            .iter()
117            .map(|arg_type| {
118                if matches!(arg_type, Null) {
119                    let field = Arc::new(Field::new_list_field(DataType::Float64, true));
120                    return if any_large_list {
121                        LargeList(field)
122                    } else {
123                        List(field)
124                    };
125                }
126                let coerced = coerced_type_with_base_type_only(
127                    arg_type,
128                    &DataType::Float64,
129                    coercion,
130                );
131                match coerced {
132                    List(field) if any_large_list => LargeList(field),
133                    other => other,
134                }
135            })
136            .collect();
137
138        Ok(coerced)
139    }
140
141    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
142        make_scalar_function(cosine_distance_inner)(&args.args)
143    }
144
145    fn documentation(&self) -> Option<&Documentation> {
146        self.doc()
147    }
148}
149
150fn cosine_distance_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
151    let [array1, array2] = take_function_args("cosine_distance", args)?;
152    match (array1.data_type(), array2.data_type()) {
153        (List(_), List(_)) => general_cosine_distance::<i32>(args),
154        (LargeList(_), LargeList(_)) => general_cosine_distance::<i64>(args),
155        (arg_type1, arg_type2) => internal_err!(
156            "cosine_distance received unexpected types after coercion: {arg_type1} and {arg_type2}"
157        ),
158    }
159}
160
161fn general_cosine_distance<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
162    let list_array1 = as_generic_list_array::<O>(&arrays[0])?;
163    let list_array2 = as_generic_list_array::<O>(&arrays[1])?;
164
165    let values1 = as_float64_array(list_array1.values())?;
166    let values2 = as_float64_array(list_array2.values())?;
167    let offsets1 = list_array1.value_offsets();
168    let offsets2 = list_array2.value_offsets();
169
170    let mut builder = Float64Array::builder(list_array1.len());
171    for row in 0..list_array1.len() {
172        if list_array1.is_null(row) || list_array2.is_null(row) {
173            builder.append_null();
174            continue;
175        }
176
177        let start1 = offsets1[row].as_usize();
178        let end1 = offsets1[row + 1].as_usize();
179        let start2 = offsets2[row].as_usize();
180        let end2 = offsets2[row + 1].as_usize();
181        let len1 = end1 - start1;
182        let len2 = end2 - start2;
183
184        if len1 != len2 {
185            return exec_err!(
186                "cosine_distance requires both list inputs to have the same length, got {len1} and {len2}"
187            );
188        }
189
190        let slice1 = values1.slice(start1, len1);
191        let slice2 = values2.slice(start2, len2);
192        if slice1.null_count() != 0 || slice2.null_count() != 0 {
193            builder.append_null();
194            continue;
195        }
196
197        let vals1 = slice1.values();
198        let vals2 = slice2.values();
199
200        let mut dot = 0.0;
201        let mut sq1 = 0.0;
202        let mut sq2 = 0.0;
203        for i in 0..len1 {
204            let a = vals1[i];
205            let b = vals2[i];
206            dot += a * b;
207            sq1 += a * a;
208            sq2 += b * b;
209        }
210
211        if sq1 == 0.0 || sq2 == 0.0 {
212            builder.append_null();
213        } else {
214            builder.append_value(1.0 - dot / (sq1.sqrt() * sq2.sqrt()));
215        }
216    }
217
218    Ok(Arc::new(builder.finish()) as ArrayRef)
219}