#![warn(missing_debug_implementations, missing_docs)]
use std::str::FromStr;
/// Identifies one of the supported metrics.
///
/// Each variant has a canonical lowercase name (see [`Metric::as_str`]) and a
/// fixed integer encoding given by its explicit discriminant. Because the enum
/// is `#[repr(C)]`, the discriminants are part of its binary representation.
/// They are spelled out so that reordering variants cannot silently renumber
/// the `i32` encoding.
#[repr(C)]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Metric {
    /// Cosine metric; name `"cosine"`, encoded as `0`.
    Cosine = 0,
    /// Inner-product metric; name `"innerproduct"`, encoded as `1`.
    InnerProduct = 1,
    /// L2 metric; name `"l2"`, encoded as `2`.
    L2 = 2,
    /// Cosine variant named `"cosinenormalized"`, encoded as `3`.
    ///
    /// NOTE(review): presumably intended for pre-normalized inputs; confirm
    /// against the consumers of this enum.
    CosineNormalized = 3,
}

impl Metric {
    /// Returns the canonical lowercase name of this metric.
    ///
    /// This is the text written by the `Display` impl and the set of names
    /// accepted (case-insensitively) by the `FromStr` impl.
    pub const fn as_str(self) -> &'static str {
        match self {
            Metric::Cosine => "cosine",
            Metric::InnerProduct => "innerproduct",
            Metric::L2 => "l2",
            Metric::CosineNormalized => "cosinenormalized",
        }
    }
}
/// Encodes a [`Metric`] as its `i32` discriminant.
impl From<Metric> for i32 {
    fn from(value: Metric) -> Self {
        // Field-less enum cast: yields the variant's discriminant exactly.
        value as Self
    }
}
/// Decodes an `i32` produced by `i32::from(Metric)` back into a [`Metric`].
impl TryFrom<i32> for Metric {
    type Error = TryFromMetricError;

    /// Looks up the variant whose discriminant equals `value`.
    ///
    /// # Errors
    ///
    /// Returns `TryFromMetricError(value)` when no variant has that discriminant.
    fn try_from(value: i32) -> Result<Self, Self::Error> {
        // Candidates in declaration order; each is compared via its i32 encoding.
        const VARIANTS: [Metric; 4] = [
            Metric::Cosine,
            Metric::InnerProduct,
            Metric::L2,
            Metric::CosineNormalized,
        ];

        VARIANTS
            .iter()
            .copied()
            .find(|&candidate| i32::from(candidate) == value)
            .ok_or(TryFromMetricError(value))
    }
}
/// Error returned by `Metric::try_from(i32)` when the integer is not the
/// discriminant of any `Metric` variant.
///
/// The wrapped value is the rejected input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TryFromMetricError(pub i32);

impl std::fmt::Display for TryFromMetricError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "invalid Metric discriminant: {}", self.0)
    }
}

impl std::error::Error for TryFromMetricError {}
/// Formats a [`Metric`] as its canonical lowercase name (see `Metric::as_str`).
///
/// Uses `Formatter::pad` rather than `write_str` so width, fill, alignment and
/// precision flags are honoured, just as for a plain `&str`. For example,
/// `format!("{:>4}", Metric::L2)` yields `"  l2"`. Without flags the output is
/// exactly the name.
impl std::fmt::Display for Metric {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
/// Error returned when parsing a `Metric` from a string fails.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseMetricError {
    /// The input did not match any metric name. Carries the original,
    /// unmodified input string.
    InvalidFormat(String),
}

impl std::fmt::Display for ParseMetricError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // `input` rather than `str`, to avoid shadowing the primitive type name.
            Self::InvalidFormat(input) => write!(f, "Invalid format for Metric: {}", input),
        }
    }
}

impl std::error::Error for ParseMetricError {}
/// Parses a [`Metric`] from its name, ignoring case.
impl FromStr for Metric {
    type Err = ParseMetricError;

    /// Lowercases `s` and compares it with each variant's `as_str()` name.
    ///
    /// # Errors
    ///
    /// Returns `ParseMetricError::InvalidFormat` holding the original (not
    /// lowercased) input when no name matches.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        let candidates = [
            Metric::L2,
            Metric::Cosine,
            Metric::InnerProduct,
            Metric::CosineNormalized,
        ];

        candidates
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == lowered)
            .ok_or_else(|| ParseMetricError::InvalidFormat(s.to_owned()))
    }
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::{Metric, ParseMetricError, TryFromMetricError};

    /// Every variant paired with its canonical name and expected i32 encoding.
    const CASES: [(Metric, &str, i32); 4] = [
        (Metric::Cosine, "cosine", 0),
        (Metric::InnerProduct, "innerproduct", 1),
        (Metric::L2, "l2", 2),
        (Metric::CosineNormalized, "cosinenormalized", 3),
    ];

    #[test]
    fn test_metric_from_str() {
        for (expected, name, _) in CASES {
            assert_eq!(Metric::from_str(name).unwrap(), expected);
        }

        let rejected = Metric::from_str("invalid").unwrap_err();
        let reference = ParseMetricError::InvalidFormat(String::from("invalid"));
        assert_eq!(rejected.to_string(), reference.to_string());
    }

    #[test]
    fn test_metric_to_i32() {
        for (metric, _, code) in CASES {
            assert_eq!(i32::from(metric), code);
        }
    }

    #[test]
    fn test_metric_try_from_i32() {
        for (expected, _, code) in CASES {
            assert_eq!(Metric::try_from(code), Ok(expected));
        }

        // Values just outside the valid discriminant range on both sides.
        for bad in [-1, 4] {
            assert_eq!(Metric::try_from(bad), Err(TryFromMetricError(bad)));
        }
    }
}