1use crate::vector::{EmbeddingType, SimilarityMetric};
7
8#[derive(Debug, Clone, Copy)]
20pub struct EmbeddingView<'a> {
21 pub dim: u16,
23 pub dtype: EmbeddingType,
25 data: &'a [u8],
27}
28
29impl<'a> EmbeddingView<'a> {
30 pub fn from_bytes(bytes: &'a [u8]) -> Result<Self, String> {
43 if bytes.len() < 4 {
44 return Err("Buffer too small for embedding header".to_string());
45 }
46
47 let dim = u16::from_le_bytes([bytes[0], bytes[1]]);
49 let dtype_byte = bytes[2];
50 let dtype = match dtype_byte {
53 0x01 => EmbeddingType::F32,
54 0x02 => EmbeddingType::F16,
55 0x03 => EmbeddingType::I8,
56 0x04 => EmbeddingType::U8,
57 0x05 => EmbeddingType::Binary,
58 _ => return Err(format!("Invalid dtype: 0x{:02x}", dtype_byte)),
59 };
60
61 let data = &bytes[4..];
62 let expected_len = dim as usize * dtype.size_bytes();
63
64 if data.len() != expected_len {
65 return Err(format!(
66 "Data length mismatch: expected {} bytes, found {}",
67 expected_len,
68 data.len()
69 ));
70 }
71
72 Ok(Self { dim, dtype, data })
73 }
74
75 pub fn data(&self) -> &'a [u8] {
77 self.data
78 }
79
80 pub fn as_f32_vec(&self) -> Result<Vec<f32>, String> {
95 if self.dtype != EmbeddingType::F32 {
96 return Err(format!(
97 "Cannot convert {:?} to f32 vec (expected F32)",
98 self.dtype
99 ));
100 }
101
102 if !self.data.len().is_multiple_of(4) {
103 return Err("Invalid data length for F32".to_string());
104 }
105
106 let mut result = Vec::with_capacity(self.dim as usize);
107 for chunk in self.data.chunks_exact(4) {
108 result.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
109 }
110 Ok(result)
111 }
112
113 #[cfg(feature = "zerocopy")]
128 pub fn as_f32_slice(&self) -> Result<&'a [f32], String> {
129 if self.dtype != EmbeddingType::F32 {
130 return Err(format!(
131 "Cannot cast {:?} to f32 slice (expected F32)",
132 self.dtype
133 ));
134 }
135
136 Ok(bytemuck::cast_slice(self.data))
138 }
139
140 pub fn cosine_similarity(&self, other: &EmbeddingView) -> Result<f32, String> {
153 self.check_compatibility(other)?;
154
155 match self.dtype {
156 EmbeddingType::F32 => {
157 let a = self.as_f32_vec()?;
158 let b = other.as_f32_vec()?;
159 Ok(cosine_similarity_f32(&a, &b))
160 }
161 _ => Err(format!(
162 "Cosine similarity not implemented for {:?}",
163 self.dtype
164 )),
165 }
166 }
167
168 pub fn dot_product(&self, other: &EmbeddingView) -> Result<f32, String> {
169 self.check_compatibility(other)?;
170 let a = self.as_f32_vec()?;
171 let b = other.as_f32_vec()?;
172 Ok(dot_product_f32(&a, &b))
173 }
174
175 pub fn euclidean_distance(&self, other: &EmbeddingView) -> Result<f32, String> {
177 self.check_compatibility(other)?;
178 let a = self.as_f32_vec()?;
179 let b = other.as_f32_vec()?;
180 Ok(euclidean_distance_f32(&a, &b))
181 }
182
183 pub fn similarity(
185 &self,
186 other: &EmbeddingView,
187 metric: SimilarityMetric,
188 ) -> Result<f32, String> {
189 match metric {
190 SimilarityMetric::Cosine => self.cosine_similarity(other),
191 SimilarityMetric::DotProduct => self.dot_product(other),
192 SimilarityMetric::Euclidean => self.euclidean_distance(other),
193 }
194 }
195
196 fn check_compatibility(&self, other: &EmbeddingView) -> Result<(), String> {
197 if self.dim != other.dim {
198 return Err(format!("Dimension mismatch: {} vs {}", self.dim, other.dim));
199 }
200 if self.dtype != other.dtype {
201 return Err(format!(
202 "DType mismatch: {:?} vs {:?}",
203 self.dtype, other.dtype
204 ));
205 }
206 Ok(())
207 }
208}
209
210impl EmbeddingType {
211 pub fn size_bytes(&self) -> usize {
213 match self {
214 EmbeddingType::F32 => 4,
215 EmbeddingType::F16 => 2,
216 EmbeddingType::I8 => 1,
217 EmbeddingType::U8 => 1,
218 EmbeddingType::Binary => 1, }
220 }
221}
222
/// Cosine similarity of two `f32` slices.
///
/// Only the first `min(a.len(), b.len())` element pairs are considered. Returns `0.0` when
/// either input has zero magnitude over that range (including the empty case).
///
/// On x86_64 builds with the `zerocopy` feature, an AVX kernel is used when the running CPU
/// reports AVX support; every other case uses the scalar path.
fn cosine_similarity_f32(a: &[f32], b: &[f32]) -> f32 {
    // Pair elements the way `zip` does so every code path agrees on mismatched lengths and the
    // SIMD kernel never reads past the shorter slice.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    #[cfg(all(target_arch = "x86_64", feature = "zerocopy"))]
    {
        if std::arch::is_x86_feature_detected!("avx") {
            // SAFETY: AVX support was confirmed at runtime on the line above, and both slices
            // were truncated to the same length.
            let (dot, norm_a_sq, norm_b_sq) = unsafe { cosine_terms_avx(a, b) };
            return finish_cosine(dot, norm_a_sq, norm_b_sq);
        }
    }

    let (dot, norm_a_sq, norm_b_sq) = cosine_terms_scalar(a, b);
    finish_cosine(dot, norm_a_sq, norm_b_sq)
}

/// Combines a dot product and two squared norms into a cosine similarity.
fn finish_cosine(dot: f32, norm_a_sq: f32, norm_b_sq: f32) -> f32 {
    let norm_a = norm_a_sq.sqrt();
    let norm_b = norm_b_sq.sqrt();

    // A zero-magnitude vector has no direction; report 0.0 instead of dividing by zero.
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}

/// Scalar accumulation of `(a·b, |a|², |b|²)` over the zipped elements.
fn cosine_terms_scalar(a: &[f32], b: &[f32]) -> (f32, f32, f32) {
    a.iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(d, na, nb), (&x, &y)| {
            (d + x * y, na + x * x, nb + y * y)
        })
}

/// AVX accumulation of `(a·b, |a|², |b|²)`, eight lanes at a time with a scalar tail.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX. `a` and `b` are expected to have the
/// same length.
#[cfg(all(target_arch = "x86_64", feature = "zerocopy"))]
#[target_feature(enable = "avx")]
unsafe fn cosine_terms_avx(a: &[f32], b: &[f32]) -> (f32, f32, f32) {
    use std::arch::x86_64::*;

    debug_assert_eq!(a.len(), b.len());

    let a_blocks = a.chunks_exact(8);
    let b_blocks = b.chunks_exact(8);
    // The leftover (< 8 elements) is fixed at construction, independent of iteration.
    let a_tail = a_blocks.remainder();
    let b_tail = b_blocks.remainder();

    let mut dot = _mm256_setzero_ps();
    let mut norm_a = _mm256_setzero_ps();
    let mut norm_b = _mm256_setzero_ps();

    for (block_a, block_b) in a_blocks.zip(b_blocks) {
        // SAFETY: `chunks_exact(8)` yields slices of exactly eight `f32`s, so each unaligned
        // 256-bit load stays inside its slice.
        let (va, vb) = unsafe {
            (
                _mm256_loadu_ps(block_a.as_ptr()),
                _mm256_loadu_ps(block_b.as_ptr()),
            )
        };

        dot = _mm256_add_ps(dot, _mm256_mul_ps(va, vb));
        norm_a = _mm256_add_ps(norm_a, _mm256_mul_ps(va, va));
        norm_b = _mm256_add_ps(norm_b, _mm256_mul_ps(vb, vb));
    }

    // SAFETY: this function's own contract guarantees AVX is available.
    let (dot_sum, norm_a_sum, norm_b_sum) = unsafe {
        (
            horizontal_sum_avx(dot),
            horizontal_sum_avx(norm_a),
            horizontal_sum_avx(norm_b),
        )
    };

    let (dot_rem, norm_a_rem, norm_b_rem) = cosine_terms_scalar(a_tail, b_tail);

    (
        dot_sum + dot_rem,
        norm_a_sum + norm_a_rem,
        norm_b_sum + norm_b_rem,
    )
}

/// Sums the eight `f32` lanes of an AVX register.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX.
#[cfg(all(target_arch = "x86_64", feature = "zerocopy"))]
#[target_feature(enable = "avx")]
unsafe fn horizontal_sum_avx(v: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;

    // Fold 256 -> 128 bits, then pairwise within the 128-bit register.
    let hi = _mm256_extractf128_ps(v, 1);
    let lo = _mm256_castps256_ps128(v);
    let sum = _mm_add_ps(hi, lo);
    let shuf = _mm_movehdup_ps(sum);
    let sums = _mm_add_ps(sum, shuf);
    let shuf = _mm_movehl_ps(shuf, sums);
    let result = _mm_add_ss(sums, shuf);
    _mm_cvtss_f32(result)
}
307
/// Dot product of two `f32` slices over their zipped elements.
///
/// Only the first `min(a.len(), b.len())` pairs contribute. This is plain arithmetic and does
/// not depend on any optional feature, so it is available in every build configuration.
fn dot_product_f32(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(&x, &y)| x * y).sum()
}
317
/// Euclidean (L2) distance between two `f32` slices over their zipped elements.
///
/// Only the first `min(a.len(), b.len())` pairs contribute. This is plain arithmetic and does
/// not depend on any optional feature, so it is available in every build configuration.
fn euclidean_distance_f32(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum::<f32>()
        .sqrt()
}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoder::Encoder;
    use crate::vector::Vector;

    /// A view parsed from an encoded three-element f32 vector reports that header.
    #[test]
    fn test_embedding_view_from_bytes() {
        let source = Vector::from_f32(vec![1.0, 2.0, 3.0]);
        let buffer = Encoder::encode(&source).unwrap();

        let parsed = EmbeddingView::from_bytes(&buffer).unwrap();

        assert_eq!(parsed.dim, 3);
        assert_eq!(parsed.dtype, EmbeddingType::F32);
    }

    /// Orthogonal unit vectors have a cosine similarity of zero.
    #[cfg(feature = "zerocopy")]
    #[test]
    fn test_cosine_similarity_zerocopy() {
        let x_axis = Vector::from_f32(vec![1.0, 0.0, 0.0]);
        let y_axis = Vector::from_f32(vec![0.0, 1.0, 0.0]);

        let x_bytes = Encoder::encode(&x_axis).unwrap();
        let y_bytes = Encoder::encode(&y_axis).unwrap();

        let x_view = EmbeddingView::from_bytes(&x_bytes).unwrap();
        let y_view = EmbeddingView::from_bytes(&y_bytes).unwrap();

        let similarity = x_view.cosine_similarity(&y_view).unwrap();
        assert!(similarity.abs() < 1e-6);
    }

    /// 1*4 + 2*5 + 3*6 = 32.
    #[cfg(feature = "zerocopy")]
    #[test]
    fn test_dot_product_zerocopy() {
        let lhs = Vector::from_f32(vec![1.0, 2.0, 3.0]);
        let rhs = Vector::from_f32(vec![4.0, 5.0, 6.0]);

        let lhs_bytes = Encoder::encode(&lhs).unwrap();
        let rhs_bytes = Encoder::encode(&rhs).unwrap();

        let lhs_view = EmbeddingView::from_bytes(&lhs_bytes).unwrap();
        let rhs_view = EmbeddingView::from_bytes(&rhs_bytes).unwrap();

        let dot = lhs_view.dot_product(&rhs_view).unwrap();
        assert!((dot - 32.0).abs() < 1e-6);
    }
}