// diskann_vector/distance/metric.rs
//! Defines [`Metric`] together with its conversions to and from `i32` discriminants and
//! lowercase string names, plus the error types those conversions return.
#![warn(missing_debug_implementations, missing_docs)]

use std::str::FromStr;
7
/// A vector comparison metric.
///
/// The discriminants are written out explicitly because they are observable: `i32::from(Metric)`
/// returns them and `Metric::try_from(i32)` accepts them. Reordering or inserting variants
/// must therefore never renumber the existing ones.
#[repr(C)]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Metric {
    /// Cosine metric; string name `"cosine"`.
    Cosine = 0,
    /// Inner-product metric; string name `"innerproduct"`.
    InnerProduct = 1,
    /// L2 metric; string name `"l2"`.
    L2 = 2,
    /// Normalized variant of [`Metric::Cosine`]; string name `"cosinenormalized"`.
    /// (Presumably intended for pre-normalized vectors — confirm against the distance kernels.)
    CosineNormalized = 3,
}
21
22impl Metric {
23 pub const fn as_str(self) -> &'static str {
25 match self {
26 Metric::Cosine => "cosine",
27 Metric::InnerProduct => "innerproduct",
28 Metric::L2 => "l2",
29 Metric::CosineNormalized => "cosinenormalized",
30 }
31 }
32}
33
34impl From<Metric> for i32 {
35 fn from(metric: Metric) -> Self {
36 metric as i32
37 }
38}
39
40impl TryFrom<i32> for Metric {
41 type Error = TryFromMetricError;
42
43 fn try_from(value: i32) -> Result<Self, Self::Error> {
44 match value {
45 x if x == Metric::Cosine.into() => Ok(Metric::Cosine),
46 x if x == Metric::InnerProduct.into() => Ok(Metric::InnerProduct),
47 x if x == Metric::L2.into() => Ok(Metric::L2),
48 x if x == Metric::CosineNormalized.into() => Ok(Metric::CosineNormalized),
49 _ => Err(TryFromMetricError(value)),
50 }
51 }
52}
53
/// Error returned by `Metric::try_from(i32)` when the integer is not a valid [`Metric`]
/// discriminant.
///
/// The wrapped value is the rejected input, unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TryFromMetricError(pub i32);

impl std::fmt::Display for TryFromMetricError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(rejected) = *self;
        write!(f, "invalid Metric discriminant: {}", rejected)
    }
}

impl std::error::Error for TryFromMetricError {}
65
66impl std::fmt::Display for Metric {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 f.write_str(self.as_str())
69 }
70}
71
/// Error returned when parsing a [`Metric`] from a string fails.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseMetricError {
    /// The input did not name any known metric. Carries the input exactly as the caller
    /// supplied it, not lowercased.
    InvalidFormat(String),
}

impl std::fmt::Display for ParseMetricError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidFormat(input) => write!(f, "Invalid format for Metric: {}", input),
        }
    }
}

impl std::error::Error for ParseMetricError {}
86
87impl FromStr for Metric {
88 type Err = ParseMetricError;
89
90 fn from_str(s: &str) -> Result<Self, Self::Err> {
91 match s.to_lowercase().as_str() {
92 x if x == Metric::L2.as_str() => Ok(Metric::L2),
93 x if x == Metric::Cosine.as_str() => Ok(Metric::Cosine),
94 x if x == Metric::InnerProduct.as_str() => Ok(Metric::InnerProduct),
95 x if x == Metric::CosineNormalized.as_str() => Ok(Metric::CosineNormalized),
96 _ => Err(ParseMetricError::InvalidFormat(String::from(s))),
97 }
98 }
99}
100
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::{Metric, ParseMetricError, TryFromMetricError};

    /// Every variant, used by the round-trip tests.
    const ALL: [Metric; 4] = [
        Metric::Cosine,
        Metric::InnerProduct,
        Metric::L2,
        Metric::CosineNormalized,
    ];

    #[test]
    fn test_metric_from_str() {
        assert_eq!(Metric::from_str("cosine").unwrap(), Metric::Cosine);
        assert_eq!(Metric::from_str("l2").unwrap(), Metric::L2);
        assert_eq!(
            Metric::from_str("innerproduct").unwrap(),
            Metric::InnerProduct
        );
        assert_eq!(
            Metric::from_str("cosinenormalized").unwrap(),
            Metric::CosineNormalized
        );
        assert_eq!(
            Metric::from_str("invalid").unwrap_err().to_string(),
            ParseMetricError::InvalidFormat(String::from("invalid")).to_string()
        );
    }

    #[test]
    fn test_metric_from_str_is_case_insensitive() {
        assert_eq!(Metric::from_str("COSINE").unwrap(), Metric::Cosine);
        assert_eq!(Metric::from_str("L2").unwrap(), Metric::L2);
        assert_eq!(
            Metric::from_str("InnerProduct").unwrap(),
            Metric::InnerProduct
        );
        assert_eq!(
            Metric::from_str("CosineNormalized").unwrap(),
            Metric::CosineNormalized
        );
    }

    #[test]
    fn test_metric_from_str_error_keeps_original_input() {
        // The error reports the caller's spelling, not the lowercased form used for matching.
        assert_eq!(
            Metric::from_str("Bogus").unwrap_err().to_string(),
            "Invalid format for Metric: Bogus"
        );
        // Whitespace is not trimmed, so padded names are rejected.
        assert!(Metric::from_str(" l2").is_err());
    }

    #[test]
    fn test_metric_display_round_trips_through_from_str() {
        for metric in ALL.iter().copied() {
            assert_eq!(metric.to_string(), metric.as_str());
            assert_eq!(Metric::from_str(&metric.to_string()).unwrap(), metric);
        }
    }

    #[test]
    fn test_metric_to_i32() {
        assert_eq!(i32::from(Metric::Cosine), 0);
        assert_eq!(i32::from(Metric::InnerProduct), 1);
        assert_eq!(i32::from(Metric::L2), 2);
        assert_eq!(i32::from(Metric::CosineNormalized), 3);
    }

    #[test]
    fn test_metric_try_from_i32() {
        assert_eq!(Metric::try_from(0), Ok(Metric::Cosine));
        assert_eq!(Metric::try_from(1), Ok(Metric::InnerProduct));
        assert_eq!(Metric::try_from(2), Ok(Metric::L2));
        assert_eq!(Metric::try_from(3), Ok(Metric::CosineNormalized));
        assert_eq!(Metric::try_from(-1), Err(TryFromMetricError(-1)));
        assert_eq!(Metric::try_from(4), Err(TryFromMetricError(4)));
    }

    #[test]
    fn test_metric_i32_round_trip() {
        for metric in ALL.iter().copied() {
            assert_eq!(Metric::try_from(i32::from(metric)), Ok(metric));
        }
    }
}