use std::sync::Arc;
use crate::error::{Error, Result};
/// An embedding vector held in a shared, immutable `Arc<[f32]>` buffer.
///
/// Cloning only bumps the reference count; the components themselves are
/// never copied. The constructors in this file (`TryFrom<Vec<f32>>` and the
/// crate-internal `from_model_output`) only produce vectors with exactly
/// `Embedding::EMBED_DIM` components rescaled to unit L2 norm (up to float
/// rounding). Code inside this module can still build arbitrary values
/// through the private tuple field.
#[derive(Debug, Clone)]
pub struct Embedding(Arc<[f32]>);
impl Embedding {
pub const EMBED_DIM: usize = 768;
pub const NORM_EPSILON: f32 = 5e-4;
pub fn dim(&self) -> usize {
self.0.len()
}
pub fn as_slice(&self) -> &[f32] {
&self.0
}
pub fn into_inner(self) -> Arc<[f32]> {
self.0
}
pub fn try_cosine(&self, other: &Embedding) -> Result<f32> {
if self.dim() != other.dim() {
return Err(Error::EmbeddingDim {
expected: self.dim(),
got: other.dim(),
});
}
let a: &[f32; Self::EMBED_DIM] =
self
.as_slice()
.try_into()
.map_err(|_| Error::EmbeddingDim {
expected: Self::EMBED_DIM,
got: self.dim(),
})?;
let b: &[f32; Self::EMBED_DIM] =
other
.as_slice()
.try_into()
.map_err(|_| Error::EmbeddingDim {
expected: Self::EMBED_DIM,
got: other.dim(),
})?;
Ok(crate::simd::dot_768(a, b))
}
#[cfg(feature = "inference")]
pub(crate) fn from_model_output(data: &[f32]) -> Result<Self> {
let arr: &[f32; Self::EMBED_DIM] = data.try_into().map_err(|_| Error::EmbeddingDim {
expected: Self::EMBED_DIM,
got: data.len(),
})?;
let norm_sq = crate::simd::dot_768(arr, arr);
let norm = norm_sq.sqrt();
if !norm.is_finite() || norm == 0.0 {
return Err(Error::NotNormalized {
norm,
epsilon: Self::NORM_EPSILON,
});
}
let factor = 1.0 / norm;
let arc: Arc<[f32]> = data.iter().map(|&x| x * factor).collect();
Ok(Self(arc))
}
}
impl TryFrom<Vec<f32>> for Embedding {
type Error = Error;
fn try_from(mut v: Vec<f32>) -> Result<Self> {
let norm_sq = {
let arr: &[f32; Self::EMBED_DIM] =
v.as_slice().try_into().map_err(|_| Error::EmbeddingDim {
expected: Self::EMBED_DIM,
got: v.len(),
})?;
crate::simd::dot_768(arr, arr)
};
let norm = norm_sq.sqrt();
if !norm.is_finite() || (norm - 1.0).abs() > Self::NORM_EPSILON {
return Err(Error::NotNormalized {
norm,
epsilon: Self::NORM_EPSILON,
});
}
let factor = 1.0 / norm;
for x in &mut v {
*x *= factor;
}
Ok(Self(v.into()))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Length-`dim` standard basis vector (1.0 at index 0, zeros elsewhere),
    /// whose L2 norm is exactly 1. Panics when `dim == 0`.
    fn unit_vec(dim: usize) -> Vec<f32> {
        let mut basis = vec![0.0f32; dim];
        basis[0] = 1.0;
        basis
    }

    #[test]
    fn try_from_accepts_unit_norm_768() {
        let embedding =
            Embedding::try_from(unit_vec(768)).expect("unit-norm 768-dim should succeed");
        assert_eq!(embedding.dim(), 768);
        let similarity = embedding.try_cosine(&embedding).expect("happy path");
        assert!((similarity - 1.0).abs() < 1e-5);
    }

    #[test]
    fn try_from_rejects_wrong_dim() {
        let err = Embedding::try_from(vec![0.0; 100]).unwrap_err();
        if let Error::EmbeddingDim { expected, got } = err {
            assert_eq!(expected, 768);
            assert_eq!(got, 100);
        } else {
            panic!("expected EmbeddingDim, got {err}");
        }
    }

    #[test]
    fn try_from_rejects_non_unit_norm() {
        let err = Embedding::try_from(vec![0.5f32; 768]).unwrap_err();
        assert!(
            matches!(err, Error::NotNormalized { .. }),
            "expected NotNormalized, got {err}"
        );
    }

    #[cfg(feature = "inference")]
    #[test]
    fn from_model_output_normalizes_arbitrary_norm() {
        let all_ones = vec![1.0f32; 768];
        let embedding = Embedding::from_model_output(&all_ones)
            .expect("arbitrary-norm output must be normalized");
        let cos = embedding.try_cosine(&embedding).expect("happy path");
        assert!(
            (cos - 1.0).abs() < 1e-5,
            "post-norm cosine should be 1.0; got {cos}"
        );
        // Each component of the normalized all-ones vector is 1/sqrt(768).
        let expected_component = 1.0 / (768.0_f32).sqrt();
        assert!((embedding.as_slice()[0] - expected_component).abs() < 1e-6);
    }

    #[cfg(feature = "inference")]
    #[test]
    fn from_model_output_rejects_wrong_dim() {
        let err = Embedding::from_model_output(&[0.5f32; 100]).unwrap_err();
        if let Error::EmbeddingDim { expected, got } = err {
            assert_eq!(expected, 768);
            assert_eq!(got, 100);
        } else {
            panic!("expected EmbeddingDim, got {err}");
        }
    }

    #[cfg(feature = "inference")]
    #[test]
    fn from_model_output_rejects_zero_norm() {
        let err = Embedding::from_model_output(&[0.0f32; 768]).unwrap_err();
        if let Error::NotNormalized { norm, .. } = err {
            assert_eq!(norm, 0.0);
        } else {
            panic!("expected NotNormalized for zero output, got {err}");
        }
    }

    #[cfg(feature = "inference")]
    #[test]
    fn from_model_output_rejects_nan_component() {
        let poisoned: Vec<f32> = (0..768)
            .map(|i| if i == 100 { f32::NAN } else { 0.5 })
            .collect();
        let err = Embedding::from_model_output(&poisoned).unwrap_err();
        if let Error::NotNormalized { norm, .. } = err {
            assert!(norm.is_nan());
        } else {
            panic!("expected NotNormalized for NaN, got {err}");
        }
    }

    #[test]
    fn try_from_renormalizes_within_tolerance() {
        let mut near_unit = unit_vec(768);
        near_unit[1] = Embedding::NORM_EPSILON / 2.0;
        let embedding =
            Embedding::try_from(near_unit).expect("near-unit norm should be accepted");
        let dot = embedding.try_cosine(&embedding).expect("happy path");
        assert!(
            (dot - 1.0).abs() < 1e-5,
            "renormalized cosine should be 1.0; got {dot}"
        );
    }

    #[test]
    fn try_cosine_returns_dim_error_on_mismatch() {
        let short = Embedding(vec![1.0f32, 0.0].into());
        let long = Embedding(vec![1.0f32, 0.0, 0.0].into());
        match short
            .try_cosine(&long)
            .expect_err("dim mismatch must surface as Err")
        {
            Error::EmbeddingDim { expected, got } => {
                assert_eq!(expected, 2, "lhs dim");
                assert_eq!(got, 3, "rhs dim");
            }
            other => panic!("expected Error::EmbeddingDim, got {other}"),
        }
    }

    #[test]
    fn try_cosine_returns_dim_error_when_both_wrong_size() {
        let x_axis = Embedding(vec![1.0f32, 0.0, 0.0, 0.0].into());
        let y_axis = Embedding(vec![0.0f32, 1.0, 0.0, 0.0].into());
        match x_axis
            .try_cosine(&y_axis)
            .expect_err("non-EMBED_DIM operands must error")
        {
            Error::EmbeddingDim { expected, got } => {
                assert_eq!(expected, Embedding::EMBED_DIM);
                assert_eq!(got, 4);
            }
            other => panic!("expected Error::EmbeddingDim, got {other}"),
        }
    }

    #[test]
    fn try_cosine_self_unit_pair() {
        let embedding =
            Embedding::try_from(unit_vec(768)).expect("unit-norm 768-d should succeed");
        let cos = embedding
            .try_cosine(&embedding)
            .expect("happy path must be Ok");
        assert!((cos - 1.0).abs() < 1e-5);
    }

    #[test]
    fn into_inner_exposes_arc_unchanged() {
        let embedding =
            Embedding::try_from(unit_vec(768)).expect("unit-norm 768-d should succeed");
        let buffer = embedding.into_inner();
        assert_eq!(buffer.len(), 768);
        assert!((buffer[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn embedding_is_send_sync() {
        // Compile-time check: this only builds if `Embedding: Send + Sync`.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Embedding>();
    }
}