1use serde::{Deserialize, Serialize};
2use std::fmt;
3
4#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
5#[repr(u8)]
6pub enum EmbeddingType {
7 F32 = 0x01,
8 F16 = 0x02,
9 I8 = 0x03,
10 U8 = 0x04,
11 Binary = 0x05,
12}
13
14impl fmt::Display for EmbeddingType {
15 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
16 match self {
17 EmbeddingType::F32 => write!(f, "F32"),
18 EmbeddingType::F16 => write!(f, "F16"),
19 EmbeddingType::I8 => write!(f, "I8"),
20 EmbeddingType::U8 => write!(f, "U8"),
21 EmbeddingType::Binary => write!(f, "Binary"),
22 }
23 }
24}
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
27#[repr(u8)]
28pub enum SimilarityMetric {
29 Cosine = 0x01,
30 Euclidean = 0x02,
31 DotProduct = 0x03,
32}
33
34#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
35pub struct Vector {
36 pub dtype: EmbeddingType,
37 pub dim: u16,
38 pub data: Vec<u8>, }
40
41impl Vector {
42 pub fn new(dtype: EmbeddingType, dim: u16, data: Vec<u8>) -> Self {
43 Self { dtype, dim, data }
44 }
45
46 pub fn from_f32(data: Vec<f32>) -> Self {
47 let mut bytes = Vec::with_capacity(data.len() * 4);
48 for val in &data {
49 bytes.extend_from_slice(&val.to_le_bytes());
50 }
51 Self {
52 dtype: EmbeddingType::F32,
53 dim: data.len() as u16,
54 data: bytes,
55 }
56 }
57
58 pub fn as_f32(&self) -> Result<Vec<f32>, String> {
59 if self.dtype != EmbeddingType::F32 {
60 return Err(format!("Cannot convert {:?} to F32", self.dtype));
61 }
62
63 if !self.data.len().is_multiple_of(4) {
64 return Err("Invalid data length for F32".to_string());
65 }
66
67 let mut res = Vec::with_capacity(self.data.len() / 4);
68 for chunk in self.data.chunks_exact(4) {
69 let val = f32::from_le_bytes(chunk.try_into().unwrap());
70 res.push(val);
71 }
72 Ok(res)
73 }
74
75 pub fn normalize(&self) -> Result<Vector, String> {
76 if self.dtype != EmbeddingType::F32 {
77 return Err("Normalization not implemented for this dtype".to_string());
78 }
79
80 let norm_sq = Self::norm_sq_f32(&self.data);
82 let norm = norm_sq.sqrt();
83
84 if norm == 0.0 {
85 return Ok(self.clone());
86 }
87
88 let mut res_data = Vec::with_capacity(self.data.len());
89
90 for chunk in self.data.chunks_exact(4) {
91 let val = f32::from_le_bytes(chunk.try_into().unwrap());
92 let normalized_val = val / norm;
93 res_data.extend_from_slice(&normalized_val.to_le_bytes());
94 }
95
96 Ok(Vector {
97 dtype: self.dtype,
98 dim: self.dim,
99 data: res_data,
100 })
101 }
102
103 pub fn similarity(&self, other: &Vector, metric: SimilarityMetric) -> Result<f32, String> {
104 if self.dtype != other.dtype {
105 return Err("DType mismatch".to_string());
106 }
107 if self.dim != other.dim {
108 return Err("Dimension mismatch".to_string());
109 }
110
111 if self.dtype == EmbeddingType::F32 {
113 match metric {
115 SimilarityMetric::Cosine => {
116 let (dot, norm1_sq, norm2_sq) =
117 Self::dot_and_norms_f32(&self.data, &other.data);
118
119 let norm1 = norm1_sq.sqrt();
120 let norm2 = norm2_sq.sqrt();
121
122 if norm1 == 0.0 || norm2 == 0.0 {
123 return Ok(0.0);
124 }
125 Ok(dot / (norm1 * norm2))
126 }
127 SimilarityMetric::DotProduct => {
128 let dot = Self::dot_product_f32(&self.data, &other.data);
129 Ok(dot)
130 }
131 SimilarityMetric::Euclidean => {
132 let sum_sq = Self::euclidean_distance_sq_f32(&self.data, &other.data);
133 Ok(sum_sq.sqrt())
134 }
135 }
136 } else {
137 Err("Similarity not implemented for this dtype yet".to_string())
138 }
139 }
140
141 #[inline]
143 fn dot_product_f32(data1: &[u8], data2: &[u8]) -> f32 {
144 let mut sum = 0.0f32;
145 for (c1, c2) in data1.chunks_exact(4).zip(data2.chunks_exact(4)) {
146 let v1 = f32::from_le_bytes(c1.try_into().unwrap());
147 let v2 = f32::from_le_bytes(c2.try_into().unwrap());
148 sum += v1 * v2;
149 }
150 sum
151 }
152
153 #[inline]
155 fn euclidean_distance_sq_f32(data1: &[u8], data2: &[u8]) -> f32 {
156 let mut sum = 0.0f32;
157 for (c1, c2) in data1.chunks_exact(4).zip(data2.chunks_exact(4)) {
158 let v1 = f32::from_le_bytes(c1.try_into().unwrap());
159 let v2 = f32::from_le_bytes(c2.try_into().unwrap());
160 let diff = v1 - v2;
161 sum += diff * diff;
162 }
163 sum
164 }
165
166 #[inline]
169 fn dot_and_norms_f32(data1: &[u8], data2: &[u8]) -> (f32, f32, f32) {
170 let mut dot = 0.0f32;
171 let mut norm1_sq = 0.0f32;
172 let mut norm2_sq = 0.0f32;
173
174 for (c1, c2) in data1.chunks_exact(4).zip(data2.chunks_exact(4)) {
175 let v1 = f32::from_le_bytes(c1.try_into().unwrap());
176 let v2 = f32::from_le_bytes(c2.try_into().unwrap());
177 dot += v1 * v2;
178 norm1_sq += v1 * v1;
179 norm2_sq += v2 * v2;
180 }
181
182 (dot, norm1_sq, norm2_sq)
183 }
184
185 #[inline]
187 fn norm_sq_f32(data: &[u8]) -> f32 {
188 let mut sum_sq = 0.0f32;
189 for chunk in data.chunks_exact(4) {
190 let val = f32::from_le_bytes(chunk.try_into().unwrap());
191 sum_sq += val * val;
192 }
193 sum_sq
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding with `from_f32` and decoding with `as_f32` must be lossless.
    #[test]
    fn test_f32_conversion() {
        let values = vec![1.0, 2.0, 3.0];
        let encoded = Vector::from_f32(values.clone());

        assert_eq!(encoded.dtype, EmbeddingType::F32);
        assert_eq!(encoded.dim, 3);
        assert_eq!(encoded.as_f32().unwrap(), values);
    }

    /// Orthogonal axes score 0; a vector against an equal copy scores 1.
    #[test]
    fn test_cosine_similarity() {
        let x_axis = Vector::from_f32(vec![1.0, 0.0, 0.0]);
        let y_axis = Vector::from_f32(vec![0.0, 1.0, 0.0]);
        let x_copy = Vector::from_f32(vec![1.0, 0.0, 0.0]);
        let cosine = |a: &Vector, b: &Vector| a.similarity(b, SimilarityMetric::Cosine).unwrap();

        assert_eq!(cosine(&x_axis, &y_axis), 0.0);
        assert_eq!(cosine(&x_axis, &x_copy), 1.0);
    }

    /// A 3-4-5 triangle normalizes to (0.6, 0.8), which has unit length.
    #[test]
    fn test_normalize() {
        let unit = Vector::from_f32(vec![3.0, 4.0]).normalize().unwrap();
        let components = unit.as_f32().unwrap();
        let close = |actual: f32, expected: f32| (actual - expected).abs() < 1e-6;

        assert!(close(components[0], 0.6));
        assert!(close(components[1], 0.8));

        let length = components.iter().map(|c| c * c).sum::<f32>().sqrt();
        assert!(close(length, 1.0));
    }
}