1use serde::{Deserialize, Serialize};
2use std::fmt;
3
4#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
5#[repr(u8)]
6pub enum EmbeddingType {
7 F32 = 0x01,
8 F16 = 0x02,
9 I8 = 0x03,
10 U8 = 0x04,
11 Binary = 0x05,
12}
13
14impl fmt::Display for EmbeddingType {
15 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
16 match self {
17 EmbeddingType::F32 => write!(f, "F32"),
18 EmbeddingType::F16 => write!(f, "F16"),
19 EmbeddingType::I8 => write!(f, "I8"),
20 EmbeddingType::U8 => write!(f, "U8"),
21 EmbeddingType::Binary => write!(f, "Binary"),
22 }
23 }
24}
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
27#[repr(u8)]
28pub enum SimilarityMetric {
29 Cosine = 0x01,
30 Euclidean = 0x02,
31 DotProduct = 0x03,
32}
33
34#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
35pub struct Vector {
36 pub dtype: EmbeddingType,
37 pub dim: u16,
38 pub data: Vec<u8>, }
40
41impl Vector {
42 pub fn new(dtype: EmbeddingType, dim: u16, data: Vec<u8>) -> Self {
43 Self { dtype, dim, data }
44 }
45
46 pub fn from_f32(data: Vec<f32>) -> Self {
47 let mut bytes = Vec::with_capacity(data.len() * 4);
48 for val in &data {
49 bytes.extend_from_slice(&val.to_le_bytes());
50 }
51 Self {
52 dtype: EmbeddingType::F32,
53 dim: data.len() as u16,
54 data: bytes,
55 }
56 }
57
58 pub fn as_f32(&self) -> Result<Vec<f32>, String> {
59 if self.dtype != EmbeddingType::F32 {
60 return Err(format!("Cannot convert {:?} to F32", self.dtype));
61 }
62
63 if !self.data.len().is_multiple_of(4) {
64 return Err("Invalid data length for F32".to_string());
65 }
66
67 let mut res = Vec::with_capacity(self.data.len() / 4);
68 for chunk in self.data.chunks_exact(4) {
69 let val = f32::from_le_bytes(chunk.try_into().unwrap());
70 res.push(val);
71 }
72 Ok(res)
73 }
74
75 pub fn similarity(&self, other: &Vector, metric: SimilarityMetric) -> Result<f32, String> {
76 if self.dtype != other.dtype {
77 return Err("DType mismatch".to_string());
78 }
79 if self.dim != other.dim {
80 return Err("Dimension mismatch".to_string());
81 }
82
83 if self.dtype == EmbeddingType::F32 {
85 match metric {
87 SimilarityMetric::Cosine => {
88 let (dot, norm1_sq, norm2_sq) =
89 unsafe { Self::dot_and_norms_f32(&self.data, &other.data) };
90
91 let norm1 = norm1_sq.sqrt();
92 let norm2 = norm2_sq.sqrt();
93
94 if norm1 == 0.0 || norm2 == 0.0 {
95 return Ok(0.0);
96 }
97 Ok(dot / (norm1 * norm2))
98 }
99 SimilarityMetric::DotProduct => {
100 let dot = unsafe { Self::dot_product_f32(&self.data, &other.data) };
101 Ok(dot)
102 }
103 SimilarityMetric::Euclidean => {
104 let sum_sq =
105 unsafe { Self::euclidean_distance_sq_f32(&self.data, &other.data) };
106 Ok(sum_sq.sqrt())
107 }
108 }
109 } else {
110 Err("Similarity not implemented for this dtype yet".to_string())
111 }
112 }
113
114 #[inline]
117 unsafe fn dot_product_f32(data1: &[u8], data2: &[u8]) -> f32 {
118 let len = data1.len() / 4;
119 let ptr1 = data1.as_ptr() as *const f32;
120 let ptr2 = data2.as_ptr() as *const f32;
121
122 let mut sum = 0.0f32;
123 for i in 0..len {
124 sum += (*ptr1.add(i)) * (*ptr2.add(i));
125 }
126 sum
127 }
128
129 #[inline]
132 unsafe fn euclidean_distance_sq_f32(data1: &[u8], data2: &[u8]) -> f32 {
133 let len = data1.len() / 4;
134 let ptr1 = data1.as_ptr() as *const f32;
135 let ptr2 = data2.as_ptr() as *const f32;
136
137 let mut sum = 0.0f32;
138 for i in 0..len {
139 let diff = *ptr1.add(i) - *ptr2.add(i);
140 sum += diff * diff;
141 }
142 sum
143 }
144
145 #[inline]
149 unsafe fn dot_and_norms_f32(data1: &[u8], data2: &[u8]) -> (f32, f32, f32) {
150 let len = data1.len() / 4;
151 let ptr1 = data1.as_ptr() as *const f32;
152 let ptr2 = data2.as_ptr() as *const f32;
153
154 let mut dot = 0.0f32;
155 let mut norm1_sq = 0.0f32;
156 let mut norm2_sq = 0.0f32;
157
158 for i in 0..len {
159 let v1 = *ptr1.add(i);
160 let v2 = *ptr2.add(i);
161 dot += v1 * v2;
162 norm1_sq += v1 * v1;
163 norm2_sq += v2 * v2;
164 }
165
166 (dot, norm1_sq, norm2_sq)
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_f32_conversion() {
        let data = vec![1.0, 2.0, 3.0];
        let vec = Vector::from_f32(data.clone());
        assert_eq!(vec.dtype, EmbeddingType::F32);
        assert_eq!(vec.dim, 3);
        assert_eq!(vec.as_f32().unwrap(), data);
    }

    #[test]
    fn test_f32_is_stored_little_endian() {
        let vec = Vector::from_f32(vec![1.0]);
        assert_eq!(vec.data, 1.0f32.to_le_bytes().to_vec());
    }

    #[test]
    fn test_as_f32_rejects_bad_input() {
        let wrong_dtype = Vector::new(EmbeddingType::U8, 4, vec![1, 2, 3, 4]);
        assert!(wrong_dtype.as_f32().is_err());

        // 3 bytes cannot hold a whole f32.
        let truncated = Vector::new(EmbeddingType::F32, 1, vec![0, 0, 0]);
        assert!(truncated.as_f32().is_err());
    }

    #[test]
    fn test_cosine_similarity() {
        let v1 = Vector::from_f32(vec![1.0, 0.0, 0.0]);
        let v2 = Vector::from_f32(vec![0.0, 1.0, 0.0]);
        assert_eq!(v1.similarity(&v2, SimilarityMetric::Cosine).unwrap(), 0.0);

        let v3 = Vector::from_f32(vec![1.0, 0.0, 0.0]);
        assert_eq!(v1.similarity(&v3, SimilarityMetric::Cosine).unwrap(), 1.0);
    }

    #[test]
    fn test_cosine_with_zero_vector_is_zero() {
        let zero = Vector::from_f32(vec![0.0, 0.0]);
        let other = Vector::from_f32(vec![3.0, 4.0]);
        assert_eq!(zero.similarity(&other, SimilarityMetric::Cosine).unwrap(), 0.0);
    }

    #[test]
    fn test_dot_product() {
        let v1 = Vector::from_f32(vec![1.0, 2.0, 3.0]);
        let v2 = Vector::from_f32(vec![4.0, 5.0, 6.0]);
        assert_eq!(v1.similarity(&v2, SimilarityMetric::DotProduct).unwrap(), 32.0);
    }

    #[test]
    fn test_euclidean_distance() {
        let v1 = Vector::from_f32(vec![0.0, 0.0]);
        let v2 = Vector::from_f32(vec![3.0, 4.0]);
        assert_eq!(v1.similarity(&v2, SimilarityMetric::Euclidean).unwrap(), 5.0);
    }

    #[test]
    fn test_similarity_errors() {
        let a = Vector::from_f32(vec![1.0, 2.0]);
        let b = Vector::from_f32(vec![1.0, 2.0, 3.0]);
        assert_eq!(
            a.similarity(&b, SimilarityMetric::Cosine),
            Err("Dimension mismatch".to_string())
        );

        let c = Vector::new(EmbeddingType::I8, 2, vec![1, 2]);
        assert_eq!(
            a.similarity(&c, SimilarityMetric::Cosine),
            Err("DType mismatch".to_string())
        );

        // Only F32 similarity is implemented.
        assert!(c.similarity(&c, SimilarityMetric::DotProduct).is_err());
    }
}