Skip to main content

lance_linalg/
distance.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Distance metrics.
//!
//! This module provides distance metrics for vectors.
//!
//! - `bf16`, `f16`, `f32` and `f64` element types are supported.
//! - SIMD is used when available on the `x86_64`, `aarch64` and `loongarch64`
//!   architectures.
11
12use std::sync::Arc;
13
14use arrow_array::cast::AsArray;
15use arrow_array::types::{Float16Type, Float32Type, Float64Type, UInt8Type};
16use arrow_array::{Array, ArrowPrimitiveType, FixedSizeListArray, Float32Array, ListArray};
17use arrow_schema::{ArrowError, DataType};
18
19pub mod cosine;
20pub mod cosine_u8;
21pub mod dot;
22pub mod dot_u8;
23pub mod hamming;
24pub mod l2;
25pub mod l2_u8;
26pub mod norm_l2;
27
28pub use cosine::*;
29use deepsize::DeepSizeOf;
30pub use dot::*;
31use hamming::hamming_distance_arrow_batch;
32pub use l2::*;
33pub use norm_l2::*;
34
35use crate::Result;
36
/// Distance metric used to compare vectors.
///
/// The lowercase names written by `Display` (`"l2"`, `"cosine"`, `"dot"`,
/// `"hamming"`) are also accepted, case-insensitively, by `TryFrom<&str>`.
#[derive(Debug, Copy, Clone, PartialEq, DeepSizeOf)]
pub enum DistanceType {
    /// Euclidean (L2) distance; see [`l2`]. Also parsed from `"euclidean"`.
    L2,
    /// Cosine distance; see [`cosine_distance`].
    Cosine,
    /// Dot-product distance; see [`dot_distance`].
    Dot,
    /// Hamming distance, computed on `u8` vectors (see `multivec_distance`).
    Hamming,
}
47
/// Former name of [`DistanceType`], kept for backwards compatibility.
pub type MetricType = DistanceType;

/// Distance between two vectors with elements of type `T`, as an `f32`.
///
/// Returned by [`DistanceType::func`].
pub type DistanceFunc<T> = fn(&[T], &[T]) -> f32;
/// Batch distance over flat `f32` buffers.
///
/// NOTE(review): presumably `(query, vectors, dimension)` returning one distance
/// per vector — confirm against the implementations.
pub type BatchDistanceFunc = fn(&[f32], &[f32], usize) -> Arc<Float32Array>;
/// Distance from one vector to every vector of a [`FixedSizeListArray`].
///
/// Returned by [`DistanceType::arrow_batch_func`].
pub type ArrowBatchDistanceFunc = fn(&dyn Array, &FixedSizeListArray) -> Result<Arc<Float32Array>>;
54
55impl DistanceType {
56    /// Compute the distance from one vector to a batch of vectors.
57    ///
58    /// This propagates nulls to the output.
59    pub fn arrow_batch_func(&self) -> ArrowBatchDistanceFunc {
60        match self {
61            Self::L2 => l2_distance_arrow_batch,
62            Self::Cosine => cosine_distance_arrow_batch,
63            Self::Dot => dot_distance_arrow_batch,
64            Self::Hamming => hamming_distance_arrow_batch,
65        }
66    }
67
68    /// Returns the distance function between two vectors.
69    pub fn func<T: L2 + Cosine + Dot>(&self) -> DistanceFunc<T> {
70        match self {
71            Self::L2 => l2,
72            Self::Cosine => cosine_distance,
73            Self::Dot => dot_distance,
74            Self::Hamming => todo!(),
75        }
76    }
77}
78
79impl std::fmt::Display for DistanceType {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        write!(
82            f,
83            "{}",
84            match self {
85                Self::L2 => "l2",
86                Self::Cosine => "cosine",
87                Self::Dot => "dot",
88                Self::Hamming => "hamming",
89            }
90        )
91    }
92}
93
94impl TryFrom<&str> for DistanceType {
95    type Error = ArrowError;
96
97    fn try_from(s: &str) -> std::result::Result<Self, Self::Error> {
98        match s.to_lowercase().as_str() {
99            "l2" | "euclidean" => Ok(Self::L2),
100            "cosine" => Ok(Self::Cosine),
101            "dot" => Ok(Self::Dot),
102            "hamming" => Ok(Self::Hamming),
103            _ => Err(ArrowError::InvalidArgumentError(format!(
104                "Metric type '{s}' is not supported"
105            ))),
106        }
107    }
108}
109
110pub fn multivec_distance(
111    query: &dyn Array,
112    vectors: &ListArray,
113    distance_type: DistanceType,
114) -> Result<Vec<f32>> {
115    let dim = if let DataType::FixedSizeList(_, dim) = vectors.value_type() {
116        dim as usize
117    } else {
118        return Err(ArrowError::InvalidArgumentError(
119            "vectors must be a list of fixed size list".to_string(),
120        ));
121    };
122
123    // check the query vectors type first
124    // because we don't want to check the vectors type for each vector
125    match query.data_type() {
126        DataType::Float16 | DataType::Float32 | DataType::Float64 | DataType::UInt8 => {}
127        _ => {
128            return Err(ArrowError::InvalidArgumentError(
129                "query must be a float array or binary array".to_string(),
130            ));
131        }
132    }
133
134    let mut dists = Vec::with_capacity(vectors.len());
135    for v in vectors.iter() {
136        match v {
137            None => dists.push(f32::NAN),
138            Some(v) => {
139                let multivector = v.as_fixed_size_list();
140                if multivector.len() == 0 {
141                    dists.push(f32::NAN);
142                    continue;
143                }
144
145                let sim = match distance_type {
146                    DistanceType::Hamming => {
147                        let query = query.as_primitive::<UInt8Type>().values();
148                        query
149                            .chunks_exact(dim)
150                            .map(|q| {
151                                multivector
152                                    .values()
153                                    .as_primitive::<UInt8Type>()
154                                    .values()
155                                    .chunks_exact(dim)
156                                    .map(|v| hamming::hamming(q, v))
157                                    .min_by(|a, b| a.partial_cmp(b).unwrap())
158                                    .unwrap()
159                            })
160                            .sum()
161                    }
162                    _ => match query.data_type() {
163                        DataType::Float16 => multivec_distance_impl::<Float16Type>(
164                            query,
165                            multivector,
166                            dim,
167                            distance_type,
168                        ),
169                        DataType::Float32 => multivec_distance_impl::<Float32Type>(
170                            query,
171                            multivector,
172                            dim,
173                            distance_type,
174                        ),
175                        DataType::Float64 => multivec_distance_impl::<Float64Type>(
176                            query,
177                            multivector,
178                            dim,
179                            distance_type,
180                        ),
181                        _ => unreachable!("missed to check query type"),
182                    },
183                };
184
185                dists.push(1.0 - sim);
186            }
187        }
188    }
189    Ok(dists)
190}
191
192fn multivec_distance_impl<T: ArrowPrimitiveType>(
193    query: &dyn Array,
194    multivector: &FixedSizeListArray,
195    dim: usize,
196    distance_type: DistanceType,
197) -> f32
198where
199    T::Native: L2 + Cosine + Dot,
200{
201    let query = query.as_primitive::<T>().values();
202    query
203        .chunks_exact(dim)
204        .map(|q| {
205            multivector
206                .values()
207                .as_primitive::<T>()
208                .values()
209                .chunks_exact(dim)
210                .map(|v| 1.0 - distance_type.func()(q, v))
211                .max_by(|a, b| a.total_cmp(b))
212                .unwrap()
213        })
214        .sum()
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;

    use arrow_array::types::Float32Type;
    use arrow_array::{Float32Array, ListArray};
    use arrow_buffer::{NullBuffer, OffsetBuffer};
    use arrow_schema::Field;

    /// Wraps `values` in a `ListArray` whose rows have the given lengths.
    fn list_of(
        values: FixedSizeListArray,
        lengths: &[usize],
        nulls: Option<NullBuffer>,
    ) -> ListArray {
        let offsets = OffsetBuffer::from_lengths(lengths.iter().copied());
        let field = Arc::new(Field::new("item", values.data_type().clone(), true));
        ListArray::try_new(field, offsets, Arc::new(values), nulls).unwrap()
    }

    /// A single 2-d sub-vector `[1.0, 2.0]`.
    fn one_sub_vector() -> FixedSizeListArray {
        FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
            vec![Some(vec![Some(1.0_f32), Some(2.0)])],
            2,
        )
    }

    #[test]
    fn test_multivec_distance_empty_row_is_nan() {
        let query: Arc<dyn Array> = Arc::new(Float32Array::from_iter_values([1.0_f32, 2.0]));

        // Two rows: first is empty list, second has one sub-vector.
        let vectors = list_of(one_sub_vector(), &[0, 1], None);

        let dists = multivec_distance(query.as_ref(), &vectors, DistanceType::Dot).unwrap();
        assert_eq!(dists.len(), 2);
        assert!(dists[0].is_nan());
        assert_eq!(dists[1], -4.0);
    }

    #[test]
    fn test_multivec_distance_null_row_is_nan() {
        let query: Arc<dyn Array> = Arc::new(Float32Array::from_iter_values([1.0_f32, 2.0]));

        // Two rows: first is null, second has one sub-vector.
        let nulls = NullBuffer::from(vec![false, true]);
        let vectors = list_of(one_sub_vector(), &[0, 1], Some(nulls));

        let dists = multivec_distance(query.as_ref(), &vectors, DistanceType::Dot).unwrap();
        assert_eq!(dists.len(), 2);
        assert!(dists[0].is_nan());
        assert_eq!(dists[1], -4.0);
    }

    #[test]
    fn test_distance_type_display_round_trip() {
        for dt in [
            DistanceType::L2,
            DistanceType::Cosine,
            DistanceType::Dot,
            DistanceType::Hamming,
        ] {
            assert_eq!(DistanceType::try_from(dt.to_string().as_str()).unwrap(), dt);
        }
        assert_eq!(DistanceType::try_from("Euclidean").unwrap(), DistanceType::L2);
        assert!(DistanceType::try_from("manhattan").is_err());
    }
}