diskann_vector/distance/
metric.rs1#![warn(missing_debug_implementations, missing_docs)]
6use std::str::FromStr;
7
/// Selects which vector distance/similarity measure is in use.
///
/// Every variant has a canonical lowercase name (`"cosine"`, `"innerproduct"`,
/// `"l2"`, `"cosinenormalized"`) that is used when the metric is rendered as,
/// or parsed from, text.
///
/// The enum is `#[repr(C)]` with implicit discriminants, so the declaration
/// order below fixes the integer value of each variant (`Cosine` = 0 through
/// `CosineNormalized` = 3).
/// NOTE(review): presumably these values cross an FFI or serialization
/// boundary — confirm before reordering or inserting variants.
///
/// `Hash` is derived alongside `Eq` so a metric can be used directly as a
/// `HashMap`/`HashSet` key.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Metric {
    /// Cosine metric, named `"cosine"`.
    Cosine,
    /// Inner-product metric, named `"innerproduct"`.
    InnerProduct,
    /// L2 metric, named `"l2"`.
    ///
    /// NOTE(review): presumably Euclidean (possibly squared) distance —
    /// confirm against the distance implementations.
    L2,
    /// Cosine metric variant, named `"cosinenormalized"`.
    ///
    /// NOTE(review): presumably intended for inputs that are already
    /// normalized to unit length — confirm against the distance
    /// implementations.
    CosineNormalized,
}
21
22impl Metric {
23 pub const fn as_str(self) -> &'static str {
25 match self {
26 Metric::Cosine => "cosine",
27 Metric::InnerProduct => "innerproduct",
28 Metric::L2 => "l2",
29 Metric::CosineNormalized => "cosinenormalized",
30 }
31 }
32}
33
34impl std::fmt::Display for Metric {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 f.write_str(self.as_str())
37 }
38}
39
/// Error produced when a string cannot be parsed into a `Metric`.
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers (and tests) can
/// compare or store errors directly instead of comparing rendered messages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseMetricError {
    /// The input did not match any known metric name.
    ///
    /// The payload is the rejected input, which is included in the error
    /// message.
    InvalidFormat(String),
}
44
45impl std::fmt::Display for ParseMetricError {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 match self {
48 Self::InvalidFormat(str) => write!(f, "Invalid format for Metric: {}", str),
49 }
50 }
51}
52
53impl std::error::Error for ParseMetricError {}
54
55impl FromStr for Metric {
56 type Err = ParseMetricError;
57
58 fn from_str(s: &str) -> Result<Self, Self::Err> {
59 match s.to_lowercase().as_str() {
60 x if x == Metric::L2.as_str() => Ok(Metric::L2),
61 x if x == Metric::Cosine.as_str() => Ok(Metric::Cosine),
62 x if x == Metric::InnerProduct.as_str() => Ok(Metric::InnerProduct),
63 x if x == Metric::CosineNormalized.as_str() => Ok(Metric::CosineNormalized),
64 _ => Err(ParseMetricError::InvalidFormat(String::from(s))),
65 }
66 }
67}
68
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::{Metric, ParseMetricError};

    /// Every variant, so the table-driven tests below cover the whole enum.
    const ALL_METRICS: [Metric; 4] = [
        Metric::Cosine,
        Metric::InnerProduct,
        Metric::L2,
        Metric::CosineNormalized,
    ];

    /// Canonical lowercase names parse to the matching variant, and an
    /// unknown name yields `InvalidFormat`.
    #[test]
    fn test_metric_from_str() {
        assert_eq!(Metric::from_str("cosine").unwrap(), Metric::Cosine);
        assert_eq!(Metric::from_str("l2").unwrap(), Metric::L2);
        assert_eq!(
            Metric::from_str("innerproduct").unwrap(),
            Metric::InnerProduct
        );
        assert_eq!(
            Metric::from_str("cosinenormalized").unwrap(),
            Metric::CosineNormalized
        );
        // Compared through `to_string()` so the test does not require
        // `PartialEq` on the error type.
        assert_eq!(
            Metric::from_str("invalid").unwrap_err().to_string(),
            ParseMetricError::InvalidFormat(String::from("invalid")).to_string()
        );
    }

    /// Parsing ignores letter case.
    #[test]
    fn test_metric_from_str_ignores_case() {
        assert_eq!(Metric::from_str("L2").unwrap(), Metric::L2);
        assert_eq!(Metric::from_str("Cosine").unwrap(), Metric::Cosine);
        assert_eq!(
            Metric::from_str("InnerProduct").unwrap(),
            Metric::InnerProduct
        );
        assert_eq!(
            Metric::from_str("COSINENORMALIZED").unwrap(),
            Metric::CosineNormalized
        );
    }

    /// `Display`, `as_str` and `FromStr` agree for every variant.
    #[test]
    fn test_metric_display_round_trips() {
        for metric in ALL_METRICS.iter().copied() {
            let rendered = metric.to_string();
            assert_eq!(rendered, metric.as_str());
            assert_eq!(Metric::from_str(&rendered).unwrap(), metric);
        }
    }

    /// The error message carries the rejected input verbatim (original case).
    #[test]
    fn test_metric_from_str_error_message() {
        assert_eq!(
            Metric::from_str("Euclidean").unwrap_err().to_string(),
            "Invalid format for Metric: Euclidean"
        );
        assert!(Metric::from_str("").is_err());
    }
}