use thiserror::Error;
/// Errors reported by [`Embedder::embed`].
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub enum EmbedError {
/// Embedding failed; the wrapped string is interpolated into the message
/// after the `embed failed: ` prefix.
#[error("embed failed: {0}")]
Failure(String),
/// An embedding had the wrong dimension: `expected` is the dimension that
/// was required and `got` is the dimension that was actually produced.
#[error("embed produced wrong dim: expected {expected}, got {got}")]
DimensionMismatch { expected: u16, got: u16 },
}
/// Interface for turning text into a vector of `f32` components.
///
/// Implementors must be `Send + Sync`, as required by the supertrait bounds.
pub trait Embedder: Send + Sync {
/// Returns this embedder's identifier string.
fn id(&self) -> &str;
/// Returns this embedder's version string.
fn version(&self) -> &str;
/// Returns the dimension reported by this embedder.
fn dim(&self) -> u16;
/// Embeds `text`, returning the resulting vector or an [`EmbedError`].
// NOTE(review): presumably the returned vector is expected to have `dim()`
// elements; the signature does not enforce this — confirm against callers.
fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedError>;
}
/// Embedder configuration made of an output dimension and a seed.
///
/// The identifier is the same for every instance, while the version string
/// records the `dim` and `seed` the instance was built with.
#[derive(Debug, Clone)]
pub struct HashEmbedder {
    /// Number of components in each produced vector.
    dim: u16,
    /// Seed chosen at construction time.
    seed: u64,
    /// Fixed identifier, `prollytree:hash-embedder/v1`.
    id: String,
    /// Configuration string of the form `dim=<dim>;seed=<seed>`.
    version: String,
}

impl HashEmbedder {
    /// Creates an embedder with the given output dimension and seed.
    ///
    /// The identifier is always `prollytree:hash-embedder/v1`; the version
    /// string is derived from both arguments, so it differs whenever either
    /// `dim` or `seed` differs.
    pub fn new(dim: u16, seed: u64) -> Self {
        Self {
            dim,
            seed,
            id: String::from("prollytree:hash-embedder/v1"),
            version: format!("dim={dim};seed={seed}"),
        }
    }
}
impl Embedder for HashEmbedder {
    /// Returns the identifier stored at construction time.
    fn id(&self) -> &str {
        &self.id
    }

    /// Returns the version string stored at construction time.
    fn version(&self) -> &str {
        &self.version
    }

    /// Returns the number of components produced by [`Embedder::embed`].
    fn dim(&self) -> u16 {
        self.dim
    }

    /// Produces a deterministic vector of `self.dim` components for `text`.
    ///
    /// Component `i` is derived from the first byte of a 32-byte-wide
    /// `ValueDigest` over `seed (8 bytes LE) || i (4 bytes LE) || text`.
    /// That byte is reinterpreted as a signed value and divided by 128, so
    /// each raw component lies in `[-1.0, 127/128]`. The vector is then
    /// scaled to unit Euclidean length.
    ///
    /// If every raw component is zero (including the `dim == 0` case, which
    /// yields an empty vector) the norm is zero and the vector is returned
    /// unscaled rather than dividing by zero.
    ///
    /// # Errors
    ///
    /// This implementation has no error path and always returns `Ok`.
    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedError> {
        use crate::digest::ValueDigest;

        const SEED_LEN: usize = std::mem::size_of::<u64>();
        const INDEX_LEN: usize = std::mem::size_of::<u32>();

        // Build the hashed message once. Only the index bytes change between
        // components, so they are patched in place instead of reallocating
        // and re-copying the seed and the text for every dimension.
        let mut message = Vec::with_capacity(SEED_LEN + INDEX_LEN + text.len());
        message.extend_from_slice(&self.seed.to_le_bytes());
        message.extend_from_slice(&[0u8; INDEX_LEN]);
        message.extend_from_slice(text.as_bytes());

        let mut out = Vec::with_capacity(usize::from(self.dim));
        for index in 0..u32::from(self.dim) {
            message[SEED_LEN..SEED_LEN + INDEX_LEN].copy_from_slice(&index.to_le_bytes());
            let digest = ValueDigest::<32>::new(&message);
            // Reinterpret the leading digest byte as signed so components are
            // spread over both negative and positive values.
            let signed = digest.as_bytes()[0] as i8;
            out.push(f32::from(signed) / 128.0);
        }

        let norm: f32 = out.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for component in out.iter_mut() {
                *component /= norm;
            }
        }
        Ok(out)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean length of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// The output has the requested length and unit norm.
    #[test]
    fn hash_embedder_basic_shape() {
        let embedder = HashEmbedder::new(32, 0);
        let vector = embedder.embed("hello").unwrap();
        assert_eq!(vector.len(), 32);

        let norm = l2_norm(&vector);
        let deviation = (norm - 1.0).abs();
        assert!(deviation < 1e-4, "expected unit norm, got {norm}");
    }

    /// Equal inputs give equal vectors; different inputs give different ones.
    #[test]
    fn hash_embedder_deterministic() {
        let embedder = HashEmbedder::new(8, 42);

        let first = embedder.embed("foo").unwrap();
        let second = embedder.embed("foo").unwrap();
        assert_eq!(first, second);

        let foo = embedder.embed("foo").unwrap();
        let bar = embedder.embed("bar").unwrap();
        assert_ne!(foo, bar);
    }

    /// Changing only the seed changes the embedding of the same text.
    #[test]
    fn different_seeds_produce_different_vectors() {
        let seed_zero = HashEmbedder::new(16, 0);
        let seed_one = HashEmbedder::new(16, 1);

        let from_zero = seed_zero.embed("abc").unwrap();
        let from_one = seed_one.embed("abc").unwrap();
        assert_ne!(from_zero, from_one);
    }

    /// The version string reflects both the dimension and the seed.
    #[test]
    fn version_changes_with_dim_and_seed() {
        let baseline = HashEmbedder::new(8, 0).version().to_string();
        let other_dim = HashEmbedder::new(16, 0).version().to_string();
        let other_seed = HashEmbedder::new(8, 1).version().to_string();

        assert_ne!(baseline, other_dim);
        assert_ne!(baseline, other_seed);
    }

    /// The identifier does not depend on the configuration.
    #[test]
    fn id_is_stable_across_instances() {
        let small = HashEmbedder::new(8, 0);
        let large = HashEmbedder::new(16, 99);
        assert_eq!(small.id(), large.id());
    }

    /// Empty input still yields a vector of the configured length.
    #[test]
    fn empty_text_is_embeddable() {
        let embedder = HashEmbedder::new(8, 0);
        let vector = embedder.embed("").unwrap();
        assert_eq!(vector.len(), 8);
    }
}