lnmp_embedding/
vector.rs

1use serde::{Deserialize, Serialize};
2use std::fmt;
3
4#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
5#[repr(u8)]
6pub enum EmbeddingType {
7    F32 = 0x01,
8    F16 = 0x02,
9    I8 = 0x03,
10    U8 = 0x04,
11    Binary = 0x05,
12}
13
14impl fmt::Display for EmbeddingType {
15    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
16        match self {
17            EmbeddingType::F32 => write!(f, "F32"),
18            EmbeddingType::F16 => write!(f, "F16"),
19            EmbeddingType::I8 => write!(f, "I8"),
20            EmbeddingType::U8 => write!(f, "U8"),
21            EmbeddingType::Binary => write!(f, "Binary"),
22        }
23    }
24}
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
27#[repr(u8)]
28pub enum SimilarityMetric {
29    Cosine = 0x01,
30    Euclidean = 0x02,
31    DotProduct = 0x03,
32}
33
34#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
35pub struct Vector {
36    pub dtype: EmbeddingType,
37    pub dim: u16,
38    pub data: Vec<u8>, // Raw bytes
39}
40
41impl Vector {
42    pub fn new(dtype: EmbeddingType, dim: u16, data: Vec<u8>) -> Self {
43        Self { dtype, dim, data }
44    }
45
46    pub fn from_f32(data: Vec<f32>) -> Self {
47        let mut bytes = Vec::with_capacity(data.len() * 4);
48        for val in &data {
49            bytes.extend_from_slice(&val.to_le_bytes());
50        }
51        Self {
52            dtype: EmbeddingType::F32,
53            dim: data.len() as u16,
54            data: bytes,
55        }
56    }
57
58    pub fn as_f32(&self) -> Result<Vec<f32>, String> {
59        if self.dtype != EmbeddingType::F32 {
60            return Err(format!("Cannot convert {:?} to F32", self.dtype));
61        }
62
63        if !self.data.len().is_multiple_of(4) {
64            return Err("Invalid data length for F32".to_string());
65        }
66
67        let mut res = Vec::with_capacity(self.data.len() / 4);
68        for chunk in self.data.chunks_exact(4) {
69            let val = f32::from_le_bytes(chunk.try_into().unwrap());
70            res.push(val);
71        }
72        Ok(res)
73    }
74
75    pub fn similarity(&self, other: &Vector, metric: SimilarityMetric) -> Result<f32, String> {
76        if self.dtype != other.dtype {
77            return Err("DType mismatch".to_string());
78        }
79        if self.dim != other.dim {
80            return Err("Dimension mismatch".to_string());
81        }
82
83        // Currently only implementing for F32
84        if self.dtype == EmbeddingType::F32 {
85            // Optimized: avoid intermediate Vec allocation, work directly with bytes
86            match metric {
87                SimilarityMetric::Cosine => {
88                    let (dot, norm1_sq, norm2_sq) =
89                        unsafe { Self::dot_and_norms_f32(&self.data, &other.data) };
90
91                    let norm1 = norm1_sq.sqrt();
92                    let norm2 = norm2_sq.sqrt();
93
94                    if norm1 == 0.0 || norm2 == 0.0 {
95                        return Ok(0.0);
96                    }
97                    Ok(dot / (norm1 * norm2))
98                }
99                SimilarityMetric::DotProduct => {
100                    let dot = unsafe { Self::dot_product_f32(&self.data, &other.data) };
101                    Ok(dot)
102                }
103                SimilarityMetric::Euclidean => {
104                    let sum_sq =
105                        unsafe { Self::euclidean_distance_sq_f32(&self.data, &other.data) };
106                    Ok(sum_sq.sqrt())
107                }
108            }
109        } else {
110            Err("Similarity not implemented for this dtype yet".to_string())
111        }
112    }
113
114    /// Optimized dot product for f32 from raw bytes
115    /// SAFETY: Assumes data is properly aligned f32 data with length % 4 == 0
116    #[inline]
117    unsafe fn dot_product_f32(data1: &[u8], data2: &[u8]) -> f32 {
118        let len = data1.len() / 4;
119        let ptr1 = data1.as_ptr() as *const f32;
120        let ptr2 = data2.as_ptr() as *const f32;
121
122        let mut sum = 0.0f32;
123        for i in 0..len {
124            sum += (*ptr1.add(i)) * (*ptr2.add(i));
125        }
126        sum
127    }
128
129    /// Optimized euclidean distance squared for f32 from raw bytes
130    /// SAFETY: Assumes data is properly aligned f32 data with length % 4 == 0
131    #[inline]
132    unsafe fn euclidean_distance_sq_f32(data1: &[u8], data2: &[u8]) -> f32 {
133        let len = data1.len() / 4;
134        let ptr1 = data1.as_ptr() as *const f32;
135        let ptr2 = data2.as_ptr() as *const f32;
136
137        let mut sum = 0.0f32;
138        for i in 0..len {
139            let diff = *ptr1.add(i) - *ptr2.add(i);
140            sum += diff * diff;
141        }
142        sum
143    }
144
145    /// Optimized combined dot product and norms calculation for f32 from raw bytes
146    /// Returns (dot_product, norm1_squared, norm2_squared)
147    /// SAFETY: Assumes data is properly aligned f32 data with length % 4 == 0
148    #[inline]
149    unsafe fn dot_and_norms_f32(data1: &[u8], data2: &[u8]) -> (f32, f32, f32) {
150        let len = data1.len() / 4;
151        let ptr1 = data1.as_ptr() as *const f32;
152        let ptr2 = data2.as_ptr() as *const f32;
153
154        let mut dot = 0.0f32;
155        let mut norm1_sq = 0.0f32;
156        let mut norm2_sq = 0.0f32;
157
158        for i in 0..len {
159            let v1 = *ptr1.add(i);
160            let v2 = *ptr2.add(i);
161            dot += v1 * v2;
162            norm1_sq += v1 * v1;
163            norm2_sq += v2 * v2;
164        }
165
166        (dot, norm1_sq, norm2_sq)
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for float results that involve `sqrt` or division.
    const EPS: f32 = 1e-6;

    /// Asserts that `actual` is within [`EPS`] of `expected`, printing both on failure.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() <= EPS,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_f32_conversion() {
        let data = vec![1.0, 2.0, 3.0];
        let vec = Vector::from_f32(data.clone());
        assert_eq!(vec.dtype, EmbeddingType::F32);
        assert_eq!(vec.dim, 3);
        assert_eq!(vec.as_f32().unwrap(), data);
    }

    #[test]
    fn test_as_f32_errors() {
        let wrong_dtype = Vector::new(EmbeddingType::I8, 3, vec![1, 2, 3]);
        assert!(wrong_dtype.as_f32().is_err());

        // 5 bytes cannot hold a whole number of f32 values.
        let ragged = Vector::new(EmbeddingType::F32, 1, vec![0; 5]);
        assert!(ragged.as_f32().is_err());
    }

    #[test]
    fn test_cosine_similarity() {
        let v1 = Vector::from_f32(vec![1.0, 0.0, 0.0]);
        let v2 = Vector::from_f32(vec![0.0, 1.0, 0.0]);
        assert_close(v1.similarity(&v2, SimilarityMetric::Cosine).unwrap(), 0.0);

        let v3 = Vector::from_f32(vec![1.0, 0.0, 0.0]);
        assert_close(v1.similarity(&v3, SimilarityMetric::Cosine).unwrap(), 1.0);

        // Cosine ignores magnitude: parallel vectors of different lengths score ~1.
        let long = Vector::from_f32(vec![2.0, 4.0, 6.0]);
        let short = Vector::from_f32(vec![1.0, 2.0, 3.0]);
        assert_close(long.similarity(&short, SimilarityMetric::Cosine).unwrap(), 1.0);

        // A zero vector has no direction; `similarity` defines the result as 0.
        let zero = Vector::from_f32(vec![0.0, 0.0, 0.0]);
        assert_eq!(zero.similarity(&v1, SimilarityMetric::Cosine).unwrap(), 0.0);
    }

    #[test]
    fn test_dot_product_and_euclidean() {
        let a = Vector::from_f32(vec![1.0, 2.0, 3.0]);
        let b = Vector::from_f32(vec![4.0, 5.0, 6.0]);
        assert_close(a.similarity(&b, SimilarityMetric::DotProduct).unwrap(), 32.0);

        let origin = Vector::from_f32(vec![0.0, 0.0]);
        let p = Vector::from_f32(vec![3.0, 4.0]);
        assert_close(origin.similarity(&p, SimilarityMetric::Euclidean).unwrap(), 5.0);
        assert_close(p.similarity(&p, SimilarityMetric::Euclidean).unwrap(), 0.0);
    }

    #[test]
    fn test_similarity_errors() {
        let f32_3 = Vector::from_f32(vec![1.0, 2.0, 3.0]);
        let f32_2 = Vector::from_f32(vec![1.0, 2.0]);
        let i8_3 = Vector::new(EmbeddingType::I8, 3, vec![1, 2, 3]);

        assert_eq!(
            f32_3.similarity(&i8_3, SimilarityMetric::Cosine).unwrap_err(),
            "DType mismatch"
        );
        assert_eq!(
            f32_3.similarity(&f32_2, SimilarityMetric::Cosine).unwrap_err(),
            "Dimension mismatch"
        );
        // Non-F32 dtypes are not supported yet.
        assert!(i8_3.similarity(&i8_3, SimilarityMetric::DotProduct).is_err());
    }
}