1use std::sync::Arc;
13
14use arrow_array::cast::AsArray;
15use arrow_array::types::{Float16Type, Float32Type, Float64Type, UInt8Type};
16use arrow_array::{Array, ArrowPrimitiveType, FixedSizeListArray, Float32Array, ListArray};
17use arrow_schema::{ArrowError, DataType};
18
19pub mod cosine;
20pub mod dot;
21pub mod hamming;
22pub mod l2;
23pub mod norm_l2;
24
25pub use cosine::*;
26use deepsize::DeepSizeOf;
27pub use dot::*;
28use hamming::hamming_distance_arrow_batch;
29pub use l2::*;
30pub use norm_l2::*;
31
32use crate::Result;
33
34#[derive(Debug, Copy, Clone, PartialEq, DeepSizeOf)]
36pub enum DistanceType {
37 L2,
38 Cosine,
39 Dot,
41 Hamming,
43}
44
45pub type MetricType = DistanceType;
47
48pub type DistanceFunc<T> = fn(&[T], &[T]) -> f32;
49pub type BatchDistanceFunc = fn(&[f32], &[f32], usize) -> Arc<Float32Array>;
50pub type ArrowBatchDistanceFunc = fn(&dyn Array, &FixedSizeListArray) -> Result<Arc<Float32Array>>;
51
52impl DistanceType {
53 pub fn arrow_batch_func(&self) -> ArrowBatchDistanceFunc {
57 match self {
58 Self::L2 => l2_distance_arrow_batch,
59 Self::Cosine => cosine_distance_arrow_batch,
60 Self::Dot => dot_distance_arrow_batch,
61 Self::Hamming => hamming_distance_arrow_batch,
62 }
63 }
64
65 pub fn func<T: L2 + Cosine + Dot>(&self) -> DistanceFunc<T> {
67 match self {
68 Self::L2 => l2,
69 Self::Cosine => cosine_distance,
70 Self::Dot => dot_distance,
71 Self::Hamming => todo!(),
72 }
73 }
74}
75
76impl std::fmt::Display for DistanceType {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 write!(
79 f,
80 "{}",
81 match self {
82 Self::L2 => "l2",
83 Self::Cosine => "cosine",
84 Self::Dot => "dot",
85 Self::Hamming => "hamming",
86 }
87 )
88 }
89}
90
91impl TryFrom<&str> for DistanceType {
92 type Error = ArrowError;
93
94 fn try_from(s: &str) -> std::result::Result<Self, Self::Error> {
95 match s.to_lowercase().as_str() {
96 "l2" | "euclidean" => Ok(Self::L2),
97 "cosine" => Ok(Self::Cosine),
98 "dot" => Ok(Self::Dot),
99 "hamming" => Ok(Self::Hamming),
100 _ => Err(ArrowError::InvalidArgumentError(format!(
101 "Metric type '{s}' is not supported"
102 ))),
103 }
104 }
105}
106
107pub fn multivec_distance(
108 query: &dyn Array,
109 vectors: &ListArray,
110 distance_type: DistanceType,
111) -> Result<Vec<f32>> {
112 let dim = if let DataType::FixedSizeList(_, dim) = vectors.value_type() {
113 dim as usize
114 } else {
115 return Err(ArrowError::InvalidArgumentError(
116 "vectors must be a list of fixed size list".to_string(),
117 ));
118 };
119
120 match query.data_type() {
123 DataType::Float16 | DataType::Float32 | DataType::Float64 | DataType::UInt8 => {}
124 _ => {
125 return Err(ArrowError::InvalidArgumentError(
126 "query must be a float array or binary array".to_string(),
127 ));
128 }
129 }
130
131 let dists = vectors
132 .iter()
133 .map(|v| {
134 v.map(|v| {
135 let multivector = v.as_fixed_size_list();
136 match distance_type {
137 DistanceType::Hamming => {
138 let query = query.as_primitive::<UInt8Type>().values();
139 query
140 .chunks_exact(dim)
141 .map(|q| {
142 multivector
143 .values()
144 .as_primitive::<UInt8Type>()
145 .values()
146 .chunks_exact(dim)
147 .map(|v| hamming::hamming(q, v))
148 .min_by(|a, b| a.partial_cmp(b).unwrap())
149 .unwrap()
150 })
151 .sum()
152 }
153 _ => match query.data_type() {
154 DataType::Float16 => multivec_distance_impl::<Float16Type>(
155 query,
156 multivector,
157 dim,
158 distance_type,
159 ),
160 DataType::Float32 => multivec_distance_impl::<Float32Type>(
161 query,
162 multivector,
163 dim,
164 distance_type,
165 ),
166 DataType::Float64 => multivec_distance_impl::<Float64Type>(
167 query,
168 multivector,
169 dim,
170 distance_type,
171 ),
172 _ => unreachable!("missed to check query type"),
173 },
174 }
175 })
176 .unwrap_or(f32::NAN)
177 })
178 .map(|sim| 1.0 - sim)
179 .collect();
180 Ok(dists)
181}
182
183fn multivec_distance_impl<T: ArrowPrimitiveType>(
184 query: &dyn Array,
185 multivector: &FixedSizeListArray,
186 dim: usize,
187 distance_type: DistanceType,
188) -> f32
189where
190 T::Native: L2 + Cosine + Dot,
191{
192 let query = query.as_primitive::<T>().values();
193 query
194 .chunks_exact(dim)
195 .map(|q| {
196 multivector
197 .values()
198 .as_primitive::<T>()
199 .values()
200 .chunks_exact(dim)
201 .map(|v| 1.0 - distance_type.func()(q, v))
202 .max_by(|a, b| a.total_cmp(b))
203 .unwrap()
204 })
205 .sum()
206}