semantic_search/
embedding.rs1use super::SenseError;
26use std::{convert::TryFrom, ops::Deref};
27
/// Raw embedding vector: 1024 single-precision components.
pub type EmbeddingRaw = [f32; 1024];

/// Byte-level form of an [`EmbeddingRaw`]: one byte slot per byte of the raw
/// array, i.e. 4 bytes for each of the 1024 `f32` components (4096 in total).
pub type EmbeddingBytes = [u8; std::mem::size_of::<EmbeddingRaw>()];
33
34#[derive(Debug, Clone, PartialEq)]
38pub struct Embedding {
39 inner: EmbeddingRaw,
40 norm: f32,
41}
42
43impl Embedding {
46 #[must_use]
48 pub fn cosine_similarity(&self, other: &Self) -> f32 {
49 let dot_product = self
50 .iter()
51 .zip(other.iter())
52 .map(|(a, b)| a * b)
53 .sum::<f32>();
54 dot_product / (self.norm * other.norm)
55 }
56}
57
58impl Default for Embedding {
59 fn default() -> Self {
60 Self {
61 inner: [0.0; 1024],
62 norm: 0.0,
63 }
64 }
65}
66
67impl From<EmbeddingRaw> for Embedding {
70 fn from(inner: EmbeddingRaw) -> Self {
72 let norm = inner.iter().map(|a| a * a).sum::<f32>().sqrt();
73 Self { inner, norm }
74 }
75}
76
77impl From<EmbeddingBytes> for Embedding {
78 fn from(bytes: EmbeddingBytes) -> Self {
80 let mut embedding = [0.0; 1024];
81 bytes.chunks_exact(4).enumerate().for_each(|(i, chunk)| {
82 let f = f32::from_le_bytes(chunk.try_into().unwrap()); embedding[i] = f;
84 });
85 Self::from(embedding)
86 }
87}
88
89impl From<Embedding> for EmbeddingBytes {
90 fn from(embedding: Embedding) -> Self {
92 let mut bytes = [0; 1024 * 4];
93 bytes
94 .chunks_exact_mut(4)
95 .enumerate()
96 .for_each(|(i, chunk)| {
97 let f = embedding[i];
98 chunk.copy_from_slice(&f.to_le_bytes());
99 });
100 bytes
101 }
102}
103
104impl TryFrom<&[f32]> for Embedding {
105 type Error = SenseError;
106
107 fn try_from(value: &[f32]) -> Result<Self, Self::Error> {
113 let embedding: EmbeddingRaw = value.try_into()?;
114 Ok(Self::from(embedding))
115 }
116}
117
118impl TryFrom<&[u8]> for Embedding {
119 type Error = SenseError;
120
121 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
127 let bytes: EmbeddingBytes = value.try_into()?;
128 Ok(Self::from(bytes))
129 }
130}
131
132impl TryFrom<Vec<f32>> for Embedding {
133 type Error = SenseError;
134
135 fn try_from(value: Vec<f32>) -> Result<Self, Self::Error> {
141 let embedding: EmbeddingRaw = value.try_into()?;
142 Ok(Self::from(embedding))
143 }
144}
145
146impl TryFrom<Vec<u8>> for Embedding {
147 type Error = SenseError;
148
149 fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
155 let bytes: EmbeddingBytes = value.try_into()?;
156 Ok(Self::from(bytes))
157 }
158}
159
160impl Deref for Embedding {
163 type Target = EmbeddingRaw;
164
165 fn deref(&self) -> &Self::Target {
166 &self.inner
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Component value used to fill the test embeddings.
    const EMBEDDING_FLOAT: f32 = 1.14;
    /// Little-endian byte encoding of `EMBEDDING_FLOAT`.
    const EMBEDDING_CHUNK: [u8; 4] = [0x85, 0xEB, 0x91, 0x3F];

    /// Decoding a buffer made of repeated `EMBEDDING_CHUNK` groups must give
    /// `EMBEDDING_FLOAT` for every component.
    #[test]
    #[allow(clippy::float_cmp, reason = "They should be equal exactly")]
    fn embedding_from_bytes() {
        let mut bytes = [0; 1024 * 4];
        for chunk in bytes.chunks_exact_mut(4) {
            chunk.copy_from_slice(&EMBEDDING_CHUNK);
        }

        let embedding = Embedding::from(bytes);
        for &component in embedding.iter() {
            assert_eq!(component, EMBEDDING_FLOAT);
        }
    }

    /// Encoding an embedding filled with `EMBEDDING_FLOAT` must give
    /// `EMBEDDING_CHUNK` for every 4-byte group.
    #[test]
    fn bytes_from_embedding() {
        let embedding = Embedding::from([EMBEDDING_FLOAT; 1024]);
        let bytes = EmbeddingBytes::from(embedding);

        for chunk in bytes.chunks_exact(4) {
            assert_eq!(chunk, EMBEDDING_CHUNK);
        }
    }

    /// An embedding compared with itself must have similarity 1 (within
    /// `f32::EPSILON`).
    #[test]
    fn similar_to_self() {
        let embedding = Embedding::from([EMBEDDING_FLOAT; 1024]);
        let similarity = embedding.cosine_similarity(&embedding);
        let delta = (similarity - 1.0).abs();
        assert!(delta <= f32::EPSILON);
    }
}