use std::error::Error;
use std::fmt;
use std::hash::{Hash, Hasher};
// Number of components in every vector produced by `PlaceholderEmbedder::embed`.
//
// NOTE(review): `embed` fills component `i` from the 4-bit nibble at bit offset
// `i * 4` of a single `u64` hash, so only 16 distinct nibbles exist. Raising this
// above 16 would push the shift amount to 64 or more — revisit `embed` before
// changing this value.
const PLACEHOLDER_EMBED_DIM: usize = 16;
/// Errors an [`Embedder`] can report.
///
/// The `Display` output is a short, lowercase, human-readable message prefixed
/// with `embedder`.
#[derive(Debug)]
pub enum EmbedderError {
    /// Transport-level failure; the payload is the detail text appended after
    /// the `embedder transport: ` prefix when displayed.
    Transport(String),
    /// Displayed as `embedder model not loaded`.
    ModelMissing,
    /// Any other failure; the payload is the detail text appended after the
    /// `embedder: ` prefix when displayed.
    Other(String),
}

impl fmt::Display for EmbedderError {
    /// Writes the variant's message; width/fill flags on the outer formatter
    /// are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            // Payload-free variant first: it is a fixed string.
            EmbedderError::ModelMissing => f.write_str("embedder model not loaded"),
            EmbedderError::Transport(ref m) => write!(f, "embedder transport: {m}"),
            EmbedderError::Other(ref m) => write!(f, "embedder: {m}"),
        }
    }
}

// No underlying `source()` is exposed: the payloads are plain strings.
impl Error for EmbedderError {}
/// Turns text into a vector of `f32` components.
///
/// Implementors must be `Send + Sync`, so a single instance can be shared
/// between threads.
pub trait Embedder: Send + Sync {
    /// Embeds a single piece of text.
    ///
    /// # Errors
    ///
    /// Returns an [`EmbedderError`] when the implementation cannot produce a
    /// vector for `text`.
    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedderError>;

    /// Embeds every entry of `texts`, returning the vectors in input order.
    ///
    /// The default implementation calls [`Embedder::embed`] once per entry and
    /// stops at the first failure.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by [`Embedder::embed`]; entries after
    /// the failing one are not embedded.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbedderError> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(self.embed(text)?);
        }
        Ok(vectors)
    }
}
/// Stateless embedder whose vectors are derived from a hash of the input text
/// rather than from a model (see its `Embedder` implementation).
///
/// The type has no fields, so it is zero-sized and freely copyable.
#[derive(Clone, Copy, Debug, Default)]
pub struct PlaceholderEmbedder;

impl PlaceholderEmbedder {
    /// Creates a placeholder embedder; equivalent to `Default::default()` but
    /// usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        PlaceholderEmbedder
    }
}
impl Embedder for PlaceholderEmbedder {
    /// Builds a `PLACEHOLDER_EMBED_DIM`-component vector from a hash of `text`.
    ///
    /// Component `i` is the 4-bit nibble at bit offset `i * 4` of the 64-bit
    /// hash, scaled into `[0, 1]`; the vector is then L2-normalised. If every
    /// nibble is zero the all-zero vector is returned unchanged.
    ///
    /// The hash comes from `DefaultHasher::new()`, whose algorithm the standard
    /// library does not guarantee to keep across releases, so vectors should
    /// not be assumed stable between toolchain versions.
    ///
    /// # Errors
    ///
    /// This implementation never fails; it always returns `Ok`.
    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedderError> {
        use std::collections::hash_map::DefaultHasher;

        let mut hasher = DefaultHasher::new();
        text.hash(&mut hasher);
        let digest = hasher.finish();

        // One nibble per component; 0xF (15) is the largest nibble value, so
        // dividing by 15.0 maps each component into [0, 1].
        let mut components: Vec<f32> = (0..PLACEHOLDER_EMBED_DIM)
            .map(|slot| (((digest >> (slot * 4)) & 0xF) as f32) / 15.0)
            .collect();

        let norm = components.iter().map(|c| c * c).sum::<f32>().sqrt();
        // Skip normalisation for the all-zero vector to avoid dividing by zero.
        if norm > 0.0 {
            components.iter_mut().for_each(|c| *c /= norm);
        }
        Ok(components)
    }
}
/// Cosine similarity between `a` and `b`.
///
/// Returns a value in `[-1.0, 1.0]` for finite inputs. Returns `0.0` when the
/// slices differ in length, are empty, or when either vector has zero norm.
/// Non-finite components are not special-cased and generally produce NaN.
///
/// Sums are accumulated in `f64`: the square of any finite `f32` fits in
/// `f64`, so very large components no longer overflow to infinity (which
/// produced NaN) and very small ones no longer underflow to a zero norm
/// (which produced a spurious `0.0`).
#[must_use]
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass computing the dot product and both squared norms.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, norm_a_sq, norm_b_sq), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a_sq + x * x, norm_b_sq + y * y)
        },
    );

    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }

    // Take the square roots separately so the denominator stays well inside
    // the f64 range, then clamp away rounding error that could push the
    // result fractionally outside [-1, 1]. NaN passes through `clamp`.
    let similarity = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
    // Narrowing to f32 is intentional: the value is already within [-1, 1].
    similarity.clamp(-1.0, 1.0) as f32
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing similarities.
    const EPS: f32 = 1e-6;

    #[test]
    fn cosine_identical_is_one() {
        let v = [1.0_f32, 2.0, 3.0];
        assert!((cosine(&v, &v) - 1.0).abs() < EPS);
    }

    #[test]
    fn cosine_orthogonal_is_zero() {
        let a = [1.0_f32, 0.0];
        let b = [0.0_f32, 1.0];
        assert!(cosine(&a, &b).abs() < EPS);
    }

    #[test]
    fn cosine_opposite_is_minus_one() {
        // A 3-4-5 triangle keeps both norms exactly representable.
        let a = [3.0_f32, 4.0];
        let b = [-3.0_f32, -4.0];
        assert!((cosine(&a, &b) + 1.0).abs() < EPS);
    }

    #[test]
    fn cosine_mismatched_lengths_returns_zero() {
        let a = [1.0_f32, 0.0];
        let b = [1.0_f32, 0.0, 0.0];
        assert_eq!(cosine(&a, &b), 0.0);
    }

    #[test]
    fn cosine_empty_returns_zero() {
        let empty: [f32; 0] = [];
        assert_eq!(cosine(&empty, &empty), 0.0);
    }

    #[test]
    fn cosine_zero_vector_returns_zero() {
        let zero = [0.0_f32, 0.0];
        let other = [1.0_f32, 2.0];
        assert_eq!(cosine(&zero, &other), 0.0);
        assert_eq!(cosine(&other, &zero), 0.0);
    }

    #[test]
    fn placeholder_l2_unit_vector() {
        let e = PlaceholderEmbedder::new();
        let v = e.embed("hello world").expect("ok");
        assert_eq!(v.len(), PLACEHOLDER_EMBED_DIM);
        let n: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        // A hash whose nibbles are all zero legitimately yields the zero vector.
        assert!((n - 1.0).abs() < 1e-5 || n.abs() < 1e-5);
    }

    #[test]
    fn placeholder_is_deterministic() {
        let e = PlaceholderEmbedder::new();
        let first = e.embed("same input").expect("ok");
        let second = e.embed("same input").expect("ok");
        assert_eq!(first, second);
    }

    #[test]
    fn embed_batch_matches_embed() {
        let e = PlaceholderEmbedder::new();
        let batch = e.embed_batch(&["alpha", "beta"]).expect("ok");
        assert_eq!(batch.len(), 2);
        assert_eq!(batch[0], e.embed("alpha").expect("ok"));
        assert_eq!(batch[1], e.embed("beta").expect("ok"));

        let none = e.embed_batch(&[]).expect("ok");
        assert!(none.is_empty());
    }
}