1use anyhow::{anyhow, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::time::{Duration, Instant};
18use tokio::sync::RwLock;
19
/// Model name used by `EmbeddingsConfig::openai_from_env` when
/// `OPENAI_EMBEDDING_MODEL` is not set.
pub const DEFAULT_EMBEDDING_MODEL: &str = "text-embedding-ada-002";

/// Fallback embedding width (used by the default config, the mock factory path and
/// OpenAI models without a known dimension).
pub const DEFAULT_EMBEDDING_DIMENSION: usize = 1536;

/// Default lifetime of a cache entry, in seconds (one hour).
pub const DEFAULT_CACHE_TTL_SECS: u64 = 60 * 60;

/// Default capacity of an [`EmbeddingCache`] created via `default_cache`.
pub const DEFAULT_CACHE_MAX_ENTRIES: usize = 10_000;
35
36#[async_trait]
44pub trait EmbeddingModel: Send + Sync {
45 async fn embed(&self, text: &str) -> Result<Vec<f32>>;
47
48 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
50
51 fn dimension(&self) -> usize;
53
54 fn model_name(&self) -> &str;
56
57 fn provider(&self) -> &str;
59}
60
/// A single cached embedding together with its bookkeeping data.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached vector, returned by value (cloned) on every hit.
    embedding: Vec<f32>,
    /// Insertion time; compared against the cache TTL to detect expiry.
    created_at: Instant,
    /// Number of cache hits served by this entry; used to pick eviction victims.
    access_count: usize,
}
75
76#[derive(Debug)]
80pub struct EmbeddingCache {
81 store: RwLock<HashMap<String, CacheEntry>>,
83 max_entries: usize,
85 ttl_secs: u64,
87}
88
89impl EmbeddingCache {
90 pub fn new(max_entries: usize, ttl_secs: u64) -> Self {
92 Self {
93 store: RwLock::new(HashMap::new()),
94 max_entries,
95 ttl_secs,
96 }
97 }
98
99 pub fn default_cache() -> Self {
101 Self::new(DEFAULT_CACHE_MAX_ENTRIES, DEFAULT_CACHE_TTL_SECS)
102 }
103
104 fn cache_key(provider: &str, model: &str, text: &str) -> String {
106 use std::collections::hash_map::DefaultHasher;
107 use std::hash::{Hash, Hasher};
108
109 let mut hasher = DefaultHasher::new();
110 provider.hash(&mut hasher);
111 model.hash(&mut hasher);
112 text.hash(&mut hasher);
113 format!("{}:{}:{:016x}", provider, model, hasher.finish())
114 }
115
116 pub async fn get(&self, provider: &str, model: &str, text: &str) -> Option<Vec<f32>> {
118 let key = Self::cache_key(provider, model, text);
119 let mut store = self.store.write().await;
120
121 if let Some(entry) = store.get_mut(&key) {
122 if entry.created_at.elapsed() > Duration::from_secs(self.ttl_secs) {
124 store.remove(&key);
125 return None;
126 }
127
128 entry.access_count += 1;
129 return Some(entry.embedding.clone());
130 }
131
132 None
133 }
134
135 pub async fn put(&self, provider: &str, model: &str, text: &str, embedding: Vec<f32>) {
137 let key = Self::cache_key(provider, model, text);
138 let mut store = self.store.write().await;
139
140 if store.len() >= self.max_entries {
142 if let Some((lru_key, _)) = store
143 .iter()
144 .min_by_key(|(_, e)| e.access_count)
145 .map(|(k, v)| (k.clone(), v.access_count))
146 {
147 store.remove(&lru_key);
148 }
149 }
150
151 store.insert(
152 key,
153 CacheEntry {
154 embedding,
155 created_at: Instant::now(),
156 access_count: 0,
157 },
158 );
159 }
160
161 pub async fn get_batch(
163 &self,
164 provider: &str,
165 model: &str,
166 texts: &[String],
167 ) -> Vec<Option<Vec<f32>>> {
168 let mut results = Vec::with_capacity(texts.len());
169 let mut store = self.store.write().await;
170
171 for text in texts {
172 let key = Self::cache_key(provider, model, text);
173
174 if let Some(entry) = store.get_mut(&key) {
175 if entry.created_at.elapsed() > Duration::from_secs(self.ttl_secs) {
176 store.remove(&key);
177 results.push(None);
178 } else {
179 entry.access_count += 1;
180 results.push(Some(entry.embedding.clone()));
181 }
182 } else {
183 results.push(None);
184 }
185 }
186
187 results
188 }
189
190 pub async fn clear(&self) {
192 let mut store = self.store.write().await;
193 store.clear();
194 }
195
196 pub async fn stats(&self) -> CacheStats {
198 let store = self.store.read().await;
199 let total_entries = store.len();
200 let total_access: usize = store.values().map(|e| e.access_count).sum();
201
202 CacheStats {
203 total_entries,
204 total_access,
205 max_entries: self.max_entries,
206 ttl_secs: self.ttl_secs,
207 }
208 }
209}
210
/// Point-in-time summary returned by `EmbeddingCache::stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries currently stored.
    pub total_entries: usize,
    /// Sum of the hit counters of all stored entries.
    pub total_access: usize,
    /// Configured capacity of the cache.
    pub max_entries: usize,
    /// Configured entry lifetime in seconds.
    pub ttl_secs: u64,
}
219
/// The embedding backends this module knows how to construct.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EmbeddingProvider {
    OpenAI,
    HuggingFace,
    Cohere,
    Local,
    Mock,
}

impl EmbeddingProvider {
    /// Lower-case identifier for the provider (e.g. `"openai"`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            EmbeddingProvider::OpenAI => "openai",
            EmbeddingProvider::HuggingFace => "huggingface",
            EmbeddingProvider::Cohere => "cohere",
            EmbeddingProvider::Local => "local",
            EmbeddingProvider::Mock => "mock",
        }
    }
}
246
247#[derive(Debug, Clone)]
249pub struct EmbeddingsConfig {
250 pub provider: EmbeddingProvider,
252 pub api_key: String,
254 pub base_url: Option<String>,
256 pub model: String,
258 pub dimension: Option<usize>,
260}
261
262impl Default for EmbeddingsConfig {
263 fn default() -> Self {
264 Self {
267 provider: EmbeddingProvider::Mock,
268 api_key: String::new(),
269 base_url: None,
270 model: "mock-embedding".to_string(),
271 dimension: Some(DEFAULT_EMBEDDING_DIMENSION),
272 }
273 }
274}
275
276impl EmbeddingsConfig {
277 pub fn openai_from_env() -> Result<Self> {
279 let api_key = std::env::var("OPENAI_API_KEY")
280 .map_err(|_| anyhow!("OPENAI_API_KEY environment variable not set"))?;
281
282 let base_url = std::env::var("OPENAI_BASE_URL")
283 .ok()
284 .or_else(|| Some("https://api.openai.com/v1".to_string()));
285
286 let model = std::env::var("OPENAI_EMBEDDING_MODEL")
287 .unwrap_or_else(|_| DEFAULT_EMBEDDING_MODEL.to_string());
288
289 Ok(Self {
290 provider: EmbeddingProvider::OpenAI,
291 api_key,
292 base_url,
293 model,
294 dimension: None,
295 })
296 }
297
298 pub fn huggingface_from_env() -> Result<Self> {
300 let api_key = std::env::var("HUGGINGFACE_API_KEY")
301 .map_err(|_| anyhow!("HUGGINGFACE_API_KEY environment variable not set"))?;
302
303 let model = std::env::var("HUGGINGFACE_EMBEDDING_MODEL")
304 .unwrap_or_else(|_| "sentence-transformers/all-MiniLM-L6-v2".to_string());
305
306 Ok(Self {
307 provider: EmbeddingProvider::HuggingFace,
308 api_key,
309 base_url: Some(
310 "https://api-inference.huggingface.co/pipeline/feature-extraction".to_string(),
311 ),
312 model,
313 dimension: None,
314 })
315 }
316
317 pub fn cohere_from_env() -> Result<Self> {
319 let api_key = std::env::var("COHERE_API_KEY")
320 .map_err(|_| anyhow!("COHERE_API_KEY environment variable not set"))?;
321
322 let model = std::env::var("COHERE_EMBEDDING_MODEL")
323 .unwrap_or_else(|_| "embed-english-v3.0".to_string());
324
325 Ok(Self {
326 provider: EmbeddingProvider::Cohere,
327 api_key,
328 base_url: Some("https://api.cohere.ai/v1".to_string()),
329 model,
330 dimension: None,
331 })
332 }
333
334 pub fn local(model: impl Into<String>, dimension: Option<usize>) -> Self {
336 Self {
337 provider: EmbeddingProvider::Local,
338 api_key: String::new(),
339 base_url: None,
340 model: model.into(),
341 dimension,
342 }
343 }
344
345 pub fn is_valid(&self) -> bool {
347 matches!(
348 self.provider,
349 EmbeddingProvider::Local | EmbeddingProvider::Mock
350 ) || !self.api_key.is_empty()
351 }
352}
353
354#[derive(Debug)]
360pub struct OpenAIEmbeddings {
361 client: Client,
362 config: EmbeddingsConfig,
363 cache: Option<Arc<EmbeddingCache>>,
364}
365
366impl OpenAIEmbeddings {
367 pub fn new(config: EmbeddingsConfig) -> Result<Self> {
368 if !config.is_valid() {
369 return Err(anyhow!("OpenAI Embeddings API not configured"));
370 }
371
372 Ok(Self {
373 client: Client::new(),
374 config,
375 cache: None,
376 })
377 }
378
379 pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
380 let mut embeddings = Self::new(config)?;
381 embeddings.cache = Some(cache);
382 Ok(embeddings)
383 }
384
385 fn base_url(&self) -> &str {
386 self.config
387 .base_url
388 .as_deref()
389 .unwrap_or("https://api.openai.com/v1")
390 }
391}
392
393#[async_trait]
394impl EmbeddingModel for OpenAIEmbeddings {
395 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
396 let embeddings = self.embed_batch(&[text.to_string()]).await?;
397 embeddings
398 .into_iter()
399 .next()
400 .ok_or_else(|| anyhow!("No embedding returned"))
401 }
402
403 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
404 if texts.is_empty() {
405 return Ok(Vec::new());
406 }
407
408 if let Some(cache) = &self.cache {
410 let cached = cache.get_batch("openai", &self.config.model, texts).await;
411 let all_cached = cached.iter().all(|c| c.is_some());
412 if all_cached {
413 return Ok(cached.into_iter().map(|c| c.unwrap()).collect());
414 }
415 }
416
417 let url = format!("{}/embeddings", self.base_url());
418
419 let request_body = OpenAiEmbeddingRequest {
420 model: self.config.model.clone(),
421 input: texts.to_vec(),
422 encoding_format: Some("float".to_string()),
423 };
424
425 tracing::debug!("Sending OpenAI embedding request for {} texts", texts.len());
426
427 let response = self
428 .client
429 .post(&url)
430 .header("Authorization", format!("Bearer {}", self.config.api_key))
431 .header("Content-Type", "application/json")
432 .json(&request_body)
433 .send()
434 .await?;
435
436 let status = response.status();
437 let response_text = response.text().await?;
438
439 if !status.is_success() {
440 tracing::error!("OpenAI Embedding API error: {} - {}", status, response_text);
441 return Err(anyhow!(
442 "OpenAI Embedding API request failed with status {}: {}",
443 status,
444 response_text
445 ));
446 }
447
448 let response_body: OpenAiEmbeddingResponse =
449 serde_json::from_str(&response_text).map_err(|e| {
450 anyhow!(
451 "Failed to parse OpenAI embedding response: {} - {}",
452 e,
453 response_text
454 )
455 })?;
456
457 let mut embeddings: Vec<(usize, Vec<f32>)> = response_body
459 .data
460 .into_iter()
461 .map(|item| (item.index, item.embedding))
462 .collect();
463 embeddings.sort_by_key(|(idx, _)| *idx);
464 let result: Vec<Vec<f32>> = embeddings.into_iter().map(|(_, emb)| emb).collect();
465
466 if let Some(cache) = &self.cache {
468 for (text, embedding) in texts.iter().zip(result.iter()) {
469 cache
470 .put("openai", &self.config.model, text, embedding.clone())
471 .await;
472 }
473 }
474
475 Ok(result)
476 }
477
478 fn dimension(&self) -> usize {
479 match self.config.model.as_str() {
480 "text-embedding-ada-002" => 1536,
481 "text-embedding-3-small" => 1536,
482 "text-embedding-3-large" => 3072,
483 _ => DEFAULT_EMBEDDING_DIMENSION,
484 }
485 }
486
487 fn model_name(&self) -> &str {
488 &self.config.model
489 }
490
491 fn provider(&self) -> &str {
492 "openai"
493 }
494}
495
496#[derive(Serialize)]
497struct OpenAiEmbeddingRequest {
498 model: String,
499 input: Vec<String>,
500 #[serde(skip_serializing_if = "Option::is_none")]
501 encoding_format: Option<String>,
502}
503
504#[derive(Deserialize)]
505struct OpenAiEmbeddingResponse {
506 data: Vec<OpenAiEmbeddingData>,
507 #[allow(dead_code)]
508 model: String,
509 #[allow(dead_code)]
510 usage: OpenAiEmbeddingUsage,
511}
512
513#[derive(Deserialize)]
514struct OpenAiEmbeddingData {
515 embedding: Vec<f32>,
516 index: usize,
517 #[allow(dead_code)]
518 object: String,
519}
520
521#[derive(Deserialize)]
522#[allow(dead_code)]
523struct OpenAiEmbeddingUsage {
524 prompt_tokens: u32,
525 total_tokens: u32,
526}
527
528#[derive(Debug)]
534pub struct HuggingFaceEmbeddings {
535 client: Client,
536 config: EmbeddingsConfig,
537 cache: Option<Arc<EmbeddingCache>>,
538}
539
540impl HuggingFaceEmbeddings {
541 pub fn new(config: EmbeddingsConfig) -> Result<Self> {
542 if !config.is_valid() {
543 return Err(anyhow!("HuggingFace API not configured"));
544 }
545
546 Ok(Self {
547 client: Client::new(),
548 config,
549 cache: None,
550 })
551 }
552
553 pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
554 let mut embeddings = Self::new(config)?;
555 embeddings.cache = Some(cache);
556 Ok(embeddings)
557 }
558}
559
560#[async_trait]
561impl EmbeddingModel for HuggingFaceEmbeddings {
562 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
563 let embeddings = self.embed_batch(&[text.to_string()]).await?;
565 embeddings
566 .into_iter()
567 .next()
568 .ok_or_else(|| anyhow!("No embedding returned from HuggingFace"))
569 }
570
571 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
572 if texts.is_empty() {
573 return Ok(Vec::new());
574 }
575
576 if let Some(cache) = &self.cache {
578 let cached = cache
579 .get_batch("huggingface", &self.config.model, texts)
580 .await;
581 let all_cached = cached.iter().all(|c| c.is_some());
582 if all_cached {
583 return Ok(cached.into_iter().map(|c| c.unwrap()).collect());
584 }
585 }
586
587 let url = format!(
588 "https://api-inference.huggingface.co/pipeline/feature-extraction/{}",
589 self.config.model
590 );
591
592 tracing::debug!(
593 "Sending HuggingFace embedding request for {} texts",
594 texts.len()
595 );
596
597 let response = self
598 .client
599 .post(&url)
600 .header("Authorization", format!("Bearer {}", self.config.api_key))
601 .header("Content-Type", "application/json")
602 .json(&serde_json::json!({ "inputs": texts }))
603 .send()
604 .await?;
605
606 let status = response.status();
607 let response_text = response.text().await?;
608
609 if !status.is_success() {
610 tracing::error!("HuggingFace API error: {} - {}", status, response_text);
611 return Err(anyhow!(
612 "HuggingFace API request failed with status {}: {}",
613 status,
614 response_text
615 ));
616 }
617
618 let embeddings: Vec<Vec<f32>> = serde_json::from_str(&response_text).map_err(|e| {
620 anyhow!(
621 "Failed to parse HuggingFace response: {} - {}",
622 e,
623 response_text
624 )
625 })?;
626
627 if let Some(cache) = &self.cache {
629 for (text, embedding) in texts.iter().zip(embeddings.iter()) {
630 cache
631 .put("huggingface", &self.config.model, text, embedding.clone())
632 .await;
633 }
634 }
635
636 Ok(embeddings)
637 }
638
639 fn dimension(&self) -> usize {
640 match self.config.model.as_str() {
642 "sentence-transformers/all-MiniLM-L6-v2" => 384,
643 "sentence-transformers/all-mpnet-base-v2" => 768,
644 "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" => 384,
645 _ => self.config.dimension.unwrap_or(768),
646 }
647 }
648
649 fn model_name(&self) -> &str {
650 &self.config.model
651 }
652
653 fn provider(&self) -> &str {
654 "huggingface"
655 }
656}
657
658#[derive(Debug)]
664pub struct CohereEmbeddings {
665 client: Client,
666 config: EmbeddingsConfig,
667 cache: Option<Arc<EmbeddingCache>>,
668}
669
670impl CohereEmbeddings {
671 pub fn new(config: EmbeddingsConfig) -> Result<Self> {
672 if !config.is_valid() {
673 return Err(anyhow!("Cohere API not configured"));
674 }
675
676 Ok(Self {
677 client: Client::new(),
678 config,
679 cache: None,
680 })
681 }
682
683 pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
684 let mut embeddings = Self::new(config)?;
685 embeddings.cache = Some(cache);
686 Ok(embeddings)
687 }
688}
689
690#[async_trait]
691impl EmbeddingModel for CohereEmbeddings {
692 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
693 let embeddings = self.embed_batch(&[text.to_string()]).await?;
694 embeddings
695 .into_iter()
696 .next()
697 .ok_or_else(|| anyhow!("No embedding returned from Cohere"))
698 }
699
700 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
701 if texts.is_empty() {
702 return Ok(Vec::new());
703 }
704
705 if let Some(cache) = &self.cache {
707 let cached = cache.get_batch("cohere", &self.config.model, texts).await;
708 let all_cached = cached.iter().all(|c| c.is_some());
709 if all_cached {
710 return Ok(cached.into_iter().map(|c| c.unwrap()).collect());
711 }
712 }
713
714 let url = "https://api.cohere.ai/v1/embed";
715
716 let request_body = CohereEmbeddingRequest {
717 model: self.config.model.clone(),
718 texts: texts.to_vec(),
719 input_type: "search_document",
720 embedding_types: Some(vec!["float".to_string()]),
721 };
722
723 tracing::debug!("Sending Cohere embedding request for {} texts", texts.len());
724
725 let response = self
726 .client
727 .post(url)
728 .header("Authorization", format!("Bearer {}", self.config.api_key))
729 .header("Content-Type", "application/json")
730 .json(&request_body)
731 .send()
732 .await?;
733
734 let status = response.status();
735 let response_text = response.text().await?;
736
737 if !status.is_success() {
738 tracing::error!("Cohere API error: {} - {}", status, response_text);
739 return Err(anyhow!(
740 "Cohere API request failed with status {}: {}",
741 status,
742 response_text
743 ));
744 }
745
746 let response_body: CohereEmbeddingResponse = serde_json::from_str(&response_text)
747 .map_err(|e| anyhow!("Failed to parse Cohere response: {} - {}", e, response_text))?;
748
749 let result = response_body.embeddings.float;
750
751 if let Some(cache) = &self.cache {
753 for (text, embedding) in texts.iter().zip(result.iter()) {
754 cache
755 .put("cohere", &self.config.model, text, embedding.clone())
756 .await;
757 }
758 }
759
760 Ok(result)
761 }
762
763 fn dimension(&self) -> usize {
764 match self.config.model.as_str() {
765 "embed-english-v3.0" | "embed-english-light-v3.0" => 1024,
766 "embed-multilingual-v3.0" => 1024,
767 "embed-english-v2.0" => 4096,
768 _ => self.config.dimension.unwrap_or(1024),
769 }
770 }
771
772 fn model_name(&self) -> &str {
773 &self.config.model
774 }
775
776 fn provider(&self) -> &str {
777 "cohere"
778 }
779}
780
781#[derive(Serialize)]
782struct CohereEmbeddingRequest {
783 model: String,
784 texts: Vec<String>,
785 input_type: &'static str,
786 #[serde(skip_serializing_if = "Option::is_none")]
787 embedding_types: Option<Vec<String>>,
788}
789
790#[derive(Deserialize)]
791struct CohereEmbeddingResponse {
792 embeddings: CohereEmbeddingsData,
793 #[allow(dead_code)]
794 id: String,
795 #[allow(dead_code)]
796 text_type: String,
797}
798
799#[derive(Deserialize)]
800struct CohereEmbeddingsData {
801 float: Vec<Vec<f32>>,
802}
803
804#[derive(Debug)]
813pub struct LocalEmbeddings {
814 config: EmbeddingsConfig,
815 cache: Option<Arc<EmbeddingCache>>,
816 #[cfg(feature = "local-embeddings")]
817 model: Option<std::sync::Mutex<Box<dyn LocalModelBackend>>>,
818}
819
820impl LocalEmbeddings {
821 pub fn new(config: EmbeddingsConfig) -> Result<Self> {
822 Ok(Self {
823 config,
824 cache: None,
825 #[cfg(feature = "local-embeddings")]
826 model: None,
827 })
828 }
829
830 pub fn with_cache(config: EmbeddingsConfig, cache: Arc<EmbeddingCache>) -> Result<Self> {
831 let mut embeddings = Self::new(config)?;
832 embeddings.cache = Some(cache);
833 Ok(embeddings)
834 }
835
836 #[cfg(feature = "local-embeddings")]
838 pub fn load_model(&mut self) -> Result<()> {
839 tracing::info!("Loading local embedding model: {}", self.config.model);
842 Ok(())
843 }
844}
845
846#[async_trait]
847impl EmbeddingModel for LocalEmbeddings {
848 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
849 if let Some(cache) = &self.cache {
851 if let Some(embedding) = cache.get("local", &self.config.model, text).await {
852 return Ok(embedding);
853 }
854 }
855
856 #[cfg(feature = "local-embeddings")]
857 {
858 let embedding = vec![0.0f32; self.dimension()];
861
862 if let Some(cache) = &self.cache {
863 cache
864 .put("local", &self.config.model, text, embedding.clone())
865 .await;
866 }
867
868 Ok(embedding)
869 }
870
871 #[cfg(not(feature = "local-embeddings"))]
872 {
873 Err(anyhow!(
874 "Local embeddings require 'local-embeddings' feature. \
875 Enable it in Cargo.toml and ensure candle or ort is available."
876 ))
877 }
878 }
879
880 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
881 if let Some(cache) = &self.cache {
883 let cached = cache.get_batch("local", &self.config.model, texts).await;
884 if cached.iter().all(|c| c.is_some()) {
885 return Ok(cached.into_iter().map(|c| c.unwrap()).collect());
886 }
887 }
888
889 #[cfg(feature = "local-embeddings")]
890 {
891 let mut results = Vec::with_capacity(texts.len());
892 for text in texts {
893 results.push(self.embed(text).await?);
894 }
895
896 if let Some(cache) = &self.cache {
898 for (text, embedding) in texts.iter().zip(results.iter()) {
899 cache
900 .put("local", &self.config.model, text, embedding.clone())
901 .await;
902 }
903 }
904
905 Ok(results)
906 }
907
908 #[cfg(not(feature = "local-embeddings"))]
909 {
910 Err(anyhow!(
911 "Local embeddings require 'local-embeddings' feature"
912 ))
913 }
914 }
915
916 fn dimension(&self) -> usize {
917 self.config.dimension.unwrap_or(384)
918 }
919
920 fn model_name(&self) -> &str {
921 &self.config.model
922 }
923
924 fn provider(&self) -> &str {
925 "local"
926 }
927}
928
/// Synchronous encoder hook for an in-process model (compiled only with the
/// `local-embeddings` feature).
#[cfg(feature = "local-embeddings")]
trait LocalModelBackend: Send + Sync {
    /// Encodes `text` into an embedding vector.
    fn encode(&self, text: &str) -> Result<Vec<f32>>;
}
934
935pub struct EmbeddingsFactory {
941 cache: Arc<EmbeddingCache>,
942}
943
944impl EmbeddingsFactory {
945 pub fn new() -> Self {
946 Self {
947 cache: Arc::new(EmbeddingCache::default_cache()),
948 }
949 }
950
951 pub fn with_cache(cache: Arc<EmbeddingCache>) -> Self {
952 Self { cache }
953 }
954
955 pub fn create(&self, config: EmbeddingsConfig) -> Result<Box<dyn EmbeddingModel>> {
957 match config.provider {
958 EmbeddingProvider::OpenAI => Ok(Box::new(OpenAIEmbeddings::with_cache(
959 config,
960 self.cache.clone(),
961 )?)),
962 EmbeddingProvider::HuggingFace => Ok(Box::new(HuggingFaceEmbeddings::with_cache(
963 config,
964 self.cache.clone(),
965 )?)),
966 EmbeddingProvider::Cohere => Ok(Box::new(CohereEmbeddings::with_cache(
967 config,
968 self.cache.clone(),
969 )?)),
970 EmbeddingProvider::Local => Ok(Box::new(LocalEmbeddings::with_cache(
971 config,
972 self.cache.clone(),
973 )?)),
974 EmbeddingProvider::Mock => {
975 let dimension = config.dimension.unwrap_or(DEFAULT_EMBEDDING_DIMENSION);
976 #[cfg(any(feature = "mock", test))]
977 {
978 Ok(Box::new(MockEmbeddingModel::with_name(
979 dimension,
980 &config.model,
981 )))
982 }
983 #[cfg(not(any(feature = "mock", test)))]
984 {
985 let local_config = EmbeddingsConfig::local(&config.model, Some(dimension));
987 Ok(Box::new(LocalEmbeddings::new(local_config)?))
988 }
989 }
990 }
991 }
992
993 pub fn create_safe(&self, config: EmbeddingsConfig) -> Box<dyn EmbeddingModel> {
998 if config.is_valid() {
999 self.create(config)
1000 .unwrap_or_else(|_| self.create_mock_default())
1001 } else {
1002 self.create_mock_default()
1003 }
1004 }
1005
1006 fn create_mock_default(&self) -> Box<dyn EmbeddingModel> {
1008 #[cfg(any(feature = "mock", test))]
1009 {
1010 Box::new(MockEmbeddingModel::new(DEFAULT_EMBEDDING_DIMENSION))
1011 }
1012 #[cfg(not(any(feature = "mock", test)))]
1013 {
1014 let config = EmbeddingsConfig::local("fallback", Some(DEFAULT_EMBEDDING_DIMENSION));
1016 Box::new(LocalEmbeddings::new(config).expect("Local embeddings should always work"))
1017 }
1018 }
1019
1020 pub fn openai(&self) -> Result<Box<dyn EmbeddingModel>> {
1022 let config = EmbeddingsConfig::openai_from_env()?;
1023 self.create(config)
1024 }
1025
1026 pub fn huggingface(&self) -> Result<Box<dyn EmbeddingModel>> {
1028 let config = EmbeddingsConfig::huggingface_from_env()?;
1029 self.create(config)
1030 }
1031
1032 pub fn cohere(&self) -> Result<Box<dyn EmbeddingModel>> {
1034 let config = EmbeddingsConfig::cohere_from_env()?;
1035 self.create(config)
1036 }
1037
1038 pub fn local(&self, model: &str, dimension: Option<usize>) -> Result<Box<dyn EmbeddingModel>> {
1040 let config = EmbeddingsConfig::local(model, dimension);
1041 self.create(config)
1042 }
1043
1044 #[cfg(any(feature = "mock", test))]
1049 pub fn mock(&self, dimension: usize) -> Box<dyn EmbeddingModel> {
1050 Box::new(MockEmbeddingModel::new(dimension))
1051 }
1052
1053 pub fn cache(&self) -> Arc<EmbeddingCache> {
1055 self.cache.clone()
1056 }
1057}
1058
1059impl Default for EmbeddingsFactory {
1060 fn default() -> Self {
1061 Self::new()
1062 }
1063}
1064
/// Deterministic test double that returns all-zero vectors of a fixed dimension.
///
/// Available with the `mock` feature or under `cfg(test)`. Derives `Debug`/`Clone`,
/// consistent with the other public types in this module.
#[cfg(any(feature = "mock", test))]
#[derive(Debug, Clone)]
pub struct MockEmbeddingModel {
    /// Length of every returned vector.
    dimension: usize,
    /// Name reported by `model_name`.
    model_name: String,
}

#[cfg(any(feature = "mock", test))]
impl MockEmbeddingModel {
    /// Creates a mock named `"mock-embedding"` with the given dimension.
    pub fn new(dimension: usize) -> Self {
        Self {
            dimension,
            model_name: "mock-embedding".to_string(),
        }
    }

    /// Creates a mock with a custom model name.
    pub fn with_name(dimension: usize, model_name: impl Into<String>) -> Self {
        Self {
            dimension,
            model_name: model_name.into(),
        }
    }
}
1099
#[cfg(any(feature = "mock", test))]
#[async_trait]
impl EmbeddingModel for MockEmbeddingModel {
    /// Returns a zero vector of `dimension` elements, ignoring the input.
    async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
        let zeros = vec![0.0; self.dimension];
        Ok(zeros)
    }

    /// Returns one zero vector per input text.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let batch = std::iter::repeat_with(|| vec![0.0; self.dimension])
            .take(texts.len())
            .collect();
        Ok(batch)
    }

    fn dimension(&self) -> usize {
        self.dimension
    }

    fn model_name(&self) -> &str {
        self.model_name.as_str()
    }

    fn provider(&self) -> &str {
        "mock"
    }
}
1123
1124pub type Embeddings = OpenAIEmbeddings;
1132
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes tests that mutate process-wide environment variables.
    ///
    /// The test harness runs tests on parallel threads, and several tests set/remove the
    /// same `*_API_KEY` variables. Without this lock, one test can remove a key while
    /// another sits between `set_var` and `*_from_env()`, making the suite flaky.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering the guard if a previous holder panicked.
    fn env_guard() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[tokio::test]
    async fn test_cache_basic_operations() {
        let cache = EmbeddingCache::new(100, 3600);

        let embedding = vec![0.1f32, 0.2, 0.3];
        cache
            .put("openai", "test-model", "hello", embedding.clone())
            .await;

        let cached = cache.get("openai", "test-model", "hello").await;
        assert!(cached.is_some());
        assert_eq!(cached.unwrap(), embedding);

        let not_cached = cache.get("openai", "test-model", "not-exists").await;
        assert!(not_cached.is_none());
    }

    #[tokio::test]
    async fn test_cache_batch_operations() {
        let cache = EmbeddingCache::new(100, 3600);

        let texts: Vec<String> = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let embeddings: Vec<Vec<f32>> = texts.iter().map(|t| vec![t.len() as f32]).collect();

        for (text, emb) in texts.iter().zip(embeddings.iter()) {
            cache.put("test", "model", text, emb.clone()).await;
        }

        let cached = cache.get_batch("test", "model", &texts).await;
        assert!(cached.iter().all(|c| c.is_some()));
    }

    #[tokio::test]
    async fn test_cache_stats() {
        let cache = EmbeddingCache::new(100, 3600);

        cache.put("test", "model", "a", vec![1.0f32]).await;
        cache.put("test", "model", "b", vec![2.0]).await;

        let _ = cache.get("test", "model", "a").await;
        let _ = cache.get("test", "model", "a").await;

        let stats = cache.stats().await;
        assert_eq!(stats.total_entries, 2);
        assert_eq!(stats.total_access, 2);
    }

    #[test]
    fn test_config_openai_from_env() {
        let _env = env_guard();
        std::env::set_var("OPENAI_API_KEY", "test_key");
        std::env::remove_var("OPENAI_BASE_URL");
        std::env::remove_var("OPENAI_EMBEDDING_MODEL");

        let config = EmbeddingsConfig::openai_from_env().unwrap();
        assert_eq!(config.api_key, "test_key");
        assert_eq!(config.model, DEFAULT_EMBEDDING_MODEL);

        std::env::remove_var("OPENAI_API_KEY");
    }

    #[test]
    fn test_config_huggingface_from_env() {
        let _env = env_guard();
        std::env::set_var("HUGGINGFACE_API_KEY", "hf_test");
        std::env::remove_var("HUGGINGFACE_EMBEDDING_MODEL");

        let config = EmbeddingsConfig::huggingface_from_env().unwrap();
        assert_eq!(config.api_key, "hf_test");
        assert!(config.model.contains("sentence-transformers"));

        std::env::remove_var("HUGGINGFACE_API_KEY");
    }

    #[test]
    fn test_config_cohere_from_env() {
        let _env = env_guard();
        std::env::set_var("COHERE_API_KEY", "cohere_test");
        std::env::remove_var("COHERE_EMBEDDING_MODEL");

        let config = EmbeddingsConfig::cohere_from_env().unwrap();
        assert_eq!(config.api_key, "cohere_test");
        assert!(config.model.starts_with("embed-"));

        std::env::remove_var("COHERE_API_KEY");
    }

    #[test]
    fn test_config_local() {
        let config = EmbeddingsConfig::local("all-MiniLM-L6-v2", Some(384));
        assert_eq!(config.provider, EmbeddingProvider::Local);
        assert!(config.api_key.is_empty());
        assert!(config.is_valid());
    }

    #[test]
    fn test_openai_dimension() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::OpenAI,
            api_key: "test".to_string(),
            base_url: None,
            model: "text-embedding-ada-002".to_string(),
            dimension: None,
        };
        let embeddings = OpenAIEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 1536);

        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::OpenAI,
            api_key: "test".to_string(),
            base_url: None,
            model: "text-embedding-3-large".to_string(),
            dimension: None,
        };
        let embeddings = OpenAIEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 3072);
    }

    #[test]
    fn test_huggingface_dimension() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::HuggingFace,
            api_key: "test".to_string(),
            base_url: None,
            model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
            dimension: None,
        };
        let embeddings = HuggingFaceEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 384);
    }

    #[test]
    fn test_cohere_dimension() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::Cohere,
            api_key: "test".to_string(),
            base_url: None,
            model: "embed-english-v3.0".to_string(),
            dimension: None,
        };
        let embeddings = CohereEmbeddings::new(config).unwrap();
        assert_eq!(embeddings.dimension(), 1024);
    }

    #[test]
    fn test_factory_create_openai() {
        let _env = env_guard();
        std::env::set_var("OPENAI_API_KEY", "test_key");

        let factory = EmbeddingsFactory::new();
        let model = factory.openai().unwrap();
        assert_eq!(model.provider(), "openai");

        std::env::remove_var("OPENAI_API_KEY");
    }

    #[test]
    fn test_factory_create_local() {
        let factory = EmbeddingsFactory::new();
        let model = factory.local("test-model", Some(384)).unwrap();
        assert_eq!(model.provider(), "local");
        assert_eq!(model.dimension(), 384);
    }

    #[test]
    fn test_factory_create_mock() {
        let factory = EmbeddingsFactory::new();
        let model = factory.mock(512);
        assert_eq!(model.provider(), "mock");
        assert_eq!(model.dimension(), 512);
    }

    #[test]
    fn test_factory_create_safe_with_invalid_config() {
        let factory = EmbeddingsFactory::new();
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::OpenAI,
            api_key: String::new(),
            base_url: None,
            model: "test".to_string(),
            dimension: None,
        };
        let model = factory.create_safe(config);
        assert_eq!(model.provider(), "mock");
    }

    #[test]
    fn test_factory_create_safe_with_valid_config() {
        let _env = env_guard();
        std::env::set_var("OPENAI_API_KEY", "test_key");
        let factory = EmbeddingsFactory::new();
        let config = EmbeddingsConfig::openai_from_env().unwrap();
        let model = factory.create_safe(config);
        assert_eq!(model.provider(), "openai");
        std::env::remove_var("OPENAI_API_KEY");
    }

    #[test]
    fn test_config_default_is_safe() {
        let config = EmbeddingsConfig::default();
        assert_eq!(config.provider, EmbeddingProvider::Mock);
        assert!(config.is_valid());
    }

    #[test]
    fn test_provider_mock_is_valid() {
        let config = EmbeddingsConfig {
            provider: EmbeddingProvider::Mock,
            api_key: String::new(),
            base_url: None,
            model: "mock-test".to_string(),
            dimension: Some(256),
        };
        assert!(config.is_valid());
    }

    #[test]
    fn test_embeddings_factory_mock_default_dimension() {
        let factory = EmbeddingsFactory::new();
        let model = factory.mock(DEFAULT_EMBEDDING_DIMENSION);
        assert_eq!(model.dimension(), DEFAULT_EMBEDDING_DIMENSION);
    }

    #[test]
    fn test_backward_compatible_embeddings() {
        let _env = env_guard();
        std::env::set_var("OPENAI_API_KEY", "test_key");

        let config = EmbeddingsConfig::openai_from_env().unwrap();
        let embeddings = Embeddings::new(config).unwrap();
        assert_eq!(embeddings.provider(), "openai");

        std::env::remove_var("OPENAI_API_KEY");
    }
}