1use std::collections::HashMap;
27use std::sync::Arc;
28
29use async_trait::async_trait;
30use serde::{Deserialize, Serialize};
31use tokio::sync::RwLock;
32
33use argentor_core::{ArgentorError, ArgentorResult};
34
35use crate::embedding::{EmbeddingProvider, LocalEmbedding};
36
/// Computes the 64-bit FNV-1a hash of `data`.
///
/// The state starts at the offset basis; each input byte is XORed into the
/// state, which is then multiplied by the FNV prime using wrapping arithmetic.
/// An empty slice therefore hashes to the offset basis itself.
fn fnv1a_hash(data: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 14695981039346656037;
    const PRIME: u64 = 1099511628211;

    data.iter().fold(OFFSET_BASIS, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
50
/// Serializable request body carrying a model name and a list of input texts.
///
/// NOTE(review): presumably mirrors the OpenAI embeddings request shape; the
/// provider code visible in this file builds its payload with `json!` instead.
#[derive(Debug, Serialize)]
pub struct OpenAiEmbeddingRequest {
    /// Name of the embedding model to use.
    pub model: String,
    /// Texts to embed.
    pub input: Vec<String>,
}
63
/// One element of the `data` array in an OpenAI-style embeddings response.
#[derive(Debug, Deserialize)]
pub struct OpenAiEmbeddingObject {
    /// The embedding vector.
    pub embedding: Vec<f32>,
    /// Index reported by the service for this element (presumably the position
    /// of the corresponding input — confirm against the API docs).
    pub index: usize,
}
72
/// Top-level shape deserialized by `parse_openai_embedding_response`.
///
/// Both fields are required: deserialization fails if either is missing.
#[derive(Debug, Deserialize)]
pub struct OpenAiEmbeddingResponse {
    /// Embedding objects returned by the service.
    pub data: Vec<OpenAiEmbeddingObject>,
    /// Model name echoed back in the response.
    pub model: String,
}
81
/// Serializable request body with the same keys the Cohere providers in this
/// file send (`model`, `texts`, `input_type`, `embedding_types`).
#[derive(Debug, Serialize)]
pub struct CohereEmbedRequest {
    /// Name of the embedding model to use.
    pub model: String,
    /// Texts to embed.
    pub texts: Vec<String>,
    /// Input type hint (the providers here default to `"search_document"`).
    pub input_type: String,
    /// Requested embedding encodings (the providers here request `"float"`).
    pub embedding_types: Vec<String>,
}
94
/// Top-level shape deserialized by `parse_cohere_embedding_response`.
#[derive(Debug, Deserialize)]
pub struct CohereEmbedResponse {
    /// Embeddings grouped by encoding type.
    pub embeddings: CohereEmbeddingsMap,
}
101
/// The `embeddings` object of a Cohere response; only the `float` encoding is
/// modelled here.
#[derive(Debug, Deserialize)]
pub struct CohereEmbeddingsMap {
    /// Float vectors, one per input text. Defaults to an empty list when the
    /// `float` key is absent from the response.
    #[serde(default)]
    pub float: Vec<Vec<f32>>,
}
109
/// Serializable request body carrying a model name and a list of input texts.
///
/// NOTE(review): presumably mirrors the Voyage embeddings request shape; the
/// provider code visible in this file builds its payload with `json!` instead.
#[derive(Debug, Serialize)]
pub struct VoyageEmbeddingRequest {
    /// Name of the embedding model to use.
    pub model: String,
    /// Texts to embed.
    pub input: Vec<String>,
}
118
/// One element of the `data` array in a Voyage embeddings response.
#[derive(Debug, Deserialize)]
pub struct VoyageEmbeddingObject {
    /// The embedding vector.
    pub embedding: Vec<f32>,
    /// Index reported by the service for this element (presumably the position
    /// of the corresponding input — confirm against the API docs).
    pub index: usize,
}
127
/// Top-level shape deserialized by `parse_voyage_embedding_response`.
#[derive(Debug, Deserialize)]
pub struct VoyageEmbeddingResponse {
    /// Embedding objects returned by the service.
    pub data: Vec<VoyageEmbeddingObject>,
}
134
135pub fn parse_openai_embedding_response(json: &serde_json::Value) -> ArgentorResult<Vec<f32>> {
144 let response: OpenAiEmbeddingResponse = serde_json::from_value(json.clone())
145 .map_err(|e| ArgentorError::Agent(format!("Failed to parse OpenAI response: {e}")))?;
146 response
147 .data
148 .into_iter()
149 .next()
150 .map(|obj| obj.embedding)
151 .ok_or_else(|| {
152 ArgentorError::Agent("OpenAI response contains no embedding data".to_string())
153 })
154}
155
156pub fn parse_cohere_embedding_response(json: &serde_json::Value) -> ArgentorResult<Vec<f32>> {
161 let response: CohereEmbedResponse = serde_json::from_value(json.clone())
162 .map_err(|e| ArgentorError::Agent(format!("Failed to parse Cohere response: {e}")))?;
163 response.embeddings.float.into_iter().next().ok_or_else(|| {
164 ArgentorError::Agent("Cohere response contains no float embeddings".to_string())
165 })
166}
167
168pub fn parse_voyage_embedding_response(json: &serde_json::Value) -> ArgentorResult<Vec<f32>> {
173 let response: VoyageEmbeddingResponse = serde_json::from_value(json.clone())
174 .map_err(|e| ArgentorError::Agent(format!("Failed to parse Voyage response: {e}")))?;
175 response
176 .data
177 .into_iter()
178 .next()
179 .map(|obj| obj.embedding)
180 .ok_or_else(|| {
181 ArgentorError::Agent("Voyage response contains no embedding data".to_string())
182 })
183}
184
/// Embedding provider configuration for an OpenAI-style embeddings endpoint.
///
/// Holds the API key, model name, the dimension this provider reports, and the
/// endpoint URL requests are sent to.
pub struct OpenAiEmbeddingProvider {
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    model: String,
    dimensions: usize,
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}

impl OpenAiEmbeddingProvider {
    /// Creates a provider that targets `https://api.openai.com/v1/embeddings`.
    ///
    /// `model` falls back to `"text-embedding-3-small"` when `None`; the
    /// reported dimension is derived from the model name.
    pub fn new(api_key: impl Into<String>, model: Option<String>) -> Self {
        Self::with_base_url(api_key, model, "https://api.openai.com/v1/embeddings")
    }

    /// Creates a provider that sends requests to `base_url` instead of the
    /// default endpoint. Model defaulting works as in [`Self::new`].
    pub fn with_base_url(
        api_key: impl Into<String>,
        model: Option<String>,
        base_url: impl Into<String>,
    ) -> Self {
        let model = match model {
            Some(name) => name,
            None => String::from("text-embedding-3-small"),
        };
        Self {
            dimensions: Self::default_dimensions(&model),
            api_key: api_key.into(),
            base_url: base_url.into(),
            model,
        }
    }

    /// Overrides the dimension this provider reports.
    pub fn with_dimensions(self, dimensions: usize) -> Self {
        Self { dimensions, ..self }
    }

    /// Reported dimension for a model name: 3072 for
    /// `"text-embedding-3-large"`, 1536 for every other name (which includes
    /// `"text-embedding-3-small"` and `"text-embedding-ada-002"`).
    fn default_dimensions(model: &str) -> usize {
        if model == "text-embedding-3-large" {
            3072
        } else {
            1536
        }
    }

    /// Returns the configured model name.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }
}
255
256#[async_trait]
257impl EmbeddingProvider for OpenAiEmbeddingProvider {
258 #[cfg(feature = "http-embeddings")]
259 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
260 let client = reqwest::Client::new();
261 let response = client
262 .post(&self.base_url)
263 .header("Authorization", format!("Bearer {}", self.api_key))
264 .json(&serde_json::json!({
265 "model": self.model,
266 "input": text,
267 }))
268 .send()
269 .await
270 .map_err(|e| ArgentorError::Http(format!("OpenAI embedding request failed: {e}")))?;
271
272 let status = response.status();
273 if !status.is_success() {
274 let body = response.text().await.unwrap_or_default();
275 return Err(ArgentorError::Http(format!(
276 "OpenAI API error {status}: {body}"
277 )));
278 }
279
280 let json: serde_json::Value = response.json().await.map_err(|e| {
281 ArgentorError::Http(format!("Failed to read OpenAI response body: {e}"))
282 })?;
283
284 parse_openai_embedding_response(&json)
285 }
286
287 #[cfg(not(feature = "http-embeddings"))]
288 async fn embed(&self, _text: &str) -> ArgentorResult<Vec<f32>> {
289 Err(ArgentorError::Http(
290 "HTTP embeddings not enabled. Enable the 'http-embeddings' feature flag \
291 or use LocalEmbedding for offline embeddings."
292 .to_string(),
293 ))
294 }
295
296 fn dimension(&self) -> usize {
297 self.dimensions
298 }
299}
300
/// Embedding provider configuration for the Cohere embed endpoint.
///
/// Holds the API key, model name, the dimension this provider reports, and the
/// `input_type` value sent with each request.
pub struct CohereEmbeddingProvider {
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    model: String,
    dimensions: usize,
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    input_type: String,
}

impl CohereEmbeddingProvider {
    /// Creates a provider with input type `"search_document"`.
    ///
    /// `model` falls back to `"embed-english-v3.0"` when `None`; the reported
    /// dimension is derived from the model name.
    pub fn new(api_key: impl Into<String>, model: Option<String>) -> Self {
        let model = match model {
            Some(name) => name,
            None => String::from("embed-english-v3.0"),
        };
        Self {
            dimensions: Self::default_dimensions(&model),
            api_key: api_key.into(),
            input_type: String::from("search_document"),
            model,
        }
    }

    /// Replaces the `input_type` value sent with requests.
    pub fn with_input_type(self, input_type: impl Into<String>) -> Self {
        Self {
            input_type: input_type.into(),
            ..self
        }
    }

    /// Overrides the dimension this provider reports.
    pub fn with_dimensions(self, dimensions: usize) -> Self {
        Self { dimensions, ..self }
    }

    /// Reported dimension for a model name: 384 for the two `-light-v3.0`
    /// models, 1024 for every other name (which includes
    /// `"embed-english-v3.0"` and `"embed-multilingual-v3.0"`).
    fn default_dimensions(model: &str) -> usize {
        let is_light = matches!(
            model,
            "embed-english-light-v3.0" | "embed-multilingual-light-v3.0"
        );
        if is_light {
            384
        } else {
            1024
        }
    }

    /// Returns the configured model name.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }

    /// Returns the configured input type.
    pub fn input_type(&self) -> &str {
        self.input_type.as_str()
    }
}
365
366#[async_trait]
367impl EmbeddingProvider for CohereEmbeddingProvider {
368 #[cfg(feature = "http-embeddings")]
369 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
370 let client = reqwest::Client::new();
371 let response = client
372 .post("https://api.cohere.com/v2/embed")
373 .header("Authorization", format!("Bearer {}", self.api_key))
374 .json(&serde_json::json!({
375 "model": self.model,
376 "texts": [text],
377 "input_type": self.input_type,
378 "embedding_types": ["float"],
379 }))
380 .send()
381 .await
382 .map_err(|e| ArgentorError::Http(format!("Cohere embedding request failed: {e}")))?;
383
384 let status = response.status();
385 if !status.is_success() {
386 let body = response.text().await.unwrap_or_default();
387 return Err(ArgentorError::Http(format!(
388 "Cohere API error {status}: {body}"
389 )));
390 }
391
392 let json: serde_json::Value = response.json().await.map_err(|e| {
393 ArgentorError::Http(format!("Failed to read Cohere response body: {e}"))
394 })?;
395
396 parse_cohere_embedding_response(&json)
397 }
398
399 #[cfg(not(feature = "http-embeddings"))]
400 async fn embed(&self, _text: &str) -> ArgentorResult<Vec<f32>> {
401 Err(ArgentorError::Http(
402 "HTTP embeddings not enabled. Enable the 'http-embeddings' feature flag \
403 or use LocalEmbedding for offline embeddings."
404 .to_string(),
405 ))
406 }
407
408 fn dimension(&self) -> usize {
409 self.dimensions
410 }
411}
412
/// Embedding provider configuration for the Voyage embeddings endpoint.
///
/// Holds the API key, model name and the dimension this provider reports.
pub struct VoyageEmbeddingProvider {
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    model: String,
    dimensions: usize,
}

impl VoyageEmbeddingProvider {
    /// Creates a provider; `model` falls back to `"voyage-2"` when `None` and
    /// the reported dimension is derived from the model name.
    pub fn new(api_key: impl Into<String>, model: Option<String>) -> Self {
        let model = match model {
            Some(name) => name,
            None => String::from("voyage-2"),
        };
        Self {
            dimensions: Self::default_dimensions(&model),
            api_key: api_key.into(),
            model,
        }
    }

    /// Overrides the dimension this provider reports.
    pub fn with_dimensions(self, dimensions: usize) -> Self {
        Self { dimensions, ..self }
    }

    /// Reported dimension for a model name: 1536 for `"voyage-code-2"`, 1024
    /// for every other name (which includes `"voyage-2"`, `"voyage-large-2"`,
    /// `"voyage-lite-02-instruct"` and `"voyage-3"`).
    fn default_dimensions(model: &str) -> usize {
        if model == "voyage-code-2" {
            1536
        } else {
            1024
        }
    }

    /// Returns the configured model name.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }
}
465
466#[async_trait]
467impl EmbeddingProvider for VoyageEmbeddingProvider {
468 #[cfg(feature = "http-embeddings")]
469 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
470 let client = reqwest::Client::new();
471 let response = client
472 .post("https://api.voyageai.com/v1/embeddings")
473 .header("Authorization", format!("Bearer {}", self.api_key))
474 .json(&serde_json::json!({
475 "model": self.model,
476 "input": [text],
477 }))
478 .send()
479 .await
480 .map_err(|e| ArgentorError::Http(format!("Voyage embedding request failed: {e}")))?;
481
482 let status = response.status();
483 if !status.is_success() {
484 let body = response.text().await.unwrap_or_default();
485 return Err(ArgentorError::Http(format!(
486 "Voyage API error {status}: {body}"
487 )));
488 }
489
490 let json: serde_json::Value = response.json().await.map_err(|e| {
491 ArgentorError::Http(format!("Failed to read Voyage response body: {e}"))
492 })?;
493
494 parse_voyage_embedding_response(&json)
495 }
496
497 #[cfg(not(feature = "http-embeddings"))]
498 async fn embed(&self, _text: &str) -> ArgentorResult<Vec<f32>> {
499 Err(ArgentorError::Http(
500 "HTTP embeddings not enabled. Enable the 'http-embeddings' feature flag \
501 or use LocalEmbedding for offline embeddings."
502 .to_string(),
503 ))
504 }
505
506 fn dimension(&self) -> usize {
507 self.dimensions
508 }
509}
510
/// Counters describing the state of a `CachedEmbeddingProvider`.
///
/// All counters start at zero via `Default`.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of lookups answered from the cache.
    pub hits: u64,
    /// Number of lookups that had to call the wrapped provider.
    pub misses: u64,
    /// Number of cached entries as of the last insertion or `clear`.
    pub size: usize,
}
525
526pub struct CachedEmbeddingProvider {
531 inner: Arc<dyn EmbeddingProvider>,
532 cache: Arc<RwLock<HashMap<u64, Vec<f32>>>>,
533 max_cache_size: usize,
534 stats: Arc<RwLock<CacheStats>>,
535}
536
537impl CachedEmbeddingProvider {
538 pub fn new(inner: Arc<dyn EmbeddingProvider>, max_cache_size: usize) -> Self {
540 Self {
541 inner,
542 cache: Arc::new(RwLock::new(HashMap::new())),
543 max_cache_size,
544 stats: Arc::new(RwLock::new(CacheStats::default())),
545 }
546 }
547
548 pub async fn cache_stats(&self) -> CacheStats {
550 self.stats.read().await.clone()
551 }
552
553 pub async fn clear(&self) {
555 self.cache.write().await.clear();
556 let mut stats = self.stats.write().await;
557 stats.size = 0;
558 }
559
560 fn text_hash(text: &str) -> u64 {
561 fnv1a_hash(text.as_bytes())
562 }
563}
564
565#[async_trait]
566impl EmbeddingProvider for CachedEmbeddingProvider {
567 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
568 let key = Self::text_hash(text);
569
570 {
572 let cache = self.cache.read().await;
573 if let Some(cached) = cache.get(&key) {
574 let mut stats = self.stats.write().await;
575 stats.hits += 1;
576 return Ok(cached.clone());
577 }
578 }
579
580 let embedding = self.inner.embed(text).await?;
582
583 {
585 let mut cache = self.cache.write().await;
586
587 if cache.len() >= self.max_cache_size {
589 if let Some(&evict_key) = cache.keys().next() {
592 cache.remove(&evict_key);
593 }
594 }
595
596 cache.insert(key, embedding.clone());
597
598 let mut stats = self.stats.write().await;
599 stats.misses += 1;
600 stats.size = cache.len();
601 }
602
603 Ok(embedding)
604 }
605
606 fn dimension(&self) -> usize {
607 self.inner.dimension()
608 }
609}
610
/// Thin wrapper around a shared [`EmbeddingProvider`] that exposes batch
/// embedding as an inherent method.
pub struct BatchEmbeddingProvider {
    inner: Arc<dyn EmbeddingProvider>,
}

impl BatchEmbeddingProvider {
    /// Wraps `inner`.
    pub fn new(inner: Arc<dyn EmbeddingProvider>) -> Self {
        Self { inner }
    }

    /// Embeds every text in `texts` by forwarding to the wrapped provider's
    /// `embed_batch`; its result or error is returned unchanged.
    pub async fn embed_batch(&self, texts: &[&str]) -> ArgentorResult<Vec<Vec<f32>>> {
        self.inner.embed_batch(texts).await
    }
}
635
// Every trait method forwards directly to the wrapped provider.
#[async_trait]
impl EmbeddingProvider for BatchEmbeddingProvider {
    /// Forwards to the wrapped provider's `embed`.
    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
        self.inner.embed(text).await
    }

    /// Forwards to the wrapped provider's `embed_batch`.
    async fn embed_batch(&self, texts: &[&str]) -> ArgentorResult<Vec<Vec<f32>>> {
        self.inner.embed_batch(texts).await
    }

    /// Reports the dimension of the wrapped provider.
    fn dimension(&self) -> usize {
        self.inner.dimension()
    }
}
650
651pub struct EmbeddingProviderFactory;
657
658impl EmbeddingProviderFactory {
659 pub fn create(
663 provider_name: &str,
664 api_key: impl Into<String>,
665 model: Option<String>,
666 ) -> ArgentorResult<Box<dyn EmbeddingProvider>> {
667 let api_key = api_key.into();
668 match provider_name {
669 "openai" => Ok(Box::new(OpenAiEmbeddingProvider::new(api_key, model))),
670 "cohere" => Ok(Box::new(CohereEmbeddingProvider::new(api_key, model))),
671 "voyage" => Ok(Box::new(VoyageEmbeddingProvider::new(api_key, model))),
672 "local" => {
673 let dim = model
674 .as_deref()
675 .and_then(|m| m.parse::<usize>().ok())
676 .unwrap_or(256);
677 Ok(Box::new(LocalEmbedding::new(dim)))
678 }
679 other => Err(ArgentorError::Config(format!(
680 "Unknown embedding provider: {other}"
681 ))),
682 }
683 }
684
685 pub fn available_providers() -> &'static [&'static str] {
687 &["openai", "cohere", "voyage", "local"]
688 }
689}
690
/// Serializable description of an embedding provider, turned into a live
/// provider by `EmbeddingConfig::build`.
///
/// Every field except `provider` may be omitted when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Provider name; `build` recognises `"openai"`, `"cohere"`, `"voyage"`
    /// and `"local"`.
    pub provider: String,
    /// API key passed to the provider (empty when omitted).
    #[serde(default)]
    pub api_key: String,
    /// Model name; each provider applies its own default when `None`.
    #[serde(default)]
    pub model: Option<String>,
    /// Dimension override; for `"local"` this is the vector size
    /// (256 when `None`).
    #[serde(default)]
    pub dimensions: Option<usize>,
    /// Endpoint override; `build` only consults it for `"openai"`.
    #[serde(default)]
    pub base_url: Option<String>,
    /// When set, `build` wraps the provider in a cache of this many entries.
    #[serde(default)]
    pub cache_size: Option<usize>,
}
716
717impl EmbeddingConfig {
718 pub fn build(&self) -> ArgentorResult<Arc<dyn EmbeddingProvider>> {
722 let mut provider: Box<dyn EmbeddingProvider> = match self.provider.as_str() {
723 "openai" => {
724 let mut p = if let Some(ref url) = self.base_url {
725 OpenAiEmbeddingProvider::with_base_url(&self.api_key, self.model.clone(), url)
726 } else {
727 OpenAiEmbeddingProvider::new(&self.api_key, self.model.clone())
728 };
729 if let Some(dim) = self.dimensions {
730 p = p.with_dimensions(dim);
731 }
732 Box::new(p)
733 }
734 "cohere" => {
735 let mut p = CohereEmbeddingProvider::new(&self.api_key, self.model.clone());
736 if let Some(dim) = self.dimensions {
737 p = p.with_dimensions(dim);
738 }
739 Box::new(p)
740 }
741 "voyage" => {
742 let mut p = VoyageEmbeddingProvider::new(&self.api_key, self.model.clone());
743 if let Some(dim) = self.dimensions {
744 p = p.with_dimensions(dim);
745 }
746 Box::new(p)
747 }
748 "local" => {
749 let dim = self.dimensions.unwrap_or(256);
750 Box::new(LocalEmbedding::new(dim))
751 }
752 other => {
753 return Err(ArgentorError::Config(format!(
754 "Unknown embedding provider: {other}"
755 )));
756 }
757 };
758
759 let _ = &mut provider; let arc: Arc<dyn EmbeddingProvider> = Arc::from(provider);
764
765 if let Some(cache_size) = self.cache_size {
767 Ok(Arc::new(CachedEmbeddingProvider::new(arc, cache_size)))
768 } else {
769 Ok(arc)
770 }
771 }
772}
773
774impl Default for EmbeddingConfig {
775 fn default() -> Self {
776 Self {
777 provider: "local".to_string(),
778 api_key: String::new(),
779 model: None,
780 dimensions: None,
781 base_url: None,
782 cache_size: None,
783 }
784 }
785}
786
/// Deterministic placeholder embedding derived from the bytes of `text`.
///
/// The vector has `max(dimensions, 1)` components. Byte `i` of the text adds
/// `byte / 255` to component `i % len`; the result is then scaled to unit L2
/// norm unless it is all zeros (for example, for empty text).
#[cfg_attr(
    all(feature = "http-embeddings", not(test)),
    allow(dead_code)
)]
fn stub_embedding(text: &str, dimensions: usize) -> Vec<f32> {
    let width = dimensions.max(1);
    let mut buckets = vec![0.0f32; width];

    // Walk the slots round-robin alongside the bytes.
    for (slot, byte) in (0..width).cycle().zip(text.bytes()) {
        buckets[slot] += f32::from(byte) / 255.0;
    }

    let magnitude = buckets.iter().map(|c| c * c).sum::<f32>().sqrt();
    if magnitude > 0.0 {
        buckets.iter_mut().for_each(|c| *c /= magnitude);
    }
    buckets
}
814
815pub struct JinaEmbeddingProvider {
826 #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
827 api_key: String,
828 model: String,
829 dimensions: usize,
830 #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
831 base_url: String,
832}
833
834impl JinaEmbeddingProvider {
835 pub fn new(api_key: impl Into<String>) -> Self {
837 Self::with_model(api_key, "jina-embeddings-v3", 1024)
838 }
839
840 pub fn with_model(
842 api_key: impl Into<String>,
843 model: impl Into<String>,
844 dimensions: usize,
845 ) -> Self {
846 Self {
847 api_key: api_key.into(),
848 model: model.into(),
849 dimensions,
850 base_url: "https://api.jina.ai/v1/embeddings".to_string(),
851 }
852 }
853
854 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
856 self.base_url = base_url.into();
857 self
858 }
859
860 pub fn model(&self) -> &str {
862 &self.model
863 }
864
865 pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
867 serde_json::json!({
868 "model": self.model,
869 "input": texts,
870 })
871 }
872}
873
874#[async_trait]
875impl EmbeddingProvider for JinaEmbeddingProvider {
876 #[cfg(feature = "http-embeddings")]
877 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
878 let client = reqwest::Client::new();
879 let payload = self.build_payload(&[text.to_string()]);
880 let response = client
881 .post(&self.base_url)
882 .header("Authorization", format!("Bearer {}", self.api_key))
883 .json(&payload)
884 .send()
885 .await
886 .map_err(|e| ArgentorError::Http(format!("Jina embedding request failed: {e}")))?;
887
888 let status = response.status();
889 if !status.is_success() {
890 let body = response.text().await.unwrap_or_default();
891 return Err(ArgentorError::Http(format!(
892 "Jina API error {status}: {body}"
893 )));
894 }
895
896 let json: serde_json::Value = response
897 .json()
898 .await
899 .map_err(|e| ArgentorError::Http(format!("Failed to read Jina response body: {e}")))?;
900
901 parse_openai_embedding_response(&json)
903 }
904
905 #[cfg(not(feature = "http-embeddings"))]
906 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
907 Ok(stub_embedding(text, self.dimensions))
908 }
909
910 fn dimension(&self) -> usize {
911 self.dimensions
912 }
913}
914
915pub struct MistralEmbedProvider {
924 #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
925 api_key: String,
926 model: String,
927 dimensions: usize,
928 #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
929 base_url: String,
930}
931
932impl MistralEmbedProvider {
933 pub fn new(api_key: impl Into<String>) -> Self {
935 Self::with_model(api_key, "mistral-embed", 1024)
936 }
937
938 pub fn with_model(
940 api_key: impl Into<String>,
941 model: impl Into<String>,
942 dimensions: usize,
943 ) -> Self {
944 Self {
945 api_key: api_key.into(),
946 model: model.into(),
947 dimensions,
948 base_url: "https://api.mistral.ai/v1/embeddings".to_string(),
949 }
950 }
951
952 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
954 self.base_url = base_url.into();
955 self
956 }
957
958 pub fn model(&self) -> &str {
960 &self.model
961 }
962
963 pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
965 serde_json::json!({
966 "model": self.model,
967 "input": texts,
968 })
969 }
970}
971
972#[async_trait]
973impl EmbeddingProvider for MistralEmbedProvider {
974 #[cfg(feature = "http-embeddings")]
975 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
976 let client = reqwest::Client::new();
977 let payload = self.build_payload(&[text.to_string()]);
978 let response = client
979 .post(&self.base_url)
980 .header("Authorization", format!("Bearer {}", self.api_key))
981 .json(&payload)
982 .send()
983 .await
984 .map_err(|e| ArgentorError::Http(format!("Mistral embedding request failed: {e}")))?;
985
986 let status = response.status();
987 if !status.is_success() {
988 let body = response.text().await.unwrap_or_default();
989 return Err(ArgentorError::Http(format!(
990 "Mistral API error {status}: {body}"
991 )));
992 }
993
994 let json: serde_json::Value = response.json().await.map_err(|e| {
995 ArgentorError::Http(format!("Failed to read Mistral response body: {e}"))
996 })?;
997
998 parse_openai_embedding_response(&json)
999 }
1000
1001 #[cfg(not(feature = "http-embeddings"))]
1002 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1003 Ok(stub_embedding(text, self.dimensions))
1004 }
1005
1006 fn dimension(&self) -> usize {
1007 self.dimensions
1008 }
1009}
1010
1011pub struct NomicEmbedProvider {
1021 #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
1022 api_key: String,
1023 model: String,
1024 dimensions: usize,
1025 #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
1026 base_url: String,
1027 task_type: String,
1028}
1029
1030impl NomicEmbedProvider {
1031 pub fn new(api_key: impl Into<String>) -> Self {
1033 Self::with_model(api_key, "nomic-embed-text-v1.5", 768)
1034 }
1035
1036 pub fn with_model(
1038 api_key: impl Into<String>,
1039 model: impl Into<String>,
1040 dimensions: usize,
1041 ) -> Self {
1042 Self {
1043 api_key: api_key.into(),
1044 model: model.into(),
1045 dimensions,
1046 base_url: "https://api-atlas.nomic.ai/v1/embedding/text".to_string(),
1047 task_type: "search_document".to_string(),
1048 }
1049 }
1050
1051 pub fn with_task_type(mut self, task_type: impl Into<String>) -> Self {
1053 self.task_type = task_type.into();
1054 self
1055 }
1056
1057 pub fn model(&self) -> &str {
1059 &self.model
1060 }
1061
1062 pub fn task_type(&self) -> &str {
1064 &self.task_type
1065 }
1066
1067 pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
1069 serde_json::json!({
1070 "model": self.model,
1071 "texts": texts,
1072 "task_type": self.task_type,
1073 })
1074 }
1075}
1076
1077#[async_trait]
1078impl EmbeddingProvider for NomicEmbedProvider {
1079 #[cfg(feature = "http-embeddings")]
1080 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1081 let client = reqwest::Client::new();
1082 let payload = self.build_payload(&[text.to_string()]);
1083 let response = client
1084 .post(&self.base_url)
1085 .header("Authorization", format!("Bearer {}", self.api_key))
1086 .json(&payload)
1087 .send()
1088 .await
1089 .map_err(|e| ArgentorError::Http(format!("Nomic embedding request failed: {e}")))?;
1090
1091 let status = response.status();
1092 if !status.is_success() {
1093 let body = response.text().await.unwrap_or_default();
1094 return Err(ArgentorError::Http(format!(
1095 "Nomic API error {status}: {body}"
1096 )));
1097 }
1098
1099 let json: serde_json::Value = response.json().await.map_err(|e| {
1100 ArgentorError::Http(format!("Failed to read Nomic response body: {e}"))
1101 })?;
1102
1103 let embeddings = json
1105 .get("embeddings")
1106 .and_then(|v| v.as_array())
1107 .ok_or_else(|| {
1108 ArgentorError::Agent("Nomic response missing 'embeddings' array".to_string())
1109 })?;
1110 let first = embeddings.first().ok_or_else(|| {
1111 ArgentorError::Agent("Nomic response contains no embedding vectors".to_string())
1112 })?;
1113 let vec: Vec<f32> = serde_json::from_value(first.clone()).map_err(|e| {
1114 ArgentorError::Agent(format!("Failed to parse Nomic embedding vector: {e}"))
1115 })?;
1116 Ok(vec)
1117 }
1118
1119 #[cfg(not(feature = "http-embeddings"))]
1120 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1121 Ok(stub_embedding(text, self.dimensions))
1122 }
1123
1124 fn dimension(&self) -> usize {
1125 self.dimensions
1126 }
1127}
1128
1129pub struct SentenceTransformersProvider {
1139 #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
1140 api_key: String,
1141 model: String,
1142 dimensions: usize,
1143 #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
1144 base_url: String,
1145}
1146
1147impl SentenceTransformersProvider {
1148 pub fn new(api_key: impl Into<String>) -> Self {
1150 Self::with_model(api_key, "sentence-transformers/all-MiniLM-L6-v2", 384)
1151 }
1152
1153 pub fn with_model(
1155 api_key: impl Into<String>,
1156 model: impl Into<String>,
1157 dimensions: usize,
1158 ) -> Self {
1159 let model = model.into();
1160 let base_url =
1161 format!("https://api-inference.huggingface.co/pipeline/feature-extraction/{model}");
1162 Self {
1163 api_key: api_key.into(),
1164 model,
1165 dimensions,
1166 base_url,
1167 }
1168 }
1169
1170 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
1172 self.base_url = base_url.into();
1173 self
1174 }
1175
1176 pub fn model(&self) -> &str {
1178 &self.model
1179 }
1180
1181 pub fn default_dimensions(model: &str) -> usize {
1183 match model {
1184 "sentence-transformers/all-MiniLM-L6-v2" => 384,
1185 "sentence-transformers/all-mpnet-base-v2"
1186 | "sentence-transformers/multi-qa-mpnet-base-dot-v1" => 768,
1187 _ => 384,
1188 }
1189 }
1190
1191 pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
1193 serde_json::json!({
1194 "inputs": texts,
1195 "options": { "wait_for_model": true },
1196 })
1197 }
1198}
1199
1200#[async_trait]
1201impl EmbeddingProvider for SentenceTransformersProvider {
1202 #[cfg(feature = "http-embeddings")]
1203 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1204 let client = reqwest::Client::new();
1205 let payload = self.build_payload(&[text.to_string()]);
1206 let response = client
1207 .post(&self.base_url)
1208 .header("Authorization", format!("Bearer {}", self.api_key))
1209 .json(&payload)
1210 .send()
1211 .await
1212 .map_err(|e| {
1213 ArgentorError::Http(format!("HuggingFace embedding request failed: {e}"))
1214 })?;
1215
1216 let status = response.status();
1217 if !status.is_success() {
1218 let body = response.text().await.unwrap_or_default();
1219 return Err(ArgentorError::Http(format!(
1220 "HuggingFace API error {status}: {body}"
1221 )));
1222 }
1223
1224 let json: serde_json::Value = response.json().await.map_err(|e| {
1225 ArgentorError::Http(format!("Failed to read HuggingFace response body: {e}"))
1226 })?;
1227
1228 match &json {
1230 serde_json::Value::Array(arr)
1231 if arr.first().is_some_and(serde_json::Value::is_array) =>
1232 {
1233 let first = arr.first().cloned().ok_or_else(|| {
1234 ArgentorError::Agent("HuggingFace response empty".to_string())
1235 })?;
1236 serde_json::from_value(first).map_err(|e| {
1237 ArgentorError::Agent(format!("Failed to parse HF vector: {e}"))
1238 })
1239 }
1240 serde_json::Value::Array(_) => serde_json::from_value(json).map_err(|e| {
1241 ArgentorError::Agent(format!("Failed to parse HF vector: {e}"))
1242 }),
1243 _ => Err(ArgentorError::Agent(
1244 "HuggingFace response is not an array".to_string(),
1245 )),
1246 }
1247 }
1248
1249 #[cfg(not(feature = "http-embeddings"))]
1250 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1251 Ok(stub_embedding(text, self.dimensions))
1252 }
1253
1254 fn dimension(&self) -> usize {
1255 self.dimensions
1256 }
1257}
1258
1259pub struct TogetherEmbedProvider {
1268 #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
1269 api_key: String,
1270 model: String,
1271 dimensions: usize,
1272 #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
1273 base_url: String,
1274}
1275
1276impl TogetherEmbedProvider {
1277 pub fn new(api_key: impl Into<String>) -> Self {
1279 Self::with_model(
1280 api_key,
1281 "togethercomputer/m2-bert-80M-32k-retrieval",
1282 768,
1283 )
1284 }
1285
1286 pub fn with_model(
1288 api_key: impl Into<String>,
1289 model: impl Into<String>,
1290 dimensions: usize,
1291 ) -> Self {
1292 Self {
1293 api_key: api_key.into(),
1294 model: model.into(),
1295 dimensions,
1296 base_url: "https://api.together.xyz/v1/embeddings".to_string(),
1297 }
1298 }
1299
1300 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
1302 self.base_url = base_url.into();
1303 self
1304 }
1305
1306 pub fn model(&self) -> &str {
1308 &self.model
1309 }
1310
1311 pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
1313 serde_json::json!({
1314 "model": self.model,
1315 "input": texts,
1316 })
1317 }
1318}
1319
1320#[async_trait]
1321impl EmbeddingProvider for TogetherEmbedProvider {
1322 #[cfg(feature = "http-embeddings")]
1323 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1324 let client = reqwest::Client::new();
1325 let payload = self.build_payload(&[text.to_string()]);
1326 let response = client
1327 .post(&self.base_url)
1328 .header("Authorization", format!("Bearer {}", self.api_key))
1329 .json(&payload)
1330 .send()
1331 .await
1332 .map_err(|e| ArgentorError::Http(format!("Together embedding request failed: {e}")))?;
1333
1334 let status = response.status();
1335 if !status.is_success() {
1336 let body = response.text().await.unwrap_or_default();
1337 return Err(ArgentorError::Http(format!(
1338 "Together API error {status}: {body}"
1339 )));
1340 }
1341
1342 let json: serde_json::Value = response.json().await.map_err(|e| {
1343 ArgentorError::Http(format!("Failed to read Together response body: {e}"))
1344 })?;
1345
1346 parse_openai_embedding_response(&json)
1347 }
1348
1349 #[cfg(not(feature = "http-embeddings"))]
1350 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1351 Ok(stub_embedding(text, self.dimensions))
1352 }
1353
1354 fn dimension(&self) -> usize {
1355 self.dimensions
1356 }
1357}
1358
1359pub struct CohereEmbedV4Provider {
1371 #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
1372 api_key: String,
1373 model: String,
1374 dimensions: usize,
1375 input_type: String,
1376 #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
1377 base_url: String,
1378}
1379
1380impl CohereEmbedV4Provider {
1381 pub fn new(api_key: impl Into<String>) -> Self {
1383 Self::with_model(api_key, "embed-english-v3.0", 1024)
1384 }
1385
1386 pub fn with_model(
1388 api_key: impl Into<String>,
1389 model: impl Into<String>,
1390 dimensions: usize,
1391 ) -> Self {
1392 Self {
1393 api_key: api_key.into(),
1394 model: model.into(),
1395 dimensions,
1396 input_type: "search_document".to_string(),
1397 base_url: "https://api.cohere.com/v2/embed".to_string(),
1398 }
1399 }
1400
1401 pub fn for_search_document(mut self) -> Self {
1403 self.input_type = "search_document".to_string();
1404 self
1405 }
1406
1407 pub fn for_search_query(mut self) -> Self {
1409 self.input_type = "search_query".to_string();
1410 self
1411 }
1412
1413 pub fn with_input_type(mut self, input_type: impl Into<String>) -> Self {
1415 self.input_type = input_type.into();
1416 self
1417 }
1418
1419 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
1421 self.base_url = base_url.into();
1422 self
1423 }
1424
1425 pub fn model(&self) -> &str {
1427 &self.model
1428 }
1429
1430 pub fn input_type(&self) -> &str {
1432 &self.input_type
1433 }
1434
1435 pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
1437 serde_json::json!({
1438 "model": self.model,
1439 "texts": texts,
1440 "input_type": self.input_type,
1441 "embedding_types": ["float"],
1442 })
1443 }
1444}
1445
1446#[async_trait]
1447impl EmbeddingProvider for CohereEmbedV4Provider {
1448 #[cfg(feature = "http-embeddings")]
1449 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1450 let client = reqwest::Client::new();
1451 let payload = self.build_payload(&[text.to_string()]);
1452 let response = client
1453 .post(&self.base_url)
1454 .header("Authorization", format!("Bearer {}", self.api_key))
1455 .json(&payload)
1456 .send()
1457 .await
1458 .map_err(|e| {
1459 ArgentorError::Http(format!("Cohere v4 embedding request failed: {e}"))
1460 })?;
1461
1462 let status = response.status();
1463 if !status.is_success() {
1464 let body = response.text().await.unwrap_or_default();
1465 return Err(ArgentorError::Http(format!(
1466 "Cohere v4 API error {status}: {body}"
1467 )));
1468 }
1469
1470 let json: serde_json::Value = response.json().await.map_err(|e| {
1471 ArgentorError::Http(format!("Failed to read Cohere v4 response body: {e}"))
1472 })?;
1473
1474 parse_cohere_embedding_response(&json)
1475 }
1476
1477 #[cfg(not(feature = "http-embeddings"))]
1478 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1479 Ok(stub_embedding(text, self.dimensions))
1480 }
1481
1482 fn dimension(&self) -> usize {
1483 self.dimensions
1484 }
1485}
1486
#[cfg(test)]
// Tests unwrap freely: a panic is the intended failure signal here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // OpenAI provider: construction and defaults
    // ------------------------------------------------------------------

    /// With no model override, the provider reports `text-embedding-3-small` / 1536.
    #[test]
    fn test_openai_provider_default_model() {
        let p = OpenAiEmbeddingProvider::new("sk-test", None);
        assert_eq!(p.model(), "text-embedding-3-small");
        assert_eq!(p.dimension(), 1536);
    }

    /// Selecting the large model changes the reported dimension to 3072.
    #[test]
    fn test_openai_provider_large_model() {
        let p = OpenAiEmbeddingProvider::new("sk-test", Some("text-embedding-3-large".into()));
        assert_eq!(p.dimension(), 3072);
    }

    /// `with_dimensions` overrides the model-derived dimension.
    #[test]
    fn test_openai_provider_custom_dimensions() {
        let p = OpenAiEmbeddingProvider::new("sk-test", None).with_dimensions(512);
        assert_eq!(p.dimension(), 512);
    }

    /// A custom base URL leaves the default dimension untouched.
    /// NOTE(review): the URL itself is not asserted here.
    #[test]
    fn test_openai_provider_custom_base_url() {
        let p = OpenAiEmbeddingProvider::with_base_url(
            "sk-test",
            None,
            "https://my-azure.openai.azure.com/openai/deployments/embed",
        );
        assert_eq!(p.dimension(), 1536);
    }

    /// Without the `http-embeddings` feature, `embed` is expected to fail
    /// with a message containing "HTTP embeddings not enabled".
    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_openai_provider_returns_feature_error() {
        let p = OpenAiEmbeddingProvider::new("sk-test", None);
        let err = p.embed("hello").await.unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("HTTP embeddings not enabled"), "got: {msg}");
    }

    // ------------------------------------------------------------------
    // Cohere provider: construction and defaults
    // ------------------------------------------------------------------

    /// Defaults: `embed-english-v3.0`, 1024 dimensions, `search_document`.
    #[test]
    fn test_cohere_provider_default() {
        let p = CohereEmbeddingProvider::new("key", None);
        assert_eq!(p.model(), "embed-english-v3.0");
        assert_eq!(p.dimension(), 1024);
        assert_eq!(p.input_type(), "search_document");
    }

    /// `with_input_type` replaces the default input type.
    #[test]
    fn test_cohere_provider_query_input_type() {
        let p = CohereEmbeddingProvider::new("key", None).with_input_type("search_query");
        assert_eq!(p.input_type(), "search_query");
    }

    /// The light model reports 384 dimensions.
    #[test]
    fn test_cohere_provider_light_model() {
        let p = CohereEmbeddingProvider::new("key", Some("embed-english-light-v3.0".into()));
        assert_eq!(p.dimension(), 384);
    }

    /// Without the `http-embeddings` feature, `embed` is expected to fail
    /// with a message containing "HTTP embeddings not enabled".
    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_cohere_provider_returns_feature_error() {
        let p = CohereEmbeddingProvider::new("key", None);
        let err = p.embed("hello").await.unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("HTTP embeddings not enabled"), "got: {msg}");
    }

    // ------------------------------------------------------------------
    // Voyage provider: construction and defaults
    // ------------------------------------------------------------------

    /// Defaults: `voyage-2`, 1024 dimensions.
    #[test]
    fn test_voyage_provider_default() {
        let p = VoyageEmbeddingProvider::new("key", None);
        assert_eq!(p.model(), "voyage-2");
        assert_eq!(p.dimension(), 1024);
    }

    /// The code model reports 1536 dimensions.
    #[test]
    fn test_voyage_provider_code_model() {
        let p = VoyageEmbeddingProvider::new("key", Some("voyage-code-2".into()));
        assert_eq!(p.dimension(), 1536);
    }

    /// Without the `http-embeddings` feature, `embed` is expected to fail
    /// with a message containing "HTTP embeddings not enabled".
    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_voyage_provider_returns_feature_error() {
        let p = VoyageEmbeddingProvider::new("key", None);
        let err = p.embed("hello").await.unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("HTTP embeddings not enabled"), "got: {msg}");
    }

    // ------------------------------------------------------------------
    // Response parsing: OpenAI
    // ------------------------------------------------------------------

    /// A well-formed response yields the embedding of its single data entry.
    #[test]
    fn test_parse_openai_embedding_response_valid() {
        let json = serde_json::json!({
            "data": [
                {
                    "embedding": [0.1, 0.2, 0.3, 0.4],
                    "index": 0
                }
            ],
            "model": "text-embedding-3-small"
        });
        let result = parse_openai_embedding_response(&json).unwrap();
        assert_eq!(result, vec![0.1, 0.2, 0.3, 0.4]);
    }

    /// An empty `data` array is an error mentioning "no embedding data".
    #[test]
    fn test_parse_openai_embedding_response_empty_data() {
        let json = serde_json::json!({
            "data": [],
            "model": "text-embedding-3-small"
        });
        let err = parse_openai_embedding_response(&json).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("no embedding data"), "got: {msg}");
    }

    /// A payload of the wrong shape is reported as a parse failure.
    #[test]
    fn test_parse_openai_embedding_response_invalid_shape() {
        let json = serde_json::json!({ "error": "bad request" });
        let err = parse_openai_embedding_response(&json).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("Failed to parse"), "got: {msg}");
    }

    /// With several data entries, the first one in the array is returned.
    #[test]
    fn test_parse_openai_embedding_response_multiple_picks_first() {
        let json = serde_json::json!({
            "data": [
                { "embedding": [1.0, 2.0], "index": 0 },
                { "embedding": [3.0, 4.0], "index": 1 }
            ],
            "model": "text-embedding-3-small"
        });
        let result = parse_openai_embedding_response(&json).unwrap();
        assert_eq!(result, vec![1.0, 2.0]);
    }

    // ------------------------------------------------------------------
    // Response parsing: Cohere
    // ------------------------------------------------------------------

    /// A well-formed response yields the first `float` embedding.
    #[test]
    fn test_parse_cohere_embedding_response_valid() {
        let json = serde_json::json!({
            "embeddings": {
                "float": [
                    [0.5, 0.6, 0.7]
                ]
            }
        });
        let result = parse_cohere_embedding_response(&json).unwrap();
        assert_eq!(result, vec![0.5, 0.6, 0.7]);
    }

    /// An empty `float` array is an error mentioning "no float embeddings".
    #[test]
    fn test_parse_cohere_embedding_response_empty_float() {
        let json = serde_json::json!({
            "embeddings": {
                "float": []
            }
        });
        let err = parse_cohere_embedding_response(&json).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("no float embeddings"), "got: {msg}");
    }

    /// A payload of the wrong shape is reported as a parse failure.
    #[test]
    fn test_parse_cohere_embedding_response_invalid_shape() {
        let json = serde_json::json!({ "message": "unauthorized" });
        let err = parse_cohere_embedding_response(&json).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("Failed to parse"), "got: {msg}");
    }

    /// A missing `float` key is treated like an empty list (the field has a
    /// serde default), so the error is "no float embeddings", not a parse failure.
    #[test]
    fn test_parse_cohere_embedding_response_missing_float_key() {
        let json = serde_json::json!({
            "embeddings": {}
        });
        let err = parse_cohere_embedding_response(&json).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("no float embeddings"), "got: {msg}");
    }

    // ------------------------------------------------------------------
    // Response parsing: Voyage
    // ------------------------------------------------------------------

    /// A well-formed response yields the embedding of its single data entry.
    #[test]
    fn test_parse_voyage_embedding_response_valid() {
        let json = serde_json::json!({
            "data": [
                {
                    "embedding": [0.9, 0.8, 0.7, 0.6, 0.5],
                    "index": 0
                }
            ]
        });
        let result = parse_voyage_embedding_response(&json).unwrap();
        assert_eq!(result, vec![0.9, 0.8, 0.7, 0.6, 0.5]);
    }

    /// An empty `data` array is an error mentioning "no embedding data".
    #[test]
    fn test_parse_voyage_embedding_response_empty_data() {
        let json = serde_json::json!({ "data": [] });
        let err = parse_voyage_embedding_response(&json).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("no embedding data"), "got: {msg}");
    }

    /// A payload of the wrong shape is reported as a parse failure.
    #[test]
    fn test_parse_voyage_embedding_response_invalid_shape() {
        let json = serde_json::json!({ "error": "invalid key" });
        let err = parse_voyage_embedding_response(&json).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("Failed to parse"), "got: {msg}");
    }

    // ------------------------------------------------------------------
    // CachedEmbeddingProvider
    // ------------------------------------------------------------------

    /// Embedding the same text twice records one miss, then one hit, and
    /// stores a single entry.
    #[tokio::test]
    async fn test_cache_hit() {
        let local = Arc::new(LocalEmbedding::new(64));
        let cached = CachedEmbeddingProvider::new(local, 100);

        let v1 = cached.embed("hello world").await.unwrap();
        let v2 = cached.embed("hello world").await.unwrap();
        assert_eq!(v1, v2);

        let stats = cached.cache_stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.size, 1);
    }

    /// Two distinct texts are two misses and two stored entries.
    #[tokio::test]
    async fn test_cache_miss_different_texts() {
        let local = Arc::new(LocalEmbedding::new(64));
        let cached = CachedEmbeddingProvider::new(local, 100);

        let _ = cached.embed("alpha").await.unwrap();
        let _ = cached.embed("bravo").await.unwrap();

        let stats = cached.cache_stats().await;
        assert_eq!(stats.misses, 2);
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.size, 2);
    }

    /// With capacity 2, inserting three distinct texts must not leave more
    /// than two entries; all three lookups count as misses.
    #[tokio::test]
    async fn test_cache_eviction() {
        let local = Arc::new(LocalEmbedding::new(64));
        let cached = CachedEmbeddingProvider::new(local, 2);

        let _ = cached.embed("one").await.unwrap();
        let _ = cached.embed("two").await.unwrap();
        let _ = cached.embed("three").await.unwrap();

        let stats = cached.cache_stats().await;
        assert!(stats.size <= 2, "size={} should be <= 2", stats.size);
        assert_eq!(stats.misses, 3);
    }

    /// `clear` empties the cache.
    #[tokio::test]
    async fn test_cache_clear() {
        let local = Arc::new(LocalEmbedding::new(64));
        let cached = CachedEmbeddingProvider::new(local, 100);

        let _ = cached.embed("text").await.unwrap();
        cached.clear().await;

        let stats = cached.cache_stats().await;
        assert_eq!(stats.size, 0);
    }

    /// The cache wrapper reports the wrapped provider's dimension.
    #[tokio::test]
    async fn test_cache_dimension_delegates() {
        let local = Arc::new(LocalEmbedding::new(128));
        let cached = CachedEmbeddingProvider::new(local, 10);
        assert_eq!(cached.dimension(), 128);
    }

    // ------------------------------------------------------------------
    // BatchEmbeddingProvider
    // ------------------------------------------------------------------

    /// A batch of three texts yields three vectors of the inner dimension.
    #[tokio::test]
    async fn test_batch_embed() {
        let local = Arc::new(LocalEmbedding::new(64));
        let batch = BatchEmbeddingProvider::new(local);

        let results = batch
            .embed_batch(&["hello", "world", "test"])
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
        for v in &results {
            assert_eq!(v.len(), 64);
        }
    }

    /// Single-text `embed` on the batch wrapper returns a vector of the inner dimension.
    #[tokio::test]
    async fn test_batch_single_embed_delegates() {
        let local = Arc::new(LocalEmbedding::new(64));
        let batch = BatchEmbeddingProvider::new(local);

        let v = batch.embed("hello").await.unwrap();
        assert_eq!(v.len(), 64);
    }

    /// An empty batch succeeds and returns no vectors.
    #[tokio::test]
    async fn test_batch_empty() {
        let local = Arc::new(LocalEmbedding::new(64));
        let batch = BatchEmbeddingProvider::new(local);

        let results = batch.embed_batch(&[]).await.unwrap();
        assert!(results.is_empty());
    }

    /// The batch wrapper reports the wrapped provider's dimension.
    #[tokio::test]
    async fn test_batch_dimension_delegates() {
        let local = Arc::new(LocalEmbedding::new(200));
        let batch = BatchEmbeddingProvider::new(local);
        assert_eq!(batch.dimension(), 200);
    }

    // ------------------------------------------------------------------
    // EmbeddingProviderFactory
    // ------------------------------------------------------------------

    /// `local` with no third argument yields a 256-dimension provider.
    #[test]
    fn test_factory_create_local() {
        let p = EmbeddingProviderFactory::create("local", "", None).unwrap();
        assert_eq!(p.dimension(), 256);
    }

    /// For `local`, a numeric third argument ("128") sets the dimension.
    #[test]
    fn test_factory_create_local_custom_dim() {
        let p = EmbeddingProviderFactory::create("local", "", Some("128".into())).unwrap();
        assert_eq!(p.dimension(), 128);
    }

    #[test]
    fn test_factory_create_openai() {
        let p = EmbeddingProviderFactory::create("openai", "sk-test", None).unwrap();
        assert_eq!(p.dimension(), 1536);
    }

    #[test]
    fn test_factory_create_cohere() {
        let p = EmbeddingProviderFactory::create("cohere", "key", None).unwrap();
        assert_eq!(p.dimension(), 1024);
    }

    #[test]
    fn test_factory_create_voyage() {
        let p = EmbeddingProviderFactory::create("voyage", "key", None).unwrap();
        assert_eq!(p.dimension(), 1024);
    }

    /// An unrecognised provider name is an error rather than a fallback.
    #[test]
    fn test_factory_unknown_provider() {
        let result = EmbeddingProviderFactory::create("unknown", "", None);
        assert!(result.is_err(), "Unknown provider should return Err");
    }

    /// The advertised provider list includes the four names the factory tests above create.
    #[test]
    fn test_factory_available_providers() {
        let names = EmbeddingProviderFactory::available_providers();
        assert!(names.contains(&"openai"));
        assert!(names.contains(&"cohere"));
        assert!(names.contains(&"voyage"));
        assert!(names.contains(&"local"));
    }

    // ------------------------------------------------------------------
    // EmbeddingConfig
    // ------------------------------------------------------------------

    /// The default config selects `local` and leaves every optional field unset.
    #[test]
    fn test_config_default() {
        let cfg = EmbeddingConfig::default();
        assert_eq!(cfg.provider, "local");
        assert!(cfg.api_key.is_empty());
        assert!(cfg.model.is_none());
        assert!(cfg.dimensions.is_none());
        assert!(cfg.base_url.is_none());
        assert!(cfg.cache_size.is_none());
    }

    /// A config survives a JSON round trip with its fields intact.
    #[test]
    fn test_config_serialize_deserialize() {
        let cfg = EmbeddingConfig {
            provider: "openai".to_string(),
            api_key: "sk-123".to_string(),
            model: Some("text-embedding-3-small".to_string()),
            dimensions: Some(1536),
            base_url: None,
            cache_size: Some(500),
        };
        let json = serde_json::to_string(&cfg).unwrap();
        let parsed: EmbeddingConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.provider, "openai");
        assert_eq!(parsed.api_key, "sk-123");
        assert_eq!(parsed.dimensions, Some(1536));
        assert_eq!(parsed.cache_size, Some(500));
    }

    /// Only `provider` is required in JSON; `api_key` defaults to empty.
    #[test]
    fn test_config_deserialize_minimal() {
        let json = r#"{"provider":"local"}"#;
        let cfg: EmbeddingConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.provider, "local");
        assert!(cfg.api_key.is_empty());
    }

    /// The default config builds a working 256-dimension provider.
    #[tokio::test]
    async fn test_config_build_local() {
        let cfg = EmbeddingConfig::default();
        let provider = cfg.build().unwrap();
        assert_eq!(provider.dimension(), 256);
        let v = provider.embed("test text").await.unwrap();
        assert_eq!(v.len(), 256);
    }

    /// With `cache_size` set, repeated embeds of the same text agree.
    #[tokio::test]
    async fn test_config_build_local_with_cache() {
        let cfg = EmbeddingConfig {
            provider: "local".to_string(),
            cache_size: Some(50),
            ..Default::default()
        };
        let provider = cfg.build().unwrap();
        assert_eq!(provider.dimension(), 256);
        let v1 = provider.embed("cached text").await.unwrap();
        let v2 = provider.embed("cached text").await.unwrap();
        assert_eq!(v1, v2);
    }

    /// `dimensions` in the config overrides the local default.
    #[tokio::test]
    async fn test_config_build_local_custom_dimensions() {
        let cfg = EmbeddingConfig {
            provider: "local".to_string(),
            dimensions: Some(512),
            ..Default::default()
        };
        let provider = cfg.build().unwrap();
        assert_eq!(provider.dimension(), 512);
    }

    /// Building with an unrecognised provider name fails.
    #[test]
    fn test_config_build_unknown_provider() {
        let cfg = EmbeddingConfig {
            provider: "imaginary".to_string(),
            ..Default::default()
        };
        assert!(cfg.build().is_err());
    }

    // ------------------------------------------------------------------
    // fnv1a_hash
    // ------------------------------------------------------------------

    #[test]
    fn test_fnv_hash_deterministic() {
        let h1 = fnv1a_hash(b"hello world");
        let h2 = fnv1a_hash(b"hello world");
        assert_eq!(h1, h2);
    }

    #[test]
    fn test_fnv_hash_different_inputs() {
        let h1 = fnv1a_hash(b"alpha");
        let h2 = fnv1a_hash(b"bravo");
        assert_ne!(h1, h2);
    }

    // ------------------------------------------------------------------
    // stub_embedding
    // ------------------------------------------------------------------

    #[test]
    fn test_stub_embedding_length() {
        let v = stub_embedding("hello", 128);
        assert_eq!(v.len(), 128);
    }

    #[test]
    fn test_stub_embedding_deterministic() {
        let v1 = stub_embedding("same input", 64);
        let v2 = stub_embedding("same input", 64);
        assert_eq!(v1, v2);
    }

    /// Non-empty input is expected to produce a vector of roughly unit L2 norm.
    #[test]
    fn test_stub_embedding_normalized() {
        let v = stub_embedding("the quick brown fox", 256);
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01, "norm={norm}");
    }

    #[test]
    fn test_stub_embedding_different_inputs_differ() {
        let a = stub_embedding("alpha", 64);
        let b = stub_embedding("bravo", 64);
        assert_ne!(a, b);
    }

    /// Empty text yields an all-zero vector of the requested length.
    #[test]
    fn test_stub_embedding_empty_text_zeroes() {
        let v = stub_embedding("", 32);
        assert_eq!(v.len(), 32);
        assert!(v.iter().all(|&x| x == 0.0));
    }

    /// A requested dimension of 0 is expected to yield a one-element vector.
    #[test]
    fn test_stub_embedding_zero_dimension_safe() {
        let v = stub_embedding("hi", 0);
        assert_eq!(v.len(), 1);
    }

    // ------------------------------------------------------------------
    // Jina provider
    // ------------------------------------------------------------------

    #[test]
    fn test_jina_default_construction() {
        let p = JinaEmbeddingProvider::new("jina-key");
        assert_eq!(p.model(), "jina-embeddings-v3");
        assert_eq!(p.dimension(), 1024);
    }

    #[test]
    fn test_jina_with_model_clip() {
        let p = JinaEmbeddingProvider::with_model("k", "jina-clip-v2", 768);
        assert_eq!(p.model(), "jina-clip-v2");
        assert_eq!(p.dimension(), 768);
    }

    /// Overriding the base URL does not change the model.
    /// NOTE(review): the URL itself is not asserted here.
    #[test]
    fn test_jina_with_base_url() {
        let p = JinaEmbeddingProvider::new("k").with_base_url("https://custom.jina/v1");
        assert_eq!(p.model(), "jina-embeddings-v3");
    }

    #[test]
    fn test_jina_build_payload_shape() {
        let p = JinaEmbeddingProvider::new("k");
        let payload = p.build_payload(&["hello".to_string(), "world".to_string()]);
        assert_eq!(payload["model"], "jina-embeddings-v3");
        assert_eq!(payload["input"][0], "hello");
        assert_eq!(payload["input"][1], "world");
    }

    /// The embed call only runs when HTTP support is compiled out;
    /// the dimension check runs in either configuration.
    #[tokio::test]
    async fn test_jina_embed_length_matches_dimension() {
        let p = JinaEmbeddingProvider::new("k");
        #[cfg(not(feature = "http-embeddings"))]
        {
            let v = p.embed("hello jina").await.unwrap();
            assert_eq!(v.len(), 1024);
        }
        assert_eq!(p.dimension(), 1024);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_jina_stub_is_normalized() {
        let p = JinaEmbeddingProvider::new("k");
        let v = p.embed("some input").await.unwrap();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_jina_stub_deterministic() {
        let p = JinaEmbeddingProvider::new("k");
        let a = p.embed("consistent").await.unwrap();
        let b = p.embed("consistent").await.unwrap();
        assert_eq!(a, b);
    }

    // ------------------------------------------------------------------
    // Mistral provider
    // ------------------------------------------------------------------

    #[test]
    fn test_mistral_default_construction() {
        let p = MistralEmbedProvider::new("mistral-key");
        assert_eq!(p.model(), "mistral-embed");
        assert_eq!(p.dimension(), 1024);
    }

    #[test]
    fn test_mistral_with_model_and_dimensions() {
        let p = MistralEmbedProvider::with_model("k", "mistral-embed-large", 2048);
        assert_eq!(p.model(), "mistral-embed-large");
        assert_eq!(p.dimension(), 2048);
    }

    #[test]
    fn test_mistral_build_payload_shape() {
        let p = MistralEmbedProvider::new("k");
        let payload = p.build_payload(&["alpha".to_string()]);
        assert_eq!(payload["model"], "mistral-embed");
        assert_eq!(payload["input"][0], "alpha");
    }

    /// Overriding the base URL leaves the default dimension untouched.
    #[test]
    fn test_mistral_with_base_url() {
        let p = MistralEmbedProvider::new("k").with_base_url("https://custom.mistral/v1");
        assert_eq!(p.dimension(), 1024);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_mistral_embed_length() {
        let p = MistralEmbedProvider::new("k");
        let v = p.embed("hello mistral").await.unwrap();
        assert_eq!(v.len(), 1024);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_mistral_stub_normalized() {
        let p = MistralEmbedProvider::new("k");
        let v = p.embed("normalized?").await.unwrap();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    // ------------------------------------------------------------------
    // Nomic provider
    // ------------------------------------------------------------------

    #[test]
    fn test_nomic_default_construction() {
        let p = NomicEmbedProvider::new("nomic-key");
        assert_eq!(p.model(), "nomic-embed-text-v1.5");
        assert_eq!(p.dimension(), 768);
        assert_eq!(p.task_type(), "search_document");
    }

    #[test]
    fn test_nomic_with_task_type() {
        let p = NomicEmbedProvider::new("k").with_task_type("search_query");
        assert_eq!(p.task_type(), "search_query");
    }

    /// The payload uses `texts` (not `input`) and carries the task type.
    #[test]
    fn test_nomic_build_payload_shape() {
        let p = NomicEmbedProvider::new("k").with_task_type("clustering");
        let payload = p.build_payload(&["doc a".to_string(), "doc b".to_string()]);
        assert_eq!(payload["model"], "nomic-embed-text-v1.5");
        assert_eq!(payload["texts"][0], "doc a");
        assert_eq!(payload["texts"][1], "doc b");
        assert_eq!(payload["task_type"], "clustering");
    }

    #[test]
    fn test_nomic_with_model_custom_dims() {
        let p = NomicEmbedProvider::with_model("k", "custom-nomic", 512);
        assert_eq!(p.dimension(), 512);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_nomic_embed_length() {
        let p = NomicEmbedProvider::new("k");
        let v = p.embed("nomic test").await.unwrap();
        assert_eq!(v.len(), 768);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_nomic_embed_normalized() {
        let p = NomicEmbedProvider::new("k");
        let v = p.embed("some text").await.unwrap();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    // ------------------------------------------------------------------
    // Sentence-Transformers provider
    // ------------------------------------------------------------------

    #[test]
    fn test_sentence_transformers_default_construction() {
        let p = SentenceTransformersProvider::new("hf-key");
        assert_eq!(p.model(), "sentence-transformers/all-MiniLM-L6-v2");
        assert_eq!(p.dimension(), 384);
    }

    #[test]
    fn test_sentence_transformers_mpnet_dims() {
        let dims = SentenceTransformersProvider::default_dimensions(
            "sentence-transformers/all-mpnet-base-v2",
        );
        assert_eq!(dims, 768);
    }

    #[test]
    fn test_sentence_transformers_multi_qa_dims() {
        let dims = SentenceTransformersProvider::default_dimensions(
            "sentence-transformers/multi-qa-mpnet-base-dot-v1",
        );
        assert_eq!(dims, 768);
    }

    /// An unrecognised model name yields 384 from `default_dimensions`.
    #[test]
    fn test_sentence_transformers_unknown_model_fallback() {
        let dims = SentenceTransformersProvider::default_dimensions("sentence-transformers/unknown");
        assert_eq!(dims, 384);
    }

    #[test]
    fn test_sentence_transformers_with_model() {
        let p = SentenceTransformersProvider::with_model(
            "k",
            "sentence-transformers/all-mpnet-base-v2",
            768,
        );
        assert_eq!(p.model(), "sentence-transformers/all-mpnet-base-v2");
        assert_eq!(p.dimension(), 768);
    }

    /// The payload uses `inputs` and sets `options.wait_for_model` to true.
    #[test]
    fn test_sentence_transformers_build_payload_shape() {
        let p = SentenceTransformersProvider::new("k");
        let payload = p.build_payload(&["hi".to_string()]);
        assert_eq!(payload["inputs"][0], "hi");
        assert_eq!(payload["options"]["wait_for_model"], true);
    }

    /// Overriding the base URL leaves the default dimension untouched.
    #[test]
    fn test_sentence_transformers_with_base_url() {
        let p = SentenceTransformersProvider::new("k")
            .with_base_url("https://self-hosted.hf/embed");
        assert_eq!(p.dimension(), 384);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_sentence_transformers_embed_length() {
        let p = SentenceTransformersProvider::new("k");
        let v = p.embed("minilm test").await.unwrap();
        assert_eq!(v.len(), 384);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_sentence_transformers_embed_normalized() {
        let p = SentenceTransformersProvider::new("k");
        let v = p.embed("some input").await.unwrap();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    // ------------------------------------------------------------------
    // Together provider
    // ------------------------------------------------------------------

    #[test]
    fn test_together_default_construction() {
        let p = TogetherEmbedProvider::new("together-key");
        assert_eq!(p.model(), "togethercomputer/m2-bert-80M-32k-retrieval");
        assert_eq!(p.dimension(), 768);
    }

    #[test]
    fn test_together_with_model() {
        let p = TogetherEmbedProvider::with_model("k", "togethercomputer/custom", 1024);
        assert_eq!(p.model(), "togethercomputer/custom");
        assert_eq!(p.dimension(), 1024);
    }

    #[test]
    fn test_together_build_payload_shape() {
        let p = TogetherEmbedProvider::new("k");
        let payload = p.build_payload(&["x".to_string(), "y".to_string()]);
        assert_eq!(payload["model"], "togethercomputer/m2-bert-80M-32k-retrieval");
        assert_eq!(payload["input"][0], "x");
        assert_eq!(payload["input"][1], "y");
    }

    /// Overriding the base URL leaves the default dimension untouched.
    #[test]
    fn test_together_with_base_url() {
        let p = TogetherEmbedProvider::new("k").with_base_url("https://custom.together/v1");
        assert_eq!(p.dimension(), 768);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_together_embed_length() {
        let p = TogetherEmbedProvider::new("k");
        let v = p.embed("together test").await.unwrap();
        assert_eq!(v.len(), 768);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_together_embed_normalized() {
        let p = TogetherEmbedProvider::new("k");
        let v = p.embed("text").await.unwrap();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    // ------------------------------------------------------------------
    // Cohere v4 provider
    // ------------------------------------------------------------------

    /// Defaults: `embed-english-v3.0`, 1024 dimensions, `search_document`.
    #[test]
    fn test_cohere_v4_default_construction() {
        let p = CohereEmbedV4Provider::new("cohere-key");
        assert_eq!(p.model(), "embed-english-v3.0");
        assert_eq!(p.dimension(), 1024);
        assert_eq!(p.input_type(), "search_document");
    }

    #[test]
    fn test_cohere_v4_multilingual_model() {
        let p = CohereEmbedV4Provider::with_model("k", "embed-multilingual-v3.0", 1024);
        assert_eq!(p.model(), "embed-multilingual-v3.0");
        assert_eq!(p.dimension(), 1024);
    }

    #[test]
    fn test_cohere_v4_for_search_document() {
        let p = CohereEmbedV4Provider::new("k").for_search_document();
        assert_eq!(p.input_type(), "search_document");
    }

    #[test]
    fn test_cohere_v4_for_search_query() {
        let p = CohereEmbedV4Provider::new("k").for_search_query();
        assert_eq!(p.input_type(), "search_query");
    }

    /// Arbitrary input types are stored verbatim.
    #[test]
    fn test_cohere_v4_with_input_type() {
        let p = CohereEmbedV4Provider::new("k").with_input_type("classification");
        assert_eq!(p.input_type(), "classification");
    }

    /// The payload carries model, texts, input type, and requests `float` embeddings.
    #[test]
    fn test_cohere_v4_build_payload_shape_document() {
        let p = CohereEmbedV4Provider::new("k").for_search_document();
        let payload = p.build_payload(&["doc".to_string()]);
        assert_eq!(payload["model"], "embed-english-v3.0");
        assert_eq!(payload["texts"][0], "doc");
        assert_eq!(payload["input_type"], "search_document");
        assert_eq!(payload["embedding_types"][0], "float");
    }

    #[test]
    fn test_cohere_v4_build_payload_shape_query() {
        let p = CohereEmbedV4Provider::new("k").for_search_query();
        let payload = p.build_payload(&["q".to_string()]);
        assert_eq!(payload["input_type"], "search_query");
    }

    /// Overriding the base URL leaves the default dimension untouched.
    #[test]
    fn test_cohere_v4_with_base_url() {
        let p = CohereEmbedV4Provider::new("k").with_base_url("https://custom.cohere/v2/embed");
        assert_eq!(p.dimension(), 1024);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_cohere_v4_embed_length() {
        let p = CohereEmbedV4Provider::new("k");
        let v = p.embed("cohere v4 test").await.unwrap();
        assert_eq!(v.len(), 1024);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_cohere_v4_embed_normalized() {
        let p = CohereEmbedV4Provider::new("k");
        let v = p.embed("x").await.unwrap();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_cohere_v4_embed_deterministic() {
        let p = CohereEmbedV4Provider::new("k");
        let a = p.embed("same").await.unwrap();
        let b = p.embed("same").await.unwrap();
        assert_eq!(a, b);
    }

    // ------------------------------------------------------------------
    // Cross-provider checks
    // ------------------------------------------------------------------

    /// Compile-time check: each provider can be boxed as `dyn EmbeddingProvider`.
    #[test]
    fn test_all_new_providers_implement_embedding_provider_trait() {
        let _boxes: Vec<Box<dyn EmbeddingProvider>> = vec![
            Box::new(JinaEmbeddingProvider::new("k")),
            Box::new(MistralEmbedProvider::new("k")),
            Box::new(NomicEmbedProvider::new("k")),
            Box::new(SentenceTransformersProvider::new("k")),
            Box::new(TogetherEmbedProvider::new("k")),
            Box::new(CohereEmbedV4Provider::new("k")),
        ];
    }

    /// Pins the default dimension of every provider in one place.
    #[test]
    fn test_new_providers_have_expected_dimensions() {
        assert_eq!(JinaEmbeddingProvider::new("k").dimension(), 1024);
        assert_eq!(MistralEmbedProvider::new("k").dimension(), 1024);
        assert_eq!(NomicEmbedProvider::new("k").dimension(), 768);
        assert_eq!(SentenceTransformersProvider::new("k").dimension(), 384);
        assert_eq!(TogetherEmbedProvider::new("k").dimension(), 768);
        assert_eq!(CohereEmbedV4Provider::new("k").dimension(), 1024);
    }
}
2413}