Skip to main content

lattice_embed/
types.rs

1//! ML-domain vector types local to lattice-embed.
2//!
3//! These types are defined here because the new lattice-types foundation crate
4//! only contains identity/policy/capability primitives and does not include
5//! vector configuration types. These are ML-domain concerns.
6
7use serde::{Deserialize, Serialize};
8
9// ============================================================================
10// DistanceMetric
11// ============================================================================
12
13/// Distance metric used for vector similarity search.
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
15#[serde(rename_all = "snake_case")]
16#[non_exhaustive]
17#[repr(u8)]
18pub enum DistanceMetric {
19    /// Cosine similarity (1 - cosine distance).
20    #[default]
21    Cosine = 1,
22    /// Dot product (inner product).
23    Dot = 2,
24    /// Euclidean (L2) distance.
25    L2 = 3,
26}
27
28impl DistanceMetric {
29    /// Return the wire byte for this variant.
30    #[inline]
31    pub const fn as_byte(self) -> u8 {
32        self as u8
33    }
34
35    /// Reconstruct from a wire byte. Returns `None` for unknown values.
36    #[inline]
37    pub const fn from_byte(b: u8) -> Option<Self> {
38        match b {
39            1 => Some(Self::Cosine),
40            2 => Some(Self::Dot),
41            3 => Some(Self::L2),
42            _ => None,
43        }
44    }
45}
46
47// ============================================================================
48// VectorDType
49// ============================================================================
50
51/// Element data type for stored vectors.
52#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
53#[serde(rename_all = "snake_case")]
54#[non_exhaustive]
55#[repr(u8)]
56pub enum VectorDType {
57    /// 32-bit float.
58    #[default]
59    F32 = 1,
60    /// 16-bit float (half precision).
61    F16 = 2,
62    /// 8-bit signed integer (quantized).
63    I8 = 3,
64}
65
66impl VectorDType {
67    /// Return the wire byte for this variant.
68    #[inline]
69    pub const fn as_byte(self) -> u8 {
70        self as u8
71    }
72
73    /// Size in bytes per element.
74    #[inline]
75    pub const fn size_bytes(self) -> usize {
76        match self {
77            Self::F32 => 4,
78            Self::F16 => 2,
79            Self::I8 => 1,
80        }
81    }
82}
83
84// ============================================================================
85// VectorNorm
86// ============================================================================
87
88/// Normalization state of stored vectors.
89#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
90#[serde(rename_all = "snake_case")]
91#[non_exhaustive]
92#[repr(u8)]
93pub enum VectorNorm {
94    /// No normalization applied.
95    #[default]
96    None = 0,
97    /// Normalized to unit length (L2 norm = 1).
98    Unit = 1,
99}
100
101impl VectorNorm {
102    /// Return the wire byte for this variant.
103    #[inline]
104    pub const fn as_byte(self) -> u8 {
105        self as u8
106    }
107}
108
109// ============================================================================
110// EmbeddingKey
111// ============================================================================
112
113/// Identifies an embedding space (model + revision + dims + metric + dtype + norm).
114///
115/// Used for selecting vector store collections, caching, and embedding migration routing.
116/// `canonical_bytes()` produces a stable hash for deduplication.
117#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
118pub struct EmbeddingKey {
119    /// Provider/model name (e.g., "bge-small-en-v1.5").
120    pub model: Box<str>,
121    /// Provider-specific revision (semver, date tag, or commit hash).
122    pub revision: Box<str>,
123    /// Vector dimensionality.
124    pub dims: u32,
125    /// Distance metric for similarity.
126    pub metric: DistanceMetric,
127    /// Element data type.
128    pub dtype: VectorDType,
129    /// Normalization state.
130    pub norm: VectorNorm,
131}
132
133impl EmbeddingKey {
134    /// Create a new `EmbeddingKey`.
135    pub fn new(
136        model: impl Into<Box<str>>,
137        revision: impl Into<Box<str>>,
138        dims: u32,
139        metric: DistanceMetric,
140        dtype: VectorDType,
141        norm: VectorNorm,
142    ) -> Self {
143        Self {
144            model: model.into(),
145            revision: revision.into(),
146            dims,
147            metric,
148            dtype,
149            norm,
150        }
151    }
152
153    /// Returns canonical bytes for deterministic hashing.
154    ///
155    /// Format:
156    /// - model (4-byte big-endian length prefix + UTF-8 bytes)
157    /// - revision (4-byte big-endian length prefix + UTF-8 bytes)
158    /// - dims (4 bytes, big-endian)
159    /// - metric (1 byte)
160    /// - dtype (1 byte)
161    /// - norm (1 byte)
162    pub fn canonical_bytes(&self) -> Vec<u8> {
163        let mut buf = Vec::new();
164
165        let model_bytes = self.model.as_bytes();
166        buf.extend_from_slice(&(model_bytes.len() as u32).to_be_bytes());
167        buf.extend_from_slice(model_bytes);
168
169        let rev_bytes = self.revision.as_bytes();
170        buf.extend_from_slice(&(rev_bytes.len() as u32).to_be_bytes());
171        buf.extend_from_slice(rev_bytes);
172
173        buf.extend_from_slice(&self.dims.to_be_bytes());
174        buf.push(self.metric.as_byte());
175        buf.push(self.dtype.as_byte());
176        buf.push(self.norm.as_byte());
177
178        buf
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline key; individual tests override one field at a time.
    fn base_key() -> EmbeddingKey {
        EmbeddingKey::new(
            "model-a",
            "v1",
            384,
            DistanceMetric::Cosine,
            VectorDType::F32,
            VectorNorm::Unit,
        )
    }

    #[test]
    fn test_distance_metric_byte_roundtrip() {
        for (metric, byte) in [
            (DistanceMetric::Cosine, 1u8),
            (DistanceMetric::Dot, 2u8),
            (DistanceMetric::L2, 3u8),
        ] {
            assert_eq!(metric.as_byte(), byte);
            assert_eq!(DistanceMetric::from_byte(byte), Some(metric));
        }
        assert_eq!(DistanceMetric::from_byte(0), None);
        assert_eq!(DistanceMetric::from_byte(99), None);
    }

    #[test]
    fn test_vector_dtype_size_bytes() {
        assert_eq!(VectorDType::F32.size_bytes(), 4);
        assert_eq!(VectorDType::F16.size_bytes(), 2);
        assert_eq!(VectorDType::I8.size_bytes(), 1);
    }

    #[test]
    fn test_vector_dtype_as_byte() {
        assert_eq!(VectorDType::F32.as_byte(), 1);
        assert_eq!(VectorDType::F16.as_byte(), 2);
        assert_eq!(VectorDType::I8.as_byte(), 3);
    }

    #[test]
    fn test_vector_norm_defaults() {
        assert_eq!(VectorNorm::default(), VectorNorm::None);
    }

    #[test]
    fn test_vector_norm_as_byte() {
        assert_eq!(VectorNorm::None.as_byte(), 0);
        assert_eq!(VectorNorm::Unit.as_byte(), 1);
    }

    #[test]
    fn test_embedding_key_canonical_bytes_deterministic() {
        let k1 = EmbeddingKey::new(
            "bge-small-en-v1.5",
            "v1.5",
            384,
            DistanceMetric::Cosine,
            VectorDType::F32,
            VectorNorm::Unit,
        );
        let k2 = k1.clone();
        assert_eq!(k1.canonical_bytes(), k2.canonical_bytes());
    }

    #[test]
    fn test_embedding_key_canonical_bytes_layout() {
        // Pin the documented wire format byte for byte.
        let key = EmbeddingKey::new(
            "m",
            "r1",
            384,
            DistanceMetric::Dot,
            VectorDType::F16,
            VectorNorm::Unit,
        );
        assert_eq!(
            key.canonical_bytes(),
            vec![0, 0, 0, 1, b'm', 0, 0, 0, 2, b'r', b'1', 0, 0, 1, 0x80, 2, 2, 1]
        );
    }

    #[test]
    fn test_embedding_key_canonical_bytes_differs_by_field() {
        let base = base_key();

        // One variant per field, each differing from `base` in exactly that field.
        let variants = [
            EmbeddingKey {
                model: "model-b".into(),
                ..base_key()
            },
            EmbeddingKey {
                revision: "v2".into(),
                ..base_key()
            },
            EmbeddingKey {
                dims: 768,
                ..base_key()
            },
            EmbeddingKey {
                metric: DistanceMetric::Dot,
                ..base_key()
            },
            EmbeddingKey {
                dtype: VectorDType::F16,
                ..base_key()
            },
            EmbeddingKey {
                norm: VectorNorm::None,
                ..base_key()
            },
        ];

        for (idx, variant) in variants.iter().enumerate() {
            assert_ne!(
                base.canonical_bytes(),
                variant.canonical_bytes(),
                "variant #{idx} should encode differently from the baseline"
            );
        }
    }

    #[test]
    fn test_embedding_key_canonical_bytes_length_prefix_boundary() {
        // Without length prefixes, "ab" + "c" and "a" + "bc" would collide.
        let k1 = EmbeddingKey {
            model: "ab".into(),
            revision: "c".into(),
            ..base_key()
        };
        let k2 = EmbeddingKey {
            model: "a".into(),
            revision: "bc".into(),
            ..base_key()
        };
        assert_ne!(k1.canonical_bytes(), k2.canonical_bytes());
    }
}