use serde::{Deserialize, Serialize};
/// Distance metric recorded in an [`EmbeddingKey`].
///
/// The explicit discriminants double as the one-byte tags returned by `as_byte` and written
/// into `EmbeddingKey::canonical_bytes`; renumbering a variant therefore changes the canonical
/// encoding of every key that uses it. Serde (de)serializes the variants under their
/// snake_case names (`"cosine"`, `"dot"`, `"l2"`). `Cosine` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
#[repr(u8)]
pub enum DistanceMetric {
    /// Tag byte `1`; the default metric.
    #[default]
    Cosine = 1,
    /// Tag byte `2`.
    Dot = 2,
    /// Tag byte `3`.
    L2 = 3,
}
impl DistanceMetric {
#[inline]
pub const fn as_byte(self) -> u8 {
self as u8
}
#[inline]
pub const fn from_byte(b: u8) -> Option<Self> {
match b {
1 => Some(Self::Cosine),
2 => Some(Self::Dot),
3 => Some(Self::L2),
_ => None,
}
}
}
/// Element type of the vectors described by an [`EmbeddingKey`].
///
/// The explicit discriminants are the one-byte tags returned by `as_byte` and written into
/// `EmbeddingKey::canonical_bytes`, so they must stay fixed. Serde uses the snake_case
/// variant names (`"f32"`, `"f16"`, `"i8"`). `F32` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
#[repr(u8)]
pub enum VectorDType {
    /// Tag byte `1`; the default element type.
    #[default]
    F32 = 1,
    /// Tag byte `2`.
    F16 = 2,
    /// Tag byte `3`.
    I8 = 3,
}
impl VectorDType {
    /// Returns this element type's stable one-byte tag (its `#[repr(u8)]` discriminant).
    #[inline]
    pub const fn as_byte(self) -> u8 {
        self as u8
    }

    /// Decodes a tag produced by [`Self::as_byte`].
    ///
    /// Returns `None` for any byte that does not correspond to a known element type. This
    /// mirrors `DistanceMetric::from_byte`, so every tag byte written by
    /// `EmbeddingKey::canonical_bytes` can be decoded back into its enum.
    #[inline]
    pub const fn from_byte(b: u8) -> Option<Self> {
        match b {
            1 => Some(Self::F32),
            2 => Some(Self::F16),
            3 => Some(Self::I8),
            _ => None,
        }
    }

    /// Width in bytes of one scalar of this type: 4 for `F32`, 2 for `F16`, 1 for `I8`.
    #[inline]
    pub const fn size_bytes(self) -> usize {
        match self {
            Self::F32 => 4,
            Self::F16 => 2,
            Self::I8 => 1,
        }
    }
}
/// Normalization flag recorded in an [`EmbeddingKey`].
///
/// Unlike the other tag enums this one starts at `0`: `None` is tag byte `0` (and the
/// `Default`), `Unit` is tag byte `1`. These tags are written into
/// `EmbeddingKey::canonical_bytes`, so they must stay fixed. Serde uses the snake_case
/// names `"none"` and `"unit"`.
///
/// Note: write the variant qualified (`VectorNorm::None` / `Self::None`) to avoid confusion
/// with `Option::None`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
#[repr(u8)]
pub enum VectorNorm {
    /// Tag byte `0`; the default.
    #[default]
    None = 0,
    /// Tag byte `1`. NOTE(review): presumably marks unit-length vectors — confirm with producers.
    Unit = 1,
}
impl VectorNorm {
    /// Returns this flag's stable one-byte tag (its `#[repr(u8)]` discriminant).
    #[inline]
    pub const fn as_byte(self) -> u8 {
        self as u8
    }

    /// Decodes a tag produced by [`Self::as_byte`].
    ///
    /// Returns `None` (i.e. `Option::None`) for any byte other than `0` or `1`. This mirrors
    /// `DistanceMetric::from_byte`, so every tag byte written by
    /// `EmbeddingKey::canonical_bytes` can be decoded back into its enum.
    #[inline]
    pub const fn from_byte(b: u8) -> Option<Self> {
        match b {
            // `Self::None` is the enum variant; the bare `None` below is `Option::None`.
            0 => Some(Self::None),
            1 => Some(Self::Unit),
            _ => None,
        }
    }
}
/// Describes an embedding configuration.
///
/// A key combines the producing model and revision with the vector dimensionality, distance
/// metric, element type and normalization flag. `canonical_bytes` gives its deterministic
/// byte encoding. Field order is significant for the derived serde/`Hash` impls and must be
/// kept.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub struct EmbeddingKey {
    /// Model identifier, e.g. `"bge-small-en-v1.5"`.
    pub model: Box<str>,
    /// Model revision string, e.g. `"v1.5"`.
    pub revision: Box<str>,
    /// Number of dimensions per vector.
    pub dims: u32,
    /// Distance metric tag.
    pub metric: DistanceMetric,
    /// Element type of each vector component.
    pub dtype: VectorDType,
    /// Normalization flag.
    pub norm: VectorNorm,
}
impl EmbeddingKey {
    /// Builds a key from its parts.
    ///
    /// `model` and `revision` accept anything convertible into `Box<str>`, such as `&str`
    /// or `String`.
    #[must_use]
    pub fn new(
        model: impl Into<Box<str>>,
        revision: impl Into<Box<str>>,
        dims: u32,
        metric: DistanceMetric,
        dtype: VectorDType,
        norm: VectorNorm,
    ) -> Self {
        Self {
            model: model.into(),
            revision: revision.into(),
            dims,
            metric,
            dtype,
            norm,
        }
    }

    /// Serializes the key into a deterministic byte string.
    ///
    /// Layout, with all integers big-endian:
    ///
    /// | field      | encoding                                |
    /// |------------|-----------------------------------------|
    /// | `model`    | `u32` byte length, then the UTF-8 bytes |
    /// | `revision` | `u32` byte length, then the UTF-8 bytes |
    /// | `dims`     | `u32`                                   |
    /// | `metric`   | 1 byte, `DistanceMetric::as_byte`       |
    /// | `dtype`    | 1 byte, `VectorDType::as_byte`          |
    /// | `norm`     | 1 byte, `VectorNorm::as_byte`           |
    ///
    /// The length prefixes make the encoding unambiguous. For example, `("ab", "c")` and
    /// `("a", "bc")` encode differently.
    ///
    /// # Panics
    ///
    /// Panics if `model` or `revision` is longer than `u32::MAX` bytes. Such a length cannot
    /// be represented in the 4-byte prefix, and truncating it would let two distinct keys
    /// produce identical bytes.
    #[must_use]
    pub fn canonical_bytes(&self) -> Vec<u8> {
        let model = self.model.as_bytes();
        let revision = self.revision.as_bytes();
        // Exact final size: two 4-byte length prefixes, both strings, 4-byte dims, 3 tag bytes.
        let mut buf = Vec::with_capacity(4 + model.len() + 4 + revision.len() + 4 + 3);
        Self::push_len_prefixed(&mut buf, model);
        Self::push_len_prefixed(&mut buf, revision);
        buf.extend_from_slice(&self.dims.to_be_bytes());
        buf.extend_from_slice(&[
            self.metric.as_byte(),
            self.dtype.as_byte(),
            self.norm.as_byte(),
        ]);
        buf
    }

    /// Appends `bytes` to `buf`, preceded by its length as a big-endian `u32`.
    ///
    /// # Panics
    ///
    /// Panics if `bytes.len()` exceeds `u32::MAX` rather than silently truncating the prefix.
    fn push_len_prefixed(buf: &mut Vec<u8>, bytes: &[u8]) {
        // Explicit import keeps this compiling on editions whose prelude lacks `TryFrom`.
        use std::convert::TryFrom;
        let len = u32::try_from(bytes.len()).expect(
            "EmbeddingKey string field longer than u32::MAX bytes cannot be length-prefixed",
        );
        buf.extend_from_slice(&len.to_be_bytes());
        buf.extend_from_slice(bytes);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline key shared by the field-sensitivity tests; each test perturbs one field.
    fn base_key() -> EmbeddingKey {
        EmbeddingKey::new(
            "model-a",
            "v1",
            384,
            DistanceMetric::Cosine,
            VectorDType::F32,
            VectorNorm::Unit,
        )
    }

    #[test]
    fn test_distance_metric_byte_roundtrip() {
        for (metric, byte) in [
            (DistanceMetric::Cosine, 1u8),
            (DistanceMetric::Dot, 2u8),
            (DistanceMetric::L2, 3u8),
        ] {
            assert_eq!(metric.as_byte(), byte);
            assert_eq!(DistanceMetric::from_byte(byte), Some(metric));
        }
        assert_eq!(DistanceMetric::from_byte(0), None);
        assert_eq!(DistanceMetric::from_byte(99), None);
    }

    #[test]
    fn test_vector_dtype_size_bytes() {
        assert_eq!(VectorDType::F32.size_bytes(), 4);
        assert_eq!(VectorDType::F16.size_bytes(), 2);
        assert_eq!(VectorDType::I8.size_bytes(), 1);
    }

    /// Tag bytes feed `canonical_bytes`, so changing any of them changes every key's encoding.
    #[test]
    fn test_vector_dtype_and_norm_tag_bytes() {
        assert_eq!(VectorDType::F32.as_byte(), 1);
        assert_eq!(VectorDType::F16.as_byte(), 2);
        assert_eq!(VectorDType::I8.as_byte(), 3);
        assert_eq!(VectorNorm::None.as_byte(), 0);
        assert_eq!(VectorNorm::Unit.as_byte(), 1);
    }

    #[test]
    fn test_vector_norm_defaults() {
        assert_eq!(VectorNorm::default(), VectorNorm::None);
    }

    #[test]
    fn test_metric_and_dtype_defaults() {
        assert_eq!(DistanceMetric::default(), DistanceMetric::Cosine);
        assert_eq!(VectorDType::default(), VectorDType::F32);
    }

    #[test]
    fn test_embedding_key_canonical_bytes_deterministic() {
        let k1 = EmbeddingKey::new(
            "bge-small-en-v1.5",
            "v1.5",
            384,
            DistanceMetric::Cosine,
            VectorDType::F32,
            VectorNorm::Unit,
        );
        let k2 = EmbeddingKey::new(
            "bge-small-en-v1.5",
            "v1.5",
            384,
            DistanceMetric::Cosine,
            VectorDType::F32,
            VectorNorm::Unit,
        );
        assert_eq!(k1.canonical_bytes(), k2.canonical_bytes());
    }

    /// Pins the exact byte layout; changing it changes every key's canonical encoding.
    #[test]
    fn test_embedding_key_canonical_bytes_layout() {
        let key = EmbeddingKey::new(
            "ab",
            "v1",
            384,
            DistanceMetric::Dot,
            VectorDType::F16,
            VectorNorm::Unit,
        );
        let expected: &[u8] = &[
            0, 0, 0, 2, b'a', b'b', // model: u32 BE length, then UTF-8 bytes
            0, 0, 0, 2, b'v', b'1', // revision: u32 BE length, then UTF-8 bytes
            0, 0, 0x01, 0x80, // dims = 384 as u32 BE
            2,    // metric tag (Dot)
            2,    // dtype tag (F16)
            1,    // norm tag (Unit)
        ];
        assert_eq!(key.canonical_bytes().as_slice(), expected);
    }

    #[test]
    fn test_embedding_key_canonical_bytes_empty_strings() {
        let key = EmbeddingKey::new(
            "",
            "",
            0,
            DistanceMetric::Cosine,
            VectorDType::F32,
            VectorNorm::None,
        );
        assert_eq!(
            key.canonical_bytes(),
            vec![0u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0]
        );
    }

    /// Moving bytes between `model` and `revision` must not produce the same encoding.
    #[test]
    fn test_embedding_key_canonical_bytes_length_prefix_disambiguates() {
        let a = EmbeddingKey {
            model: "ab".into(),
            revision: "c".into(),
            ..base_key()
        };
        let b = EmbeddingKey {
            model: "a".into(),
            revision: "bc".into(),
            ..base_key()
        };
        assert_ne!(a.canonical_bytes(), b.canonical_bytes());
    }

    /// Every field must influence the canonical encoding.
    #[test]
    fn test_embedding_key_canonical_bytes_differs_by_field() {
        let base = base_key().canonical_bytes();
        let variants = [
            EmbeddingKey {
                model: "model-b".into(),
                ..base_key()
            },
            EmbeddingKey {
                revision: "v2".into(),
                ..base_key()
            },
            EmbeddingKey {
                dims: 768,
                ..base_key()
            },
            EmbeddingKey {
                metric: DistanceMetric::L2,
                ..base_key()
            },
            EmbeddingKey {
                dtype: VectorDType::I8,
                ..base_key()
            },
            EmbeddingKey {
                norm: VectorNorm::None,
                ..base_key()
            },
        ];
        for other in &variants {
            assert_ne!(base, other.canonical_bytes(), "{other:?}");
        }
    }
}