diskann_vector/distance/
metric.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
5#![warn(missing_debug_implementations, missing_docs)]
6use std::str::FromStr;
7
/// A distance metric.
///
/// The enum is `#[repr(C)]` with implicit discriminants, so each variant's
/// numeric value is fixed by its declaration position (starting at 0).
/// Do not reorder the variants.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Metric {
    /// Cosine similarity.
    Cosine,
    /// Inner (dot) product.
    InnerProduct,
    /// Squared Euclidean distance (L2-squared).
    L2,
    /// Normalized cosine similarity.
    CosineNormalized,
}
21
22impl Metric {
23    /// Returns the string representation of the metric.
24    pub const fn as_str(self) -> &'static str {
25        match self {
26            Metric::Cosine => "cosine",
27            Metric::InnerProduct => "innerproduct",
28            Metric::L2 => "l2",
29            Metric::CosineNormalized => "cosinenormalized",
30        }
31    }
32}
33
34impl std::fmt::Display for Metric {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        f.write_str(self.as_str())
37    }
38}
39
/// Error returned when a string cannot be parsed into a [`Metric`].
#[derive(Debug)]
pub enum ParseMetricError {
    /// The input matched no known metric name; carries the original input.
    InvalidFormat(String),
}
44
45impl std::fmt::Display for ParseMetricError {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        match self {
48            Self::InvalidFormat(str) => write!(f, "Invalid format for Metric: {}", str),
49        }
50    }
51}
52
53impl std::error::Error for ParseMetricError {}
54
55impl FromStr for Metric {
56    type Err = ParseMetricError;
57
58    fn from_str(s: &str) -> Result<Self, Self::Err> {
59        match s.to_lowercase().as_str() {
60            x if x == Metric::L2.as_str() => Ok(Metric::L2),
61            x if x == Metric::Cosine.as_str() => Ok(Metric::Cosine),
62            x if x == Metric::InnerProduct.as_str() => Ok(Metric::InnerProduct),
63            x if x == Metric::CosineNormalized.as_str() => Ok(Metric::CosineNormalized),
64            _ => Err(ParseMetricError::InvalidFormat(String::from(s))),
65        }
66    }
67}
68
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::{Metric, ParseMetricError};

    /// Every variant, for properties that must hold for all metrics.
    const ALL: [Metric; 4] = [
        Metric::Cosine,
        Metric::InnerProduct,
        Metric::L2,
        Metric::CosineNormalized,
    ];

    #[test]
    fn test_metric_from_str() {
        assert_eq!(Metric::from_str("cosine").unwrap(), Metric::Cosine);
        assert_eq!(Metric::from_str("l2").unwrap(), Metric::L2);
        assert_eq!(
            Metric::from_str("innerproduct").unwrap(),
            Metric::InnerProduct
        );
        assert_eq!(
            Metric::from_str("cosinenormalized").unwrap(),
            Metric::CosineNormalized
        );
        assert_eq!(
            Metric::from_str("invalid").unwrap_err().to_string(),
            ParseMetricError::InvalidFormat(String::from("invalid")).to_string()
        );
    }

    #[test]
    fn test_metric_from_str_ignores_case() {
        assert_eq!(Metric::from_str("L2").unwrap(), Metric::L2);
        assert_eq!(Metric::from_str("Cosine").unwrap(), Metric::Cosine);
        assert_eq!(
            Metric::from_str("INNERPRODUCT").unwrap(),
            Metric::InnerProduct
        );
        assert_eq!(
            Metric::from_str("CosineNormalized").unwrap(),
            Metric::CosineNormalized
        );
    }

    #[test]
    fn test_metric_display_round_trips() {
        for metric in ALL.iter().copied() {
            let text = metric.to_string();
            assert_eq!(text, metric.as_str());
            assert_eq!(Metric::from_str(&text).unwrap(), metric);
        }
    }

    #[test]
    fn test_invalid_metric_keeps_input() {
        for bad in ["l3", "", "cosine "].iter().copied() {
            match Metric::from_str(bad) {
                Err(ParseMetricError::InvalidFormat(input)) => assert_eq!(input, bad),
                other => panic!("expected InvalidFormat for {:?}, got {:?}", bad, other),
            }
        }
    }
}