1use std::sync::Arc;
13
14use arrow_array::cast::AsArray;
15use arrow_array::types::{Float16Type, Float32Type, Float64Type, UInt8Type};
16use arrow_array::{Array, ArrowPrimitiveType, FixedSizeListArray, Float32Array, ListArray};
17use arrow_schema::{ArrowError, DataType};
18
19pub mod cosine;
20pub mod cosine_u8;
21pub mod dot;
22pub mod dot_u8;
23pub mod hamming;
24pub mod l2;
25pub mod l2_u8;
26pub mod norm_l2;
27
28pub use cosine::*;
29use deepsize::DeepSizeOf;
30pub use dot::*;
31use hamming::hamming_distance_arrow_batch;
32pub use l2::*;
33pub use norm_l2::*;
34
35use crate::Result;
36
/// Selects which distance metric is used to compare vectors.
///
/// The lowercase names produced by the `Display` impl (`"l2"`, `"cosine"`,
/// `"dot"`, `"hamming"`) are also accepted, case-insensitively, by
/// `TryFrom<&str>`.
#[derive(Debug, Copy, Clone, PartialEq, DeepSizeOf)]
pub enum DistanceType {
    /// Displayed as `"l2"`; parsed from `"l2"` or `"euclidean"`.
    L2,
    /// Displayed as and parsed from `"cosine"`.
    Cosine,
    /// Displayed as and parsed from `"dot"`.
    Dot,
    /// Displayed as and parsed from `"hamming"`.
    ///
    /// `DistanceType::func` has no implementation for this variant (it hits
    /// `todo!()`); only `arrow_batch_func` and `multivec_distance` handle it.
    Hamming,
}
47
/// Alternate name for [`DistanceType`].
pub type MetricType = DistanceType;

/// Signature of a distance function over two slices of `T` that yields an `f32`.
///
/// This is the type returned by [`DistanceType::func`].
pub type DistanceFunc<T> = fn(&[T], &[T]) -> f32;
/// Signature of a batched distance function over two `f32` slices that yields
/// a shared `Float32Array`.
// NOTE(review): the `usize` argument is presumably the vector dimension — confirm.
pub type BatchDistanceFunc = fn(&[f32], &[f32], usize) -> Arc<Float32Array>;
/// Signature of a batched distance function that takes an Arrow array and a
/// `FixedSizeListArray` and yields a shared `Float32Array` or an error.
///
/// This is the type returned by [`DistanceType::arrow_batch_func`].
pub type ArrowBatchDistanceFunc = fn(&dyn Array, &FixedSizeListArray) -> Result<Arc<Float32Array>>;
54
55impl DistanceType {
56 pub fn arrow_batch_func(&self) -> ArrowBatchDistanceFunc {
60 match self {
61 Self::L2 => l2_distance_arrow_batch,
62 Self::Cosine => cosine_distance_arrow_batch,
63 Self::Dot => dot_distance_arrow_batch,
64 Self::Hamming => hamming_distance_arrow_batch,
65 }
66 }
67
68 pub fn func<T: L2 + Cosine + Dot>(&self) -> DistanceFunc<T> {
70 match self {
71 Self::L2 => l2,
72 Self::Cosine => cosine_distance,
73 Self::Dot => dot_distance,
74 Self::Hamming => todo!(),
75 }
76 }
77}
78
79impl std::fmt::Display for DistanceType {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 write!(
82 f,
83 "{}",
84 match self {
85 Self::L2 => "l2",
86 Self::Cosine => "cosine",
87 Self::Dot => "dot",
88 Self::Hamming => "hamming",
89 }
90 )
91 }
92}
93
94impl TryFrom<&str> for DistanceType {
95 type Error = ArrowError;
96
97 fn try_from(s: &str) -> std::result::Result<Self, Self::Error> {
98 match s.to_lowercase().as_str() {
99 "l2" | "euclidean" => Ok(Self::L2),
100 "cosine" => Ok(Self::Cosine),
101 "dot" => Ok(Self::Dot),
102 "hamming" => Ok(Self::Hamming),
103 _ => Err(ArrowError::InvalidArgumentError(format!(
104 "Metric type '{s}' is not supported"
105 ))),
106 }
107 }
108}
109
110pub fn multivec_distance(
111 query: &dyn Array,
112 vectors: &ListArray,
113 distance_type: DistanceType,
114) -> Result<Vec<f32>> {
115 let dim = if let DataType::FixedSizeList(_, dim) = vectors.value_type() {
116 dim as usize
117 } else {
118 return Err(ArrowError::InvalidArgumentError(
119 "vectors must be a list of fixed size list".to_string(),
120 ));
121 };
122
123 match query.data_type() {
126 DataType::Float16 | DataType::Float32 | DataType::Float64 | DataType::UInt8 => {}
127 _ => {
128 return Err(ArrowError::InvalidArgumentError(
129 "query must be a float array or binary array".to_string(),
130 ));
131 }
132 }
133
134 let mut dists = Vec::with_capacity(vectors.len());
135 for v in vectors.iter() {
136 match v {
137 None => dists.push(f32::NAN),
138 Some(v) => {
139 let multivector = v.as_fixed_size_list();
140 if multivector.len() == 0 {
141 dists.push(f32::NAN);
142 continue;
143 }
144
145 let sim = match distance_type {
146 DistanceType::Hamming => {
147 let query = query.as_primitive::<UInt8Type>().values();
148 query
149 .chunks_exact(dim)
150 .map(|q| {
151 multivector
152 .values()
153 .as_primitive::<UInt8Type>()
154 .values()
155 .chunks_exact(dim)
156 .map(|v| hamming::hamming(q, v))
157 .min_by(|a, b| a.partial_cmp(b).unwrap())
158 .unwrap()
159 })
160 .sum()
161 }
162 _ => match query.data_type() {
163 DataType::Float16 => multivec_distance_impl::<Float16Type>(
164 query,
165 multivector,
166 dim,
167 distance_type,
168 ),
169 DataType::Float32 => multivec_distance_impl::<Float32Type>(
170 query,
171 multivector,
172 dim,
173 distance_type,
174 ),
175 DataType::Float64 => multivec_distance_impl::<Float64Type>(
176 query,
177 multivector,
178 dim,
179 distance_type,
180 ),
181 _ => unreachable!("missed to check query type"),
182 },
183 };
184
185 dists.push(1.0 - sim);
186 }
187 }
188 }
189 Ok(dists)
190}
191
192fn multivec_distance_impl<T: ArrowPrimitiveType>(
193 query: &dyn Array,
194 multivector: &FixedSizeListArray,
195 dim: usize,
196 distance_type: DistanceType,
197) -> f32
198where
199 T::Native: L2 + Cosine + Dot,
200{
201 let query = query.as_primitive::<T>().values();
202 query
203 .chunks_exact(dim)
204 .map(|q| {
205 multivector
206 .values()
207 .as_primitive::<T>()
208 .values()
209 .chunks_exact(dim)
210 .map(|v| 1.0 - distance_type.func()(q, v))
211 .max_by(|a, b| a.total_cmp(b))
212 .unwrap()
213 })
214 .sum()
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;

    use arrow_array::types::Float32Type;
    use arrow_array::{Float32Array, ListArray};
    use arrow_buffer::OffsetBuffer;
    use arrow_schema::Field;

    /// A row that holds no vectors must be reported as NaN, while a populated
    /// row in the same array is still scored.
    #[test]
    fn test_multivec_distance_empty_row_is_nan() {
        // A single stored 2-d vector; the offsets below give it to the second row.
        let width = 2;
        let stored = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
            vec![Some(vec![Some(1.0_f32), Some(2.0)])],
            width,
        );
        let item_field = Arc::new(Field::new("item", stored.data_type().clone(), true));

        // Row lengths 0 and 1: the first row is empty, the second has one vector.
        let row_offsets = OffsetBuffer::from_lengths([0_usize, 1]);
        let rows = ListArray::try_new(item_field, row_offsets, Arc::new(stored), None).unwrap();

        let probe: Arc<dyn Array> = Arc::new(Float32Array::from_iter_values([1.0_f32, 2.0]));

        let result = multivec_distance(probe.as_ref(), &rows, DistanceType::Dot).unwrap();
        assert_eq!(result.len(), 2);
        assert!(result[0].is_nan());
        assert_eq!(result[1], -4.0);
    }
}