use anyhow::Result;
use std::fmt::Debug;
/// Common interface for anything that turns text into an embedding vector.
///
/// Implementors must be `Debug`, `Send` and `Sync`, as required by the
/// supertrait bounds on this trait.
pub trait EmbeddingProvider: Debug + Send + Sync {
    /// Computes the embedding for a single piece of text.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports for this input.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Computes embeddings for several texts, one output vector per input,
    /// in the same order as `texts`.
    ///
    /// The default implementation calls [`embed`](Self::embed) once per text.
    /// Implementations may override it.
    ///
    /// # Errors
    ///
    /// Stops at the first text whose `embed` call fails and returns that
    /// error; later texts are not processed.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut embeddings = Vec::with_capacity(texts.len());
        for &text in texts {
            embeddings.push(self.embed(text)?);
        }
        Ok(embeddings)
    }

    /// Reports the provider's embedding dimension.
    ///
    /// NOTE(review): presumably the length of every vector returned by
    /// `embed`; nothing in this trait enforces that, so confirm against the
    /// implementations.
    fn dimension(&self) -> usize;
}
pub mod onnx;
pub use onnx::OnnxEmbeddingProvider;
#[cfg(test)]
mod tests {
use super::*;
/// Shared smoke check that can be run against any [`EmbeddingProvider`].
///
/// Accepts unsized providers too, so it can be called with a
/// `&dyn EmbeddingProvider` as well as a reference to a concrete type.
///
/// Checks that:
/// - `embed` returns a vector whose L2 norm is within `0.01` of `1.0`;
/// - `embed_batch` returns exactly one embedding per input text;
/// - the two (different) batch inputs produce different embeddings.
///
/// # Panics
///
/// Panics if a provider call returns an error or if any check above fails.
pub fn test_provider_basics<P: EmbeddingProvider + ?Sized>(provider: &P) {
    let text = "fn main() { println!(\"Hello, world!\"); }";
    let embedding = provider
        .embed(text)
        .expect("embed() failed on a single snippet");
    let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    assert!((norm - 1.0).abs() < 0.01, "Embedding should be normalized");

    // A fixed-size array is enough here: the list is only borrowed as a slice,
    // so `vec!` would allocate for nothing.
    let texts = ["fn main() {}", "struct Point { x: i32, y: i32 }"];
    let embeddings = provider
        .embed_batch(&texts)
        .expect("embed_batch() failed on two snippets");
    assert_eq!(embeddings.len(), texts.len());
    assert_ne!(embeddings[0], embeddings[1]);
}
/// Rescales `v` in place so that its Euclidean (L2) norm is 1.
///
/// Vectors whose norm is not greater than `1e-6` (which includes empty and
/// all-zero slices, and any slice containing NaN) are left untouched.
///
/// The norm is accumulated in `f64`, so components whose square would
/// overflow `f32` are still normalized correctly instead of collapsing to
/// zeros.
///
/// # Panics
///
/// Panics with "Embedding should be normalized" if the rescaled vector's
/// norm is not within `0.01` of `1.0`, e.g. when `v` contains an infinite
/// component.
// The allow silences the unused-function lint: no caller is visible in this
// module (NOTE(review): confirm whether this helper is still needed).
#[allow(dead_code)]
fn normalize(v: &mut [f32]) {
    /// L2 norm of `values`, computed in `f64` to avoid `f32` overflow.
    fn l2_norm(values: &[f32]) -> f64 {
        values
            .iter()
            .map(|&x| f64::from(x) * f64::from(x))
            .sum::<f64>()
            .sqrt()
    }

    let norm = l2_norm(v);
    if norm > 1e-6 {
        for x in v.iter_mut() {
            // For finite input the quotient lies in [-1, 1], so narrowing
            // back to f32 cannot overflow.
            *x = (f64::from(*x) / norm) as f32;
        }
        let norm_after = l2_norm(v);
        assert!((norm_after - 1.0).abs() < 0.01, "Embedding should be normalized");
    }
}
}