semantic_search/
embedding.rs

1//! # Embedding module
2//!
3//! Embedding types, representation, conversion and calculation. Assumes little-endian byte order.
4//!
5//! ## Types
6//!
7//! - [`EmbeddingRaw`]: Raw embedding representation, alias for `[f32; 1024]`.
8//! - [`EmbeddingBytes`]: Embedding represented in bytes (little-endian), alias for `[u8; 1024 * 4]`.
9//! - [`Embedding`]: Wrapped embedding representation.
10//!
11//! ## Representation
12//!
//! An embedding is represented as a 1024-dimensional vector of 32-bit floating point numbers. The [`Embedding`] struct is a wrapper around [`EmbeddingRaw`] (alias for `[f32; 1024]`), and provides methods for conversion and calculation.
14//!
15//! ## Conversion
16//!
17//! - [`Embedding`] can be converted from [`EmbeddingRaw`] and [`EmbeddingBytes`].
18//! - [`Embedding`] can be immutably dereferenced to [`EmbeddingRaw`] and converted to [`EmbeddingBytes`].
19//! - [`Embedding`] can be converted from `&[f32]`, `&[u8]`, `Vec<f32>` and `Vec<u8>`, but [`DimensionMismatch`](SenseError::DimensionMismatch) error is returned if the length mismatches.
20//!
21//! ## Calculation
22//!
23//! Cosine similarity between two embeddings can be calculated using [`cosine_similarity`](Embedding::cosine_similarity) method.
24
25use super::SenseError;
26use std::{convert::TryFrom, ops::Deref};
27
28/// Raw embedding representation.
29pub type EmbeddingRaw = [f32; 1024];
30
31/// Embedding represented in bytes (little-endian).
32pub type EmbeddingBytes = [u8; 1024 * 4];
33
34/// Wrapped embedding representation.
35///
36/// See [module-level documentation](crate::embedding) for more details.
37#[derive(Debug, Clone, PartialEq)]
38pub struct Embedding {
39    inner: EmbeddingRaw,
40    norm: f32,
41}
42
43// Cosine similarity calculation
44
45impl Embedding {
46    /// Calculate cosine similarity between two embeddings.
47    #[must_use]
48    pub fn cosine_similarity(&self, other: &Self) -> f32 {
49        let dot_product: f32 = self.iter().zip(other.iter()).map(|(a, b)| a * b).sum();
50        dot_product / (self.norm * other.norm)
51    }
52}
53
54impl Default for Embedding {
55    fn default() -> Self {
56        Self {
57            inner: [0.0; 1024],
58            norm: 0.0,
59        }
60    }
61}
62
// Conversion
64
65impl From<EmbeddingRaw> for Embedding {
66    /// Convert `[f32; 1024]` to `Embedding`.
67    fn from(inner: EmbeddingRaw) -> Self {
68        let norm = inner.iter().map(|a| a * a).sum::<f32>().sqrt();
69        Self { inner, norm }
70    }
71}
72
73impl From<EmbeddingBytes> for Embedding {
74    /// Convert 1024 * 4 bytes to `Embedding` (little-endian).
75    fn from(bytes: EmbeddingBytes) -> Self {
76        let mut embedding = [0.0; 1024];
77        bytes.chunks_exact(4).enumerate().for_each(|(i, chunk)| {
78            let f = f32::from_le_bytes(chunk.try_into().unwrap()); // Safe to unwrap, as we know the length is 4
79            embedding[i] = f;
80        });
81        Self::from(embedding)
82    }
83}
84
85impl From<Embedding> for EmbeddingBytes {
86    /// Convert `Embedding` to 1024 * 4 bytes (little-endian).
87    fn from(embedding: Embedding) -> Self {
88        let mut bytes = [0; 1024 * 4];
89        bytes
90            .chunks_exact_mut(4)
91            .enumerate()
92            .for_each(|(i, chunk)| {
93                let f = embedding[i];
94                chunk.copy_from_slice(&f.to_le_bytes());
95            });
96        bytes
97    }
98}
99
100impl TryFrom<&[f32]> for Embedding {
101    type Error = SenseError;
102
103    /// Convert `&[f32]` to `Embedding`.
104    ///
105    /// # Errors
106    ///
107    /// Returns [`DimensionMismatch`](SenseError::DimensionMismatch) if the length of the input slice is not 1024.
108    fn try_from(value: &[f32]) -> Result<Self, Self::Error> {
109        let embedding: EmbeddingRaw = value.try_into()?;
110        Ok(Self::from(embedding))
111    }
112}
113
114impl TryFrom<&[u8]> for Embedding {
115    type Error = SenseError;
116
117    /// Convert `&[u8]` to `Embedding`.
118    ///
119    /// # Errors
120    ///
121    /// Returns [`DimensionMismatch`](SenseError::DimensionMismatch) if the length of the input slice is not 1024 * 4.
122    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
123        let bytes: EmbeddingBytes = value.try_into()?;
124        Ok(Self::from(bytes))
125    }
126}
127
128impl TryFrom<Vec<f32>> for Embedding {
129    type Error = SenseError;
130
131    /// Convert `Vec<f32>` to `Embedding`.
132    ///
133    /// # Errors
134    ///
135    /// Returns [`DimensionMismatch`](SenseError::DimensionMismatch) if the length of the input vector is not 1024.
136    fn try_from(value: Vec<f32>) -> Result<Self, Self::Error> {
137        let embedding: EmbeddingRaw = value.try_into()?;
138        Ok(Self::from(embedding))
139    }
140}
141
142impl TryFrom<Vec<u8>> for Embedding {
143    type Error = SenseError;
144
145    /// Convert `Vec<u8>` to `Embedding`.
146    ///
147    /// # Errors
148    ///
149    /// Returns [`DimensionMismatch`](SenseError::DimensionMismatch) if the length of the input vector is not 1024 * 4.
150    fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
151        let bytes: EmbeddingBytes = value.try_into()?;
152        Ok(Self::from(bytes))
153    }
154}
155
156// Implement `Deref` for `Embedding`
157
158impl Deref for Embedding {
159    type Target = EmbeddingRaw;
160
161    fn deref(&self) -> &Self::Target {
162        &self.inner
163    }
164}
165
// No `DerefMut` impl here: the inner representation should not be mutated, since `norm` is cached based on it
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample component value; its IEEE-754 bit pattern is `0x3F91EB85`.
    const EMBEDDING_FLOAT: f32 = 1.14;
    /// Little-endian byte encoding of `EMBEDDING_FLOAT`.
    const EMBEDDING_CHUNK: [u8; 4] = [0x85, 0xEB, 0x91, 0x3F];

    /// Decoding bytes must reproduce the encoded float in every component.
    #[test]
    #[allow(clippy::float_cmp, reason = "They should be equal exactly")]
    fn embedding_from_bytes() {
        let mut bytes = [0; 1024 * 4];
        for chunk in bytes.chunks_exact_mut(4) {
            chunk.copy_from_slice(&EMBEDDING_CHUNK);
        }

        let embedding = Embedding::from(bytes);
        for &component in embedding.iter() {
            assert_eq!(component, EMBEDDING_FLOAT);
        }
    }

    /// Encoding an embedding must emit the little-endian bytes of every component.
    #[test]
    fn bytes_from_embedding() {
        let embedding = Embedding::from([EMBEDDING_FLOAT; 1024]);
        let bytes = EmbeddingBytes::from(embedding);

        for chunk in bytes.chunks_exact(4) {
            assert_eq!(chunk, EMBEDDING_CHUNK);
        }
    }

    /// An embedding compared with itself has a cosine similarity of (approximately) one.
    #[test]
    fn similar_to_self() {
        let embedding = Embedding::from([EMBEDDING_FLOAT; 1024]);
        let deviation = (embedding.cosine_similarity(&embedding) - 1.0).abs();
        // Approximate equality
        assert!(deviation <= f32::EPSILON);
    }
}
205        assert!(delta <= f32::EPSILON);
206    }
207}