use anyhow::{anyhow, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Model name used by `EmbeddingsConfig::openai_from_env` when
/// `OPENAI_EMBEDDING_MODEL` is not set.
pub const DEFAULT_EMBEDDING_MODEL: &str = "text-embedding-ada-002";
/// Embedding vector length used when neither the model name nor the
/// configuration determines one.
pub const DEFAULT_EMBEDDING_DIMENSION: usize = 1536;
/// Time-to-live, in seconds, used by `EmbeddingCache::default_cache`.
pub const DEFAULT_CACHE_TTL_SECS: u64 = 3600;
/// Maximum entry count used by `EmbeddingCache::default_cache`.
pub const DEFAULT_CACHE_MAX_ENTRIES: usize = 10000;
/// Common interface implemented by every embedding backend in this file.
#[async_trait]
pub trait EmbeddingModel: Send + Sync {
/// Computes the embedding vector for a single text.
async fn embed(&self, text: &str) -> Result<Vec<f32>>;
/// Computes embedding vectors for a slice of texts in one call.
async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
/// Returns the vector length this model reports for its embeddings.
fn dimension(&self) -> usize;
/// Returns the name of the underlying model.
fn model_name(&self) -> &str;
/// Returns the provider identifier (for example `"openai"` or `"local"`).
fn provider(&self) -> &str;
}
/// One cached embedding plus the bookkeeping used for expiry and eviction.
#[derive(Debug, Clone)]
struct CacheEntry {
/// The cached embedding vector.
embedding: Vec<f32>,
/// Moment the entry was inserted; compared against the cache TTL on lookup.
created_at: Instant,
/// Number of successful lookups served from this entry; the entry with the
/// lowest count is the one evicted when the cache is full.
access_count: usize,
}
/// In-memory embedding cache keyed by provider, model and text, guarded by an
/// async read/write lock.
#[derive(Debug)]
pub struct EmbeddingCache {
/// Cached entries, keyed by the string produced by `cache_key`.
store: RwLock<HashMap<String, CacheEntry>>,
/// Entry count at which `put` starts evicting before it inserts.
max_entries: usize,
/// Lifetime of an entry in seconds, measured from insertion.
ttl_secs: u64,
}
impl EmbeddingCache {
pub fn new(max_entries: usize, ttl_secs: u64) -> Self {
Self {
store: RwLock::new(HashMap::new()),
max_entries,
ttl_secs,
}
}
pub fn default_cache() -> Self {
Self::new(DEFAULT_CACHE_MAX_ENTRIES, DEFAULT_CACHE_TTL_SECS)
}
fn cache_key(provider: &str, model: &str, text: &str) -> String {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
provider.hash(&mut hasher);
model.hash(&mut hasher);
text.hash(&mut hasher);
format!("{}:{}:{:016x}", provider, model, hasher.finish())
}
pub async fn get(&self, provider: &str, model: &str, text: &str) -> Option<Vec<f32>> {
let key = Self::cache_key(provider, model, text);
let mut store = self.store.write().await;
if let Some(entry) = store.get_mut(&key) {
if entry.created_at.elapsed() > Duration::from_secs(self.ttl_secs) {
store.remove(&key);
return None;
}
entry.access_count += 1;
return Some(entry.embedding.clone());
}
None
}
pub async fn put(&self, provider: &str, model: &str, text: &str, embedding: Vec<f32>) {
let key = Self::cache_key(provider, model, text);
let mut store = self.store.write().await;
if store.len() >= self.max_entries {
if let Some((lru_key, _)) = store
.iter()
.min_by_key(|(_, e)| e.access_count)
.map(|(k, v)| (k.clone(), v.access_count))
{
store.remove(&lru_key);
}
}
store.insert(
key,
CacheEntry {
embedding,
created_at: Instant::now(),
access_count: 0,
},
);
}
pub async fn get_batch(
&self,
provider: &str,
model: &str,
texts: &[String],
) -> Vec<Option<Vec<f32>>> {
let mut results = Vec::with_capacity(texts.len());
let mut store = self.store.write().await;
for text in texts {
let key = Self::cache_key(provider, model, text);
if let Some(entry) = store.get_mut(&key) {
if entry.created_at.elapsed() > Duration::from_secs(self.ttl_secs) {
store.remove(&key);
results.push(None);
} else {
entry.access_count += 1;
results.push(Some(entry.embedding.clone()));
}
} else {
results.push(None);
}
}
results
}
pub async fn clear(&self) {
let mut store = self.store.write().await;
store.clear();
}
pub async fn stats(&self) -> CacheStats {
let store = self.store.read().await;
let total_entries = store.len();
let total_access: usize = store.values().map(|e| e.access_count).sum();
CacheStats {
total_entries,
total_access,
max_entries: self.max_entries,
ttl_secs: self.ttl_secs,
}
}
}
/// Snapshot returned by `EmbeddingCache::stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
/// Number of entries in the cache when the snapshot was taken.
pub total_entries: usize,
/// Sum of the access counts of all entries present at snapshot time.
pub total_access: usize,
/// Configured entry limit of the cache.
pub max_entries: usize,
/// Configured entry lifetime in seconds.
pub ttl_secs: u64,
}
/// Backends an [`EmbeddingsConfig`] can select.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EmbeddingProvider {
    /// OpenAI embeddings API.
    OpenAI,
    /// HuggingFace inference API.
    HuggingFace,
    /// Cohere embed API.
    Cohere,
    /// Locally evaluated model.
    Local,
    /// Mock model.
    Mock,
}

impl EmbeddingProvider {
    /// Returns the lowercase identifier of the provider.
    pub fn as_str(&self) -> &'static str {
        match *self {
            EmbeddingProvider::OpenAI => "openai",
            EmbeddingProvider::HuggingFace => "huggingface",
            EmbeddingProvider::Cohere => "cohere",
            EmbeddingProvider::Local => "local",
            EmbeddingProvider::Mock => "mock",
        }
    }
}
/// Settings needed to construct an embedding backend.
#[derive(Debug, Clone)]
pub struct EmbeddingsConfig {
/// Which backend to build.
pub provider: EmbeddingProvider,
/// API key; left empty for the `Local` and `Mock` providers.
pub api_key: String,
/// Optional override of the provider's base URL.
pub base_url: Option<String>,
/// Name of the embedding model.
pub model: String,
/// Optional explicit embedding vector length.
pub dimension: Option<usize>,
}
impl Default for EmbeddingsConfig {
fn default() -> Self {
Self {
provider: EmbeddingProvider::Mock,
api_key: String::new(),
base_url: None,
model: "mock-embedding".to_string(),
dimension: Some(DEFAULT_EMBEDDING_DIMENSION),
}
}
}
impl EmbeddingsConfig {
    /// Reads a mandatory environment variable, failing with
    /// `"<NAME> environment variable not set"` when it cannot be read.
    fn required_env(name: &str) -> Result<String> {
        match std::env::var(name) {
            Ok(value) => Ok(value),
            Err(_) => Err(anyhow!("{} environment variable not set", name)),
        }
    }

    /// Reads an optional environment variable, substituting `fallback` when it
    /// cannot be read.
    fn env_or(name: &str, fallback: &str) -> String {
        match std::env::var(name) {
            Ok(value) => value,
            Err(_) => fallback.to_string(),
        }
    }

    /// Builds an OpenAI configuration from `OPENAI_API_KEY` (required),
    /// `OPENAI_BASE_URL` and `OPENAI_EMBEDDING_MODEL` (both optional).
    ///
    /// # Errors
    /// Fails when `OPENAI_API_KEY` is not set.
    pub fn openai_from_env() -> Result<Self> {
        let api_key = Self::required_env("OPENAI_API_KEY")?;
        let base_url = Self::env_or("OPENAI_BASE_URL", "https://api.openai.com/v1");
        let model = Self::env_or("OPENAI_EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL);
        Ok(Self {
            provider: EmbeddingProvider::OpenAI,
            api_key,
            base_url: Some(base_url),
            model,
            dimension: None,
        })
    }

    /// Builds a HuggingFace configuration from `HUGGINGFACE_API_KEY` (required)
    /// and `HUGGINGFACE_EMBEDDING_MODEL` (optional).
    ///
    /// # Errors
    /// Fails when `HUGGINGFACE_API_KEY` is not set.
    pub fn huggingface_from_env() -> Result<Self> {
        let api_key = Self::required_env("HUGGINGFACE_API_KEY")?;
        let model = Self::env_or(
            "HUGGINGFACE_EMBEDDING_MODEL",
            "sentence-transformers/all-MiniLM-L6-v2",
        );
        let base_url =
            String::from("https://api-inference.huggingface.co/pipeline/feature-extraction");
        Ok(Self {
            provider: EmbeddingProvider::HuggingFace,
            api_key,
            base_url: Some(base_url),
            model,
            dimension: None,
        })
    }

    /// Builds a Cohere configuration from `COHERE_API_KEY` (required) and
    /// `COHERE_EMBEDDING_MODEL` (optional).
    ///
    /// # Errors
    /// Fails when `COHERE_API_KEY` is not set.
    pub fn cohere_from_env() -> Result<Self> {
        let api_key = Self::required_env("COHERE_API_KEY")?;
        let model = Self::env_or("COHERE_EMBEDDING_MODEL", "embed-english-v3.0");
        Ok(Self {
            provider: EmbeddingProvider::Cohere,
            api_key,
            base_url: Some(String::from("https://api.cohere.ai/v1")),
            model,
            dimension: None,
        })
    }

    /// Builds a configuration for the local provider; no API key or URL is set.
    pub fn local(model: impl Into<String>, dimension: Option<usize>) -> Self {
        let model = model.into();
        Self {
            provider: EmbeddingProvider::Local,
            api_key: String::new(),
            base_url: None,
            model,
            dimension,
        }
    }

    /// Reports whether the configuration is usable: local and mock providers
    /// always are, every other provider needs a non-empty API key.
    pub fn is_valid(&self) -> bool {
        match self.provider {
            EmbeddingProvider::Local | EmbeddingProvider::Mock => true,
            EmbeddingProvider::OpenAI
            | EmbeddingProvider::HuggingFace
            | EmbeddingProvider::Cohere => !self.api_key.is_empty(),
        }
    }
}
/// Embedding backend that calls the OpenAI `/embeddings` HTTP endpoint.
#[derive(Debug)]
pub struct OpenAIEmbeddings {
/// HTTP client used for all requests.
client: Client,
/// Provider settings (API key, base URL, model).
config: EmbeddingsConfig,
/// Optional shared cache consulted before, and filled after, each request.
cache: Option<Arc<EmbeddingCache>>,
}
impl OpenAIEmbeddings {
    /// Creates a backend without a cache.
    ///
    /// # Errors
    /// Fails when `config.is_valid()` is false.
    pub fn new(config: EmbeddingsConfig) -> Result<Self> {
        if config.is_valid() {
            Ok(Self {
                client: Client::new(),
                config,
                cache: None,
            })
        } else {
            Err(anyhow!("OpenAI Embeddings API not configured"))
        }
    }

    /// Creates a backend that reads from and writes to `cache`.
    ///
    /// # Errors
    /// Fails under the same condition as [`OpenAIEmbeddings::new`].
    pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
        Self::new(config).map(|uncached| Self {
            cache: Some(cache),
            ..uncached
        })
    }

    /// Returns the configured base URL, or the public OpenAI endpoint when
    /// none is configured.
    fn base_url(&self) -> &str {
        match self.config.base_url.as_deref() {
            Some(url) => url,
            None => "https://api.openai.com/v1",
        }
    }
}
#[async_trait]
impl EmbeddingModel for OpenAIEmbeddings {
    /// Embeds one text by sending it through `embed_batch` as a one-element batch.
    ///
    /// # Errors
    /// Propagates any `embed_batch` error, and fails when no embedding comes back.
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        self.embed_batch(&[text.to_string()])
            .await?
            .into_iter()
            .next()
            .ok_or_else(|| anyhow!("No embedding returned"))
    }

    /// Embeds `texts`, returning one vector per input in input order.
    ///
    /// When a cache is attached and every text is cached, no request is made.
    /// Otherwise the whole batch is sent to `{base_url}/embeddings` and the
    /// results are written back to the cache.
    ///
    /// # Errors
    /// Fails on transport errors, on a non-success HTTP status, on an
    /// unparsable body, or when the number of returned embeddings differs
    /// from the number of inputs.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        if let Some(cache) = &self.cache {
            let cached = cache.get_batch("openai", &self.config.model, texts).await;
            // Collecting into `Option<Vec<_>>` yields `Some` only if every slot hit.
            if let Some(hits) = cached.into_iter().collect::<Option<Vec<_>>>() {
                return Ok(hits);
            }
        }
        // Tolerate a trailing slash in a user-supplied base URL.
        let url = format!("{}/embeddings", self.base_url().trim_end_matches('/'));
        let request_body = OpenAiEmbeddingRequest {
            model: self.config.model.clone(),
            input: texts.to_vec(),
            encoding_format: Some("float".to_string()),
        };
        tracing::debug!("Sending OpenAI embedding request for {} texts", texts.len());
        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await?;
        let status = response.status();
        let response_text = response.text().await?;
        if !status.is_success() {
            tracing::error!("OpenAI Embedding API error: {} - {}", status, response_text);
            return Err(anyhow!(
                "OpenAI Embedding API request failed with status {}: {}",
                status,
                response_text
            ));
        }
        let response_body: OpenAiEmbeddingResponse =
            serde_json::from_str(&response_text).map_err(|e| {
                anyhow!(
                    "Failed to parse OpenAI embedding response: {} - {}",
                    e,
                    response_text
                )
            })?;
        // Restore request order using the index reported for each item.
        let mut indexed: Vec<(usize, Vec<f32>)> = response_body
            .data
            .into_iter()
            .map(|item| (item.index, item.embedding))
            .collect();
        indexed.sort_by_key(|(idx, _)| *idx);
        let result: Vec<Vec<f32>> = indexed.into_iter().map(|(_, emb)| emb).collect();
        // A short or long response would silently pair embeddings with the
        // wrong texts below (and poison the cache), so reject it outright.
        if result.len() != texts.len() {
            return Err(anyhow!(
                "OpenAI returned {} embeddings for {} input texts",
                result.len(),
                texts.len()
            ));
        }
        if let Some(cache) = &self.cache {
            for (text, embedding) in texts.iter().zip(result.iter()) {
                cache
                    .put("openai", &self.config.model, text, embedding.clone())
                    .await;
            }
        }
        Ok(result)
    }

    /// Returns the vector length for the configured model: a fixed value for
    /// the known model names, otherwise `config.dimension`, otherwise
    /// [`DEFAULT_EMBEDDING_DIMENSION`].
    fn dimension(&self) -> usize {
        match self.config.model.as_str() {
            "text-embedding-ada-002" | "text-embedding-3-small" => 1536,
            "text-embedding-3-large" => 3072,
            // Honour an explicitly configured dimension for unknown models,
            // as the HuggingFace and Cohere backends do.
            _ => self
                .config
                .dimension
                .unwrap_or(DEFAULT_EMBEDDING_DIMENSION),
        }
    }

    /// Returns the configured model name.
    fn model_name(&self) -> &str {
        &self.config.model
    }

    /// Returns `"openai"`.
    fn provider(&self) -> &str {
        "openai"
    }
}
/// JSON request body posted to the OpenAI `/embeddings` endpoint.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest {
/// Name of the embedding model to use.
model: String,
/// Texts to embed.
input: Vec<String>,
/// Requested encoding of the returned vectors; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
encoding_format: Option<String>,
}
/// JSON response body parsed from the OpenAI `/embeddings` endpoint.
/// All three fields are required for deserialization to succeed.
#[derive(Deserialize)]
struct OpenAiEmbeddingResponse {
/// One item per returned embedding.
data: Vec<OpenAiEmbeddingData>,
/// Parsed but not read by this file.
#[allow(dead_code)]
model: String,
/// Parsed but not read by this file.
#[allow(dead_code)]
usage: OpenAiEmbeddingUsage,
}
/// A single embedding item inside an OpenAI response.
#[derive(Deserialize)]
struct OpenAiEmbeddingData {
/// The embedding vector.
embedding: Vec<f32>,
/// Sort key used to order the returned items — presumably the position of
/// the matching input text; confirm against the API reference.
index: usize,
/// Parsed but not read by this file.
#[allow(dead_code)]
object: String,
}
/// Token usage section of an OpenAI response; parsed but not read by this file.
#[derive(Deserialize)]
#[allow(dead_code)]
struct OpenAiEmbeddingUsage {
/// Value of the `prompt_tokens` field of the response.
prompt_tokens: u32,
/// Value of the `total_tokens` field of the response.
total_tokens: u32,
}
/// Embedding backend that calls the HuggingFace feature-extraction HTTP endpoint.
#[derive(Debug)]
pub struct HuggingFaceEmbeddings {
/// HTTP client used for all requests.
client: Client,
/// Provider settings (API key, model, optional dimension).
config: EmbeddingsConfig,
/// Optional shared cache consulted before, and filled after, each request.
cache: Option<Arc<EmbeddingCache>>,
}
impl HuggingFaceEmbeddings {
    /// Creates a backend without a cache.
    ///
    /// # Errors
    /// Fails when `config.is_valid()` is false.
    pub fn new(config: EmbeddingsConfig) -> Result<Self> {
        if config.is_valid() {
            Ok(Self {
                client: Client::new(),
                config,
                cache: None,
            })
        } else {
            Err(anyhow!("HuggingFace API not configured"))
        }
    }

    /// Creates a backend that reads from and writes to `cache`.
    ///
    /// # Errors
    /// Fails under the same condition as [`HuggingFaceEmbeddings::new`].
    pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
        Self::new(config).map(|uncached| Self {
            cache: Some(cache),
            ..uncached
        })
    }
}
#[async_trait]
impl EmbeddingModel for HuggingFaceEmbeddings {
    /// Embeds one text by sending it through `embed_batch` as a one-element batch.
    ///
    /// # Errors
    /// Propagates any `embed_batch` error, and fails when no embedding comes back.
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        self.embed_batch(&[text.to_string()])
            .await?
            .into_iter()
            .next()
            .ok_or_else(|| anyhow!("No embedding returned from HuggingFace"))
    }

    /// Embeds `texts`, returning one vector per input in input order.
    ///
    /// When a cache is attached and every text is cached, no request is made.
    /// Otherwise the whole batch is posted to `{base_url}/{model}`, where
    /// `base_url` comes from the configuration and falls back to the public
    /// feature-extraction endpoint.
    ///
    /// # Errors
    /// Fails on transport errors, on a non-success HTTP status, on an
    /// unparsable body, or when the number of returned embeddings differs
    /// from the number of inputs.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        if let Some(cache) = &self.cache {
            let cached = cache
                .get_batch("huggingface", &self.config.model, texts)
                .await;
            // Collecting into `Option<Vec<_>>` yields `Some` only if every slot hit.
            if let Some(hits) = cached.into_iter().collect::<Option<Vec<_>>>() {
                return Ok(hits);
            }
        }
        // Respect a configured base URL instead of always using the public
        // endpoint; the fallback keeps the previous URL for `base_url: None`.
        let base = self
            .config
            .base_url
            .as_deref()
            .unwrap_or("https://api-inference.huggingface.co/pipeline/feature-extraction")
            .trim_end_matches('/');
        let url = format!("{}/{}", base, self.config.model);
        tracing::debug!(
            "Sending HuggingFace embedding request for {} texts",
            texts.len()
        );
        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&serde_json::json!({ "inputs": texts }))
            .send()
            .await?;
        let status = response.status();
        let response_text = response.text().await?;
        if !status.is_success() {
            tracing::error!("HuggingFace API error: {} - {}", status, response_text);
            return Err(anyhow!(
                "HuggingFace API request failed with status {}: {}",
                status,
                response_text
            ));
        }
        let embeddings: Vec<Vec<f32>> = serde_json::from_str(&response_text).map_err(|e| {
            anyhow!(
                "Failed to parse HuggingFace response: {} - {}",
                e,
                response_text
            )
        })?;
        // A count mismatch would silently pair embeddings with the wrong
        // texts below (and poison the cache), so reject it outright.
        if embeddings.len() != texts.len() {
            return Err(anyhow!(
                "HuggingFace returned {} embeddings for {} input texts",
                embeddings.len(),
                texts.len()
            ));
        }
        if let Some(cache) = &self.cache {
            for (text, embedding) in texts.iter().zip(embeddings.iter()) {
                cache
                    .put("huggingface", &self.config.model, text, embedding.clone())
                    .await;
            }
        }
        Ok(embeddings)
    }

    /// Returns the vector length for the configured model: a fixed value for
    /// the known model names, otherwise `config.dimension`, otherwise 768.
    fn dimension(&self) -> usize {
        match self.config.model.as_str() {
            "sentence-transformers/all-MiniLM-L6-v2"
            | "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" => 384,
            "sentence-transformers/all-mpnet-base-v2" => 768,
            _ => self.config.dimension.unwrap_or(768),
        }
    }

    /// Returns the configured model name.
    fn model_name(&self) -> &str {
        &self.config.model
    }

    /// Returns `"huggingface"`.
    fn provider(&self) -> &str {
        "huggingface"
    }
}
/// Embedding backend that calls the Cohere `/embed` HTTP endpoint.
#[derive(Debug)]
pub struct CohereEmbeddings {
/// HTTP client used for all requests.
client: Client,
/// Provider settings (API key, model, optional dimension).
config: EmbeddingsConfig,
/// Optional shared cache consulted before, and filled after, each request.
cache: Option<Arc<EmbeddingCache>>,
}
impl CohereEmbeddings {
    /// Creates a backend without a cache.
    ///
    /// # Errors
    /// Fails when `config.is_valid()` is false.
    pub fn new(config: EmbeddingsConfig) -> Result<Self> {
        if config.is_valid() {
            Ok(Self {
                client: Client::new(),
                config,
                cache: None,
            })
        } else {
            Err(anyhow!("Cohere API not configured"))
        }
    }

    /// Creates a backend that reads from and writes to `cache`.
    ///
    /// # Errors
    /// Fails under the same condition as [`CohereEmbeddings::new`].
    pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
        Self::new(config).map(|uncached| Self {
            cache: Some(cache),
            ..uncached
        })
    }
}
#[async_trait]
impl EmbeddingModel for CohereEmbeddings {
    /// Embeds one text by sending it through `embed_batch` as a one-element batch.
    ///
    /// # Errors
    /// Propagates any `embed_batch` error, and fails when no embedding comes back.
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        self.embed_batch(&[text.to_string()])
            .await?
            .into_iter()
            .next()
            .ok_or_else(|| anyhow!("No embedding returned from Cohere"))
    }

    /// Embeds `texts`, returning one vector per input in input order.
    ///
    /// When a cache is attached and every text is cached, no request is made.
    /// Otherwise the whole batch is posted to `{base_url}/embed`, where
    /// `base_url` comes from the configuration and falls back to
    /// `https://api.cohere.ai/v1`.
    ///
    /// # Errors
    /// Fails on transport errors, on a non-success HTTP status, on an
    /// unparsable body, or when the number of returned embeddings differs
    /// from the number of inputs.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        if let Some(cache) = &self.cache {
            let cached = cache.get_batch("cohere", &self.config.model, texts).await;
            // Collecting into `Option<Vec<_>>` yields `Some` only if every slot hit.
            if let Some(hits) = cached.into_iter().collect::<Option<Vec<_>>>() {
                return Ok(hits);
            }
        }
        // Respect a configured base URL instead of always using the public
        // endpoint; the fallback keeps the previous URL for `base_url: None`.
        let base = self
            .config
            .base_url
            .as_deref()
            .unwrap_or("https://api.cohere.ai/v1")
            .trim_end_matches('/');
        let url = format!("{}/embed", base);
        let request_body = CohereEmbeddingRequest {
            model: self.config.model.clone(),
            texts: texts.to_vec(),
            input_type: "search_document",
            embedding_types: Some(vec!["float".to_string()]),
        };
        tracing::debug!("Sending Cohere embedding request for {} texts", texts.len());
        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await?;
        let status = response.status();
        let response_text = response.text().await?;
        if !status.is_success() {
            tracing::error!("Cohere API error: {} - {}", status, response_text);
            return Err(anyhow!(
                "Cohere API request failed with status {}: {}",
                status,
                response_text
            ));
        }
        let response_body: CohereEmbeddingResponse = serde_json::from_str(&response_text)
            .map_err(|e| anyhow!("Failed to parse Cohere response: {} - {}", e, response_text))?;
        let result = response_body.embeddings.float;
        // A count mismatch would silently pair embeddings with the wrong
        // texts below (and poison the cache), so reject it outright.
        if result.len() != texts.len() {
            return Err(anyhow!(
                "Cohere returned {} embeddings for {} input texts",
                result.len(),
                texts.len()
            ));
        }
        if let Some(cache) = &self.cache {
            for (text, embedding) in texts.iter().zip(result.iter()) {
                cache
                    .put("cohere", &self.config.model, text, embedding.clone())
                    .await;
            }
        }
        Ok(result)
    }

    /// Returns the vector length for the configured model: a fixed value for
    /// the known model names, otherwise `config.dimension`, otherwise 1024.
    fn dimension(&self) -> usize {
        match self.config.model.as_str() {
            "embed-english-v3.0" | "embed-english-light-v3.0" | "embed-multilingual-v3.0" => 1024,
            "embed-english-v2.0" => 4096,
            _ => self.config.dimension.unwrap_or(1024),
        }
    }

    /// Returns the configured model name.
    fn model_name(&self) -> &str {
        &self.config.model
    }

    /// Returns `"cohere"`.
    fn provider(&self) -> &str {
        "cohere"
    }
}
/// JSON request body posted to the Cohere embed endpoint.
#[derive(Serialize)]
struct CohereEmbeddingRequest {
/// Name of the embedding model to use.
model: String,
/// Texts to embed.
texts: Vec<String>,
/// Value of the `input_type` request field.
input_type: &'static str,
/// Requested embedding encodings; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
embedding_types: Option<Vec<String>>,
}
#[derive(Deserialize)]
struct CohereEmbeddingResponse {
embeddings: CohereEmbeddingsData,
#[allow(dead_code)]
id: String,
#[allow(dead_code)]
text_type: String,
}
/// The `embeddings` object of a Cohere response.
#[derive(Deserialize)]
struct CohereEmbeddingsData {
/// Embedding vectors read from the `float` key.
float: Vec<Vec<f32>>,
}
/// Embedding backend for the `local` provider.
///
/// Without the `local-embeddings` feature its embedding methods return an error.
pub struct LocalEmbeddings {
/// Provider settings (model name, optional dimension).
config: EmbeddingsConfig,
/// Optional shared cache consulted before, and filled after, computing embeddings.
cache: Option<Arc<EmbeddingCache>>,
/// Model backend slot; only present with the `local-embeddings` feature and
/// initialised to `None` by `new`.
#[cfg(feature = "local-embeddings")]
#[allow(dead_code)]
model: Option<std::sync::Mutex<Box<dyn LocalModelBackend>>>,
}
impl std::fmt::Debug for LocalEmbeddings {
    /// Formats the config and cache fields and prints a fixed `"<model>"`
    /// placeholder for the model backend.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("LocalEmbeddings");
        out.field("config", &self.config);
        out.field("cache", &self.cache);
        out.field("model", &"<model>");
        out.finish()
    }
}
impl LocalEmbeddings {
    /// Creates a backend without a cache and without a loaded model.
    /// This constructor currently never fails.
    pub fn new(config: EmbeddingsConfig) -> Result<Self> {
        let embeddings = Self {
            config,
            cache: None,
            #[cfg(feature = "local-embeddings")]
            model: None,
        };
        Ok(embeddings)
    }

    /// Creates a backend that reads from and writes to `cache`.
    pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
        Self::new(config).map(|uncached| Self {
            cache: Some(cache),
            ..uncached
        })
    }

    /// Logs the name of the model to load; no model is actually loaded yet.
    #[cfg(feature = "local-embeddings")]
    pub fn load_model(&mut self) -> Result<()> {
        let model_name = &self.config.model;
        tracing::info!("Loading local embedding model: {}", model_name);
        Ok(())
    }
}
#[async_trait]
impl EmbeddingModel for LocalEmbeddings {
    /// Returns the embedding for `text`, served from the cache when possible.
    ///
    /// With the `local-embeddings` feature the current implementation produces
    /// an all-zero vector of length `dimension()` and stores it in the cache.
    ///
    /// # Errors
    /// Without the `local-embeddings` feature every cache miss is an error.
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        if let Some(cache) = &self.cache {
            if let Some(embedding) = cache.get("local", &self.config.model, text).await {
                return Ok(embedding);
            }
        }
        #[cfg(feature = "local-embeddings")]
        {
            let embedding = vec![0.0f32; self.dimension()];
            if let Some(cache) = &self.cache {
                cache
                    .put("local", &self.config.model, text, embedding.clone())
                    .await;
            }
            Ok(embedding)
        }
        #[cfg(not(feature = "local-embeddings"))]
        {
            Err(anyhow!(
                "Local embeddings require 'local-embeddings' feature. \
Enable it in Cargo.toml and ensure candle or ort is available."
            ))
        }
    }

    /// Embeds `texts`, returning one vector per input in input order.
    ///
    /// An empty batch yields an empty result, as in the other providers. When
    /// a cache is attached and every text is cached, the cached vectors are
    /// returned directly.
    ///
    /// # Errors
    /// Without the `local-embeddings` feature any batch that is not fully
    /// cached is an error.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        if let Some(cache) = &self.cache {
            let cached = cache.get_batch("local", &self.config.model, texts).await;
            // Collecting into `Option<Vec<_>>` yields `Some` only if every slot hit.
            if let Some(hits) = cached.into_iter().collect::<Option<Vec<_>>>() {
                return Ok(hits);
            }
        }
        #[cfg(feature = "local-embeddings")]
        {
            let mut results = Vec::with_capacity(texts.len());
            for text in texts {
                // `embed` already writes each result to the cache, so there is
                // no second insertion pass here.
                results.push(self.embed(text).await?);
            }
            Ok(results)
        }
        #[cfg(not(feature = "local-embeddings"))]
        {
            Err(anyhow!(
                "Local embeddings require 'local-embeddings' feature"
            ))
        }
    }

    /// Returns `config.dimension`, or 384 when none is configured.
    fn dimension(&self) -> usize {
        self.config.dimension.unwrap_or(384)
    }

    /// Returns the configured model name.
    fn model_name(&self) -> &str {
        &self.config.model
    }

    /// Returns `"local"`.
    fn provider(&self) -> &str {
        "local"
    }
}
/// Interface for a model backend held by `LocalEmbeddings`; only compiled
/// with the `local-embeddings` feature. No implementation is visible in this file.
#[cfg(feature = "local-embeddings")]
#[allow(dead_code)]
trait LocalModelBackend: Send + Sync {
/// Encodes one text into an embedding vector.
fn encode(&self, text: &str) -> Result<Vec<f32>>;
}
/// Builds embedding backends that share a single [`EmbeddingCache`].
pub struct EmbeddingsFactory {
/// Cache handed to the backends this factory creates with a cache.
cache: Arc<EmbeddingCache>,
}
impl EmbeddingsFactory {
pub fn new() -> Self {
Self {
cache: Arc::new(EmbeddingCache::default_cache()),
}
}
pub fn with_cache(cache: Arc<EmbeddingCache>) -> Self {
Self { cache }
}
pub fn create(&self, config: EmbeddingsConfig) -> Result<Box<dyn EmbeddingModel>> {
match config.provider {
EmbeddingProvider::OpenAI => Ok(Box::new(OpenAIEmbeddings::with_cache(
config,
self.cache.clone(),
)?)),
EmbeddingProvider::HuggingFace => Ok(Box::new(HuggingFaceEmbeddings::with_cache(
config,
self.cache.clone(),
)?)),
EmbeddingProvider::Cohere => Ok(Box::new(CohereEmbeddings::with_cache(
config,
self.cache.clone(),
)?)),
EmbeddingProvider::Local => Ok(Box::new(LocalEmbeddings::with_cache(
config,
self.cache.clone(),
)?)),
EmbeddingProvider::Mock => {
let dimension = config.dimension.unwrap_or(DEFAULT_EMBEDDING_DIMENSION);
#[cfg(any(feature = "mock", test))]
{
Ok(Box::new(MockEmbeddingModel::with_name(
dimension,
&config.model,
)))
}
#[cfg(not(any(feature = "mock", test)))]
{
let local_config = EmbeddingsConfig::local(&config.model, Some(dimension));
Ok(Box::new(LocalEmbeddings::new(local_config)?))
}
}
}
}
pub fn create_safe(&self, config: EmbeddingsConfig) -> Box<dyn EmbeddingModel> {
if config.is_valid() {
self.create(config)
.unwrap_or_else(|_| self.create_mock_default())
} else {
self.create_mock_default()
}
}
fn create_mock_default(&self) -> Box<dyn EmbeddingModel> {
#[cfg(any(feature = "mock", test))]
{
Box::new(MockEmbeddingModel::new(DEFAULT_EMBEDDING_DIMENSION))
}
#[cfg(not(any(feature = "mock", test)))]
{
let config = EmbeddingsConfig::local("fallback", Some(DEFAULT_EMBEDDING_DIMENSION));
Box::new(LocalEmbeddings::new(config).expect("Local embeddings should always work"))
}
}
pub fn openai(&self) -> Result<Box<dyn EmbeddingModel>> {
let config = EmbeddingsConfig::openai_from_env()?;
self.create(config)
}
pub fn huggingface(&self) -> Result<Box<dyn EmbeddingModel>> {
let config = EmbeddingsConfig::huggingface_from_env()?;
self.create(config)
}
pub fn cohere(&self) -> Result<Box<dyn EmbeddingModel>> {
let config = EmbeddingsConfig::cohere_from_env()?;
self.create(config)
}
pub fn local(&self, model: &str, dimension: Option<usize>) -> Result<Box<dyn EmbeddingModel>> {
let config = EmbeddingsConfig::local(model, dimension);
self.create(config)
}
#[cfg(any(feature = "mock", test))]
pub fn mock(&self, dimension: usize) -> Box<dyn EmbeddingModel> {
Box::new(MockEmbeddingModel::new(dimension))
}
pub fn cache(&self) -> Arc<EmbeddingCache> {
self.cache.clone()
}
}
impl Default for EmbeddingsFactory {
fn default() -> Self {
Self::new()
}
}
/// Embedding model stand-in that needs no network access; compiled only for
/// tests or with the `mock` feature.
#[cfg(any(feature = "mock", test))]
pub struct MockEmbeddingModel {
    /// Length of every vector the mock returns.
    dimension: usize,
    /// Name reported by `model_name`.
    model_name: String,
}

#[cfg(any(feature = "mock", test))]
impl MockEmbeddingModel {
    /// Creates a mock named `mock-embedding` with the given vector length.
    pub fn new(dimension: usize) -> Self {
        Self::with_name(dimension, "mock-embedding")
    }

    /// Creates a mock with the given vector length and model name.
    pub fn with_name(dimension: usize, model_name: impl Into<String>) -> Self {
        let model_name = model_name.into();
        Self {
            dimension,
            model_name,
        }
    }
}
#[cfg(any(feature = "mock", test))]
#[async_trait]
impl EmbeddingModel for MockEmbeddingModel {
    /// Returns an all-zero vector of the configured length, ignoring the text.
    async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
        let zeros = vec![0.0_f32; self.dimension];
        Ok(zeros)
    }

    /// Returns one all-zero vector per input text.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let row = vec![0.0_f32; self.dimension];
        Ok(vec![row; texts.len()])
    }

    /// Returns the configured vector length.
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Returns the configured model name.
    fn model_name(&self) -> &str {
        self.model_name.as_str()
    }

    /// Returns `"mock"`.
    fn provider(&self) -> &str {
        "mock"
    }
}
/// Alias for [`OpenAIEmbeddings`] under the shorter `Embeddings` name; exercised
/// by the `test_backward_compatible_embeddings` test.
pub type Embeddings = OpenAIEmbeddings;
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialises every test that touches process-wide environment variables.
    ///
    /// The test harness runs tests in parallel, and several tests set and then
    /// remove the same variables; without this lock one test can remove a
    /// variable between another test's `set_var` and its read.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires [`ENV_LOCK`], ignoring poisoning so that one failed test does
    /// not cascade into every other environment-dependent test.
    fn env_guard() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[tokio::test]
    async fn test_cache_basic_operations() {
        let cache = EmbeddingCache::new(100, 3600);
        let embedding = vec![0.1f32, 0.2, 0.3];
        cache
            .put("openai", "test-model", "hello", embedding.clone())
            .await;
        let cached = cache.get("openai", "test-model", "hello").await;
        assert!(cached.is_some());
        assert_eq!(cached.unwrap(), embedding);
        let not_cached = cache.get("openai", "test-model", "not-exists").await;
        assert!(not_cached.is_none());
    }

    #[tokio::test]
    async fn test_cache_batch_operations() {
        let cache = EmbeddingCache::new(100, 3600);
        let texts: Vec<String> = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let embeddings: Vec<Vec<f32>> = texts.iter().map(|t| vec![t.len() as f32]).collect();
        for (text, emb) in texts.iter().zip(embeddings.iter()) {
            cache.put("test", "model", text, emb.clone()).await;
        }
        let cached = cache.get_batch("test", "model", &texts).await;
        assert!(cached.iter().all(|c| c.is_some()));
    }

    #[tokio::test]
    async fn test_cache_stats() {
        let cache = EmbeddingCache::new(100, 3600);
        cache.put("test", "model", "a", vec![1.0f32]).await;
        cache.put("test", "model", "b", vec![2.0]).await;
        let _ = cache.get("test", "model", "a").await;
        let _ = cache.get("test", "model", "a").await;
        let stats = cache.stats().await;
        assert_eq!(stats.total_entries, 2);
        assert_eq!(stats.total_access, 2);
    }

    #[test]
    fn test_config_openai_from_env() {
        let _env = env_guard();
        std::env::set_var("OPENAI_API_KEY", "test_key");
        std::env::remove_var("OPENAI_BASE_URL");
        std::env::remove_var("OPENAI_EMBEDDING_MODEL");
        let config = EmbeddingsConfig::openai_from_env().unwrap();
        assert_eq!(config.api_key, "test_key");
        assert_eq!(config.model, DEFAULT_EMBEDDING_MODEL);
        std::env::remove_var("OPENAI_API_KEY");
    }

    #[test]
    fn test_config_huggingface_from_env() {
        let _env = env_guard();
        std::env::set_var("HUGGINGFACE_API_KEY", "hf_test");
        std::env::remove_var("HUGGINGFACE_EMBEDDING_MODEL");
        let config = EmbeddingsConfig::huggingface_from_env().unwrap();
        assert_eq!(config.api_key, "hf_test");
        assert!(config.model.contains("sentence-transformers"));
        std::env::remove_var("HUGGINGFACE_API_KEY");
    }

    #[test]
    fn test_config_cohere_from_env() {
        let _env = env_guard();
        std::env::set_var("COHERE_API_KEY", "cohere_test");
        std::env::remove_var("COHERE_EMBEDDING_MODEL");
        let config = EmbeddingsConfig::cohere_from_env().unwrap();
        assert_eq!(config.api_key, "cohere_test");
        assert!(config.model.starts_with("embed-"));
        std::env::remove_var("COHERE_API_KEY");
    }

    #[test]
    fn test_config_local() {
        let config = EmbeddingsConfig::local("all-MiniLM-L6-v2", Some(384));
        assert_eq!(config.provider, EmbeddingProvider::Local);
        assert!(config.api_key.is_empty());
        assert!(config.is_valid());
    }

    #[test]
    fn test_openai_dimension() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::OpenAI,
            api_key: "test".to_string(),
            base_url: None,
            model: "text-embedding-ada-002".to_string(),
            dimension: None,
        };
        let embeddings = OpenAIEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 1536);
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::OpenAI,
            api_key: "test".to_string(),
            base_url: None,
            model: "text-embedding-3-large".to_string(),
            dimension: None,
        };
        let embeddings = OpenAIEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 3072);
    }

    #[test]
    fn test_huggingface_dimension() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::HuggingFace,
            api_key: "test".to_string(),
            base_url: None,
            model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
            dimension: None,
        };
        let embeddings = HuggingFaceEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 384);
    }

    #[test]
    fn test_cohere_dimension() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::Cohere,
            api_key: "test".to_string(),
            base_url: None,
            model: "embed-english-v3.0".to_string(),
            dimension: None,
        };
        let embeddings = CohereEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 1024);
    }

    #[test]
    fn test_factory_create_openai() {
        let _env = env_guard();
        std::env::set_var("OPENAI_API_KEY", "test_key");
        let factory = EmbeddingsFactory::new();
        let model = factory.openai().unwrap();
        assert_eq!(model.provider(), "openai");
        std::env::remove_var("OPENAI_API_KEY");
    }

    #[test]
    fn test_factory_create_local() {
        let factory = EmbeddingsFactory::new();
        let model = factory.local("test-model", Some(384)).unwrap();
        assert_eq!(model.provider(), "local");
        assert_eq!(model.dimension(), 384);
    }

    #[test]
    fn test_factory_create_mock() {
        let factory = EmbeddingsFactory::new();
        let model = factory.mock(512);
        assert_eq!(model.provider(), "mock");
        assert_eq!(model.dimension(), 512);
    }

    #[test]
    fn test_factory_create_safe_with_invalid_config() {
        let factory = EmbeddingsFactory::new();
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::OpenAI,
            api_key: String::new(),
            base_url: None,
            model: "test".to_string(),
            dimension: None,
        };
        let model = factory.create_safe(config);
        assert_eq!(model.provider(), "mock");
    }

    #[test]
    fn test_factory_create_safe_with_valid_config() {
        let _env = env_guard();
        std::env::set_var("OPENAI_API_KEY", "test_key");
        let factory = EmbeddingsFactory::new();
        let config = EmbeddingsConfig::openai_from_env().unwrap();
        let model = factory.create_safe(config);
        assert_eq!(model.provider(), "openai");
        std::env::remove_var("OPENAI_API_KEY");
    }

    #[test]
    fn test_config_default_is_safe() {
        let config = EmbeddingsConfig::default();
        assert_eq!(config.provider, EmbeddingProvider::Mock);
        assert!(config.is_valid());
    }

    #[test]
    fn test_provider_mock_is_valid() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::Mock,
            api_key: String::new(),
            base_url: None,
            model: "mock-test".to_string(),
            dimension: Some(256),
        };
        assert!(config.is_valid());
    }

    #[test]
    fn test_embeddings_factory_mock_default_dimension() {
        let factory = EmbeddingsFactory::new();
        let model = factory.mock(DEFAULT_EMBEDDING_DIMENSION);
        assert_eq!(model.dimension(), DEFAULT_EMBEDDING_DIMENSION);
    }

    #[test]
    fn test_backward_compatible_embeddings() {
        let _env = env_guard();
        std::env::set_var("OPENAI_API_KEY", "test_key");
        let config = EmbeddingsConfig::openai_from_env().unwrap();
        let embeddings = Embeddings::new(config).unwrap();
        assert_eq!(embeddings.provider(), "openai");
        std::env::remove_var("OPENAI_API_KEY");
    }
}