Skip to main content

diskann_vector/distance/
metric.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5#![warn(missing_debug_implementations, missing_docs)]
6use std::str::FromStr;
7
/// Distance metric.
///
/// The enum is `#[repr(C)]` and every variant carries an explicit discriminant. These integer
/// values are part of the public contract: they are what `i32::from(metric)` returns and what
/// `Metric::try_from(i32)` accepts. They are pinned explicitly so that reordering or inserting
/// variants cannot silently renumber them; changing an existing value is a breaking change.
#[repr(C)]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Metric {
    /// Cosine similarity.
    Cosine = 0,
    /// Inner product.
    InnerProduct = 1,
    /// Squared Euclidean (L2-Squared).
    L2 = 2,
    /// Normalized Cosine Similarity.
    CosineNormalized = 3,
}

impl Metric {
    /// Returns the lowercase name of the metric.
    ///
    /// This is the spelling written by `Display` and matched (after lowercasing the input)
    /// by `FromStr`, e.g. `Metric::InnerProduct.as_str() == "innerproduct"`.
    pub const fn as_str(self) -> &'static str {
        match self {
            Metric::Cosine => "cosine",
            Metric::InnerProduct => "innerproduct",
            Metric::L2 => "l2",
            Metric::CosineNormalized => "cosinenormalized",
        }
    }
}
33
34impl From<Metric> for i32 {
35    fn from(metric: Metric) -> Self {
36        metric as i32
37    }
38}
39
40impl TryFrom<i32> for Metric {
41    type Error = TryFromMetricError;
42
43    fn try_from(value: i32) -> Result<Self, Self::Error> {
44        match value {
45            x if x == Metric::Cosine.into() => Ok(Metric::Cosine),
46            x if x == Metric::InnerProduct.into() => Ok(Metric::InnerProduct),
47            x if x == Metric::L2.into() => Ok(Metric::L2),
48            x if x == Metric::CosineNormalized.into() => Ok(Metric::CosineNormalized),
49            _ => Err(TryFromMetricError(value)),
50        }
51    }
52}
53
/// Error produced when an `i32` matches no [`Metric`] discriminant.
///
/// The wrapped value is the rejected integer.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct TryFromMetricError(pub i32);

impl std::fmt::Display for TryFromMetricError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let TryFromMetricError(rejected) = *self;
        write!(f, "invalid Metric discriminant: {}", rejected)
    }
}

impl std::error::Error for TryFromMetricError {}
65
66impl std::fmt::Display for Metric {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        f.write_str(self.as_str())
69    }
70}
71
/// Error returned when a string does not name any [`Metric`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseMetricError {
    /// The input matched no metric name; holds the caller's original, unmodified input.
    InvalidFormat(String),
}

impl std::fmt::Display for ParseMetricError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Binding named `input` (not `str`) to avoid shadowing the primitive type name.
            Self::InvalidFormat(input) => write!(f, "Invalid format for Metric: {}", input),
        }
    }
}

impl std::error::Error for ParseMetricError {}
86
87impl FromStr for Metric {
88    type Err = ParseMetricError;
89
90    fn from_str(s: &str) -> Result<Self, Self::Err> {
91        match s.to_lowercase().as_str() {
92            x if x == Metric::L2.as_str() => Ok(Metric::L2),
93            x if x == Metric::Cosine.as_str() => Ok(Metric::Cosine),
94            x if x == Metric::InnerProduct.as_str() => Ok(Metric::InnerProduct),
95            x if x == Metric::CosineNormalized.as_str() => Ok(Metric::CosineNormalized),
96            _ => Err(ParseMetricError::InvalidFormat(String::from(s))),
97        }
98    }
99}
100
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::{Metric, ParseMetricError, TryFromMetricError};

    /// Every `Metric` variant, for tests that must cover all of them.
    const ALL_METRICS: [Metric; 4] = [
        Metric::Cosine,
        Metric::InnerProduct,
        Metric::L2,
        Metric::CosineNormalized,
    ];

    #[test]
    fn test_metric_from_str() {
        assert_eq!(Metric::from_str("cosine").unwrap(), Metric::Cosine);
        assert_eq!(Metric::from_str("l2").unwrap(), Metric::L2);
        assert_eq!(
            Metric::from_str("innerproduct").unwrap(),
            Metric::InnerProduct
        );
        assert_eq!(
            Metric::from_str("cosinenormalized").unwrap(),
            Metric::CosineNormalized
        );
        assert_eq!(
            Metric::from_str("invalid").unwrap_err().to_string(),
            ParseMetricError::InvalidFormat(String::from("invalid")).to_string()
        );
    }

    #[test]
    fn test_metric_from_str_ignores_case() {
        assert_eq!(Metric::from_str("L2").unwrap(), Metric::L2);
        assert_eq!(Metric::from_str("Cosine").unwrap(), Metric::Cosine);
        assert_eq!(
            Metric::from_str("InnerProduct").unwrap(),
            Metric::InnerProduct
        );
        assert_eq!(
            Metric::from_str("COSINENORMALIZED").unwrap(),
            Metric::CosineNormalized
        );
    }

    #[test]
    fn test_metric_from_str_error_keeps_original_input() {
        // The error echoes the caller's input verbatim, not the lowercased form.
        assert_eq!(
            Metric::from_str("L3").unwrap_err().to_string(),
            ParseMetricError::InvalidFormat(String::from("L3")).to_string()
        );
        assert!(Metric::from_str("").is_err());
    }

    #[test]
    fn test_metric_string_round_trip() {
        for &metric in ALL_METRICS.iter() {
            assert_eq!(metric.to_string(), metric.as_str());
            assert_eq!(Metric::from_str(&metric.to_string()).unwrap(), metric);
        }
    }

    #[test]
    fn test_metric_to_i32() {
        assert_eq!(i32::from(Metric::Cosine), 0);
        assert_eq!(i32::from(Metric::InnerProduct), 1);
        assert_eq!(i32::from(Metric::L2), 2);
        assert_eq!(i32::from(Metric::CosineNormalized), 3);
    }

    #[test]
    fn test_metric_try_from_i32() {
        assert_eq!(Metric::try_from(0), Ok(Metric::Cosine));
        assert_eq!(Metric::try_from(1), Ok(Metric::InnerProduct));
        assert_eq!(Metric::try_from(2), Ok(Metric::L2));
        assert_eq!(Metric::try_from(3), Ok(Metric::CosineNormalized));
        assert_eq!(Metric::try_from(-1), Err(TryFromMetricError(-1)));
        assert_eq!(Metric::try_from(4), Err(TryFromMetricError(4)));
    }

    #[test]
    fn test_metric_i32_round_trip() {
        for &metric in ALL_METRICS.iter() {
            assert_eq!(Metric::try_from(i32::from(metric)), Ok(metric));
        }
    }

    #[test]
    fn test_try_from_metric_error_display() {
        assert_eq!(
            TryFromMetricError(4).to_string(),
            "invalid Metric discriminant: 4"
        );
    }
}