use super::SenseError;
use std::{convert::TryFrom, ops::Deref};
/// Raw embedding vector: a fixed-size array of 1024 `f32` components.
pub type EmbeddingRaw = [f32; 1024];
/// Byte form of an embedding: 1024 components of 4 bytes each.
///
/// The conversions in this file read and write each component in
/// little-endian order.
pub type EmbeddingBytes = [u8; 1024 * 4];
/// A fixed-size embedding vector stored together with a cached norm.
///
/// Values are built through the `From`/`TryFrom` conversions or `Default`;
/// the type dereferences to [`EmbeddingRaw`] for read access to the components.
#[derive(Debug, Clone, PartialEq)]
pub struct Embedding {
/// The raw vector components.
inner: EmbeddingRaw,
/// Cached norm of `inner`, used as the denominator in `cosine_similarity`.
/// `From<EmbeddingRaw>` sets it to the square root of the sum of the squared
/// components; `Default` sets it to `0.0`.
norm: f32,
}
impl Embedding {
    /// Computes the cosine similarity between `self` and `other`.
    ///
    /// The result is the dot product of the two vectors divided by the product
    /// of their cached norms.
    ///
    /// Returns `0.0` when the product of the norms is zero (for example when
    /// either operand is the all-zero [`Default`] embedding): the similarity is
    /// undefined there and the unguarded division would silently produce `NaN`.
    #[must_use]
    pub fn cosine_similarity(&self, other: &Self) -> f32 {
        let denominator = self.norm * other.norm;
        // A zero denominator implies a zero dot product as well, so dividing
        // would evaluate `0.0 / 0.0`; report "no similarity" instead.
        if denominator == 0.0 {
            return 0.0;
        }
        let dot_product: f32 = self.iter().zip(other.iter()).map(|(a, b)| a * b).sum();
        dot_product / denominator
    }
}
impl Default for Embedding {
    /// Builds the all-zero embedding, whose cached norm is likewise zero.
    fn default() -> Self {
        let inner = [0.0; 1024];
        let norm = 0.0;
        Self { inner, norm }
    }
}
impl From<EmbeddingRaw> for Embedding {
    /// Wraps a raw vector and caches its Euclidean norm (the square root of
    /// the sum of the squared components) next to it.
    fn from(inner: EmbeddingRaw) -> Self {
        let sum_of_squares: f32 = inner
            .iter()
            .map(|&component| component * component)
            .sum();
        Self {
            inner,
            norm: sum_of_squares.sqrt(),
        }
    }
}
impl From<EmbeddingBytes> for Embedding {
fn from(bytes: EmbeddingBytes) -> Self {
let mut embedding = [0.0; 1024];
bytes.chunks_exact(4).enumerate().for_each(|(i, chunk)| {
let f = f32::from_le_bytes(chunk.try_into().unwrap()); embedding[i] = f;
});
Self::from(embedding)
}
}
impl From<Embedding> for EmbeddingBytes {
    /// Serialises every component of `embedding` as four little-endian bytes,
    /// in component order.
    fn from(embedding: Embedding) -> Self {
        let mut bytes: EmbeddingBytes = [0; 1024 * 4];
        for (chunk, component) in bytes.chunks_exact_mut(4).zip(embedding.iter()) {
            chunk.copy_from_slice(&component.to_le_bytes());
        }
        bytes
    }
}
impl TryFrom<&[f32]> for Embedding {
type Error = SenseError;
fn try_from(value: &[f32]) -> Result<Self, Self::Error> {
let embedding: EmbeddingRaw = value.try_into()?;
Ok(Self::from(embedding))
}
}
impl TryFrom<&[u8]> for Embedding {
type Error = SenseError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
let bytes: EmbeddingBytes = value.try_into()?;
Ok(Self::from(bytes))
}
}
impl TryFrom<Vec<f32>> for Embedding {
type Error = SenseError;
fn try_from(value: Vec<f32>) -> Result<Self, Self::Error> {
let embedding: EmbeddingRaw = value.try_into()?;
Ok(Self::from(embedding))
}
}
impl TryFrom<Vec<u8>> for Embedding {
type Error = SenseError;
fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
let bytes: EmbeddingBytes = value.try_into()?;
Ok(Self::from(bytes))
}
}
impl Deref for Embedding {
    type Target = EmbeddingRaw;

    /// Borrows the underlying raw vector, so array and slice methods can be
    /// called directly on an `Embedding`.
    fn deref(&self) -> &Self::Target {
        let Self { inner, .. } = self;
        inner
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample component value used to fill the test embeddings.
    const EMBEDDING_FLOAT: f32 = 1.14;
    /// Little-endian byte encoding of `EMBEDDING_FLOAT`.
    const EMBEDDING_CHUNK: [u8; 4] = [0x85, 0xEB, 0x91, 0x3F];

    /// Decoding a buffer of repeated sample chunks yields the sample float in
    /// every component.
    #[test]
    #[allow(clippy::float_cmp, reason = "They should be equal exactly")]
    fn embedding_from_bytes() {
        let mut bytes = [0_u8; 1024 * 4];
        for chunk in bytes.chunks_exact_mut(4) {
            chunk.copy_from_slice(&EMBEDDING_CHUNK);
        }

        let decoded = Embedding::from(bytes);
        for &component in decoded.iter() {
            assert_eq!(component, EMBEDDING_FLOAT);
        }
    }

    /// Encoding an embedding filled with the sample float yields the sample
    /// chunk for every component.
    #[test]
    fn bytes_from_embedding() {
        let encoded = EmbeddingBytes::from(Embedding::from([EMBEDDING_FLOAT; 1024]));
        for chunk in encoded.chunks_exact(4) {
            assert_eq!(chunk, EMBEDDING_CHUNK);
        }
    }

    /// An embedding compared with itself has a cosine similarity of one,
    /// within `f32::EPSILON`.
    #[test]
    fn similar_to_self() {
        let embedding = Embedding::from([EMBEDDING_FLOAT; 1024]);
        let distance_from_one = (embedding.cosine_similarity(&embedding) - 1.0).abs();
        assert!(distance_from_one <= f32::EPSILON);
    }
}