use anyhow::{Context, Result};
use async_trait::async_trait;
/// Length every embedding vector handled by this module must have.
///
/// Mirrors `crate::memory::config::DEFAULT_EMBEDDING_DIM`; enforced by
/// [`check_embed_dim`], [`pack_checked`], and [`unpack_embedding`].
pub const EMBEDDING_DIM: usize = crate::memory::config::DEFAULT_EMBEDDING_DIM;
/// A backend that turns text into an `f32` embedding vector.
///
/// Implementors must be `Send + Sync`; the async methods are provided via `async_trait`.
#[async_trait]
pub trait Embedder: Send + Sync {
    /// Short static identifier for this backend.
    fn name(&self) -> &'static str;

    /// Embeds a single piece of text.
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Embeds every entry of `texts`, returning one result per input in the same order.
    ///
    /// The default implementation awaits [`Embedder::embed`] once per text, one after
    /// another. A failure for one text is stored in that text's slot and does not stop
    /// the remaining texts from being embedded.
    async fn embed_batch(&self, texts: &[&str]) -> Vec<Result<Vec<f32>>> {
        let mut results = Vec::with_capacity(texts.len());
        for &text in texts {
            let embedded = self.embed(text).await;
            results.push(embedded);
        }
        results
    }
}
/// Passes `v` through unchanged when it has exactly [`EMBEDDING_DIM`] elements.
///
/// `label` identifies the embedder in the error message.
///
/// # Errors
///
/// Returns an error mentioning `label`, the actual length, and [`EMBEDDING_DIM`]
/// when the vector length differs.
pub fn check_embed_dim(v: Vec<f32>, label: &str) -> Result<Vec<f32>> {
    match v.len() {
        len if len == EMBEDDING_DIM => Ok(v),
        len => anyhow::bail!(
            "{label} embedder returned {} dims, expected {}",
            len,
            EMBEDDING_DIM
        ),
    }
}
/// Cosine similarity of `a` and `b`, in the range `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length, are empty, or either has zero norm,
/// so the division below never divides by zero.
///
/// The sums are accumulated in `f64`. Squaring large-magnitude `f32` components in `f32`
/// overflows to infinity, which turns the result into `NaN`. Squaring very small
/// components underflows to zero, which falsely reports a zero norm. Summing hundreds of
/// products in `f32` also loses precision. The final ratio is clamped because rounding
/// can push it marginally outside `[-1.0, 1.0]`. `NaN` inputs still yield `NaN`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    let mut dot = 0.0_f64;
    let mut na = 0.0_f64;
    let mut nb = 0.0_f64;
    for (&x, &y) in a.iter().zip(b) {
        let (x, y) = (f64::from(x), f64::from(y));
        dot += x * y;
        na += x * x;
        nb += y * y;
    }
    if na == 0.0 || nb == 0.0 {
        return 0.0;
    }
    // Taking each sqrt separately avoids overflowing the product na * nb.
    // Narrowing back to f32 is intentional: the clamped value lies within [-1, 1].
    (dot / (na.sqrt() * nb.sqrt())).clamp(-1.0, 1.0) as f32
}
/// Serializes `v` as consecutive little-endian `f32` values, 4 bytes per element.
///
/// Performs no length validation; [`pack_checked`] is the dimension-checked variant.
pub fn pack_embedding(v: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(v.len() * 4);
    bytes.extend(v.iter().flat_map(|value| value.to_le_bytes()));
    bytes
}
/// Decodes a blob of little-endian `f32` values; the inverse of [`pack_embedding`].
///
/// # Errors
///
/// Fails if the byte length is not a multiple of 4, or if the blob does not hold
/// exactly [`EMBEDDING_DIM`] floats.
pub fn unpack_embedding(b: &[u8]) -> Result<Vec<f32>> {
    if !b.len().is_multiple_of(4) {
        anyhow::bail!(
            "embedding blob length {} not a multiple of 4 — corrupt row",
            b.len()
        );
    }
    // Only whole 4-byte groups remain, so the float count is exact and can be
    // validated before decoding anything.
    let float_count = b.len() / 4;
    if float_count != EMBEDDING_DIM {
        anyhow::bail!(
            "embedding blob length {} floats, expected {}",
            float_count,
            EMBEDDING_DIM
        );
    }
    let mut floats = Vec::with_capacity(float_count);
    for word in b.chunks_exact(4) {
        floats.push(f32::from_le_bytes([word[0], word[1], word[2], word[3]]));
    }
    Ok(floats)
}
/// Like [`pack_embedding`], but first requires `v` to have exactly [`EMBEDDING_DIM`]
/// elements.
///
/// # Errors
///
/// Returns an error reporting the actual and expected dimension when they differ.
pub fn pack_checked(v: &[f32]) -> Result<Vec<u8>> {
    anyhow::ensure!(
        v.len() == EMBEDDING_DIM,
        "embedding vector has {} dims, expected {}",
        v.len(),
        EMBEDDING_DIM
    );
    Ok(pack_embedding(v))
}
pub fn decode_optional_blob(
blob: Option<Vec<u8>>,
context_label: &str,
) -> Result<Option<Vec<f32>>> {
match blob {
None => Ok(None),
Some(bytes) => {
let v = unpack_embedding(&bytes)
.with_context(|| format!("decode embedding for {context_label}"))?;
Ok(Some(v))
}
}
}
/// A stateless embedder whose [`Embedder`] impl ignores its input and yields an
/// all-zero vector of [`EMBEDDING_DIM`] floats.
#[derive(Debug, Default, Clone, Copy)]
pub struct InertEmbedder;

impl InertEmbedder {
    /// Creates the (field-less) inert embedder; equivalent to `InertEmbedder::default()`.
    pub fn new() -> Self {
        InertEmbedder
    }
}
#[async_trait]
impl Embedder for InertEmbedder {
    /// Always reports the backend name `"inert"`.
    fn name(&self) -> &'static str {
        "inert"
    }

    /// Ignores `_text` and returns a vector of [`EMBEDDING_DIM`] zeros.
    ///
    /// Because the vector has zero norm, [`cosine_similarity`] against it is `0.0`.
    async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
        let zeros: Vec<f32> = vec![0.0; EMBEDDING_DIM];
        Ok(zeros)
    }
}
// Unit tests for this module, compiled only under `cfg(test)` and loaded from
// `embed_tests.rs` via the `#[path]` attribute.
#[cfg(test)]
#[path = "embed_tests.rs"]
mod tests;