use anyhow::Result;
use std::future::Future;
use std::pin::Pin;
/// Boxed, `Send` future that resolves to one embedding vector (or an `anyhow` error).
///
/// The `'a` bound lets the future borrow from the embedder and the input text
/// for as long as it is alive.
pub type EmbedFuture<'a> = Pin<Box<dyn Future<Output = Result<Vec<f32>>> + Send + 'a>>;
/// Turns a piece of text into an embedding vector.
///
/// The method returns a boxed future ([`EmbedFuture`]) instead of being declared
/// `async fn`; this keeps the trait dyn-compatible, so it can be used as
/// `dyn Embedder`. The `Send + Sync` supertraits allow implementors to be shared
/// across threads.
pub trait Embedder: Send + Sync {
/// Embeds `text`, returning a future that yields the vector.
///
/// Both `self` and `text` are borrowed for `'a`, the lifetime of the returned future.
fn embed_one<'a>(&'a self, text: &'a str) -> EmbedFuture<'a>;
}
/// Adapts the `ailloy` client to [`Embedder`] by delegating to its own `embed_one`.
impl Embedder for ailloy::Client {
    fn embed_one<'a>(&'a self, text: &'a str) -> EmbedFuture<'a> {
        // `Type::method` paths look up inherent methods before trait methods, so this is
        // intended to reach ailloy's own `embed_one`.
        // NOTE(review): this assumes such an inherent method exists. Without one, the path
        // would resolve to this trait method and recurse forever.
        let request = async move { ailloy::Client::embed_one(self, text).await };
        Box::pin(request)
    }
}
/// An [`Embedder`] that derives vectors by hashing the input text instead of running a model.
///
/// The vectors carry no semantic meaning. They only guarantee that equal inputs
/// produce equal vectors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeterministicEmbedder {
    /// Number of components in every vector this embedder returns.
    pub dims: usize,
}

impl DeterministicEmbedder {
    /// Creates an embedder whose vectors have `dims` components.
    #[must_use]
    pub fn new(dims: usize) -> Self {
        Self { dims }
    }
}
/// 64-bit FNV-1a hash of `bytes`.
///
/// This is written out by hand instead of using `std`'s `DefaultHasher`, whose
/// algorithm the standard library explicitly does not guarantee to be stable
/// across Rust releases. A deterministic embedder must not silently produce
/// different vectors after a toolchain upgrade.
fn fnv1a_64(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;
    bytes
        .iter()
        .fold(OFFSET_BASIS, |h, &b| (h ^ u64::from(b)).wrapping_mul(PRIME))
}

/// SplitMix64 output function: scrambles `x` so that nearby inputs give unrelated outputs.
fn splitmix64(x: u64) -> u64 {
    let mut z = x.wrapping_add(0x9e37_79b9_7f4a_7c15);
    z = (z ^ (z >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
    z ^ (z >> 31)
}

impl Embedder for DeterministicEmbedder {
    /// Returns a reproducible, pseudo-random vector of `self.dims` components for `text`.
    ///
    /// Every component lies in `[-1.0, 1.0]`. The same `text` always yields the same
    /// vector, independent of the Rust toolchain version. Component `i` depends only on
    /// `i` and `text`, so a shorter vector is a prefix of a longer one.
    ///
    /// # Errors
    ///
    /// Never returns an error; the `Result` exists only to satisfy [`Embedder`].
    fn embed_one<'a>(&'a self, text: &'a str) -> EmbedFuture<'a> {
        Box::pin(async move {
            // Hash the text once (O(len)) and derive each component from that seed and its
            // index (O(dims)), instead of re-hashing the whole text for every dimension.
            let seed = fnv1a_64(text.as_bytes());
            let out: Vec<f32> = (0..self.dims as u64)
                .map(|i| {
                    // The inputs `seed + i * gamma` form the standard SplitMix64 sequence.
                    let raw = splitmix64(seed.wrapping_add(i.wrapping_mul(0x9e37_79b9_7f4a_7c15)));
                    // Map the full u64 range onto [-1.0, 1.0].
                    ((raw as f64 / u64::MAX as f64) * 2.0 - 1.0) as f32
                })
                .collect();
            Ok(out)
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn deterministic_embedder_is_stable() {
        let e = DeterministicEmbedder::new(8);
        let v1 = e.embed_one("hello").await.unwrap();
        let v2 = e.embed_one("hello").await.unwrap();
        assert_eq!(v1, v2);
        assert_eq!(v1.len(), 8);
    }

    #[tokio::test]
    async fn deterministic_embedder_differs_by_input() {
        let e = DeterministicEmbedder::new(16);
        let a = e.embed_one("foo").await.unwrap();
        let b = e.embed_one("bar").await.unwrap();
        assert_ne!(a, b);
    }

    /// Every component must land in the documented `[-1.0, 1.0]` range (this also rejects NaN).
    #[tokio::test]
    async fn deterministic_embedder_components_are_in_unit_range() {
        let e = DeterministicEmbedder::new(64);
        for text in ["", "range check", "ünïcödé ✓"] {
            let v = e.embed_one(text).await.unwrap();
            assert!(
                v.iter().all(|x| (-1.0..=1.0).contains(x)),
                "out of range for {:?}: {:?}",
                text,
                v
            );
        }
    }

    /// A zero-dimensional embedder is legal and yields an empty vector.
    #[tokio::test]
    async fn deterministic_embedder_zero_dims_is_empty() {
        let e = DeterministicEmbedder::new(0);
        assert!(e.embed_one("anything").await.unwrap().is_empty());
    }

    /// Component `i` depends only on `i` and the text, so shorter vectors are prefixes.
    #[tokio::test]
    async fn deterministic_embedder_shorter_dims_is_prefix() {
        let short = DeterministicEmbedder::new(4).embed_one("prefix").await.unwrap();
        let long = DeterministicEmbedder::new(8).embed_one("prefix").await.unwrap();
        assert_eq!(&long[..4], &short[..]);
    }
}