1use serde::{Deserialize, Serialize};
8use uuid::Uuid;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct Embedding {
16 pub id: Uuid,
18
19 pub chunk_id: Uuid,
21
22 pub vector: Vec<i16>,
24
25 pub model_hash: [u8; 32],
27
28 pub dim: u16,
30
31 pub l2_norm: f32,
34
35 pub embedding_version: u32,
37}
38
39impl Embedding {
40 pub fn new(
48 chunk_id: Uuid,
49 vector_f32: &[f32],
50 model_hash: [u8; 32],
51 embedding_version: u32,
52 ) -> Self {
53 let normalized = normalize_l2(vector_f32);
55
56 let quantized: Vec<i16> = normalized.iter().map(|&v| quantize_f32_to_i16(v)).collect();
58
59 let dim = quantized.len() as u16;
60
61 let l2_norm = compute_l2_norm(&quantized);
63
64 let id_bytes = crate::id::generate_composite_id(&[
66 chunk_id.as_bytes(),
67 &model_hash,
68 &embedding_version.to_le_bytes(),
69 ]);
70 let id = Uuid::from_bytes(id_bytes);
71
72 Self {
73 id,
74 chunk_id,
75 vector: quantized,
76 model_hash,
77 dim,
78 l2_norm,
79 embedding_version,
80 }
81 }
82
83 pub fn from_quantized(
87 chunk_id: Uuid,
88 vector: Vec<i16>,
89 model_hash: [u8; 32],
90 embedding_version: u32,
91 ) -> Self {
92 let dim = vector.len() as u16;
93 let l2_norm = compute_l2_norm(&vector);
94 let id_bytes = crate::id::generate_composite_id(&[
95 chunk_id.as_bytes(),
96 &model_hash,
97 &embedding_version.to_le_bytes(),
98 ]);
99 let id = Uuid::from_bytes(id_bytes);
100
101 Self {
102 id,
103 chunk_id,
104 vector,
105 model_hash,
106 dim,
107 l2_norm,
108 embedding_version,
109 }
110 }
111
112 pub fn from_quantized_with_norm(
116 chunk_id: Uuid,
117 vector: Vec<i16>,
118 model_hash: [u8; 32],
119 l2_norm: f32,
120 embedding_version: u32,
121 ) -> Self {
122 let dim = vector.len() as u16;
123 let id_bytes = crate::id::generate_composite_id(&[
124 chunk_id.as_bytes(),
125 &model_hash,
126 &embedding_version.to_le_bytes(),
127 ]);
128 let id = Uuid::from_bytes(id_bytes);
129
130 Self {
131 id,
132 chunk_id,
133 vector,
134 model_hash,
135 dim,
136 l2_norm,
137 embedding_version,
138 }
139 }
140
141 pub fn to_f32(&self) -> Vec<f32> {
143 self.vector
144 .iter()
145 .map(|&v| f32::from(v) / 32767.0)
146 .collect()
147 }
148
149 pub fn integer_dot_product(&self, other: &[i16]) -> i64 {
154 if self.vector.len() != other.len() {
155 return 0;
156 }
157
158 self.vector
159 .iter()
160 .zip(other.iter())
161 .map(|(&a, &b)| i64::from(a) * i64::from(b))
162 .sum()
163 }
164
165 pub fn norm_squared(&self) -> i64 {
169 self.vector
170 .iter()
171 .map(|&v| i64::from(v) * i64::from(v))
172 .sum()
173 }
174
175 pub fn norm_f32(&self) -> f32 {
177 (self.norm_squared() as f64).sqrt() as f32
178 }
179
180 pub fn cosine_similarity(&self, other: &Embedding) -> f32 {
185 if self.vector.len() != other.vector.len() {
186 return 0.0;
187 }
188
189 let dot = self.integer_dot_product(&other.vector);
190 let norm_a = self.norm_squared();
191 let norm_b = other.norm_squared();
192
193 if norm_a == 0 || norm_b == 0 {
194 return 0.0;
195 }
196
197 let denom = ((norm_a as f64) * (norm_b as f64)).sqrt();
200 (dot as f64 / denom) as f32
201 }
202}
203
/// Scales `vector` to unit Euclidean length.
///
/// The sum of squares is accumulated in `f64`. Squaring in `f32` overflows to
/// infinity for components around `1e20` (every output becomes `0.0`) and
/// underflows to zero for components around `1e-25` (the vector is returned
/// un-normalized); `f64` has enough range for the square of any finite `f32`.
///
/// A vector whose norm is exactly zero (all components zero, or empty) is
/// returned unchanged.
fn normalize_l2(vector: &[f32]) -> Vec<f32> {
    let norm = vector
        .iter()
        .map(|&v| f64::from(v) * f64::from(v))
        .sum::<f64>()
        .sqrt();

    if norm == 0.0 {
        return vector.to_vec();
    }

    vector
        .iter()
        .map(|&v| (f64::from(v) / norm) as f32)
        .collect()
}
212
/// Euclidean norm of a quantized vector.
///
/// Squares are summed exactly in `i64`; only the final square root is taken
/// in floating point.
fn compute_l2_norm(vector: &[i16]) -> f32 {
    let mut sum_sq: i64 = 0;
    for &component in vector {
        let wide = i64::from(component);
        sum_sq += wide * wide;
    }
    (sum_sq as f64).sqrt() as f32
}
218
/// Quantizes one component to `i16` using a scale of 32767.
///
/// Values with magnitude below `1e-7` map to `0`. Everything else is scaled,
/// rounded half-to-even, and clamped to the symmetric range
/// `[-32767, 32767]`. NaN falls through the clamp and the final cast turns
/// it into `0`.
fn quantize_f32_to_i16(val: f32) -> i16 {
    const SCALE: f32 = 32767.0;
    const DEAD_ZONE: f32 = 1e-7;

    if val.abs() < DEAD_ZONE {
        0
    } else {
        (val * SCALE).round_ties_even().clamp(-SCALE, SCALE) as i16
    }
}
232
233impl PartialEq for Embedding {
234 fn eq(&self, other: &Self) -> bool {
235 self.id == other.id
236 && self.chunk_id == other.chunk_id
237 && self.model_hash == other.model_hash
238 && self.dim == other.dim
239 && self.embedding_version == other.embedding_version
240 }
241}
242
243impl Eq for Embedding {}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Chunk id whose sixteen bytes are all `byte`.
    fn chunk(byte: u8) -> Uuid {
        Uuid::from_bytes([byte; 16])
    }

    /// All-zero model hash shared by most tests.
    const ZERO_HASH: [u8; 32] = [0u8; 32];

    // The id must be exactly what `generate_composite_id` yields for the
    // chunk id, model hash and version bytes.
    #[test]
    fn test_embedding_id_is_blake3_not_uuid_v5() {
        let chunk_id = chunk(42);
        let model_hash = [1u8; 32];

        let emb = Embedding::new(chunk_id, &[1.0_f32, 0.0], model_hash, 0);

        let expected = crate::id::generate_composite_id(&[
            chunk_id.as_bytes(),
            &model_hash,
            &0u32.to_le_bytes(),
        ]);
        assert_eq!(emb.id.as_bytes(), &expected);
    }

    #[test]
    fn test_embedding_creation_quantized() {
        let emb = Embedding::new(chunk(0), &[1.0_f32, 0.0], ZERO_HASH, 0);

        assert_eq!(emb.vector[0], 32767);
        assert_eq!(emb.vector[1], 0);
        assert!((emb.norm_f32() - 32767.0).abs() < 1.0);
    }

    // 0.5 * 32767 = 16383.5, which rounds half-to-even up to 16384.
    #[test]
    fn test_quantize_round_ties_even() {
        assert_eq!(quantize_f32_to_i16(0.5), 16384);
    }

    #[test]
    fn test_quantize_dead_zone() {
        assert_eq!(quantize_f32_to_i16(0.0), 0);
        assert_eq!(quantize_f32_to_i16(1e-8), 0);
        assert_eq!(quantize_f32_to_i16(-1e-8), 0);
    }

    // 100*1 + 200*2 + 300*3 = 1400.
    #[test]
    fn test_integer_dot_product() {
        let emb = Embedding::from_quantized(chunk(0), vec![100, 200, 300], ZERO_HASH, 0);
        let rhs = vec![1i16, 2, 3];

        assert_eq!(emb.integer_dot_product(&rhs), 1400);
    }

    #[test]
    fn test_cosine_similarity() {
        let chunk_id = chunk(0);

        let base = Embedding::new(chunk_id, &[1.0, 0.0], ZERO_HASH, 0);
        let same = Embedding::new(chunk_id, &[1.0, 0.0], ZERO_HASH, 0);
        let orthogonal = Embedding::new(chunk_id, &[0.0, 1.0], ZERO_HASH, 0);
        let opposite = Embedding::new(chunk_id, &[-1.0, 0.0], ZERO_HASH, 0);

        assert!((base.cosine_similarity(&same) - 1.0).abs() < 0.01);
        assert!(base.cosine_similarity(&orthogonal).abs() < 0.01);
        assert!((base.cosine_similarity(&opposite) + 1.0).abs() < 0.01);
    }

    #[test]
    fn test_embedding_id_determinism() {
        let chunk_id = chunk(42);
        let model_hash = [7u8; 32];
        let raw = [0.5_f32, -0.3, 0.8];

        let first = Embedding::new(chunk_id, &raw, model_hash, 0);
        let second = Embedding::new(chunk_id, &raw, model_hash, 0);
        assert_eq!(first.id, second.id);
    }

    #[test]
    fn test_from_quantized() {
        let components = vec![32767i16, 0, -32767];

        let emb = Embedding::from_quantized(chunk(0), components.clone(), ZERO_HASH, 0);
        assert_eq!(emb.vector, components);
        assert_eq!(emb.dim, 3);
    }

    #[test]
    fn test_embedding_l2_norm_computed() {
        let chunk_id = chunk(0);

        let uniform = Embedding::new(chunk_id, &[0.5_f32, 0.5, 0.5, 0.5], ZERO_HASH, 0);
        assert!(
            uniform.l2_norm > 0.0,
            "l2_norm should be positive for non-zero vectors"
        );

        // A unit vector quantizes to a single 32767 component.
        let unit = Embedding::new(chunk_id, &[1.0_f32, 0.0], ZERO_HASH, 0);
        assert!((unit.l2_norm - 32767.0).abs() < 1.0);
    }

    // 100^2 + 200^2 + 300^2 = 140000.
    #[test]
    fn test_l2_norm_from_quantized() {
        let emb = Embedding::from_quantized(chunk(0), vec![100i16, 200, 300], ZERO_HASH, 0);

        let expected = (140000.0_f64).sqrt() as f32;
        assert!((emb.l2_norm - expected).abs() < 0.01);
    }

    // The supplied norm must be stored verbatim.
    #[test]
    fn test_l2_norm_with_precomputed() {
        let precomputed_norm = 374.17_f32;

        let emb = Embedding::from_quantized_with_norm(
            chunk(0),
            vec![100i16, 200, 300],
            ZERO_HASH,
            precomputed_norm,
            0,
        );
        assert!((emb.l2_norm - precomputed_norm).abs() < 1e-6);
    }
}