/// Errors returned by [`Embedder::embed`].
#[derive(Debug, thiserror::Error)]
pub enum EmbedderError {
/// Generic failure; the payload is a free-form description that is appended
/// to the rendered message after `embedder failed: `.
#[error("embedder failed: {0}")]
Failed(String),
/// The input was too long: it contained `len` tokens, which is more than the
/// allowed maximum `max`.
#[error("input too long: {len} tokens > max {max}")]
TooLong { len: usize, max: usize },
}
/// Converts text into a vector of `f32` values.
///
/// Implementors must be `Send + Sync`.
pub trait Embedder: Send + Sync {
/// Produces the embedding for `text`, or an [`EmbedderError`] on failure.
fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedderError>;
/// Reports the embedder's dimension.
///
/// NOTE(review): presumably this is the length of every vector returned by
/// `embed`; the trait itself does not enforce that — confirm with implementors.
fn dimension(&self) -> usize;
}
#[cfg(test)]
pub mod testing {
    //! Test doubles for the [`Embedder`] trait.
    //!
    //! The module is gated on `cfg(test)`, so it is only compiled for this
    //! crate's own test builds.

    use super::*;
    use std::collections::VecDeque;
    use std::sync::Mutex;

    /// Scripted [`Embedder`] that hands out a fixed list of vectors, one per call.
    ///
    /// Every call to `embed` ignores its input text and returns the next unused
    /// vector in the order supplied to [`MockEmbedder::new`]. Once the script is
    /// used up, each further call returns
    /// `EmbedderError::Failed("script exhausted")`.
    pub struct MockEmbedder {
        /// Vectors not yet handed out; the front element is returned next.
        ///
        /// A `Mutex` is used because `embed` takes `&self` yet must consume an
        /// element, and the `Embedder` trait requires `Sync`. A `VecDeque`
        /// makes taking from the front O(1) instead of shifting the whole
        /// buffer on every call.
        scripts: Mutex<VecDeque<Vec<f32>>>,
        /// Value reported by `dimension()`; every scripted vector has this length.
        dim: usize,
    }

    impl MockEmbedder {
        /// Builds a mock that replays `scripts` in order.
        ///
        /// # Panics
        ///
        /// Panics if any vector in `scripts` has a length other than `dim`.
        pub fn new(dim: usize, scripts: Vec<Vec<f32>>) -> Self {
            for v in &scripts {
                assert_eq!(v.len(), dim, "mock vector length mismatches dim");
            }
            Self {
                // `VecDeque::from(Vec)` reuses the vector's allocation.
                scripts: Mutex::new(VecDeque::from(scripts)),
                dim,
            }
        }

        /// Builds an [`AlwaysEmbedder`] that returns `vector` on every call.
        ///
        /// # Panics
        ///
        /// Panics if `vector.len()` differs from `dim`.
        pub fn always(dim: usize, vector: Vec<f32>) -> AlwaysEmbedder {
            // Same message as `new`, so a bad fixture is diagnosed identically
            // whichever constructor was used.
            assert_eq!(vector.len(), dim, "mock vector length mismatches dim");
            AlwaysEmbedder { vector, dim }
        }
    }

    impl Embedder for MockEmbedder {
        /// Returns the next scripted vector, or `EmbedderError::Failed` once the
        /// script is exhausted. The input text is ignored.
        ///
        /// # Panics
        ///
        /// Panics if the internal mutex is poisoned.
        fn embed(&self, _text: &str) -> Result<Vec<f32>, EmbedderError> {
            let mut remaining = self
                .scripts
                .lock()
                .expect("MockEmbedder script mutex poisoned");
            remaining
                .pop_front()
                .ok_or_else(|| EmbedderError::Failed("script exhausted".into()))
        }

        /// Returns the `dim` given to [`MockEmbedder::new`].
        fn dimension(&self) -> usize {
            self.dim
        }
    }

    /// [`Embedder`] that returns a clone of the same vector for every input.
    ///
    /// Both fields are public; [`MockEmbedder::always`] checks that
    /// `vector.len() == dim`, but nothing re-checks it if a field is changed
    /// afterwards.
    pub struct AlwaysEmbedder {
        /// Vector cloned and returned by every `embed` call.
        pub vector: Vec<f32>,
        /// Value reported by `dimension()`.
        pub dim: usize,
    }

    impl Embedder for AlwaysEmbedder {
        /// Returns a clone of `self.vector`; never fails and ignores the input.
        fn embed(&self, _text: &str) -> Result<Vec<f32>, EmbedderError> {
            Ok(self.vector.clone())
        }

        /// Returns `self.dim`.
        fn dimension(&self) -> usize {
            self.dim
        }
    }
}
#[cfg(test)]
mod tests {
    use super::testing::*;
    use super::*;

    /// Scripted vectors come back in insertion order, and once the script is
    /// used up the mock reports `EmbedderError::Failed`.
    #[test]
    fn mock_embedder_yields_scripts_in_order() {
        let e = MockEmbedder::new(2, vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
        assert_eq!(e.dimension(), 2);
        assert_eq!(e.embed("a").unwrap(), vec![1.0, 0.0]);
        assert_eq!(e.embed("b").unwrap(), vec![0.0, 1.0]);
        // Pin the variant rather than accepting any error, so an unrelated
        // failure cannot satisfy this check.
        assert!(matches!(e.embed("c"), Err(EmbedderError::Failed(_))));
    }

    /// The constant embedder reports the requested dimension and returns the
    /// same vector regardless of the input text.
    #[test]
    fn always_embedder_returns_constant() {
        let e = MockEmbedder::always(3, vec![1.0, 0.0, 0.0]);
        assert_eq!(e.dimension(), 3);
        assert_eq!(e.embed("x").unwrap(), vec![1.0, 0.0, 0.0]);
        assert_eq!(e.embed("y").unwrap(), vec![1.0, 0.0, 0.0]);
    }
}