mod openai;
pub use openai::{OpenAIConfig, OpenAIEmbedding, OpenAIModel, UsageSnapshot, UsageStats};
use crate::error::{Error, Result};
use std::any::Any;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// Backend that turns text into `f32` vectors.
///
/// Implementors must be `Send + Sync` and, through the [`Any`] supertrait,
/// `'static`, so a provider can be boxed as a trait object and later downcast
/// to its concrete type via [`as_any`](Self::as_any).
#[async_trait::async_trait]
pub trait EmbeddingProvider: Any + Send + Sync {
    /// Produces the embedding vector for a single piece of text.
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Embeds every entry of `texts`, returning the vectors in input order.
    ///
    /// The default implementation awaits [`embed`](Self::embed) once per
    /// entry, one after another, and returns the first error it encounters
    /// without embedding the remaining entries.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for &text in texts {
            let vector = self.embed(text).await?;
            vectors.push(vector);
        }
        Ok(vectors)
    }

    /// Reports the dimensionality of this provider.
    ///
    /// NOTE(review): presumably equal to the length of the vectors returned
    /// by `embed`; the trait itself does not enforce this.
    fn dimensions(&self) -> usize;

    /// Exposes the provider as `&dyn Any` so callers can downcast it.
    fn as_any(&self) -> &dyn Any;
}
/// Front end that forwards embedding requests to a boxed [`EmbeddingProvider`].
pub struct EmbeddingEngine {
    /// Backend that actually produces the vectors.
    provider: Box<dyn EmbeddingProvider>,
}

impl EmbeddingEngine {
    /// Vector size of the hash fallback created by [`from_env`](Self::from_env).
    const FALLBACK_DIMENSIONS: usize = 1536;

    /// Creates an engine backed by a [`HashEmbedding`] of the given size.
    pub fn new(dimensions: usize) -> Self {
        Self::with_provider(Box::new(HashEmbedding::new(dimensions)))
    }

    /// Creates an engine from the environment, never failing.
    ///
    /// Uses `OpenAIEmbedding::from_env()` when it succeeds. On any error a
    /// warning is logged and a hash-based engine is returned instead.
    ///
    /// NOTE(review): the warning always mentions `OPENAI_API_KEY`, whatever
    /// the actual error was.
    pub fn from_env() -> Self {
        match OpenAIEmbedding::from_env() {
            Ok(provider) => Self::with_provider(Box::new(provider)),
            Err(_) => {
                tracing::warn!("OPENAI_API_KEY not set, falling back to hash embeddings");
                Self::new(Self::FALLBACK_DIMENSIONS)
            }
        }
    }

    /// Creates an engine from the environment, propagating the error when
    /// `OpenAIEmbedding::from_env()` fails instead of falling back.
    pub fn from_env_required() -> Result<Self> {
        let provider = OpenAIEmbedding::from_env()?;
        Ok(Self::with_provider(Box::new(provider)))
    }

    /// Creates an OpenAI-backed engine; `api_key` and `model` are forwarded
    /// unchanged to `OpenAIEmbedding::new`.
    pub fn with_openai(api_key: impl Into<String>, model: Option<String>) -> Self {
        Self::with_provider(Box::new(OpenAIEmbedding::new(api_key, model)))
    }

    /// Creates an OpenAI-backed engine; `api_key` and `config` are forwarded
    /// unchanged to `OpenAIEmbedding::with_config`.
    pub fn with_openai_config(api_key: impl Into<String>, config: OpenAIConfig) -> Self {
        Self::with_provider(Box::new(OpenAIEmbedding::with_config(api_key, config)))
    }

    /// Wraps an already constructed provider.
    pub fn with_provider(provider: Box<dyn EmbeddingProvider>) -> Self {
        Self { provider }
    }

    /// Dimensionality reported by the underlying provider.
    pub fn dimensions(&self) -> usize {
        self.provider.dimensions()
    }

    /// Embeds a single text using the underlying provider.
    pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        self.provider.embed(text).await
    }

    /// Embeds several texts using the underlying provider's batch method.
    pub async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        self.provider.embed_batch(texts).await
    }

    /// Cosine similarity of `a` and `b`.
    ///
    /// Returns `0.0` when the slices differ in length or when either one has
    /// a Euclidean norm of zero (which includes empty slices).
    pub fn similarity(&self, a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }

        let l2_norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let (norm_a, norm_b) = (l2_norm(a), l2_norm(b));
        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }

        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        dot / (norm_a * norm_b)
    }
}
pub struct HashEmbedding {
dimensions: usize,
}
impl HashEmbedding {
pub fn new(dimensions: usize) -> Self {
Self { dimensions }
}
fn embed_sync(&self, text: &str) -> Result<Vec<f32>> {
if text.is_empty() {
return Err(Error::embedding("Cannot embed empty text"));
}
let mut embedding = vec![0.0f32; self.dimensions];
let normalized_text = text.to_lowercase();
for word in normalized_text.split_whitespace() {
self.add_word_embedding(&mut embedding, word, 1.0);
}
let words: Vec<&str> = normalized_text.split_whitespace().collect();
for window in words.windows(2) {
let bigram = format!("{} {}", window[0], window[1]);
self.add_word_embedding(&mut embedding, &bigram, 0.5);
}
for window in words.windows(3) {
let trigram = format!("{} {} {}", window[0], window[1], window[2]);
self.add_word_embedding(&mut embedding, &trigram, 0.3);
}
for word in &words {
for char_ngram in word.as_bytes().windows(3) {
let hash = self.hash_bytes(char_ngram);
let idx = (hash as usize) % self.dimensions;
embedding[idx] += 0.1;
}
}
self.normalize(&mut embedding);
Ok(embedding)
}
fn add_word_embedding(&self, embedding: &mut [f32], text: &str, weight: f32) {
let hash = self.hash_text(text);
for i in 0..8 {
let idx = ((hash.wrapping_add(i * 0x9e37_79b9)) as usize) % self.dimensions;
let sign = if (hash >> i) & 1 == 0 { 1.0 } else { -1.0 };
embedding[idx] += sign * weight;
}
}
fn hash_text(&self, text: &str) -> u64 {
let mut hasher = DefaultHasher::new();
text.hash(&mut hasher);
hasher.finish()
}
fn hash_bytes(&self, bytes: &[u8]) -> u64 {
let mut hasher = DefaultHasher::new();
bytes.hash(&mut hasher);
hasher.finish()
}
fn normalize(&self, embedding: &mut [f32]) {
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for x in embedding.iter_mut() {
*x /= norm;
}
}
}
}
#[async_trait::async_trait]
impl EmbeddingProvider for HashEmbedding {
async fn embed(&self, text: &str) -> Result<Vec<f32>> {
self.embed_sync(text)
}
fn dimensions(&self) -> usize {
self.dimensions
}
fn as_any(&self) -> &dyn Any {
self
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean length of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// The hash engine returns vectors of the requested length.
    #[tokio::test]
    async fn test_embedding_dimensions() {
        let engine = EmbeddingEngine::new(128);
        let vector = engine.embed("test text").await.unwrap();
        assert_eq!(vector.len(), 128);
    }

    /// Embedding the same text twice gives the same vector.
    #[tokio::test]
    async fn test_embedding_consistency() {
        let engine = EmbeddingEngine::new(64);
        let first = engine.embed("hello world").await.unwrap();
        let second = engine.embed("hello world").await.unwrap();
        assert_eq!(first, second);
    }

    /// Texts sharing words score higher than texts sharing none.
    #[tokio::test]
    async fn test_embedding_similarity() {
        let engine = EmbeddingEngine::new(128);
        let reference = engine.embed("rust programming language").await.unwrap();
        let related = engine.embed("rust programming").await.unwrap();
        let unrelated = engine.embed("cooking recipes").await.unwrap();

        let related_score = engine.similarity(&reference, &related);
        let unrelated_score = engine.similarity(&reference, &unrelated);
        assert!(related_score > unrelated_score);
    }

    /// Hash embeddings of ordinary text have unit length.
    #[tokio::test]
    async fn test_normalized_embeddings() {
        let engine = EmbeddingEngine::new(256);
        let vector = engine.embed("some text here").await.unwrap();
        let norm = l2_norm(&vector);
        assert!((norm - 1.0).abs() < 1e-5);
    }

    /// Empty input is rejected with an error.
    #[tokio::test]
    async fn test_empty_text_error() {
        let engine = EmbeddingEngine::new(64);
        let outcome = engine.embed("").await;
        assert!(outcome.is_err());
    }
}