Skip to main content

use_ml_embedding/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::{error::Error, num::NonZeroUsize};
6
7pub mod prelude {
8    pub use crate::{
9        EmbeddingDimension, EmbeddingDistanceMetric, EmbeddingError, EmbeddingIndexKind,
10        EmbeddingModality, EmbeddingModelName, EmbeddingNormalizationKind, EmbeddingSearchKind,
11        EmbeddingVectorFormat, EmbeddingVectorId, EmbeddingVectorShape,
12    };
13}
14
/// Declares a public newtype around a trimmed, non-empty `String`.
///
/// Invocation: `embedding_text_newtype!(Name)`. Any attributes or doc comments placed
/// before the name are forwarded onto the generated struct, so each generated type can
/// carry its own documentation.
///
/// The expansion provides `new` (validated through `non_empty_text`), `as_str`, and
/// `AsRef<str>`, `Display`, `FromStr` and `TryFrom<&str>` implementations. Every fallible
/// path reports `EmbeddingError`.
macro_rules! embedding_text_newtype {
    ($(#[$meta:meta])* $name:ident) => {
        $(#[$meta])*
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            /// Builds the value from `value` after trimming surrounding whitespace.
            ///
            /// # Errors
            ///
            /// Propagates the error from `non_empty_text`, which rejects input that is
            /// empty once trimmed.
            pub fn new(value: impl AsRef<str>) -> Result<Self, EmbeddingError> {
                non_empty_text(value).map(Self)
            }

            /// Borrows the stored (already trimmed) text.
            pub fn as_str(&self) -> &str {
                &self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            /// Writes the stored text verbatim.
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = EmbeddingError;

            /// Same validation as `new`.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = EmbeddingError;

            /// Same validation as `new`.
            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }
    };
}
59
/// Declares a public, field-less enum whose variants each map to a fixed string label.
///
/// Invocation: `embedding_enum!(Name { Variant => "label", ... })`. Attributes or doc
/// comments placed before the enum name or before any variant are forwarded onto the
/// generated item, so the generated type can be documented at the call site.
///
/// The expansion provides a `const fn as_str`, a `Display` impl that writes the label, and
/// a `FromStr` impl that first passes the input through `normalized_label` and then
/// matches it against the declared labels. Labels should therefore be written in the
/// normalised form that helper produces, otherwise they can never match.
macro_rules! embedding_enum {
    (
        $(#[$meta:meta])*
        $name:ident {
            $($(#[$variant_meta:meta])* $variant:ident => $label:literal),+ $(,)?
        }
    ) => {
        $(#[$meta])*
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $($(#[$variant_meta])* $variant),+
        }

        impl $name {
            /// Returns the label declared for this variant.
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            /// Writes the same label that `as_str` returns.
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = EmbeddingError;

            /// Parses a label after normalisation.
            ///
            /// Errors from `normalized_label` are propagated unchanged; a normalised
            /// value that matches no declared label yields `EmbeddingError::UnknownLabel`.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(EmbeddingError::UnknownLabel),
                }
            }
        }
    };
}
93
94embedding_text_newtype!(EmbeddingModelName);
95embedding_text_newtype!(EmbeddingVectorId);
96
97#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
98pub struct EmbeddingDimension(NonZeroUsize);
99
100impl EmbeddingDimension {
101    pub fn new(value: usize) -> Result<Self, EmbeddingError> {
102        NonZeroUsize::new(value)
103            .map(Self)
104            .ok_or(EmbeddingError::Zero)
105    }
106
107    pub const fn get(self) -> usize {
108        self.0.get()
109    }
110}
111
112#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
113pub struct EmbeddingVectorShape {
114    dimension: EmbeddingDimension,
115}
116
117impl EmbeddingVectorShape {
118    pub const fn new(dimension: EmbeddingDimension) -> Self {
119        Self { dimension }
120    }
121
122    pub const fn dimension(self) -> EmbeddingDimension {
123        self.dimension
124    }
125}
126
127embedding_enum!(EmbeddingModality {
128    Text => "text",
129    Image => "image",
130    Audio => "audio",
131    Video => "video",
132    Code => "code",
133    Tabular => "tabular",
134    Graph => "graph",
135    Multimodal => "multimodal",
136    Other => "other",
137});
138
139embedding_enum!(EmbeddingDistanceMetric {
140    Cosine => "cosine",
141    DotProduct => "dot-product",
142    Euclidean => "euclidean",
143    Manhattan => "manhattan",
144    Hamming => "hamming",
145    Jaccard => "jaccard",
146    Custom => "custom",
147});
148
149embedding_enum!(EmbeddingNormalizationKind {
150    None => "none",
151    Unit => "unit",
152    MeanCentered => "mean-centered",
153    Standardized => "standardized",
154    Custom => "custom",
155});
156
157embedding_enum!(EmbeddingIndexKind {
158    Flat => "flat",
159    Hnsw => "hnsw",
160    Ivf => "ivf",
161    Pq => "pq",
162    IvfPq => "ivf-pq",
163    Annoy => "annoy",
164    Scann => "scann",
165    Other => "other",
166});
167
168embedding_enum!(EmbeddingSearchKind {
169    Exact => "exact",
170    Approximate => "approximate",
171    Hybrid => "hybrid",
172    Filtered => "filtered",
173    Reranked => "reranked",
174});
175
176embedding_enum!(EmbeddingVectorFormat {
177    Dense => "dense",
178    Sparse => "sparse",
179    Binary => "binary",
180    Quantized => "quantized",
181    Mixed => "mixed",
182});
183
/// Failure reasons reported by the constructors and parsers in this file.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EmbeddingError {
    /// Text input contained nothing once surrounding whitespace was trimmed.
    Empty,
    /// A dimension of zero was supplied.
    Zero,
    /// A label did not match any variant of the enum being parsed.
    UnknownLabel,
}

impl fmt::Display for EmbeddingError {
    /// Writes a fixed, human-readable message for each variant.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::Empty => "embedding metadata text cannot be empty",
            Self::Zero => "embedding dimension must be positive",
            Self::UnknownLabel => "unknown embedding metadata label",
        };
        formatter.write_str(message)
    }
}

// No underlying cause is carried, so the default `Error` methods are sufficient.
impl Error for EmbeddingError {}
202
203fn non_empty_text(value: impl AsRef<str>) -> Result<String, EmbeddingError> {
204    let trimmed = value.as_ref().trim();
205    if trimmed.is_empty() {
206        Err(EmbeddingError::Empty)
207    } else {
208        Ok(trimmed.to_string())
209    }
210}
211
212fn normalized_label(value: &str) -> Result<String, EmbeddingError> {
213    let trimmed = value.trim();
214    if trimmed.is_empty() {
215        Err(EmbeddingError::Empty)
216    } else {
217        Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::{
        EmbeddingDimension, EmbeddingDistanceMetric, EmbeddingError, EmbeddingIndexKind,
        EmbeddingModelName, EmbeddingNormalizationKind, EmbeddingVectorShape,
    };

    /// Text newtypes trim their input and reject blank text; dimensions reject zero.
    #[test]
    fn validates_embedding_names_and_dimensions() -> Result<(), EmbeddingError> {
        let model = EmbeddingModelName::new(" text-embedding ")?;
        let dimension = EmbeddingDimension::new(384)?;
        let shape = EmbeddingVectorShape::new(dimension);

        assert_eq!(model.as_str(), "text-embedding");
        assert_eq!(model.to_string(), "text-embedding");
        assert_eq!(shape.dimension().get(), 384);
        assert_eq!(EmbeddingDimension::new(0), Err(EmbeddingError::Zero));
        assert_eq!(EmbeddingModelName::new("   "), Err(EmbeddingError::Empty));
        Ok(())
    }

    /// Enum labels parse leniently (case, `_` and space) and display in canonical form.
    #[test]
    fn displays_and_parses_embedding_enums() -> Result<(), EmbeddingError> {
        // Parsing: spaces stand in for the `-` separator.
        assert_eq!(
            "dot product".parse::<EmbeddingDistanceMetric>()?,
            EmbeddingDistanceMetric::DotProduct
        );
        assert_eq!(
            "mean centered".parse::<EmbeddingNormalizationKind>()?,
            EmbeddingNormalizationKind::MeanCentered
        );
        assert_eq!(
            "ivf pq".parse::<EmbeddingIndexKind>()?,
            EmbeddingIndexKind::IvfPq
        );
        // Parsing: surrounding whitespace, upper case and `_` are normalised too.
        assert_eq!(
            " IVF_PQ ".parse::<EmbeddingIndexKind>()?,
            EmbeddingIndexKind::IvfPq
        );

        // Display: always the canonical hyphenated label, identical to `as_str`.
        assert_eq!(EmbeddingDistanceMetric::DotProduct.to_string(), "dot-product");
        assert_eq!(
            EmbeddingNormalizationKind::MeanCentered.to_string(),
            "mean-centered"
        );
        assert_eq!(EmbeddingIndexKind::IvfPq.as_str(), "ivf-pq");

        // Display output parses back to the same variant.
        assert_eq!(
            EmbeddingDistanceMetric::Cosine
                .to_string()
                .parse::<EmbeddingDistanceMetric>()?,
            EmbeddingDistanceMetric::Cosine
        );
        Ok(())
    }

    /// Blank input and unrecognised labels map to distinct error variants.
    #[test]
    fn rejects_blank_and_unknown_labels() {
        assert_eq!(
            "   ".parse::<EmbeddingDistanceMetric>(),
            Err(EmbeddingError::Empty)
        );
        assert_eq!(
            "chebyshev".parse::<EmbeddingDistanceMetric>(),
            Err(EmbeddingError::UnknownLabel)
        );
    }
}
256}