lnmp_embedding/
view.rs

1//! Zero-copy embedding views for high-performance vector operations.
2//!
3//! This module provides `EmbeddingView` which allows direct access to embedding
4//! data without allocation, enabling SIMD-optimized similarity computations.
5
6use crate::vector::{EmbeddingType, SimilarityMetric};
7
8/// Zero-copy view into an embedding stored in a binary buffer.
9///
10/// This struct provides direct access to embedding data without allocation.
11/// All similarity computations can be performed on borrowed data.
12///
13/// # Layout
14///
15/// The binary format is:
16/// ```text
17/// [u16 dim | u8 dtype | u8 similarity | vector data...]
18/// ```
19#[derive(Debug, Clone, Copy)]
20pub struct EmbeddingView<'a> {
21    /// Embedding dimension
22    pub dim: u16,
23    /// Embedding data type (F32, F16, etc.)
24    pub dtype: EmbeddingType,
25    /// Raw embedding data (borrowed from input buffer)
26    data: &'a [u8],
27}
28
29impl<'a> EmbeddingView<'a> {
30    /// Creates a new embedding view from raw bytes.
31    ///
32    /// # Format
33    ///
34    /// Expects bytes in the format: `[u16 dim | u8 dtype | u8 reserved | data...]`
35    ///
36    /// # Errors
37    ///
38    /// Returns error if:
39    /// - Buffer too small (< 4 bytes header)
40    /// - Invalid dtype byte
41    /// - Data length doesn't match expected size
42    pub fn from_bytes(bytes: &'a [u8]) -> Result<Self, String> {
43        if bytes.len() < 4 {
44            return Err("Buffer too small for embedding header".to_string());
45        }
46
47        // Parse header
48        let dim = u16::from_le_bytes([bytes[0], bytes[1]]);
49        let dtype_byte = bytes[2];
50        // bytes[3] is similarity/reserved
51
52        let dtype = match dtype_byte {
53            0x01 => EmbeddingType::F32,
54            0x02 => EmbeddingType::F16,
55            0x03 => EmbeddingType::I8,
56            0x04 => EmbeddingType::U8,
57            0x05 => EmbeddingType::Binary,
58            _ => return Err(format!("Invalid dtype: 0x{:02x}", dtype_byte)),
59        };
60
61        let data = &bytes[4..];
62        let expected_len = dim as usize * dtype.size_bytes();
63
64        if data.len() != expected_len {
65            return Err(format!(
66                "Data length mismatch: expected {} bytes, found {}",
67                expected_len,
68                data.len()
69            ));
70        }
71
72        Ok(Self { dim, dtype, data })
73    }
74
75    /// Returns the raw data bytes.
76    pub fn data(&self) -> &'a [u8] {
77        self.data
78    }
79
80    /// Returns the embedding as a Vec<f32> (safe copy).
81    ///
82    /// This method performs a safe copy of the embedding data, converting
83    /// from bytes to f32 values. While it allocates memory, it's guaranteed
84    /// to work regardless of memory alignment.
85    ///
86    /// # Performance
87    ///
88    /// For 256-dim embedding: ~100-200ns allocation + copy overhead
89    /// Still much faster than full record decode for large records.
90    ///
91    /// # Errors
92    ///
93    /// Returns error if dtype is not F32.
94    pub fn as_f32_vec(&self) -> Result<Vec<f32>, String> {
95        if self.dtype != EmbeddingType::F32 {
96            return Err(format!(
97                "Cannot convert {:?} to f32 vec (expected F32)",
98                self.dtype
99            ));
100        }
101
102        if !self.data.len().is_multiple_of(4) {
103            return Err("Invalid data length for F32".to_string());
104        }
105
106        let mut result = Vec::with_capacity(self.dim as usize);
107        for chunk in self.data.chunks_exact(4) {
108            result.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
109        }
110        Ok(result)
111    }
112
113    /// Returns the embedding as an f32 slice (zero-copy cast - EXPERIMENTAL).
114    ///
115    /// ⚠️ **WARNING:** This method is experimental and may fail with alignment errors.
116    /// Memory alignment is not guaranteed for slices from arbitrary buffers.
117    /// Use `as_f32_vec()` for a safe, reliable alternative.
118    ///
119    /// # Safety
120    ///
121    /// Uses `bytemuck` for casting, which requires 4-byte alignment.
122    /// Will panic if the underlying buffer is not properly aligned.
123    ///
124    /// # Errors
125    ///
126    /// Returns error if dtype is not F32 or bytemuck feature is not enabled.
127    #[cfg(feature = "zerocopy")]
128    pub fn as_f32_slice(&self) -> Result<&'a [f32], String> {
129        if self.dtype != EmbeddingType::F32 {
130            return Err(format!(
131                "Cannot cast {:?} to f32 slice (expected F32)",
132                self.dtype
133            ));
134        }
135
136        // Try zero-copy cast (may panic if unaligned!)
137        Ok(bytemuck::cast_slice(self.data))
138    }
139
140    /// Computes cosine similarity with another embedding view (zero-copy).
141    ///
142    /// # Performance
143    ///
144    /// Uses SIMD instructions when available on x86_64.
145    /// Typical performance:
146    /// - 256-dim: ~50-100 ns
147    /// - 1024-dim: ~200-400 ns
148    ///
149    /// # Errors
150    ///
151    /// Returns error if dimensions or dtypes don't match.
152    pub fn cosine_similarity(&self, other: &EmbeddingView) -> Result<f32, String> {
153        self.check_compatibility(other)?;
154
155        match self.dtype {
156            EmbeddingType::F32 => {
157                let a = self.as_f32_vec()?;
158                let b = other.as_f32_vec()?;
159                Ok(cosine_similarity_f32(&a, &b))
160            }
161            _ => Err(format!(
162                "Cosine similarity not implemented for {:?}",
163                self.dtype
164            )),
165        }
166    }
167
168    pub fn dot_product(&self, other: &EmbeddingView) -> Result<f32, String> {
169        self.check_compatibility(other)?;
170        let a = self.as_f32_vec()?;
171        let b = other.as_f32_vec()?;
172        Ok(dot_product_f32(&a, &b))
173    }
174
175    /// Computes Euclidean distance (zero-copy).
176    pub fn euclidean_distance(&self, other: &EmbeddingView) -> Result<f32, String> {
177        self.check_compatibility(other)?;
178        let a = self.as_f32_vec()?;
179        let b = other.as_f32_vec()?;
180        Ok(euclidean_distance_f32(&a, &b))
181    }
182
183    /// Generic similarity with metric selection.
184    pub fn similarity(
185        &self,
186        other: &EmbeddingView,
187        metric: SimilarityMetric,
188    ) -> Result<f32, String> {
189        match metric {
190            SimilarityMetric::Cosine => self.cosine_similarity(other),
191            SimilarityMetric::DotProduct => self.dot_product(other),
192            SimilarityMetric::Euclidean => self.euclidean_distance(other),
193        }
194    }
195
196    fn check_compatibility(&self, other: &EmbeddingView) -> Result<(), String> {
197        if self.dim != other.dim {
198            return Err(format!("Dimension mismatch: {} vs {}", self.dim, other.dim));
199        }
200        if self.dtype != other.dtype {
201            return Err(format!(
202                "DType mismatch: {:?} vs {:?}",
203                self.dtype, other.dtype
204            ));
205        }
206        Ok(())
207    }
208}
209
210impl EmbeddingType {
211    /// Returns the size in bytes for each element of this type.
212    pub fn size_bytes(&self) -> usize {
213        match self {
214            EmbeddingType::F32 => 4,
215            EmbeddingType::F16 => 2,
216            EmbeddingType::I8 => 1,
217            EmbeddingType::U8 => 1,
218            EmbeddingType::Binary => 1, // Bitpacked, but byte-aligned
219        }
220    }
221}
222
223// ============================================================================
224// SIMD-Optimized Similarity Functions
225// ============================================================================
226
/// Cosine similarity of two `f32` slices; returns `0.0` if either norm is zero.
///
/// Only the first `min(a.len(), b.len())` elements are used (`zip`
/// semantics). Callers in this file pass equal-length slices.
///
/// On x86_64 builds with the `zerocopy` feature, an AVX kernel is used when the
/// running CPU supports AVX (checked at runtime). Otherwise a portable scalar
/// loop is used.
fn cosine_similarity_f32(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(all(target_arch = "x86_64", feature = "zerocopy"))]
    {
        if is_x86_feature_detected!("avx") {
            // SAFETY: AVX support was verified at runtime just above.
            return unsafe { cosine_similarity_avx(a, b) };
        }
    }
    cosine_similarity_scalar(a, b)
}

/// Portable cosine similarity kernel (single pass over both slices).
fn cosine_similarity_scalar(a: &[f32], b: &[f32]) -> f32 {
    let (dot, norm_a_sq, norm_b_sq) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(d, na, nb), (&x, &y)| {
            (d + x * y, na + x * x, nb + y * y)
        });
    finish_cosine(dot, norm_a_sq, norm_b_sq)
}

/// Converts an accumulated dot product and squared norms into a cosine value,
/// returning `0.0` when either norm is zero (avoids dividing by zero).
fn finish_cosine(dot: f32, norm_a_sq: f32, norm_b_sq: f32) -> f32 {
    let norm_a = norm_a_sq.sqrt();
    let norm_b = norm_b_sq.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}

/// AVX kernel for [`cosine_similarity_f32`]: 8 lanes per step plus a scalar tail.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX.
#[cfg(all(target_arch = "x86_64", feature = "zerocopy"))]
#[target_feature(enable = "avx")]
unsafe fn cosine_similarity_avx(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    // Bound all reads by the shorter slice so no load can run past either one.
    let len = a.len().min(b.len());
    let simd_len = len - len % 8;

    // SAFETY: the caller guarantees AVX is available. Each unaligned load
    // reads 8 floats starting at `offset`, and `offset + 8 <= simd_len <= len`
    // keeps every read inside both slices.
    let (dot_sum, norm_a_sum, norm_b_sum) = unsafe {
        let mut dot = _mm256_setzero_ps();
        let mut norm_a = _mm256_setzero_ps();
        let mut norm_b = _mm256_setzero_ps();

        for offset in (0..simd_len).step_by(8) {
            let va = _mm256_loadu_ps(a.as_ptr().add(offset));
            let vb = _mm256_loadu_ps(b.as_ptr().add(offset));

            dot = _mm256_add_ps(dot, _mm256_mul_ps(va, vb));
            norm_a = _mm256_add_ps(norm_a, _mm256_mul_ps(va, va));
            norm_b = _mm256_add_ps(norm_b, _mm256_mul_ps(vb, vb));
        }

        (
            horizontal_sum_avx(dot),
            horizontal_sum_avx(norm_a),
            horizontal_sum_avx(norm_b),
        )
    };

    // Scalar tail for the final `len % 8` elements.
    let (dot_rem, norm_a_rem, norm_b_rem) = a[simd_len..len]
        .iter()
        .zip(&b[simd_len..len])
        .fold((0.0f32, 0.0f32, 0.0f32), |(d, na, nb), (&x, &y)| {
            (d + x * y, na + x * x, nb + y * y)
        });

    finish_cosine(
        dot_sum + dot_rem,
        norm_a_sum + norm_a_rem,
        norm_b_sum + norm_b_rem,
    )
}

/// Sums the 8 lanes of an AVX register.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX.
#[cfg(all(target_arch = "x86_64", feature = "zerocopy"))]
#[target_feature(enable = "avx")]
unsafe fn horizontal_sum_avx(v: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;
    // Fold the upper 128-bit half onto the lower half, then reduce 4 -> 2 -> 1.
    let hi = _mm256_extractf128_ps(v, 1);
    let lo = _mm256_castps256_ps128(v);
    let sum = _mm_add_ps(hi, lo);
    let shuf = _mm_movehdup_ps(sum);
    let sums = _mm_add_ps(sum, shuf);
    let shuf = _mm_movehl_ps(shuf, sums);
    let result = _mm_add_ss(sums, shuf);
    _mm_cvtss_f32(result)
}
307
/// Dot product of two `f32` slices over their common prefix (`zip` semantics).
///
/// This is plain safe iteration with no dependency on `bytemuck`. It is
/// therefore available whether or not the `zerocopy` feature is enabled;
/// previously, builds without that feature panicked here.
fn dot_product_f32(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(&x, &y)| x * y).sum()
}
317
/// Euclidean (L2) distance between two `f32` slices over their common prefix
/// (`zip` semantics).
///
/// This is plain safe iteration with no dependency on `bytemuck`. It is
/// therefore available whether or not the `zerocopy` feature is enabled;
/// previously, builds without that feature panicked here.
fn euclidean_distance_f32(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum::<f32>()
        .sqrt()
}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoder::Encoder;
    use crate::vector::Vector;

    /// Builds a raw F32 embedding buffer by hand:
    /// `[dim (LE u16) | 0x01 | 0x00 | f32 values (LE)...]`.
    fn raw_f32(values: &[f32]) -> Vec<u8> {
        let dim = u16::try_from(values.len()).expect("test vectors fit in u16");
        let mut bytes = Vec::with_capacity(4 + values.len() * 4);
        bytes.extend_from_slice(&dim.to_le_bytes());
        bytes.extend_from_slice(&[0x01, 0x00]);
        for value in values {
            bytes.extend_from_slice(&value.to_le_bytes());
        }
        bytes
    }

    #[test]
    fn test_embedding_view_from_bytes() {
        let vec = Vector::from_f32(vec![1.0, 2.0, 3.0]);
        let encoded = Encoder::encode(&vec).unwrap();

        let view = EmbeddingView::from_bytes(&encoded).unwrap();
        assert_eq!(view.dim, 3);
        assert_eq!(view.dtype, EmbeddingType::F32);
    }

    #[test]
    fn test_from_bytes_rejects_short_header() {
        let err = EmbeddingView::from_bytes(&[0x03, 0x00, 0x01]).unwrap_err();
        assert_eq!(err, "Buffer too small for embedding header");
    }

    #[test]
    fn test_from_bytes_rejects_invalid_dtype() {
        let err = EmbeddingView::from_bytes(&[0x01, 0x00, 0xff, 0x00, 0x00]).unwrap_err();
        assert_eq!(err, "Invalid dtype: 0xff");
    }

    #[test]
    fn test_from_bytes_rejects_length_mismatch() {
        let mut bytes = raw_f32(&[1.0, 2.0]);
        bytes.pop();
        let err = EmbeddingView::from_bytes(&bytes).unwrap_err();
        assert_eq!(err, "Data length mismatch: expected 8 bytes, found 7");
    }

    #[test]
    fn test_as_f32_vec_round_trip() {
        let bytes = raw_f32(&[1.5, -2.0, 0.25]);
        let view = EmbeddingView::from_bytes(&bytes).unwrap();
        assert_eq!(view.dim, 3);
        assert_eq!(view.data().len(), 12);
        assert_eq!(view.as_f32_vec().unwrap(), vec![1.5, -2.0, 0.25]);
    }

    #[test]
    fn test_non_f32_dtype_is_rejected() {
        // dim = 2, dtype = I8 (0x03), two payload bytes.
        let bytes = [0x02, 0x00, 0x03, 0x00, 0x01, 0x02];
        let view = EmbeddingView::from_bytes(&bytes).unwrap();
        assert!(view.as_f32_vec().is_err());
        assert!(view.cosine_similarity(&view).is_err());
    }

    #[test]
    fn test_dimension_mismatch_is_error() {
        let a = raw_f32(&[1.0, 0.0]);
        let b = raw_f32(&[1.0, 0.0, 0.0]);
        let view_a = EmbeddingView::from_bytes(&a).unwrap();
        let view_b = EmbeddingView::from_bytes(&b).unwrap();
        assert_eq!(
            view_a.cosine_similarity(&view_b).unwrap_err(),
            "Dimension mismatch: 2 vs 3"
        );
    }

    // Cosine has a portable fallback, so this runs with or without `zerocopy`.
    #[test]
    fn test_cosine_similarity_zerocopy() {
        let v1 = Vector::from_f32(vec![1.0, 0.0, 0.0]);
        let v2 = Vector::from_f32(vec![0.0, 1.0, 0.0]);

        let bytes1 = Encoder::encode(&v1).unwrap();
        let bytes2 = Encoder::encode(&v2).unwrap();

        let view1 = EmbeddingView::from_bytes(&bytes1).unwrap();
        let view2 = EmbeddingView::from_bytes(&bytes2).unwrap();

        let similarity = view1.cosine_similarity(&view2).unwrap();
        assert!(similarity.abs() < 1e-6);
    }

    #[test]
    #[cfg(feature = "zerocopy")]
    fn test_dot_product_zerocopy() {
        let v1 = Vector::from_f32(vec![1.0, 2.0, 3.0]);
        let v2 = Vector::from_f32(vec![4.0, 5.0, 6.0]);

        let bytes1 = Encoder::encode(&v1).unwrap();
        let bytes2 = Encoder::encode(&v2).unwrap();

        let view1 = EmbeddingView::from_bytes(&bytes1).unwrap();
        let view2 = EmbeddingView::from_bytes(&bytes2).unwrap();

        let dot = view1.dot_product(&view2).unwrap();
        assert!((dot - 32.0).abs() < 1e-6); // 1*4 + 2*5 + 3*6 = 32
    }
}