// lnmp_embedding/vector.rs

use serde::{Deserialize, Serialize};
use std::fmt;

4#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
5#[repr(u8)]
6pub enum EmbeddingType {
7    F32 = 0x01,
8    F16 = 0x02,
9    I8 = 0x03,
10    U8 = 0x04,
11    Binary = 0x05,
12}
13
14impl fmt::Display for EmbeddingType {
15    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
16        match self {
17            EmbeddingType::F32 => write!(f, "F32"),
18            EmbeddingType::F16 => write!(f, "F16"),
19            EmbeddingType::I8 => write!(f, "I8"),
20            EmbeddingType::U8 => write!(f, "U8"),
21            EmbeddingType::Binary => write!(f, "Binary"),
22        }
23    }
24}
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
27#[repr(u8)]
28pub enum SimilarityMetric {
29    Cosine = 0x01,
30    Euclidean = 0x02,
31    DotProduct = 0x03,
32}
33
34#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
35pub struct Vector {
36    pub dtype: EmbeddingType,
37    pub dim: u16,
38    pub data: Vec<u8>, // Raw bytes
39}
40
41impl Vector {
42    pub fn new(dtype: EmbeddingType, dim: u16, data: Vec<u8>) -> Self {
43        Self { dtype, dim, data }
44    }
45
46    pub fn from_f32(data: Vec<f32>) -> Self {
47        let mut bytes = Vec::with_capacity(data.len() * 4);
48        for val in &data {
49            bytes.extend_from_slice(&val.to_le_bytes());
50        }
51        Self {
52            dtype: EmbeddingType::F32,
53            dim: data.len() as u16,
54            data: bytes,
55        }
56    }
57
58    pub fn as_f32(&self) -> Result<Vec<f32>, String> {
59        if self.dtype != EmbeddingType::F32 {
60            return Err(format!("Cannot convert {:?} to F32", self.dtype));
61        }
62
63        if !self.data.len().is_multiple_of(4) {
64            return Err("Invalid data length for F32".to_string());
65        }
66
67        let mut res = Vec::with_capacity(self.data.len() / 4);
68        for chunk in self.data.chunks_exact(4) {
69            let val = f32::from_le_bytes(chunk.try_into().unwrap());
70            res.push(val);
71        }
72        Ok(res)
73    }
74
75    pub fn normalize(&self) -> Result<Vector, String> {
76        if self.dtype != EmbeddingType::F32 {
77            return Err("Normalization not implemented for this dtype".to_string());
78        }
79
80        // Optimized for F32
81        let norm_sq = Self::norm_sq_f32(&self.data);
82        let norm = norm_sq.sqrt();
83
84        if norm == 0.0 {
85            return Ok(self.clone());
86        }
87
88        let mut res_data = Vec::with_capacity(self.data.len());
89
90        for chunk in self.data.chunks_exact(4) {
91            let val = f32::from_le_bytes(chunk.try_into().unwrap());
92            let normalized_val = val / norm;
93            res_data.extend_from_slice(&normalized_val.to_le_bytes());
94        }
95
96        Ok(Vector {
97            dtype: self.dtype,
98            dim: self.dim,
99            data: res_data,
100        })
101    }
102
103    pub fn similarity(&self, other: &Vector, metric: SimilarityMetric) -> Result<f32, String> {
104        if self.dtype != other.dtype {
105            return Err("DType mismatch".to_string());
106        }
107        if self.dim != other.dim {
108            return Err("Dimension mismatch".to_string());
109        }
110
111        // Currently only implementing for F32
112        if self.dtype == EmbeddingType::F32 {
113            // Optimized: avoid intermediate Vec allocation, work directly with bytes
114            match metric {
115                SimilarityMetric::Cosine => {
116                    let (dot, norm1_sq, norm2_sq) =
117                        Self::dot_and_norms_f32(&self.data, &other.data);
118
119                    let norm1 = norm1_sq.sqrt();
120                    let norm2 = norm2_sq.sqrt();
121
122                    if norm1 == 0.0 || norm2 == 0.0 {
123                        return Ok(0.0);
124                    }
125                    Ok(dot / (norm1 * norm2))
126                }
127                SimilarityMetric::DotProduct => {
128                    let dot = Self::dot_product_f32(&self.data, &other.data);
129                    Ok(dot)
130                }
131                SimilarityMetric::Euclidean => {
132                    let sum_sq = Self::euclidean_distance_sq_f32(&self.data, &other.data);
133                    Ok(sum_sq.sqrt())
134                }
135            }
136        } else {
137            Err("Similarity not implemented for this dtype yet".to_string())
138        }
139    }
140
141    /// Optimized dot product for f32 from raw bytes
142    #[inline]
143    fn dot_product_f32(data1: &[u8], data2: &[u8]) -> f32 {
144        let mut sum = 0.0f32;
145        for (c1, c2) in data1.chunks_exact(4).zip(data2.chunks_exact(4)) {
146            let v1 = f32::from_le_bytes(c1.try_into().unwrap());
147            let v2 = f32::from_le_bytes(c2.try_into().unwrap());
148            sum += v1 * v2;
149        }
150        sum
151    }
152
153    /// Optimized euclidean distance squared for f32 from raw bytes
154    #[inline]
155    fn euclidean_distance_sq_f32(data1: &[u8], data2: &[u8]) -> f32 {
156        let mut sum = 0.0f32;
157        for (c1, c2) in data1.chunks_exact(4).zip(data2.chunks_exact(4)) {
158            let v1 = f32::from_le_bytes(c1.try_into().unwrap());
159            let v2 = f32::from_le_bytes(c2.try_into().unwrap());
160            let diff = v1 - v2;
161            sum += diff * diff;
162        }
163        sum
164    }
165
166    /// Optimized combined dot product and norms calculation for f32 from raw bytes
167    /// Returns (dot_product, norm1_squared, norm2_squared)
168    #[inline]
169    fn dot_and_norms_f32(data1: &[u8], data2: &[u8]) -> (f32, f32, f32) {
170        let mut dot = 0.0f32;
171        let mut norm1_sq = 0.0f32;
172        let mut norm2_sq = 0.0f32;
173
174        for (c1, c2) in data1.chunks_exact(4).zip(data2.chunks_exact(4)) {
175            let v1 = f32::from_le_bytes(c1.try_into().unwrap());
176            let v2 = f32::from_le_bytes(c2.try_into().unwrap());
177            dot += v1 * v2;
178            norm1_sq += v1 * v1;
179            norm2_sq += v2 * v2;
180        }
181
182        (dot, norm1_sq, norm2_sq)
183    }
184
185    /// Optimized norm squared calculation for f32 from raw bytes
186    #[inline]
187    fn norm_sq_f32(data: &[u8]) -> f32 {
188        let mut sum_sq = 0.0f32;
189        for chunk in data.chunks_exact(4) {
190            let val = f32::from_le_bytes(chunk.try_into().unwrap());
191            sum_sq += val * val;
192        }
193        sum_sq
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing computed `f32` results.
    const EPS: f32 = 1e-6;

    /// Asserts that `actual` is within `EPS` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_f32_conversion() {
        let data = vec![1.0, 2.0, 3.0];
        let vec = Vector::from_f32(data.clone());
        assert_eq!(vec.dtype, EmbeddingType::F32);
        assert_eq!(vec.dim, 3);
        assert_eq!(vec.as_f32().unwrap(), data);
    }

    #[test]
    fn test_as_f32_errors() {
        let wrong_dtype = Vector::new(EmbeddingType::U8, 1, vec![7]);
        assert_eq!(
            wrong_dtype.as_f32(),
            Err("Cannot convert U8 to F32".to_string())
        );

        let truncated = Vector::new(EmbeddingType::F32, 1, vec![0, 0, 0]);
        assert_eq!(
            truncated.as_f32(),
            Err("Invalid data length for F32".to_string())
        );
    }

    #[test]
    fn test_cosine_similarity() {
        let v1 = Vector::from_f32(vec![1.0, 0.0, 0.0]);
        let v2 = Vector::from_f32(vec![0.0, 1.0, 0.0]);
        // Orthogonal: the dot product is exactly zero, so exact comparison is safe.
        assert_eq!(v1.similarity(&v2, SimilarityMetric::Cosine).unwrap(), 0.0);

        let v3 = Vector::from_f32(vec![1.0, 0.0, 0.0]);
        assert_close(v1.similarity(&v3, SimilarityMetric::Cosine).unwrap(), 1.0);

        // Cosine ignores magnitude: parallel vectors score 1 regardless of length.
        let short = Vector::from_f32(vec![1.0, 2.0, 3.0]);
        let long = Vector::from_f32(vec![2.0, 4.0, 6.0]);
        assert_close(
            short.similarity(&long, SimilarityMetric::Cosine).unwrap(),
            1.0,
        );
    }

    #[test]
    fn test_cosine_with_zero_vector() {
        let zero = Vector::from_f32(vec![0.0, 0.0]);
        let v = Vector::from_f32(vec![3.0, 4.0]);
        assert_eq!(zero.similarity(&v, SimilarityMetric::Cosine).unwrap(), 0.0);
    }

    #[test]
    fn test_dot_product() {
        let a = Vector::from_f32(vec![1.0, 2.0, 3.0]);
        let b = Vector::from_f32(vec![4.0, 5.0, 6.0]);
        assert_eq!(
            a.similarity(&b, SimilarityMetric::DotProduct).unwrap(),
            32.0
        );
    }

    #[test]
    fn test_euclidean_distance() {
        let a = Vector::from_f32(vec![0.0, 0.0]);
        let b = Vector::from_f32(vec![3.0, 4.0]);
        // Euclidean returns a distance (3-4-5 triangle), not a similarity.
        assert_eq!(a.similarity(&b, SimilarityMetric::Euclidean).unwrap(), 5.0);
    }

    #[test]
    fn test_similarity_errors() {
        let a = Vector::from_f32(vec![1.0, 2.0]);
        let b = Vector::from_f32(vec![1.0, 2.0, 3.0]);
        assert_eq!(
            a.similarity(&b, SimilarityMetric::Cosine),
            Err("Dimension mismatch".to_string())
        );

        let bytes = Vector::new(EmbeddingType::U8, 2, vec![1, 2]);
        assert_eq!(
            a.similarity(&bytes, SimilarityMetric::Cosine),
            Err("DType mismatch".to_string())
        );
        assert_eq!(
            bytes.similarity(&bytes, SimilarityMetric::DotProduct),
            Err("Similarity not implemented for this dtype yet".to_string())
        );
    }

    #[test]
    fn test_normalize() {
        let v = Vector::from_f32(vec![3.0, 4.0]);
        let normalized = v.normalize().unwrap();
        let data = normalized.as_f32().unwrap();
        assert_close(data[0], 0.6);
        assert_close(data[1], 0.8);

        // Check that magnitude is 1
        let norm = (data[0] * data[0] + data[1] * data[1]).sqrt();
        assert_close(norm, 1.0);
    }

    #[test]
    fn test_normalize_zero_vector_is_unchanged() {
        let zero = Vector::from_f32(vec![0.0, 0.0, 0.0]);
        assert_eq!(zero.normalize().unwrap(), zero);
    }
}