semantic_search/
embedding.rs1use super::SenseError;
26use std::{convert::TryFrom, ops::Deref};
27
/// Raw embedding vector: a fixed-size array of 1024 `f32` components.
pub type EmbeddingRaw = [f32; 1024];
30
/// Byte representation of an embedding: 1024 components at 4 bytes each.
///
/// The conversions in this file encode and decode every 4-byte group as one
/// little-endian `f32`.
pub type EmbeddingBytes = [u8; 1024 * 4];
33
#[derive(Debug, Clone, PartialEq)]
/// An embedding vector stored together with its precomputed norm.
///
/// The `From<EmbeddingRaw>` conversion computes `norm` from the components;
/// `Default` stores an all-zero vector with a norm of `0.0`.
pub struct Embedding {
    /// The 1024 raw `f32` components of the vector.
    inner: EmbeddingRaw,
    /// Euclidean (L2) norm of `inner`, kept alongside the vector so that
    /// `cosine_similarity` can use it without recomputing it.
    norm: f32,
}
42
43impl Embedding {
46 #[must_use]
48 pub fn cosine_similarity(&self, other: &Self) -> f32 {
49 let dot_product: f32 = self.iter().zip(other.iter()).map(|(a, b)| a * b).sum();
50 dot_product / (self.norm * other.norm)
51 }
52}
53
54impl Default for Embedding {
55 fn default() -> Self {
56 Self {
57 inner: [0.0; 1024],
58 norm: 0.0,
59 }
60 }
61}
62
63impl From<EmbeddingRaw> for Embedding {
66 fn from(inner: EmbeddingRaw) -> Self {
68 let norm = inner.iter().map(|a| a * a).sum::<f32>().sqrt();
69 Self { inner, norm }
70 }
71}
72
73impl From<EmbeddingBytes> for Embedding {
74 fn from(bytes: EmbeddingBytes) -> Self {
76 let mut embedding = [0.0; 1024];
77 bytes.chunks_exact(4).enumerate().for_each(|(i, chunk)| {
78 let f = f32::from_le_bytes(chunk.try_into().unwrap()); embedding[i] = f;
80 });
81 Self::from(embedding)
82 }
83}
84
85impl From<Embedding> for EmbeddingBytes {
86 fn from(embedding: Embedding) -> Self {
88 let mut bytes = [0; 1024 * 4];
89 bytes
90 .chunks_exact_mut(4)
91 .enumerate()
92 .for_each(|(i, chunk)| {
93 let f = embedding[i];
94 chunk.copy_from_slice(&f.to_le_bytes());
95 });
96 bytes
97 }
98}
99
100impl TryFrom<&[f32]> for Embedding {
101 type Error = SenseError;
102
103 fn try_from(value: &[f32]) -> Result<Self, Self::Error> {
109 let embedding: EmbeddingRaw = value.try_into()?;
110 Ok(Self::from(embedding))
111 }
112}
113
114impl TryFrom<&[u8]> for Embedding {
115 type Error = SenseError;
116
117 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
123 let bytes: EmbeddingBytes = value.try_into()?;
124 Ok(Self::from(bytes))
125 }
126}
127
128impl TryFrom<Vec<f32>> for Embedding {
129 type Error = SenseError;
130
131 fn try_from(value: Vec<f32>) -> Result<Self, Self::Error> {
137 let embedding: EmbeddingRaw = value.try_into()?;
138 Ok(Self::from(embedding))
139 }
140}
141
142impl TryFrom<Vec<u8>> for Embedding {
143 type Error = SenseError;
144
145 fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
151 let bytes: EmbeddingBytes = value.try_into()?;
152 Ok(Self::from(bytes))
153 }
154}
155
156impl Deref for Embedding {
159 type Target = EmbeddingRaw;
160
161 fn deref(&self) -> &Self::Target {
162 &self.inner
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Component value used to fill every slot of the test embeddings.
    const EMBEDDING_FLOAT: f32 = 1.14;
    /// Little-endian byte encoding of `EMBEDDING_FLOAT`.
    const EMBEDDING_CHUNK: [u8; 4] = [0x85, 0xEB, 0x91, 0x3F];

    /// Decoding a buffer filled with `EMBEDDING_CHUNK` must yield
    /// `EMBEDDING_FLOAT` in every component.
    #[test]
    #[allow(clippy::float_cmp, reason = "They should be equal exactly")]
    fn embedding_from_bytes() {
        let mut raw_bytes = [0; 1024 * 4];
        for quad in raw_bytes.chunks_exact_mut(4) {
            quad.copy_from_slice(&EMBEDDING_CHUNK);
        }

        let decoded = Embedding::from(raw_bytes);
        for &component in decoded.iter() {
            assert_eq!(component, EMBEDDING_FLOAT);
        }
    }

    /// Encoding an embedding filled with `EMBEDDING_FLOAT` must yield
    /// `EMBEDDING_CHUNK` for every 4-byte group.
    #[test]
    fn bytes_from_embedding() {
        let source = Embedding::from([EMBEDDING_FLOAT; 1024]);
        let encoded = EmbeddingBytes::from(source);

        for quad in encoded.chunks_exact(4) {
            assert_eq!(quad, EMBEDDING_CHUNK);
        }
    }

    /// An embedding compared with itself must have a similarity of 1.0,
    /// within `f32::EPSILON`.
    #[test]
    fn similar_to_self() {
        let sample = Embedding::from([EMBEDDING_FLOAT; 1024]);
        let score = sample.cosine_similarity(&sample);
        assert!((score - 1.0).abs() <= f32::EPSILON);
    }
}