use async_trait::async_trait;
use std::fmt;
/// Error returned by embedding operations.
///
/// Each variant carries a free-form detail message. The [`fmt::Display`] impl
/// prefixes that message with the variant's category, e.g.
/// `embed transport: <detail>`.
///
/// `#[non_exhaustive]` means downstream crates must include a wildcard arm when
/// matching, so further categories can be added without a breaking change.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum EmbedError {
    /// Transport-category failure; displayed as `embed transport: <detail>`.
    Transport(String),
    /// Provider-category failure; displayed as `embed provider: <detail>`.
    Provider(String),
    /// Bad-input failure; displayed as `embed bad input: <detail>`.
    BadInput(String),
}

impl fmt::Display for EmbedError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Transport(s) => write!(f, "embed transport: {s}"),
            Self::Provider(s) => write!(f, "embed provider: {s}"),
            Self::BadInput(s) => write!(f, "embed bad input: {s}"),
        }
    }
}

/// Uses the default `source()` (none): variants hold only a message string.
impl std::error::Error for EmbedError {}
/// An asynchronous text-embedding backend.
///
/// Implementors must be `Send + Sync + 'static`.
#[async_trait]
pub trait Embedder: Send + Sync + 'static {
    /// Embedding dimensionality reported by this backend.
    ///
    /// NOTE(review): presumably the length of every vector returned by
    /// `embed` — confirm with implementors.
    fn dim(&self) -> usize;

    /// Identifier string for this backend, borrowed from `self`.
    ///
    /// NOTE(review): exact meaning (model name, provider key, ...) is not
    /// visible here — confirm with implementors.
    fn handle(&self) -> &str;

    /// Embeds a batch of inputs.
    ///
    /// NOTE(review): presumably returns one vector per input, in input order;
    /// `EmbedderExt::embed_one` treats the result of a one-element batch as
    /// that input's vector.
    async fn embed(&self, inputs: &[&str]) -> Result<Vec<Vec<f32>>, EmbedError>;
}
/// Convenience helpers layered on top of [`Embedder`].
///
/// Implemented for every `Embedder` (including `dyn Embedder`) by the blanket
/// impl in this module, so callers never implement it themselves.
#[async_trait]
pub trait EmbedderExt: Embedder {
    /// Embeds a single input by calling [`Embedder::embed`] with a
    /// one-element batch and unwrapping the sole result.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `embed`. Returns
    /// [`EmbedError::Provider`] if the provider returns zero vectors, or more
    /// than one vector, for the single input. An oversized result indicates a
    /// misbehaving provider, and picking an arbitrary element would hide it.
    async fn embed_one(&self, input: &str) -> Result<Vec<f32>, EmbedError> {
        let mut out = self.embed(&[input]).await?;
        if out.len() > 1 {
            return Err(EmbedError::Provider(format!(
                "expected 1 embedding for single input, got {}",
                out.len()
            )));
        }
        // At most one element remains, so `pop` yields it or `None` if empty.
        out.pop()
            .ok_or_else(|| EmbedError::Provider("empty result for single input".into()))
    }
}
/// Every [`Embedder`], including unsized ones such as `dyn Embedder`, gets
/// the [`EmbedderExt`] helpers automatically.
impl<E> EmbedderExt for E where E: Embedder + ?Sized {}
/// Scales `v` in place to unit Euclidean (L2) length.
///
/// An empty or all-zero slice is left untouched, since it has no direction.
///
/// The norm is computed relative to the largest component magnitude. Vectors
/// with very large components (whose squares would overflow `f32` to
/// infinity) or very small components (whose squares would underflow to zero)
/// are therefore still normalized correctly. Without this, they would be
/// zeroed out or left unchanged.
///
/// Any NaN or infinite component produces non-finite output.
pub fn l2_normalize(v: &mut [f32]) {
    // `f32::max` ignores NaN operands, so starting from 0.0 the result is
    // always a non-NaN value >= 0.
    let scale = v.iter().fold(0.0f32, |m, &x| m.max(x.abs()));
    if scale == 0.0 {
        return;
    }
    // Each `x / scale` lies in [-1, 1], and the largest component maps to
    // exactly ±1. The sum of squares is therefore at least 1 and at most about
    // `v.len()`: it can neither overflow nor underflow to zero.
    let sum_sq = v.iter().fold(0.0f32, |acc, &x| {
        let y = x / scale;
        acc + y * y
    });
    let inv_root = 1.0 / sum_sq.sqrt();
    for x in v.iter_mut() {
        // Divide by `scale` first instead of multiplying by `1 / norm`.
        // `norm = scale * sqrt(sum_sq)` could itself overflow for components
        // near `f32::MAX`.
        *x = *x / scale * inv_root;
    }
}
/// Dot product of `a` and `b`, accumulated left to right starting from `0.0`.
///
/// Only the overlapping prefix contributes. If the slices differ in length,
/// the trailing elements of the longer one are ignored, and empty input yields
/// `0.0`. For L2-normalized inputs the result is their cosine similarity.
///
/// NOTE(review): a length mismatch is silently truncated rather than rejected —
/// confirm callers rely on this.
pub fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| acc + x * y)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!((actual - expected).abs() < tol);
    }

    /// A normalized vector has Euclidean length 1.
    #[test]
    fn normalize_unit_length() {
        let mut v = vec![3.0f32, 4.0];
        l2_normalize(&mut v);
        let len = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert_close(len, 1.0, 1e-6);
    }

    /// The zero vector has no direction and must come back unchanged.
    #[test]
    fn normalize_zero_noop() {
        let mut v = vec![0.0f32; 2];
        l2_normalize(&mut v);
        assert_eq!(v, [0.0, 0.0]);
    }

    /// `dot` agrees with the hand-expanded sum of products.
    #[test]
    fn dot_matches_naive() {
        let (a, b) = ([1.0f32, 2.0, 3.0], [4.0f32, 5.0, 6.0]);
        let expected = 1.0 * 4.0 + 2.0 * 5.0 + 3.0 * 6.0;
        assert_close(dot(&a, &b), expected, 1e-6);
    }

    /// The dot product of two normalized vectors is their cosine similarity:
    /// (3, 4) and (4, 3) normalize to (0.6, 0.8) and (0.8, 0.6), giving 0.96.
    #[test]
    fn cosine_via_normalized_dot() {
        let mut a = vec![3.0f32, 4.0];
        let mut b = vec![4.0f32, 3.0];
        for v in [&mut a, &mut b] {
            l2_normalize(v);
        }
        assert_close(dot(&a, &b), 0.96, 1e-4);
    }
}