datafusion_functions_nested/
cosine_distance.rs1use crate::utils::make_scalar_function;
21use arrow::array::{Array, ArrayRef, Float64Array, OffsetSizeTrait};
22use arrow::datatypes::{
23 DataType,
24 DataType::{FixedSizeList, LargeList, List, Null},
25 Field,
26};
27use datafusion_common::cast::{as_float64_array, as_generic_list_array};
28use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
29use datafusion_common::{
30 Result, exec_err, internal_err, plan_err, utils::take_function_args,
31};
32use datafusion_expr::{
33 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
34 Volatility,
35};
36use datafusion_macros::user_doc;
37use std::sync::Arc;
38
// Wires the `CosineDistance` UDF into the expression API. Judging by the
// arguments, this presumably generates a `cosine_distance(array1, array2)`
// expression helper and a `cosine_distance_udf()` accessor; the exact
// generated items are defined by the `make_udf_expr_and_func!` macro
// elsewhere in this crate (NOTE(review): confirm against the macro).
make_udf_expr_and_func!(
    CosineDistance,
    cosine_distance,
    array1 array2,
    "returns the cosine distance between two numeric arrays.",
    cosine_distance_udf
);
46
47#[user_doc(
48 doc_section(label = "Array Functions"),
49 description = "Returns the cosine distance between two input arrays of equal length. The cosine distance is defined as 1 - cosine_similarity, i.e. `1 - dot(a,b) / (||a|| * ||b||)`. Returns NULL if either array is NULL or contains only zeros.",
50 syntax_example = "cosine_distance(array1, array2)",
51 sql_example = r#"```sql
52> select cosine_distance([1.0, 0.0], [0.0, 1.0]);
53+-----------------------------------------------+
54| cosine_distance(List([1.0,0.0]),List([0.0,1.0])) |
55+-----------------------------------------------+
56| 1.0 |
57+-----------------------------------------------+
58```"#,
59 argument(
60 name = "array1",
61 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
62 ),
63 argument(
64 name = "array2",
65 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
66 )
67)]
68#[derive(Debug, PartialEq, Eq, Hash)]
69pub struct CosineDistance {
70 signature: Signature,
71}
72
73impl Default for CosineDistance {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl CosineDistance {
80 pub fn new() -> Self {
81 Self {
82 signature: Signature::user_defined(Volatility::Immutable),
83 }
84 }
85}
86
87impl ScalarUDFImpl for CosineDistance {
88 fn name(&self) -> &str {
89 "cosine_distance"
90 }
91
92 fn signature(&self) -> &Signature {
93 &self.signature
94 }
95
96 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
97 Ok(DataType::Float64)
98 }
99
100 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
101 let [_, _] = take_function_args(self.name(), arg_types)?;
102 let coercion = Some(&ListCoercion::FixedSizedListToList);
103
104 for arg_type in arg_types {
105 if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
106 return plan_err!("{} does not support type {arg_type}", self.name());
107 }
108 }
109
110 let any_large_list = arg_types.iter().any(|t| matches!(t, LargeList(_)));
114
115 let coerced = arg_types
116 .iter()
117 .map(|arg_type| {
118 if matches!(arg_type, Null) {
119 let field = Arc::new(Field::new_list_field(DataType::Float64, true));
120 return if any_large_list {
121 LargeList(field)
122 } else {
123 List(field)
124 };
125 }
126 let coerced = coerced_type_with_base_type_only(
127 arg_type,
128 &DataType::Float64,
129 coercion,
130 );
131 match coerced {
132 List(field) if any_large_list => LargeList(field),
133 other => other,
134 }
135 })
136 .collect();
137
138 Ok(coerced)
139 }
140
141 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
142 make_scalar_function(cosine_distance_inner)(&args.args)
143 }
144
145 fn documentation(&self) -> Option<&Documentation> {
146 self.doc()
147 }
148}
149
150fn cosine_distance_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
151 let [array1, array2] = take_function_args("cosine_distance", args)?;
152 match (array1.data_type(), array2.data_type()) {
153 (List(_), List(_)) => general_cosine_distance::<i32>(args),
154 (LargeList(_), LargeList(_)) => general_cosine_distance::<i64>(args),
155 (arg_type1, arg_type2) => internal_err!(
156 "cosine_distance received unexpected types after coercion: {arg_type1} and {arg_type2}"
157 ),
158 }
159}
160
161fn general_cosine_distance<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
162 let list_array1 = as_generic_list_array::<O>(&arrays[0])?;
163 let list_array2 = as_generic_list_array::<O>(&arrays[1])?;
164
165 let values1 = as_float64_array(list_array1.values())?;
166 let values2 = as_float64_array(list_array2.values())?;
167 let offsets1 = list_array1.value_offsets();
168 let offsets2 = list_array2.value_offsets();
169
170 let mut builder = Float64Array::builder(list_array1.len());
171 for row in 0..list_array1.len() {
172 if list_array1.is_null(row) || list_array2.is_null(row) {
173 builder.append_null();
174 continue;
175 }
176
177 let start1 = offsets1[row].as_usize();
178 let end1 = offsets1[row + 1].as_usize();
179 let start2 = offsets2[row].as_usize();
180 let end2 = offsets2[row + 1].as_usize();
181 let len1 = end1 - start1;
182 let len2 = end2 - start2;
183
184 if len1 != len2 {
185 return exec_err!(
186 "cosine_distance requires both list inputs to have the same length, got {len1} and {len2}"
187 );
188 }
189
190 let slice1 = values1.slice(start1, len1);
191 let slice2 = values2.slice(start2, len2);
192 if slice1.null_count() != 0 || slice2.null_count() != 0 {
193 builder.append_null();
194 continue;
195 }
196
197 let vals1 = slice1.values();
198 let vals2 = slice2.values();
199
200 let mut dot = 0.0;
201 let mut sq1 = 0.0;
202 let mut sq2 = 0.0;
203 for i in 0..len1 {
204 let a = vals1[i];
205 let b = vals2[i];
206 dot += a * b;
207 sq1 += a * a;
208 sq2 += b * b;
209 }
210
211 if sq1 == 0.0 || sq2 == 0.0 {
212 builder.append_null();
213 } else {
214 builder.append_value(1.0 - dot / (sq1.sqrt() * sq2.sqrt()));
215 }
216 }
217
218 Ok(Arc::new(builder.finish()) as ArrayRef)
219}