lance_linalg/
distance.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4//! Distance metrics
5//!
6//! This module provides distance metrics for vectors.
7//!
8//! - `bf16, f16, f32, f64` types are supported.
9//! - SIMD is used when available, on `x86_64`, `aarch64` and `loongarch64`
10//!   architectures.
11
12use std::sync::Arc;
13
14use arrow_array::cast::AsArray;
15use arrow_array::types::{Float16Type, Float32Type, Float64Type, UInt8Type};
16use arrow_array::{Array, ArrowPrimitiveType, FixedSizeListArray, Float32Array, ListArray};
17use arrow_schema::{ArrowError, DataType};
18
19pub mod cosine;
20pub mod dot;
21pub mod hamming;
22pub mod l2;
23pub mod norm_l2;
24
25pub use cosine::*;
26use deepsize::DeepSizeOf;
27pub use dot::*;
28use hamming::hamming_distance_arrow_batch;
29pub use l2::*;
30pub use norm_l2::*;
31
32use crate::Result;
33
34/// Distance metrics type.
35#[derive(Debug, Copy, Clone, PartialEq, DeepSizeOf)]
36pub enum DistanceType {
37    L2,
38    Cosine,
39    /// Dot Product
40    Dot,
41    /// Hamming Distance
42    Hamming,
43}
44
45/// For backwards compatibility.
46pub type MetricType = DistanceType;
47
48pub type DistanceFunc<T> = fn(&[T], &[T]) -> f32;
49pub type BatchDistanceFunc = fn(&[f32], &[f32], usize) -> Arc<Float32Array>;
50pub type ArrowBatchDistanceFunc = fn(&dyn Array, &FixedSizeListArray) -> Result<Arc<Float32Array>>;
51
52impl DistanceType {
53    /// Compute the distance from one vector to a batch of vectors.
54    ///
55    /// This propagates nulls to the output.
56    pub fn arrow_batch_func(&self) -> ArrowBatchDistanceFunc {
57        match self {
58            Self::L2 => l2_distance_arrow_batch,
59            Self::Cosine => cosine_distance_arrow_batch,
60            Self::Dot => dot_distance_arrow_batch,
61            Self::Hamming => hamming_distance_arrow_batch,
62        }
63    }
64
65    /// Returns the distance function between two vectors.
66    pub fn func<T: L2 + Cosine + Dot>(&self) -> DistanceFunc<T> {
67        match self {
68            Self::L2 => l2,
69            Self::Cosine => cosine_distance,
70            Self::Dot => dot_distance,
71            Self::Hamming => todo!(),
72        }
73    }
74}
75
76impl std::fmt::Display for DistanceType {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        write!(
79            f,
80            "{}",
81            match self {
82                Self::L2 => "l2",
83                Self::Cosine => "cosine",
84                Self::Dot => "dot",
85                Self::Hamming => "hamming",
86            }
87        )
88    }
89}
90
91impl TryFrom<&str> for DistanceType {
92    type Error = ArrowError;
93
94    fn try_from(s: &str) -> std::result::Result<Self, Self::Error> {
95        match s.to_lowercase().as_str() {
96            "l2" | "euclidean" => Ok(Self::L2),
97            "cosine" => Ok(Self::Cosine),
98            "dot" => Ok(Self::Dot),
99            "hamming" => Ok(Self::Hamming),
100            _ => Err(ArrowError::InvalidArgumentError(format!(
101                "Metric type '{s}' is not supported"
102            ))),
103        }
104    }
105}
106
107pub fn multivec_distance(
108    query: &dyn Array,
109    vectors: &ListArray,
110    distance_type: DistanceType,
111) -> Result<Vec<f32>> {
112    let dim = if let DataType::FixedSizeList(_, dim) = vectors.value_type() {
113        dim as usize
114    } else {
115        return Err(ArrowError::InvalidArgumentError(
116            "vectors must be a list of fixed size list".to_string(),
117        ));
118    };
119
120    // check the query vectors type first
121    // because we don't want to check the vectors type for each vector
122    match query.data_type() {
123        DataType::Float16 | DataType::Float32 | DataType::Float64 | DataType::UInt8 => {}
124        _ => {
125            return Err(ArrowError::InvalidArgumentError(
126                "query must be a float array or binary array".to_string(),
127            ));
128        }
129    }
130
131    let dists = vectors
132        .iter()
133        .map(|v| {
134            v.map(|v| {
135                let multivector = v.as_fixed_size_list();
136                match distance_type {
137                    DistanceType::Hamming => {
138                        let query = query.as_primitive::<UInt8Type>().values();
139                        query
140                            .chunks_exact(dim)
141                            .map(|q| {
142                                multivector
143                                    .values()
144                                    .as_primitive::<UInt8Type>()
145                                    .values()
146                                    .chunks_exact(dim)
147                                    .map(|v| hamming::hamming(q, v))
148                                    .min_by(|a, b| a.partial_cmp(b).unwrap())
149                                    .unwrap()
150                            })
151                            .sum()
152                    }
153                    _ => match query.data_type() {
154                        DataType::Float16 => multivec_distance_impl::<Float16Type>(
155                            query,
156                            multivector,
157                            dim,
158                            distance_type,
159                        ),
160                        DataType::Float32 => multivec_distance_impl::<Float32Type>(
161                            query,
162                            multivector,
163                            dim,
164                            distance_type,
165                        ),
166                        DataType::Float64 => multivec_distance_impl::<Float64Type>(
167                            query,
168                            multivector,
169                            dim,
170                            distance_type,
171                        ),
172                        _ => unreachable!("missed to check query type"),
173                    },
174                }
175            })
176            .unwrap_or(f32::NAN)
177        })
178        .map(|sim| 1.0 - sim)
179        .collect();
180    Ok(dists)
181}
182
183fn multivec_distance_impl<T: ArrowPrimitiveType>(
184    query: &dyn Array,
185    multivector: &FixedSizeListArray,
186    dim: usize,
187    distance_type: DistanceType,
188) -> f32
189where
190    T::Native: L2 + Cosine + Dot,
191{
192    let query = query.as_primitive::<T>().values();
193    query
194        .chunks_exact(dim)
195        .map(|q| {
196            multivector
197                .values()
198                .as_primitive::<T>()
199                .values()
200                .chunks_exact(dim)
201                .map(|v| 1.0 - distance_type.func()(q, v))
202                .max_by(|a, b| a.total_cmp(b))
203                .unwrap()
204        })
205        .sum()
206}