1use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
15#[serde(rename_all = "snake_case")]
16#[non_exhaustive]
17#[repr(u8)]
18pub enum DistanceMetric {
19 #[default]
21 Cosine = 1,
22 Dot = 2,
24 L2 = 3,
26}
27
28impl DistanceMetric {
29 #[inline]
31 pub const fn as_byte(self) -> u8 {
32 self as u8
33 }
34
35 #[inline]
37 pub const fn from_byte(b: u8) -> Option<Self> {
38 match b {
39 1 => Some(Self::Cosine),
40 2 => Some(Self::Dot),
41 3 => Some(Self::L2),
42 _ => None,
43 }
44 }
45}
46
47#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
53#[serde(rename_all = "snake_case")]
54#[non_exhaustive]
55#[repr(u8)]
56pub enum VectorDType {
57 #[default]
59 F32 = 1,
60 F16 = 2,
62 I8 = 3,
64}
65
66impl VectorDType {
67 #[inline]
69 pub const fn as_byte(self) -> u8 {
70 self as u8
71 }
72
73 #[inline]
75 pub const fn size_bytes(self) -> usize {
76 match self {
77 Self::F32 => 4,
78 Self::F16 => 2,
79 Self::I8 => 1,
80 }
81 }
82}
83
84#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
90#[serde(rename_all = "snake_case")]
91#[non_exhaustive]
92#[repr(u8)]
93pub enum VectorNorm {
94 #[default]
96 None = 0,
97 Unit = 1,
99}
100
101impl VectorNorm {
102 #[inline]
104 pub const fn as_byte(self) -> u8 {
105 self as u8
106 }
107}
108
109#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
118pub struct EmbeddingKey {
119 pub model: Box<str>,
121 pub revision: Box<str>,
123 pub dims: u32,
125 pub metric: DistanceMetric,
127 pub dtype: VectorDType,
129 pub norm: VectorNorm,
131}
132
133impl EmbeddingKey {
134 pub fn new(
136 model: impl Into<Box<str>>,
137 revision: impl Into<Box<str>>,
138 dims: u32,
139 metric: DistanceMetric,
140 dtype: VectorDType,
141 norm: VectorNorm,
142 ) -> Self {
143 Self {
144 model: model.into(),
145 revision: revision.into(),
146 dims,
147 metric,
148 dtype,
149 norm,
150 }
151 }
152
153 pub fn canonical_bytes(&self) -> Vec<u8> {
163 let mut buf = Vec::new();
164
165 let model_bytes = self.model.as_bytes();
166 buf.extend_from_slice(&(model_bytes.len() as u32).to_be_bytes());
167 buf.extend_from_slice(model_bytes);
168
169 let rev_bytes = self.revision.as_bytes();
170 buf.extend_from_slice(&(rev_bytes.len() as u32).to_be_bytes());
171 buf.extend_from_slice(rev_bytes);
172
173 buf.extend_from_slice(&self.dims.to_be_bytes());
174 buf.push(self.metric.as_byte());
175 buf.push(self.dtype.as_byte());
176 buf.push(self.norm.as_byte());
177
178 buf
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline key for the canonical-bytes tests; individual tests override one field.
    fn base_key() -> EmbeddingKey {
        EmbeddingKey::new(
            "model-a",
            "v1",
            384,
            DistanceMetric::Cosine,
            VectorDType::F32,
            VectorNorm::Unit,
        )
    }

    #[test]
    fn test_distance_metric_byte_roundtrip() {
        for (metric, byte) in [
            (DistanceMetric::Cosine, 1u8),
            (DistanceMetric::Dot, 2u8),
            (DistanceMetric::L2, 3u8),
        ] {
            assert_eq!(metric.as_byte(), byte);
            assert_eq!(DistanceMetric::from_byte(byte), Some(metric));
        }
        assert_eq!(DistanceMetric::from_byte(99), None);
    }

    #[test]
    fn test_vector_dtype_size_bytes() {
        assert_eq!(VectorDType::F32.size_bytes(), 4);
        assert_eq!(VectorDType::F16.size_bytes(), 2);
        assert_eq!(VectorDType::I8.size_bytes(), 1);
    }

    #[test]
    fn test_vector_dtype_and_norm_tag_bytes() {
        // These tags are part of the canonical encoding; pin them explicitly.
        assert_eq!(VectorDType::F32.as_byte(), 1);
        assert_eq!(VectorDType::F16.as_byte(), 2);
        assert_eq!(VectorDType::I8.as_byte(), 3);
        assert_eq!(VectorNorm::None.as_byte(), 0);
        assert_eq!(VectorNorm::Unit.as_byte(), 1);
    }

    #[test]
    fn test_vector_norm_defaults() {
        assert_eq!(VectorNorm::default(), VectorNorm::None);
    }

    #[test]
    fn test_embedding_key_canonical_bytes_deterministic() {
        let k1 = EmbeddingKey::new(
            "bge-small-en-v1.5",
            "v1.5",
            384,
            DistanceMetric::Cosine,
            VectorDType::F32,
            VectorNorm::Unit,
        );
        let k2 = EmbeddingKey::new(
            "bge-small-en-v1.5",
            "v1.5",
            384,
            DistanceMetric::Cosine,
            VectorDType::F32,
            VectorNorm::Unit,
        );
        assert_eq!(k1.canonical_bytes(), k2.canonical_bytes());
    }

    #[test]
    fn test_embedding_key_canonical_bytes_layout() {
        let key = EmbeddingKey::new(
            "m",
            "r",
            384,
            DistanceMetric::Cosine,
            VectorDType::F32,
            VectorNorm::Unit,
        );
        let expected = [
            0u8, 0, 0, 1, b'm', // model: length prefix + bytes
            0, 0, 0, 1, b'r', // revision: length prefix + bytes
            0, 0, 1, 128, // dims = 384 (0x0180), big-endian
            1, 1, 1, // metric, dtype, norm tags
        ];
        assert_eq!(key.canonical_bytes(), expected);
    }

    #[test]
    fn test_embedding_key_canonical_bytes_differs_by_field() {
        let base = base_key();
        let variants = [
            EmbeddingKey { model: "model-b".into(), ..base_key() },
            EmbeddingKey { revision: "v2".into(), ..base_key() },
            EmbeddingKey { dims: 768, ..base_key() },
            EmbeddingKey { metric: DistanceMetric::Dot, ..base_key() },
            EmbeddingKey { dtype: VectorDType::F16, ..base_key() },
            EmbeddingKey { norm: VectorNorm::None, ..base_key() },
        ];
        for other in &variants {
            assert_ne!(base.canonical_bytes(), other.canonical_bytes(), "{other:?}");
        }
    }

    #[test]
    fn test_embedding_key_canonical_bytes_length_prefix_disambiguates() {
        let k1 = EmbeddingKey { model: "ab".into(), revision: "c".into(), ..base_key() };
        let k2 = EmbeddingKey { model: "a".into(), revision: "bc".into(), ..base_key() };
        assert_ne!(k1.canonical_bytes(), k2.canonical_bytes());
    }
}