use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};
/// A source of vector embeddings for text.
///
/// Implementors must be `Send + Sync` so one provider can be shared between threads.
pub trait EmbeddingProvider: Send + Sync {
    /// Embeds a single piece of text.
    ///
    /// # Errors
    /// Returns whatever error the implementation reports when it cannot produce an embedding.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Embeds every entry of `texts`, returning one vector per input in input order.
    ///
    /// The default implementation calls [`EmbeddingProvider::embed`] once per text and
    /// stops at the first error.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for &text in texts {
            vectors.push(self.embed(text)?);
        }
        Ok(vectors)
    }

    /// Expected length of each embedding vector returned by this provider.
    fn dimension(&self) -> usize;

    /// Short identifier of the provider (for example `"mock"`).
    fn name(&self) -> &str;
}
/// Deterministic [`EmbeddingProvider`] that needs no network access or model files.
///
/// Vectors are derived only from a hash of the input text, which makes the provider
/// convenient for tests.
#[derive(Debug, Clone)]
pub struct MockEmbeddingProvider {
    /// Length of every vector this provider produces.
    dimension: usize,
}

impl MockEmbeddingProvider {
    /// Creates a mock provider that emits vectors of length `dimension`.
    pub fn new(dimension: usize) -> Self {
        MockEmbeddingProvider { dimension }
    }

    /// djb2 hash of the UTF-8 bytes of `text`: seed 5381, then `h = h * 33 + byte`
    /// for each byte, wrapping on overflow.
    fn hash_text(&self, text: &str) -> u64 {
        text.bytes()
            .fold(5381u64, |h, b| h.wrapping_mul(33).wrapping_add(u64::from(b)))
    }
}
impl EmbeddingProvider for MockEmbeddingProvider {
    /// Produces a deterministic vector for `text`.
    ///
    /// Component `i` is `((hash + i) mod 1000) / 1000` (with a wrapping add). The vector
    /// is then scaled to unit L2 norm unless that norm is zero. This implementation
    /// never returns an error.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let seed = self.hash_text(text);
        let mut components: Vec<f32> = (0..self.dimension)
            .map(|i| (seed.wrapping_add(i as u64) % 1000) as f32 / 1000.0)
            .collect();
        let length = components.iter().map(|c| c * c).sum::<f32>().sqrt();
        if length > 0.0 {
            components.iter_mut().for_each(|c| *c /= length);
        }
        Ok(components)
    }

    /// The length passed to [`MockEmbeddingProvider::new`].
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Always `"mock"`.
    fn name(&self) -> &str {
        "mock"
    }
}
/// Settings for the OpenAI embedding provider.
///
/// `Debug` is implemented by hand so that formatting a config (or anything containing
/// one, e.g. in logs or panic messages) never prints the API key.
#[derive(Clone, Serialize, Deserialize)]
pub struct OpenAIConfig {
    /// Secret API key; redacted from `Debug` output.
    pub api_key: String,
    /// Embedding model name, e.g. `"text-embedding-3-small"`.
    pub model: String,
    /// Optional API endpoint; nothing in this module reads it yet.
    pub endpoint: Option<String>,
}

impl std::fmt::Debug for OpenAIConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a key is configured, never its value.
        let api_key = if self.api_key.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        f.debug_struct("OpenAIConfig")
            .field("api_key", &api_key)
            .field("model", &self.model)
            .field("endpoint", &self.endpoint)
            .finish()
    }
}

impl Default for OpenAIConfig {
    /// An empty API key, the `text-embedding-3-small` model and no endpoint.
    fn default() -> Self {
        Self {
            api_key: String::new(),
            model: "text-embedding-3-small".to_string(),
            endpoint: None,
        }
    }
}
/// [`EmbeddingProvider`] for OpenAI embedding models.
///
/// Construction validates the model name and records its output dimension.
/// NOTE(review): the `embed` implementation in this file is still a stub that returns an
/// error until an HTTP client is added.
#[derive(Debug, Clone)]
pub struct OpenAIEmbeddingProvider {
    // Kept for the future HTTP implementation. Nothing reads it yet, which is why the
    // dead-code lint is silenced.
    #[allow(dead_code)]
    config: OpenAIConfig,
    /// Vector length produced by `config.model`.
    dimension: usize,
}

impl OpenAIEmbeddingProvider {
    /// Builds a provider for `config.model`.
    ///
    /// # Errors
    /// Fails unless the model is `text-embedding-ada-002`, `text-embedding-3-small`
    /// or `text-embedding-3-large`.
    pub fn new(config: OpenAIConfig) -> Result<Self> {
        match Self::dimension_for_model(&config.model) {
            Some(dimension) => Ok(Self { config, dimension }),
            None => anyhow::bail!("Unknown OpenAI model: {}", config.model),
        }
    }

    /// Output vector length of a known OpenAI embedding model, or `None` if unknown.
    fn dimension_for_model(model: &str) -> Option<usize> {
        match model {
            "text-embedding-ada-002" | "text-embedding-3-small" => Some(1536),
            "text-embedding-3-large" => Some(3072),
            _ => None,
        }
    }
}
impl EmbeddingProvider for OpenAIEmbeddingProvider {
    /// Not implemented yet. Always returns an error explaining that an HTTP client
    /// is required.
    fn embed(&self, _text: &str) -> Result<Vec<f32>> {
        Err(anyhow::anyhow!(
            "OpenAI provider requires HTTP client implementation (add reqwest dependency)"
        ))
    }

    /// Dimension of the configured model, resolved when the provider was built.
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Always `"openai"`.
    fn name(&self) -> &str {
        "openai"
    }
}
/// A cached embedding plus the bookkeeping used for expiry and eviction.
#[derive(Debug, Clone)]
struct CacheEntry {
    embedding: Vec<f32>,
    /// Wall-clock insertion time, compared against the cache TTL.
    created_at: SystemTime,
    /// Insertion sequence number; the smallest live value is the oldest entry.
    /// Eviction orders by this rather than `created_at`, because timestamps can tie
    /// on coarse clocks or move backwards when the system clock is adjusted, making
    /// the evicted key arbitrary.
    seq: u64,
}

/// Everything guarded by the cache mutex.
#[derive(Debug, Default)]
struct CacheState {
    entries: HashMap<String, CacheEntry>,
    /// Sequence number handed to the next inserted entry.
    next_seq: u64,
}

/// Thread-safe, size-bounded, TTL-based cache of embeddings keyed by input text.
///
/// Cloning an `EmbeddingCache` yields another handle to the same shared storage.
#[derive(Debug, Clone)]
pub struct EmbeddingCache {
    state: Arc<Mutex<CacheState>>,
    ttl: Duration,
    max_entries: usize,
}

impl EmbeddingCache {
    /// Creates an empty cache.
    ///
    /// Entries expire `ttl` after insertion, and at most `max_entries` are held at
    /// once. With `max_entries == 0`, nothing is ever stored.
    pub fn new(ttl: Duration, max_entries: usize) -> Self {
        Self {
            state: Arc::new(Mutex::new(CacheState::default())),
            ttl,
            max_entries,
        }
    }

    /// Returns a copy of the embedding cached for `text` if it exists and is younger
    /// than the TTL. An expired entry found here is removed.
    pub fn get(&self, text: &str) -> Option<Vec<f32>> {
        let mut state = self.lock();
        match state.entries.get(text) {
            None => None,
            Some(entry) if Self::is_fresh(entry, self.ttl) => Some(entry.embedding.clone()),
            Some(_) => {
                state.entries.remove(text);
                None
            }
        }
    }

    /// Caches `embedding` under `text`, replacing any existing entry and restarting
    /// its TTL.
    ///
    /// Replacing an existing key never evicts another entry. When a new key arrives
    /// and the cache is full, expired entries are dropped first. If that frees no
    /// room, the oldest-inserted entry is evicted.
    pub fn put(&self, text: String, embedding: Vec<f32>) {
        if self.max_entries == 0 {
            return;
        }
        let ttl = self.ttl;
        let mut state = self.lock();
        let is_new_key = !state.entries.contains_key(text.as_str());
        if is_new_key && state.entries.len() >= self.max_entries {
            state.entries.retain(|_, entry| Self::is_fresh(entry, ttl));
            if state.entries.len() >= self.max_entries {
                Self::evict_oldest(&mut state.entries);
            }
        }
        let seq = state.next_seq;
        state.next_seq += 1;
        state.entries.insert(
            text,
            CacheEntry {
                embedding,
                created_at: SystemTime::now(),
                seq,
            },
        );
    }

    /// Removes every entry.
    pub fn clear(&self) {
        self.lock().entries.clear();
    }

    /// Number of stored entries. Counts expired entries that `get`/`put` have not
    /// purged yet.
    pub fn size(&self) -> usize {
        self.lock().entries.len()
    }

    /// Locks the shared state, recovering from poisoning instead of panicking.
    ///
    /// The state is plain data, so after a panic in another thread the worst case is
    /// a stale or missing entry.
    fn lock(&self) -> std::sync::MutexGuard<'_, CacheState> {
        self.state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Whether `entry` is still within `ttl`. An entry timestamped in the future
    /// (the system clock moved backwards) counts as expired.
    fn is_fresh(entry: &CacheEntry, ttl: Duration) -> bool {
        matches!(SystemTime::now().duration_since(entry.created_at), Ok(age) if age < ttl)
    }

    /// Removes the entry with the smallest insertion sequence number, if any.
    fn evict_oldest(entries: &mut HashMap<String, CacheEntry>) {
        let oldest_key = entries
            .iter()
            .min_by_key(|(_, entry)| entry.seq)
            .map(|(key, _)| key.clone());
        if let Some(key) = oldest_key {
            entries.remove(&key);
        }
    }
}

impl Default for EmbeddingCache {
    /// One-hour TTL and room for 10,000 entries.
    fn default() -> Self {
        Self::new(Duration::from_secs(3600), 10000)
    }
}
/// Wraps another [`EmbeddingProvider`] and memoises its results in an [`EmbeddingCache`].
pub struct CachedEmbeddingProvider<P: EmbeddingProvider> {
    /// Provider consulted on cache misses.
    provider: P,
    /// Previously computed embeddings, keyed by input text.
    cache: EmbeddingCache,
}

impl<P: EmbeddingProvider> CachedEmbeddingProvider<P> {
    /// Combines `provider` with `cache`.
    pub fn new(provider: P, cache: EmbeddingCache) -> Self {
        CachedEmbeddingProvider { cache, provider }
    }

    /// Drops every entry from the underlying cache.
    pub fn clear_cache(&self) {
        self.cache.clear()
    }

    /// Number of entries currently held by the underlying cache.
    pub fn cache_size(&self) -> usize {
        let cache = &self.cache;
        cache.size()
    }
}
impl<P: EmbeddingProvider> EmbeddingProvider for CachedEmbeddingProvider<P> {
fn embed(&self, text: &str) -> Result<Vec<f32>> {
if let Some(embedding) = self.cache.get(text) {
return Ok(embedding);
}
let embedding = self.provider.embed(text)?;
self.cache.put(text.to_string(), embedding.clone());
Ok(embedding)
}
fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
let mut results = Vec::with_capacity(texts.len());
let mut uncached_indices = Vec::new();
let mut uncached_texts = Vec::new();
for (i, text) in texts.iter().enumerate() {
if let Some(embedding) = self.cache.get(text) {
results.push(Some(embedding));
} else {
results.push(None);
uncached_indices.push(i);
uncached_texts.push(*text);
}
}
if !uncached_texts.is_empty() {
let new_embeddings = self.provider.embed_batch(&uncached_texts)?;
for (idx, embedding) in uncached_indices.iter().zip(new_embeddings.iter()) {
self.cache.put(texts[*idx].to_string(), embedding.clone());
results[*idx] = Some(embedding.clone());
}
}
results
.into_iter()
.map(|opt| opt.context("Missing embedding"))
.collect()
}
fn dimension(&self) -> usize {
self.provider.dimension()
}
fn name(&self) -> &str {
self.provider.name()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Caching wrapper around a mock provider, with the 10-second TTL and
    /// 100-entry capacity that every provider test below uses.
    fn cached_mock(dimension: usize) -> CachedEmbeddingProvider<MockEmbeddingProvider> {
        CachedEmbeddingProvider::new(
            MockEmbeddingProvider::new(dimension),
            EmbeddingCache::new(Duration::from_secs(10), 100),
        )
    }

    /// `OpenAIConfig` for `model` with a dummy key and no endpoint.
    fn openai_config(model: &str) -> OpenAIConfig {
        OpenAIConfig {
            api_key: "test-key".to_string(),
            model: model.to_string(),
            endpoint: None,
        }
    }

    #[test]
    fn test_mock_provider() {
        let provider = MockEmbeddingProvider::new(384);
        assert_eq!(provider.dimension(), 384);
        assert_eq!(provider.name(), "mock");
        let vector = provider.embed("Hello, world!").unwrap();
        assert_eq!(vector.len(), 384);
        let length = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((length - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_mock_provider_deterministic() {
        let provider = MockEmbeddingProvider::new(128);
        assert_eq!(
            provider.embed("test").unwrap(),
            provider.embed("test").unwrap()
        );
    }

    #[test]
    fn test_mock_provider_different_texts() {
        let provider = MockEmbeddingProvider::new(128);
        assert_ne!(
            provider.embed("hello").unwrap(),
            provider.embed("world").unwrap()
        );
    }

    #[test]
    fn test_mock_provider_batch() {
        let provider = MockEmbeddingProvider::new(256);
        let batch = provider.embed_batch(&["text1", "text2", "text3"]).unwrap();
        assert_eq!(batch.len(), 3);
        for vector in &batch {
            assert_eq!(vector.len(), 256);
        }
    }

    #[test]
    fn test_openai_provider_creation() {
        let provider =
            OpenAIEmbeddingProvider::new(openai_config("text-embedding-ada-002")).unwrap();
        assert_eq!(provider.dimension(), 1536);
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn test_openai_provider_unknown_model() {
        assert!(OpenAIEmbeddingProvider::new(openai_config("unknown-model")).is_err());
    }

    #[test]
    fn test_embedding_cache() {
        let cache = EmbeddingCache::new(Duration::from_secs(10), 100);
        assert_eq!(cache.size(), 0);
        assert!(cache.get("test").is_none());

        cache.put("test".to_string(), vec![1.0, 2.0, 3.0]);
        assert_eq!(cache.size(), 1);
        assert_eq!(cache.get("test").unwrap(), vec![1.0, 2.0, 3.0]);

        cache.clear();
        assert_eq!(cache.size(), 0);
        assert!(cache.get("test").is_none());
    }

    #[test]
    fn test_embedding_cache_max_entries() {
        let cache = EmbeddingCache::new(Duration::from_secs(10), 3);
        for (key, value) in [("key1", 1.0), ("key2", 2.0), ("key3", 3.0)] {
            cache.put(key.to_string(), vec![value]);
        }
        assert_eq!(cache.size(), 3);

        cache.put("key4".to_string(), vec![4.0]);
        assert_eq!(cache.size(), 3);
        assert!(cache.get("key1").is_none());
        assert!(cache.get("key4").is_some());
    }

    #[test]
    fn test_cached_provider() {
        let provider = cached_mock(128);
        assert_eq!(provider.cache_size(), 0);
        let first = provider.embed("test").unwrap();
        assert_eq!(provider.cache_size(), 1);
        let second = provider.embed("test").unwrap();
        assert_eq!(provider.cache_size(), 1);
        assert_eq!(first, second);
    }

    #[test]
    fn test_cached_provider_batch() {
        let provider = cached_mock(64);
        let texts = ["text1", "text2", "text3"];
        let first = provider.embed_batch(&texts).unwrap();
        assert_eq!(provider.cache_size(), 3);
        let second = provider.embed_batch(&texts).unwrap();
        assert_eq!(provider.cache_size(), 3);
        assert_eq!(first, second);
    }

    #[test]
    fn test_cached_provider_partial_cache() {
        let provider = cached_mock(32);
        provider.embed("text1").unwrap();
        provider.embed("text2").unwrap();
        assert_eq!(provider.cache_size(), 2);

        let batch = provider
            .embed_batch(&["text1", "text2", "text3", "text4"])
            .unwrap();
        assert_eq!(batch.len(), 4);
        assert_eq!(provider.cache_size(), 4);
    }

    #[test]
    fn test_cache_clear() {
        let provider = cached_mock(16);
        for text in ["test1", "test2"] {
            provider.embed(text).unwrap();
        }
        assert_eq!(provider.cache_size(), 2);
        provider.clear_cache();
        assert_eq!(provider.cache_size(), 0);
    }
}