1use std::collections::HashMap;
27use std::sync::Arc;
28
29use async_trait::async_trait;
30use serde::{Deserialize, Serialize};
31use tokio::sync::RwLock;
32
33use argentor_core::{ArgentorError, ArgentorResult};
34
35use crate::embedding::{EmbeddingProvider, LocalEmbedding};
36
/// Computes a 64-bit FNV-1a style hash of `data`.
///
/// The state starts at `14695981039346656037`; for every input byte the byte
/// is XOR-ed into the state, which is then multiplied (wrapping) by
/// `1099511628211`. An empty slice therefore hashes to the starting value.
fn fnv1a_hash(data: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 14695981039346656037;
    const PRIME: u64 = 1099511628211;

    data.iter().fold(OFFSET_BASIS, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
50
/// Serializable request body holding a model name and the texts to embed.
///
/// NOTE(review): presumably mirrors the OpenAI embeddings request schema —
/// confirm against the API reference.
#[derive(Debug, Serialize)]
pub struct OpenAiEmbeddingRequest {
    /// Model identifier, serialized under the `model` key.
    pub model: String,
    /// Input texts, serialized under the `input` key.
    pub input: Vec<String>,
}
63
/// One entry of the `data` array in an [`OpenAiEmbeddingResponse`].
#[derive(Debug, Deserialize)]
pub struct OpenAiEmbeddingObject {
    /// The embedding vector for this entry.
    pub embedding: Vec<f32>,
    /// Value of the entry's `index` field.
    /// NOTE(review): presumably the position of the matching input — confirm.
    pub index: usize,
}
72
/// Deserialized response body used by [`parse_openai_embedding_response`].
///
/// Both fields are required: deserialization fails if either `data` or
/// `model` is missing from the JSON.
#[derive(Debug, Deserialize)]
pub struct OpenAiEmbeddingResponse {
    /// Embedding entries returned by the service.
    pub data: Vec<OpenAiEmbeddingObject>,
    /// Model name reported in the response.
    pub model: String,
}
81
/// Serializable request body with the same keys the Cohere providers in this
/// file send (`model`, `texts`, `input_type`, `embedding_types`).
#[derive(Debug, Serialize)]
pub struct CohereEmbedRequest {
    /// Model identifier.
    pub model: String,
    /// Texts to embed.
    pub texts: Vec<String>,
    /// Input type hint (the providers in this file default to `"search_document"`).
    pub input_type: String,
    /// Requested embedding encodings (the providers in this file send `["float"]`).
    pub embedding_types: Vec<String>,
}
94
/// Deserialized response body used by [`parse_cohere_embedding_response`].
#[derive(Debug, Deserialize)]
pub struct CohereEmbedResponse {
    /// Embeddings grouped by encoding; only `float` is modelled here.
    pub embeddings: CohereEmbeddingsMap,
}
101
/// The `embeddings` object of a [`CohereEmbedResponse`].
#[derive(Debug, Deserialize)]
pub struct CohereEmbeddingsMap {
    /// Float-encoded vectors, one per entry. Defaults to an empty list when
    /// the `float` key is absent.
    #[serde(default)]
    pub float: Vec<Vec<f32>>,
}
109
/// Serializable request body holding a model name and the texts to embed,
/// using the same keys the Voyage provider in this file sends.
#[derive(Debug, Serialize)]
pub struct VoyageEmbeddingRequest {
    /// Model identifier, serialized under the `model` key.
    pub model: String,
    /// Input texts, serialized under the `input` key.
    pub input: Vec<String>,
}
118
/// One entry of the `data` array in a [`VoyageEmbeddingResponse`].
#[derive(Debug, Deserialize)]
pub struct VoyageEmbeddingObject {
    /// The embedding vector for this entry.
    pub embedding: Vec<f32>,
    /// Value of the entry's `index` field.
    /// NOTE(review): presumably the position of the matching input — confirm.
    pub index: usize,
}
127
/// Deserialized response body used by [`parse_voyage_embedding_response`].
#[derive(Debug, Deserialize)]
pub struct VoyageEmbeddingResponse {
    /// Embedding entries returned by the service.
    pub data: Vec<VoyageEmbeddingObject>,
}
134
135pub fn parse_openai_embedding_response(json: &serde_json::Value) -> ArgentorResult<Vec<f32>> {
144 let response: OpenAiEmbeddingResponse = serde_json::from_value(json.clone())
145 .map_err(|e| ArgentorError::Agent(format!("Failed to parse OpenAI response: {e}")))?;
146 response
147 .data
148 .into_iter()
149 .next()
150 .map(|obj| obj.embedding)
151 .ok_or_else(|| {
152 ArgentorError::Agent("OpenAI response contains no embedding data".to_string())
153 })
154}
155
156pub fn parse_cohere_embedding_response(json: &serde_json::Value) -> ArgentorResult<Vec<f32>> {
161 let response: CohereEmbedResponse = serde_json::from_value(json.clone())
162 .map_err(|e| ArgentorError::Agent(format!("Failed to parse Cohere response: {e}")))?;
163 response.embeddings.float.into_iter().next().ok_or_else(|| {
164 ArgentorError::Agent("Cohere response contains no float embeddings".to_string())
165 })
166}
167
168pub fn parse_voyage_embedding_response(json: &serde_json::Value) -> ArgentorResult<Vec<f32>> {
173 let response: VoyageEmbeddingResponse = serde_json::from_value(json.clone())
174 .map_err(|e| ArgentorError::Agent(format!("Failed to parse Voyage response: {e}")))?;
175 response
176 .data
177 .into_iter()
178 .next()
179 .map(|obj| obj.embedding)
180 .ok_or_else(|| {
181 ArgentorError::Agent("Voyage response contains no embedding data".to_string())
182 })
183}
184
/// Embedding provider that POSTs to an OpenAI-style embeddings endpoint.
///
/// The HTTP path is only compiled with the `http-embeddings` feature; without
/// it, `embed` returns an error.
pub struct OpenAiEmbeddingProvider {
    /// Bearer token sent in the `Authorization` header.
    // Only read by the feature-gated HTTP path, hence the conditional lint allowance.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    /// Model identifier sent in the request body.
    model: String,
    /// Dimension reported by `dimension()`; not sent to the API.
    dimensions: usize,
    /// Full URL the request is posted to.
    // Only read by the feature-gated HTTP path, hence the conditional lint allowance.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}
203
204impl OpenAiEmbeddingProvider {
205 pub fn new(api_key: impl Into<String>, model: Option<String>) -> Self {
209 let model = model.unwrap_or_else(|| "text-embedding-3-small".to_string());
210 let dimensions = Self::default_dimensions(&model);
211 Self {
212 api_key: api_key.into(),
213 model,
214 dimensions,
215 base_url: "https://api.openai.com/v1/embeddings".to_string(),
216 }
217 }
218
219 pub fn with_base_url(
221 api_key: impl Into<String>,
222 model: Option<String>,
223 base_url: impl Into<String>,
224 ) -> Self {
225 let model = model.unwrap_or_else(|| "text-embedding-3-small".to_string());
226 let dimensions = Self::default_dimensions(&model);
227 Self {
228 api_key: api_key.into(),
229 model,
230 dimensions,
231 base_url: base_url.into(),
232 }
233 }
234
235 pub fn with_dimensions(mut self, dimensions: usize) -> Self {
237 self.dimensions = dimensions;
238 self
239 }
240
241 fn default_dimensions(model: &str) -> usize {
242 match model {
243 "text-embedding-3-large" => 3072,
244 "text-embedding-3-small" => 1536,
245 "text-embedding-ada-002" => 1536,
246 _ => 1536,
247 }
248 }
249
250 pub fn model(&self) -> &str {
252 &self.model
253 }
254}
255
256#[async_trait]
257impl EmbeddingProvider for OpenAiEmbeddingProvider {
258 #[cfg(feature = "http-embeddings")]
259 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
260 let client = reqwest::Client::new();
261 let response = client
262 .post(&self.base_url)
263 .header("Authorization", format!("Bearer {}", self.api_key))
264 .json(&serde_json::json!({
265 "model": self.model,
266 "input": text,
267 }))
268 .send()
269 .await
270 .map_err(|e| ArgentorError::Http(format!("OpenAI embedding request failed: {e}")))?;
271
272 let status = response.status();
273 if !status.is_success() {
274 let body = response.text().await.unwrap_or_default();
275 return Err(ArgentorError::Http(format!(
276 "OpenAI API error {status}: {body}"
277 )));
278 }
279
280 let json: serde_json::Value = response.json().await.map_err(|e| {
281 ArgentorError::Http(format!("Failed to read OpenAI response body: {e}"))
282 })?;
283
284 parse_openai_embedding_response(&json)
285 }
286
287 #[cfg(not(feature = "http-embeddings"))]
288 async fn embed(&self, _text: &str) -> ArgentorResult<Vec<f32>> {
289 Err(ArgentorError::Http(
290 "HTTP embeddings not enabled. Enable the 'http-embeddings' feature flag \
291 or use LocalEmbedding for offline embeddings."
292 .to_string(),
293 ))
294 }
295
296 fn dimension(&self) -> usize {
297 self.dimensions
298 }
299}
300
/// Embedding provider that POSTs to `https://api.cohere.com/v2/embed`.
///
/// The HTTP path is only compiled with the `http-embeddings` feature; without
/// it, `embed` returns an error.
pub struct CohereEmbeddingProvider {
    /// Bearer token sent in the `Authorization` header.
    // Only read by the feature-gated HTTP path, hence the conditional lint allowance.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    /// Model identifier sent in the request body.
    model: String,
    /// Dimension reported by `dimension()`; not sent to the API.
    dimensions: usize,
    /// Value sent as `input_type` in the request body.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    input_type: String,
}
319
320impl CohereEmbeddingProvider {
321 pub fn new(api_key: impl Into<String>, model: Option<String>) -> Self {
325 let model = model.unwrap_or_else(|| "embed-english-v3.0".to_string());
326 let dimensions = Self::default_dimensions(&model);
327 Self {
328 api_key: api_key.into(),
329 model,
330 dimensions,
331 input_type: "search_document".to_string(),
332 }
333 }
334
335 pub fn with_input_type(mut self, input_type: impl Into<String>) -> Self {
337 self.input_type = input_type.into();
338 self
339 }
340
341 pub fn with_dimensions(mut self, dimensions: usize) -> Self {
343 self.dimensions = dimensions;
344 self
345 }
346
347 fn default_dimensions(model: &str) -> usize {
348 match model {
349 "embed-english-v3.0" | "embed-multilingual-v3.0" => 1024,
350 "embed-english-light-v3.0" | "embed-multilingual-light-v3.0" => 384,
351 _ => 1024,
352 }
353 }
354
355 pub fn model(&self) -> &str {
357 &self.model
358 }
359
360 pub fn input_type(&self) -> &str {
362 &self.input_type
363 }
364}
365
366#[async_trait]
367impl EmbeddingProvider for CohereEmbeddingProvider {
368 #[cfg(feature = "http-embeddings")]
369 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
370 let client = reqwest::Client::new();
371 let response = client
372 .post("https://api.cohere.com/v2/embed")
373 .header("Authorization", format!("Bearer {}", self.api_key))
374 .json(&serde_json::json!({
375 "model": self.model,
376 "texts": [text],
377 "input_type": self.input_type,
378 "embedding_types": ["float"],
379 }))
380 .send()
381 .await
382 .map_err(|e| ArgentorError::Http(format!("Cohere embedding request failed: {e}")))?;
383
384 let status = response.status();
385 if !status.is_success() {
386 let body = response.text().await.unwrap_or_default();
387 return Err(ArgentorError::Http(format!(
388 "Cohere API error {status}: {body}"
389 )));
390 }
391
392 let json: serde_json::Value = response.json().await.map_err(|e| {
393 ArgentorError::Http(format!("Failed to read Cohere response body: {e}"))
394 })?;
395
396 parse_cohere_embedding_response(&json)
397 }
398
399 #[cfg(not(feature = "http-embeddings"))]
400 async fn embed(&self, _text: &str) -> ArgentorResult<Vec<f32>> {
401 Err(ArgentorError::Http(
402 "HTTP embeddings not enabled. Enable the 'http-embeddings' feature flag \
403 or use LocalEmbedding for offline embeddings."
404 .to_string(),
405 ))
406 }
407
408 fn dimension(&self) -> usize {
409 self.dimensions
410 }
411}
412
/// Embedding provider that POSTs to `https://api.voyageai.com/v1/embeddings`.
///
/// The HTTP path is only compiled with the `http-embeddings` feature; without
/// it, `embed` returns an error.
pub struct VoyageEmbeddingProvider {
    /// Bearer token sent in the `Authorization` header.
    // Only read by the feature-gated HTTP path, hence the conditional lint allowance.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    /// Model identifier sent in the request body.
    model: String,
    /// Dimension reported by `dimension()`; not sent to the API.
    dimensions: usize,
}
429
430impl VoyageEmbeddingProvider {
431 pub fn new(api_key: impl Into<String>, model: Option<String>) -> Self {
435 let model = model.unwrap_or_else(|| "voyage-2".to_string());
436 let dimensions = Self::default_dimensions(&model);
437 Self {
438 api_key: api_key.into(),
439 model,
440 dimensions,
441 }
442 }
443
444 pub fn with_dimensions(mut self, dimensions: usize) -> Self {
446 self.dimensions = dimensions;
447 self
448 }
449
450 fn default_dimensions(model: &str) -> usize {
451 match model {
452 "voyage-2" | "voyage-large-2" => 1024,
453 "voyage-lite-02-instruct" => 1024,
454 "voyage-3" => 1024,
455 "voyage-code-2" => 1536,
456 _ => 1024,
457 }
458 }
459
460 pub fn model(&self) -> &str {
462 &self.model
463 }
464}
465
466#[async_trait]
467impl EmbeddingProvider for VoyageEmbeddingProvider {
468 #[cfg(feature = "http-embeddings")]
469 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
470 let client = reqwest::Client::new();
471 let response = client
472 .post("https://api.voyageai.com/v1/embeddings")
473 .header("Authorization", format!("Bearer {}", self.api_key))
474 .json(&serde_json::json!({
475 "model": self.model,
476 "input": [text],
477 }))
478 .send()
479 .await
480 .map_err(|e| ArgentorError::Http(format!("Voyage embedding request failed: {e}")))?;
481
482 let status = response.status();
483 if !status.is_success() {
484 let body = response.text().await.unwrap_or_default();
485 return Err(ArgentorError::Http(format!(
486 "Voyage API error {status}: {body}"
487 )));
488 }
489
490 let json: serde_json::Value = response.json().await.map_err(|e| {
491 ArgentorError::Http(format!("Failed to read Voyage response body: {e}"))
492 })?;
493
494 parse_voyage_embedding_response(&json)
495 }
496
497 #[cfg(not(feature = "http-embeddings"))]
498 async fn embed(&self, _text: &str) -> ArgentorResult<Vec<f32>> {
499 Err(ArgentorError::Http(
500 "HTTP embeddings not enabled. Enable the 'http-embeddings' feature flag \
501 or use LocalEmbedding for offline embeddings."
502 .to_string(),
503 ))
504 }
505
506 fn dimension(&self) -> usize {
507 self.dimensions
508 }
509}
510
/// Counters describing the activity of a [`CachedEmbeddingProvider`].
///
/// All counters start at zero via `Default`.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of `embed` calls answered from the cache.
    pub hits: u64,
    /// Number of `embed` calls that went to the inner provider and succeeded.
    pub misses: u64,
    /// Number of cached entries as of the last insert or clear.
    pub size: usize,
}
525
/// Wraps another [`EmbeddingProvider`] and caches its results in memory.
///
/// Entries are keyed by a 64-bit hash of the input text only, so two texts
/// whose hashes collide would share one cache entry.
pub struct CachedEmbeddingProvider {
    /// Provider consulted on a cache miss.
    inner: Arc<dyn EmbeddingProvider>,
    /// Cached vectors keyed by the text hash.
    cache: Arc<RwLock<HashMap<u64, Vec<f32>>>>,
    /// Entry count at which an existing entry is evicted before inserting.
    max_cache_size: usize,
    /// Hit/miss/size counters, guarded by their own lock.
    stats: Arc<RwLock<CacheStats>>,
}
536
537impl CachedEmbeddingProvider {
538 pub fn new(inner: Arc<dyn EmbeddingProvider>, max_cache_size: usize) -> Self {
540 Self {
541 inner,
542 cache: Arc::new(RwLock::new(HashMap::new())),
543 max_cache_size,
544 stats: Arc::new(RwLock::new(CacheStats::default())),
545 }
546 }
547
548 pub async fn cache_stats(&self) -> CacheStats {
550 self.stats.read().await.clone()
551 }
552
553 pub async fn clear(&self) {
555 self.cache.write().await.clear();
556 let mut stats = self.stats.write().await;
557 stats.size = 0;
558 }
559
560 fn text_hash(text: &str) -> u64 {
561 fnv1a_hash(text.as_bytes())
562 }
563}
564
565#[async_trait]
566impl EmbeddingProvider for CachedEmbeddingProvider {
567 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
568 let key = Self::text_hash(text);
569
570 {
572 let cache = self.cache.read().await;
573 if let Some(cached) = cache.get(&key) {
574 let mut stats = self.stats.write().await;
575 stats.hits += 1;
576 return Ok(cached.clone());
577 }
578 }
579
580 let embedding = self.inner.embed(text).await?;
582
583 {
585 let mut cache = self.cache.write().await;
586
587 if cache.len() >= self.max_cache_size {
589 if let Some(&evict_key) = cache.keys().next() {
592 cache.remove(&evict_key);
593 }
594 }
595
596 cache.insert(key, embedding.clone());
597
598 let mut stats = self.stats.write().await;
599 stats.misses += 1;
600 stats.size = cache.len();
601 }
602
603 Ok(embedding)
604 }
605
606 fn dimension(&self) -> usize {
607 self.inner.dimension()
608 }
609}
610
/// Thin wrapper around another [`EmbeddingProvider`].
///
/// Every method visible in this file forwards directly to `inner`.
pub struct BatchEmbeddingProvider {
    /// Provider that performs the actual work.
    inner: Arc<dyn EmbeddingProvider>,
}
623
624impl BatchEmbeddingProvider {
625 pub fn new(inner: Arc<dyn EmbeddingProvider>) -> Self {
627 Self { inner }
628 }
629
630 pub async fn embed_batch(&self, texts: &[&str]) -> ArgentorResult<Vec<Vec<f32>>> {
632 self.inner.embed_batch(texts).await
633 }
634}
635
636#[async_trait]
637impl EmbeddingProvider for BatchEmbeddingProvider {
638 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
639 self.inner.embed(text).await
640 }
641
642 async fn embed_batch(&self, texts: &[&str]) -> ArgentorResult<Vec<Vec<f32>>> {
643 self.inner.embed_batch(texts).await
644 }
645
646 fn dimension(&self) -> usize {
647 self.inner.dimension()
648 }
649}
650
/// Namespace type for constructing embedding providers by name; it holds no
/// state and is used only through its associated functions.
pub struct EmbeddingProviderFactory;
657
658impl EmbeddingProviderFactory {
659 pub fn create(
663 provider_name: &str,
664 api_key: impl Into<String>,
665 model: Option<String>,
666 ) -> ArgentorResult<Box<dyn EmbeddingProvider>> {
667 let api_key = api_key.into();
668 match provider_name {
669 "openai" => Ok(Box::new(OpenAiEmbeddingProvider::new(api_key, model))),
670 "cohere" => Ok(Box::new(CohereEmbeddingProvider::new(api_key, model))),
671 "voyage" => Ok(Box::new(VoyageEmbeddingProvider::new(api_key, model))),
672 "local" => {
673 let dim = model
674 .as_deref()
675 .and_then(|m| m.parse::<usize>().ok())
676 .unwrap_or(256);
677 Ok(Box::new(LocalEmbedding::new(dim)))
678 }
679 other => Err(ArgentorError::Config(format!(
680 "Unknown embedding provider: {other}"
681 ))),
682 }
683 }
684
685 pub fn available_providers() -> &'static [&'static str] {
687 &["openai", "cohere", "voyage", "local"]
688 }
689}
690
/// Serializable description of which embedding provider to build and how.
///
/// Only `provider` is required when deserializing; every other field falls
/// back to its `Default` value.
///
/// NOTE(review): `api_key` is part of the derived `Debug` and `Serialize`
/// output — confirm this value is never logged or persisted unintentionally.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Provider name; `build` accepts `"openai"`, `"cohere"`, `"voyage"`, `"local"`.
    pub provider: String,
    /// API key passed to the HTTP providers; empty when omitted.
    #[serde(default)]
    pub api_key: String,
    /// Optional model name; each provider supplies its own default.
    #[serde(default)]
    pub model: Option<String>,
    /// Optional override for the reported embedding dimension.
    #[serde(default)]
    pub dimensions: Option<usize>,
    /// Optional endpoint override; `build` only applies it to `"openai"`.
    #[serde(default)]
    pub base_url: Option<String>,
    /// When set, `build` wraps the provider in a cache of this size.
    #[serde(default)]
    pub cache_size: Option<usize>,
}
716
717impl EmbeddingConfig {
718 pub fn build(&self) -> ArgentorResult<Arc<dyn EmbeddingProvider>> {
722 let mut provider: Box<dyn EmbeddingProvider> = match self.provider.as_str() {
723 "openai" => {
724 let mut p = if let Some(ref url) = self.base_url {
725 OpenAiEmbeddingProvider::with_base_url(&self.api_key, self.model.clone(), url)
726 } else {
727 OpenAiEmbeddingProvider::new(&self.api_key, self.model.clone())
728 };
729 if let Some(dim) = self.dimensions {
730 p = p.with_dimensions(dim);
731 }
732 Box::new(p)
733 }
734 "cohere" => {
735 let mut p = CohereEmbeddingProvider::new(&self.api_key, self.model.clone());
736 if let Some(dim) = self.dimensions {
737 p = p.with_dimensions(dim);
738 }
739 Box::new(p)
740 }
741 "voyage" => {
742 let mut p = VoyageEmbeddingProvider::new(&self.api_key, self.model.clone());
743 if let Some(dim) = self.dimensions {
744 p = p.with_dimensions(dim);
745 }
746 Box::new(p)
747 }
748 "local" => {
749 let dim = self.dimensions.unwrap_or(256);
750 Box::new(LocalEmbedding::new(dim))
751 }
752 other => {
753 return Err(ArgentorError::Config(format!(
754 "Unknown embedding provider: {other}"
755 )));
756 }
757 };
758
759 let _ = &mut provider; let arc: Arc<dyn EmbeddingProvider> = Arc::from(provider);
764
765 if let Some(cache_size) = self.cache_size {
767 Ok(Arc::new(CachedEmbeddingProvider::new(arc, cache_size)))
768 } else {
769 Ok(arc)
770 }
771 }
772}
773
774impl Default for EmbeddingConfig {
775 fn default() -> Self {
776 Self {
777 provider: "local".to_string(),
778 api_key: String::new(),
779 model: None,
780 dimensions: None,
781 base_url: None,
782 cache_size: None,
783 }
784 }
785}
786
/// Produces a deterministic placeholder vector for `text`.
///
/// The vector has `max(dimensions, 1)` components. Each byte of `text` adds
/// `byte / 255` to the component at `byte_position % len`. If the resulting
/// vector has a positive Euclidean norm it is scaled to unit length;
/// otherwise (e.g. empty text) it is returned as all zeros.
#[cfg_attr(all(feature = "http-embeddings", not(test)), allow(dead_code))]
fn stub_embedding(text: &str, dimensions: usize) -> Vec<f32> {
    let len = dimensions.max(1);
    let mut buckets = vec![0.0f32; len];

    // `slot` walks the buckets cyclically, mirroring `position % len`.
    let mut slot = 0usize;
    for byte in text.bytes() {
        buckets[slot] += f32::from(byte) / 255.0;
        slot += 1;
        if slot == len {
            slot = 0;
        }
    }

    let magnitude = buckets.iter().map(|c| c * c).sum::<f32>().sqrt();
    if magnitude > 0.0 {
        buckets.iter_mut().for_each(|c| *c /= magnitude);
    }
    buckets
}
811
/// Embedding provider that POSTs to a Jina embeddings endpoint.
///
/// With the `http-embeddings` feature the configured URL is called; without
/// it, `embed` returns a deterministic stub vector instead.
pub struct JinaEmbeddingProvider {
    /// Bearer token sent in the `Authorization` header.
    // Only read by the feature-gated HTTP path, hence the conditional lint allowance.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    /// Model identifier sent in the request body.
    model: String,
    /// Dimension reported by `dimension()` and used for the stub vector.
    dimensions: usize,
    /// Full URL the request is posted to.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}
830
831impl JinaEmbeddingProvider {
832 pub fn new(api_key: impl Into<String>) -> Self {
834 Self::with_model(api_key, "jina-embeddings-v3", 1024)
835 }
836
837 pub fn with_model(
839 api_key: impl Into<String>,
840 model: impl Into<String>,
841 dimensions: usize,
842 ) -> Self {
843 Self {
844 api_key: api_key.into(),
845 model: model.into(),
846 dimensions,
847 base_url: "https://api.jina.ai/v1/embeddings".to_string(),
848 }
849 }
850
851 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
853 self.base_url = base_url.into();
854 self
855 }
856
857 pub fn model(&self) -> &str {
859 &self.model
860 }
861
862 pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
864 serde_json::json!({
865 "model": self.model,
866 "input": texts,
867 })
868 }
869}
870
871#[async_trait]
872impl EmbeddingProvider for JinaEmbeddingProvider {
873 #[cfg(feature = "http-embeddings")]
874 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
875 let client = reqwest::Client::new();
876 let payload = self.build_payload(&[text.to_string()]);
877 let response = client
878 .post(&self.base_url)
879 .header("Authorization", format!("Bearer {}", self.api_key))
880 .json(&payload)
881 .send()
882 .await
883 .map_err(|e| ArgentorError::Http(format!("Jina embedding request failed: {e}")))?;
884
885 let status = response.status();
886 if !status.is_success() {
887 let body = response.text().await.unwrap_or_default();
888 return Err(ArgentorError::Http(format!(
889 "Jina API error {status}: {body}"
890 )));
891 }
892
893 let json: serde_json::Value = response
894 .json()
895 .await
896 .map_err(|e| ArgentorError::Http(format!("Failed to read Jina response body: {e}")))?;
897
898 parse_openai_embedding_response(&json)
900 }
901
902 #[cfg(not(feature = "http-embeddings"))]
903 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
904 Ok(stub_embedding(text, self.dimensions))
905 }
906
907 fn dimension(&self) -> usize {
908 self.dimensions
909 }
910}
911
/// Embedding provider that POSTs to a Mistral embeddings endpoint.
///
/// With the `http-embeddings` feature the configured URL is called; without
/// it, `embed` returns a deterministic stub vector instead.
pub struct MistralEmbedProvider {
    /// Bearer token sent in the `Authorization` header.
    // Only read by the feature-gated HTTP path, hence the conditional lint allowance.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    /// Model identifier sent in the request body.
    model: String,
    /// Dimension reported by `dimension()` and used for the stub vector.
    dimensions: usize,
    /// Full URL the request is posted to.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}
928
929impl MistralEmbedProvider {
930 pub fn new(api_key: impl Into<String>) -> Self {
932 Self::with_model(api_key, "mistral-embed", 1024)
933 }
934
935 pub fn with_model(
937 api_key: impl Into<String>,
938 model: impl Into<String>,
939 dimensions: usize,
940 ) -> Self {
941 Self {
942 api_key: api_key.into(),
943 model: model.into(),
944 dimensions,
945 base_url: "https://api.mistral.ai/v1/embeddings".to_string(),
946 }
947 }
948
949 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
951 self.base_url = base_url.into();
952 self
953 }
954
955 pub fn model(&self) -> &str {
957 &self.model
958 }
959
960 pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
962 serde_json::json!({
963 "model": self.model,
964 "input": texts,
965 })
966 }
967}
968
969#[async_trait]
970impl EmbeddingProvider for MistralEmbedProvider {
971 #[cfg(feature = "http-embeddings")]
972 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
973 let client = reqwest::Client::new();
974 let payload = self.build_payload(&[text.to_string()]);
975 let response = client
976 .post(&self.base_url)
977 .header("Authorization", format!("Bearer {}", self.api_key))
978 .json(&payload)
979 .send()
980 .await
981 .map_err(|e| ArgentorError::Http(format!("Mistral embedding request failed: {e}")))?;
982
983 let status = response.status();
984 if !status.is_success() {
985 let body = response.text().await.unwrap_or_default();
986 return Err(ArgentorError::Http(format!(
987 "Mistral API error {status}: {body}"
988 )));
989 }
990
991 let json: serde_json::Value = response.json().await.map_err(|e| {
992 ArgentorError::Http(format!("Failed to read Mistral response body: {e}"))
993 })?;
994
995 parse_openai_embedding_response(&json)
996 }
997
998 #[cfg(not(feature = "http-embeddings"))]
999 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1000 Ok(stub_embedding(text, self.dimensions))
1001 }
1002
1003 fn dimension(&self) -> usize {
1004 self.dimensions
1005 }
1006}
1007
/// Embedding provider that POSTs to a Nomic text-embedding endpoint.
///
/// With the `http-embeddings` feature the configured URL is called; without
/// it, `embed` returns a deterministic stub vector instead.
pub struct NomicEmbedProvider {
    /// Bearer token sent in the `Authorization` header.
    // Only read by the feature-gated HTTP path, hence the conditional lint allowance.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    /// Model identifier sent in the request body.
    model: String,
    /// Dimension reported by `dimension()` and used for the stub vector.
    dimensions: usize,
    /// Full URL the request is posted to.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
    /// Value sent as `task_type` in the request body.
    task_type: String,
}
1026
1027impl NomicEmbedProvider {
1028 pub fn new(api_key: impl Into<String>) -> Self {
1030 Self::with_model(api_key, "nomic-embed-text-v1.5", 768)
1031 }
1032
1033 pub fn with_model(
1035 api_key: impl Into<String>,
1036 model: impl Into<String>,
1037 dimensions: usize,
1038 ) -> Self {
1039 Self {
1040 api_key: api_key.into(),
1041 model: model.into(),
1042 dimensions,
1043 base_url: "https://api-atlas.nomic.ai/v1/embedding/text".to_string(),
1044 task_type: "search_document".to_string(),
1045 }
1046 }
1047
1048 pub fn with_task_type(mut self, task_type: impl Into<String>) -> Self {
1050 self.task_type = task_type.into();
1051 self
1052 }
1053
1054 pub fn model(&self) -> &str {
1056 &self.model
1057 }
1058
1059 pub fn task_type(&self) -> &str {
1061 &self.task_type
1062 }
1063
1064 pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
1066 serde_json::json!({
1067 "model": self.model,
1068 "texts": texts,
1069 "task_type": self.task_type,
1070 })
1071 }
1072}
1073
1074#[async_trait]
1075impl EmbeddingProvider for NomicEmbedProvider {
1076 #[cfg(feature = "http-embeddings")]
1077 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1078 let client = reqwest::Client::new();
1079 let payload = self.build_payload(&[text.to_string()]);
1080 let response = client
1081 .post(&self.base_url)
1082 .header("Authorization", format!("Bearer {}", self.api_key))
1083 .json(&payload)
1084 .send()
1085 .await
1086 .map_err(|e| ArgentorError::Http(format!("Nomic embedding request failed: {e}")))?;
1087
1088 let status = response.status();
1089 if !status.is_success() {
1090 let body = response.text().await.unwrap_or_default();
1091 return Err(ArgentorError::Http(format!(
1092 "Nomic API error {status}: {body}"
1093 )));
1094 }
1095
1096 let json: serde_json::Value = response
1097 .json()
1098 .await
1099 .map_err(|e| ArgentorError::Http(format!("Failed to read Nomic response body: {e}")))?;
1100
1101 let embeddings = json
1103 .get("embeddings")
1104 .and_then(|v| v.as_array())
1105 .ok_or_else(|| {
1106 ArgentorError::Agent("Nomic response missing 'embeddings' array".to_string())
1107 })?;
1108 let first = embeddings.first().ok_or_else(|| {
1109 ArgentorError::Agent("Nomic response contains no embedding vectors".to_string())
1110 })?;
1111 let vec: Vec<f32> = serde_json::from_value(first.clone()).map_err(|e| {
1112 ArgentorError::Agent(format!("Failed to parse Nomic embedding vector: {e}"))
1113 })?;
1114 Ok(vec)
1115 }
1116
1117 #[cfg(not(feature = "http-embeddings"))]
1118 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1119 Ok(stub_embedding(text, self.dimensions))
1120 }
1121
1122 fn dimension(&self) -> usize {
1123 self.dimensions
1124 }
1125}
1126
/// Embedding provider that POSTs to a HuggingFace feature-extraction endpoint.
///
/// With the `http-embeddings` feature the configured URL is called; without
/// it, `embed` returns a deterministic stub vector instead.
pub struct SentenceTransformersProvider {
    /// Bearer token sent in the `Authorization` header.
    // Only read by the feature-gated HTTP path, hence the conditional lint allowance.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    /// Model identifier; also embedded in the default endpoint URL.
    model: String,
    /// Dimension reported by `dimension()` and used for the stub vector.
    dimensions: usize,
    /// Full URL the request is posted to.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}
1144
1145impl SentenceTransformersProvider {
1146 pub fn new(api_key: impl Into<String>) -> Self {
1148 Self::with_model(api_key, "sentence-transformers/all-MiniLM-L6-v2", 384)
1149 }
1150
1151 pub fn with_model(
1153 api_key: impl Into<String>,
1154 model: impl Into<String>,
1155 dimensions: usize,
1156 ) -> Self {
1157 let model = model.into();
1158 let base_url =
1159 format!("https://api-inference.huggingface.co/pipeline/feature-extraction/{model}");
1160 Self {
1161 api_key: api_key.into(),
1162 model,
1163 dimensions,
1164 base_url,
1165 }
1166 }
1167
1168 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
1170 self.base_url = base_url.into();
1171 self
1172 }
1173
1174 pub fn model(&self) -> &str {
1176 &self.model
1177 }
1178
1179 pub fn default_dimensions(model: &str) -> usize {
1181 match model {
1182 "sentence-transformers/all-MiniLM-L6-v2" => 384,
1183 "sentence-transformers/all-mpnet-base-v2"
1184 | "sentence-transformers/multi-qa-mpnet-base-dot-v1" => 768,
1185 _ => 384,
1186 }
1187 }
1188
1189 pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
1191 serde_json::json!({
1192 "inputs": texts,
1193 "options": { "wait_for_model": true },
1194 })
1195 }
1196}
1197
1198#[async_trait]
1199impl EmbeddingProvider for SentenceTransformersProvider {
1200 #[cfg(feature = "http-embeddings")]
1201 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1202 let client = reqwest::Client::new();
1203 let payload = self.build_payload(&[text.to_string()]);
1204 let response = client
1205 .post(&self.base_url)
1206 .header("Authorization", format!("Bearer {}", self.api_key))
1207 .json(&payload)
1208 .send()
1209 .await
1210 .map_err(|e| {
1211 ArgentorError::Http(format!("HuggingFace embedding request failed: {e}"))
1212 })?;
1213
1214 let status = response.status();
1215 if !status.is_success() {
1216 let body = response.text().await.unwrap_or_default();
1217 return Err(ArgentorError::Http(format!(
1218 "HuggingFace API error {status}: {body}"
1219 )));
1220 }
1221
1222 let json: serde_json::Value = response.json().await.map_err(|e| {
1223 ArgentorError::Http(format!("Failed to read HuggingFace response body: {e}"))
1224 })?;
1225
1226 match &json {
1228 serde_json::Value::Array(arr)
1229 if arr.first().is_some_and(serde_json::Value::is_array) =>
1230 {
1231 let first = arr.first().cloned().ok_or_else(|| {
1232 ArgentorError::Agent("HuggingFace response empty".to_string())
1233 })?;
1234 serde_json::from_value(first)
1235 .map_err(|e| ArgentorError::Agent(format!("Failed to parse HF vector: {e}")))
1236 }
1237 serde_json::Value::Array(_) => serde_json::from_value(json)
1238 .map_err(|e| ArgentorError::Agent(format!("Failed to parse HF vector: {e}"))),
1239 _ => Err(ArgentorError::Agent(
1240 "HuggingFace response is not an array".to_string(),
1241 )),
1242 }
1243 }
1244
1245 #[cfg(not(feature = "http-embeddings"))]
1246 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1247 Ok(stub_embedding(text, self.dimensions))
1248 }
1249
1250 fn dimension(&self) -> usize {
1251 self.dimensions
1252 }
1253}
1254
/// Embedding provider that POSTs to a Together embeddings endpoint.
///
/// With the `http-embeddings` feature the configured URL is called; without
/// it, `embed` returns a deterministic stub vector instead.
pub struct TogetherEmbedProvider {
    /// Bearer token sent in the `Authorization` header.
    // Only read by the feature-gated HTTP path, hence the conditional lint allowance.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    /// Model identifier sent in the request body.
    model: String,
    /// Dimension reported by `dimension()` and used for the stub vector.
    dimensions: usize,
    /// Full URL the request is posted to.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}
1271
1272impl TogetherEmbedProvider {
1273 pub fn new(api_key: impl Into<String>) -> Self {
1275 Self::with_model(api_key, "togethercomputer/m2-bert-80M-32k-retrieval", 768)
1276 }
1277
1278 pub fn with_model(
1280 api_key: impl Into<String>,
1281 model: impl Into<String>,
1282 dimensions: usize,
1283 ) -> Self {
1284 Self {
1285 api_key: api_key.into(),
1286 model: model.into(),
1287 dimensions,
1288 base_url: "https://api.together.xyz/v1/embeddings".to_string(),
1289 }
1290 }
1291
1292 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
1294 self.base_url = base_url.into();
1295 self
1296 }
1297
1298 pub fn model(&self) -> &str {
1300 &self.model
1301 }
1302
1303 pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
1305 serde_json::json!({
1306 "model": self.model,
1307 "input": texts,
1308 })
1309 }
1310}
1311
1312#[async_trait]
1313impl EmbeddingProvider for TogetherEmbedProvider {
1314 #[cfg(feature = "http-embeddings")]
1315 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1316 let client = reqwest::Client::new();
1317 let payload = self.build_payload(&[text.to_string()]);
1318 let response = client
1319 .post(&self.base_url)
1320 .header("Authorization", format!("Bearer {}", self.api_key))
1321 .json(&payload)
1322 .send()
1323 .await
1324 .map_err(|e| ArgentorError::Http(format!("Together embedding request failed: {e}")))?;
1325
1326 let status = response.status();
1327 if !status.is_success() {
1328 let body = response.text().await.unwrap_or_default();
1329 return Err(ArgentorError::Http(format!(
1330 "Together API error {status}: {body}"
1331 )));
1332 }
1333
1334 let json: serde_json::Value = response.json().await.map_err(|e| {
1335 ArgentorError::Http(format!("Failed to read Together response body: {e}"))
1336 })?;
1337
1338 parse_openai_embedding_response(&json)
1339 }
1340
1341 #[cfg(not(feature = "http-embeddings"))]
1342 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1343 Ok(stub_embedding(text, self.dimensions))
1344 }
1345
1346 fn dimension(&self) -> usize {
1347 self.dimensions
1348 }
1349}
1350
/// Embedding provider that POSTs to a Cohere `v2/embed` endpoint with a
/// configurable base URL and input type.
///
/// With the `http-embeddings` feature the configured URL is called; without
/// it, `embed` returns a deterministic stub vector instead.
pub struct CohereEmbedV4Provider {
    /// Bearer token sent in the `Authorization` header.
    // Only read by the feature-gated HTTP path, hence the conditional lint allowance.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    /// Model identifier sent in the request body.
    model: String,
    /// Dimension reported by `dimension()` and used for the stub vector.
    dimensions: usize,
    /// Value sent as `input_type` in the request body.
    input_type: String,
    /// Full URL the request is posted to.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}
1371
1372impl CohereEmbedV4Provider {
1373 pub fn new(api_key: impl Into<String>) -> Self {
1375 Self::with_model(api_key, "embed-english-v3.0", 1024)
1376 }
1377
1378 pub fn with_model(
1380 api_key: impl Into<String>,
1381 model: impl Into<String>,
1382 dimensions: usize,
1383 ) -> Self {
1384 Self {
1385 api_key: api_key.into(),
1386 model: model.into(),
1387 dimensions,
1388 input_type: "search_document".to_string(),
1389 base_url: "https://api.cohere.com/v2/embed".to_string(),
1390 }
1391 }
1392
1393 pub fn for_search_document(mut self) -> Self {
1395 self.input_type = "search_document".to_string();
1396 self
1397 }
1398
1399 pub fn for_search_query(mut self) -> Self {
1401 self.input_type = "search_query".to_string();
1402 self
1403 }
1404
1405 pub fn with_input_type(mut self, input_type: impl Into<String>) -> Self {
1407 self.input_type = input_type.into();
1408 self
1409 }
1410
1411 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
1413 self.base_url = base_url.into();
1414 self
1415 }
1416
1417 pub fn model(&self) -> &str {
1419 &self.model
1420 }
1421
1422 pub fn input_type(&self) -> &str {
1424 &self.input_type
1425 }
1426
1427 pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
1429 serde_json::json!({
1430 "model": self.model,
1431 "texts": texts,
1432 "input_type": self.input_type,
1433 "embedding_types": ["float"],
1434 })
1435 }
1436}
1437
1438#[async_trait]
1439impl EmbeddingProvider for CohereEmbedV4Provider {
1440 #[cfg(feature = "http-embeddings")]
1441 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1442 let client = reqwest::Client::new();
1443 let payload = self.build_payload(&[text.to_string()]);
1444 let response = client
1445 .post(&self.base_url)
1446 .header("Authorization", format!("Bearer {}", self.api_key))
1447 .json(&payload)
1448 .send()
1449 .await
1450 .map_err(|e| ArgentorError::Http(format!("Cohere v4 embedding request failed: {e}")))?;
1451
1452 let status = response.status();
1453 if !status.is_success() {
1454 let body = response.text().await.unwrap_or_default();
1455 return Err(ArgentorError::Http(format!(
1456 "Cohere v4 API error {status}: {body}"
1457 )));
1458 }
1459
1460 let json: serde_json::Value = response.json().await.map_err(|e| {
1461 ArgentorError::Http(format!("Failed to read Cohere v4 response body: {e}"))
1462 })?;
1463
1464 parse_cohere_embedding_response(&json)
1465 }
1466
1467 #[cfg(not(feature = "http-embeddings"))]
1468 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1469 Ok(stub_embedding(text, self.dimensions))
1470 }
1471
1472 fn dimension(&self) -> usize {
1473 self.dimensions
1474 }
1475}
1476
#[cfg(test)]
// Tests assert through `unwrap` / `unwrap_err`, so the panic-on-failure lints
// are relaxed for this module only.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    //! Unit tests for the embedding providers, response parsers, cache and
    //! batch wrappers, factory, and configuration defined in the parent module.
    //!
    //! Tests gated on `not(feature = "http-embeddings")` exercise the offline
    //! code paths only; nothing in this module performs network I/O.

    use super::*;

    // ---------------------------------------------------------------------
    // OpenAI provider
    // ---------------------------------------------------------------------

    #[test]
    fn test_openai_provider_default_model() {
        let p = OpenAiEmbeddingProvider::new("sk-test", None);
        assert_eq!(p.model(), "text-embedding-3-small");
        assert_eq!(p.dimension(), 1536);
    }

    #[test]
    fn test_openai_provider_large_model() {
        let p = OpenAiEmbeddingProvider::new("sk-test", Some("text-embedding-3-large".into()));
        assert_eq!(p.dimension(), 3072);
    }

    #[test]
    fn test_openai_provider_custom_dimensions() {
        let p = OpenAiEmbeddingProvider::new("sk-test", None).with_dimensions(512);
        assert_eq!(p.dimension(), 512);
    }

    #[test]
    fn test_openai_provider_custom_base_url() {
        let p = OpenAiEmbeddingProvider::with_base_url(
            "sk-test",
            None,
            "https://my-azure.openai.azure.com/openai/deployments/embed",
        );
        assert_eq!(p.dimension(), 1536);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_openai_provider_returns_feature_error() {
        let p = OpenAiEmbeddingProvider::new("sk-test", None);
        let err = p.embed("hello").await.unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("HTTP embeddings not enabled"), "got: {msg}");
    }

    // ---------------------------------------------------------------------
    // Cohere provider
    // ---------------------------------------------------------------------

    #[test]
    fn test_cohere_provider_default() {
        let p = CohereEmbeddingProvider::new("key", None);
        assert_eq!(p.model(), "embed-english-v3.0");
        assert_eq!(p.dimension(), 1024);
        assert_eq!(p.input_type(), "search_document");
    }

    #[test]
    fn test_cohere_provider_query_input_type() {
        let p = CohereEmbeddingProvider::new("key", None).with_input_type("search_query");
        assert_eq!(p.input_type(), "search_query");
    }

    #[test]
    fn test_cohere_provider_light_model() {
        let p = CohereEmbeddingProvider::new("key", Some("embed-english-light-v3.0".into()));
        assert_eq!(p.dimension(), 384);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_cohere_provider_returns_feature_error() {
        let p = CohereEmbeddingProvider::new("key", None);
        let err = p.embed("hello").await.unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("HTTP embeddings not enabled"), "got: {msg}");
    }

    // ---------------------------------------------------------------------
    // Voyage provider
    // ---------------------------------------------------------------------

    #[test]
    fn test_voyage_provider_default() {
        let p = VoyageEmbeddingProvider::new("key", None);
        assert_eq!(p.model(), "voyage-2");
        assert_eq!(p.dimension(), 1024);
    }

    #[test]
    fn test_voyage_provider_code_model() {
        let p = VoyageEmbeddingProvider::new("key", Some("voyage-code-2".into()));
        assert_eq!(p.dimension(), 1536);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_voyage_provider_returns_feature_error() {
        let p = VoyageEmbeddingProvider::new("key", None);
        let err = p.embed("hello").await.unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("HTTP embeddings not enabled"), "got: {msg}");
    }

    // ---------------------------------------------------------------------
    // Response parsers
    // ---------------------------------------------------------------------

    #[test]
    fn test_parse_openai_embedding_response_valid() {
        let json = serde_json::json!({
            "data": [
                {
                    "embedding": [0.1, 0.2, 0.3, 0.4],
                    "index": 0
                }
            ],
            "model": "text-embedding-3-small"
        });
        let result = parse_openai_embedding_response(&json).unwrap();
        assert_eq!(result, vec![0.1, 0.2, 0.3, 0.4]);
    }

    #[test]
    fn test_parse_openai_embedding_response_empty_data() {
        let json = serde_json::json!({
            "data": [],
            "model": "text-embedding-3-small"
        });
        let err = parse_openai_embedding_response(&json).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("no embedding data"), "got: {msg}");
    }

    #[test]
    fn test_parse_openai_embedding_response_invalid_shape() {
        let json = serde_json::json!({ "error": "bad request" });
        let err = parse_openai_embedding_response(&json).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("Failed to parse"), "got: {msg}");
    }

    #[test]
    fn test_parse_openai_embedding_response_multiple_picks_first() {
        let json = serde_json::json!({
            "data": [
                { "embedding": [1.0, 2.0], "index": 0 },
                { "embedding": [3.0, 4.0], "index": 1 }
            ],
            "model": "text-embedding-3-small"
        });
        let result = parse_openai_embedding_response(&json).unwrap();
        assert_eq!(result, vec![1.0, 2.0]);
    }

    #[test]
    fn test_parse_cohere_embedding_response_valid() {
        let json = serde_json::json!({
            "embeddings": {
                "float": [
                    [0.5, 0.6, 0.7]
                ]
            }
        });
        let result = parse_cohere_embedding_response(&json).unwrap();
        assert_eq!(result, vec![0.5, 0.6, 0.7]);
    }

    #[test]
    fn test_parse_cohere_embedding_response_empty_float() {
        let json = serde_json::json!({
            "embeddings": {
                "float": []
            }
        });
        let err = parse_cohere_embedding_response(&json).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("no float embeddings"), "got: {msg}");
    }

    #[test]
    fn test_parse_cohere_embedding_response_invalid_shape() {
        let json = serde_json::json!({ "message": "unauthorized" });
        let err = parse_cohere_embedding_response(&json).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("Failed to parse"), "got: {msg}");
    }

    #[test]
    fn test_parse_cohere_embedding_response_missing_float_key() {
        // `float` is `#[serde(default)]` on `CohereEmbeddingsMap`, so a missing
        // key deserializes to an empty list and is reported like an empty one.
        let json = serde_json::json!({
            "embeddings": {}
        });
        let err = parse_cohere_embedding_response(&json).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("no float embeddings"), "got: {msg}");
    }

    #[test]
    fn test_parse_voyage_embedding_response_valid() {
        let json = serde_json::json!({
            "data": [
                {
                    "embedding": [0.9, 0.8, 0.7, 0.6, 0.5],
                    "index": 0
                }
            ]
        });
        let result = parse_voyage_embedding_response(&json).unwrap();
        assert_eq!(result, vec![0.9, 0.8, 0.7, 0.6, 0.5]);
    }

    #[test]
    fn test_parse_voyage_embedding_response_empty_data() {
        let json = serde_json::json!({ "data": [] });
        let err = parse_voyage_embedding_response(&json).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("no embedding data"), "got: {msg}");
    }

    #[test]
    fn test_parse_voyage_embedding_response_invalid_shape() {
        let json = serde_json::json!({ "error": "invalid key" });
        let err = parse_voyage_embedding_response(&json).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("Failed to parse"), "got: {msg}");
    }

    // ---------------------------------------------------------------------
    // CachedEmbeddingProvider
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_cache_hit() {
        let local = Arc::new(LocalEmbedding::new(64));
        let cached = CachedEmbeddingProvider::new(local, 100);

        let v1 = cached.embed("hello world").await.unwrap();
        let v2 = cached.embed("hello world").await.unwrap();
        assert_eq!(v1, v2);

        let stats = cached.cache_stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.size, 1);
    }

    #[tokio::test]
    async fn test_cache_miss_different_texts() {
        let local = Arc::new(LocalEmbedding::new(64));
        let cached = CachedEmbeddingProvider::new(local, 100);

        let _ = cached.embed("alpha").await.unwrap();
        let _ = cached.embed("bravo").await.unwrap();

        let stats = cached.cache_stats().await;
        assert_eq!(stats.misses, 2);
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.size, 2);
    }

    #[tokio::test]
    async fn test_cache_eviction() {
        let local = Arc::new(LocalEmbedding::new(64));
        let cached = CachedEmbeddingProvider::new(local, 2);

        let _ = cached.embed("one").await.unwrap();
        let _ = cached.embed("two").await.unwrap();
        let _ = cached.embed("three").await.unwrap();

        let stats = cached.cache_stats().await;
        // Only the capacity bound is pinned, not which entry gets evicted.
        assert!(stats.size <= 2, "size={} should be <= 2", stats.size);
        assert_eq!(stats.misses, 3);
    }

    #[tokio::test]
    async fn test_cache_clear() {
        let local = Arc::new(LocalEmbedding::new(64));
        let cached = CachedEmbeddingProvider::new(local, 100);

        let _ = cached.embed("text").await.unwrap();
        cached.clear().await;

        let stats = cached.cache_stats().await;
        assert_eq!(stats.size, 0);
    }

    #[tokio::test]
    async fn test_cache_dimension_delegates() {
        let local = Arc::new(LocalEmbedding::new(128));
        let cached = CachedEmbeddingProvider::new(local, 10);
        assert_eq!(cached.dimension(), 128);
    }

    // ---------------------------------------------------------------------
    // BatchEmbeddingProvider
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_batch_embed() {
        let local = Arc::new(LocalEmbedding::new(64));
        let batch = BatchEmbeddingProvider::new(local);

        let results = batch
            .embed_batch(&["hello", "world", "test"])
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
        for v in &results {
            assert_eq!(v.len(), 64);
        }
    }

    #[tokio::test]
    async fn test_batch_single_embed_delegates() {
        let local = Arc::new(LocalEmbedding::new(64));
        let batch = BatchEmbeddingProvider::new(local);

        let v = batch.embed("hello").await.unwrap();
        assert_eq!(v.len(), 64);
    }

    #[tokio::test]
    async fn test_batch_empty() {
        let local = Arc::new(LocalEmbedding::new(64));
        let batch = BatchEmbeddingProvider::new(local);

        let results = batch.embed_batch(&[]).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_batch_dimension_delegates() {
        let local = Arc::new(LocalEmbedding::new(200));
        let batch = BatchEmbeddingProvider::new(local);
        assert_eq!(batch.dimension(), 200);
    }

    // ---------------------------------------------------------------------
    // EmbeddingProviderFactory
    // ---------------------------------------------------------------------

    #[test]
    fn test_factory_create_local() {
        let p = EmbeddingProviderFactory::create("local", "", None).unwrap();
        assert_eq!(p.dimension(), 256);
    }

    #[test]
    fn test_factory_create_local_custom_dim() {
        // For the local provider the third argument carries the dimension as text.
        let p = EmbeddingProviderFactory::create("local", "", Some("128".into())).unwrap();
        assert_eq!(p.dimension(), 128);
    }

    #[test]
    fn test_factory_create_openai() {
        let p = EmbeddingProviderFactory::create("openai", "sk-test", None).unwrap();
        assert_eq!(p.dimension(), 1536);
    }

    #[test]
    fn test_factory_create_cohere() {
        let p = EmbeddingProviderFactory::create("cohere", "key", None).unwrap();
        assert_eq!(p.dimension(), 1024);
    }

    #[test]
    fn test_factory_create_voyage() {
        let p = EmbeddingProviderFactory::create("voyage", "key", None).unwrap();
        assert_eq!(p.dimension(), 1024);
    }

    #[test]
    fn test_factory_unknown_provider() {
        let result = EmbeddingProviderFactory::create("unknown", "", None);
        assert!(result.is_err(), "Unknown provider should return Err");
    }

    #[test]
    fn test_factory_available_providers() {
        let names = EmbeddingProviderFactory::available_providers();
        assert!(names.contains(&"openai"));
        assert!(names.contains(&"cohere"));
        assert!(names.contains(&"voyage"));
        assert!(names.contains(&"local"));
    }

    // ---------------------------------------------------------------------
    // EmbeddingConfig
    // ---------------------------------------------------------------------

    #[test]
    fn test_config_default() {
        let cfg = EmbeddingConfig::default();
        assert_eq!(cfg.provider, "local");
        assert!(cfg.api_key.is_empty());
        assert!(cfg.model.is_none());
        assert!(cfg.dimensions.is_none());
        assert!(cfg.base_url.is_none());
        assert!(cfg.cache_size.is_none());
    }

    #[test]
    fn test_config_serialize_deserialize() {
        let cfg = EmbeddingConfig {
            provider: "openai".to_string(),
            api_key: "sk-123".to_string(),
            model: Some("text-embedding-3-small".to_string()),
            dimensions: Some(1536),
            base_url: None,
            cache_size: Some(500),
        };
        let json = serde_json::to_string(&cfg).unwrap();
        let parsed: EmbeddingConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.provider, "openai");
        assert_eq!(parsed.api_key, "sk-123");
        assert_eq!(parsed.dimensions, Some(1536));
        assert_eq!(parsed.cache_size, Some(500));
    }

    #[test]
    fn test_config_deserialize_minimal() {
        let json = r#"{"provider":"local"}"#;
        let cfg: EmbeddingConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.provider, "local");
        assert!(cfg.api_key.is_empty());
    }

    #[tokio::test]
    async fn test_config_build_local() {
        let cfg = EmbeddingConfig::default();
        let provider = cfg.build().unwrap();
        assert_eq!(provider.dimension(), 256);
        let v = provider.embed("test text").await.unwrap();
        assert_eq!(v.len(), 256);
    }

    #[tokio::test]
    async fn test_config_build_local_with_cache() {
        let cfg = EmbeddingConfig {
            provider: "local".to_string(),
            cache_size: Some(50),
            ..Default::default()
        };
        let provider = cfg.build().unwrap();
        assert_eq!(provider.dimension(), 256);
        // Embedding the same text twice must give the same vector.
        let v1 = provider.embed("cached text").await.unwrap();
        let v2 = provider.embed("cached text").await.unwrap();
        assert_eq!(v1, v2);
    }

    #[tokio::test]
    async fn test_config_build_local_custom_dimensions() {
        let cfg = EmbeddingConfig {
            provider: "local".to_string(),
            dimensions: Some(512),
            ..Default::default()
        };
        let provider = cfg.build().unwrap();
        assert_eq!(provider.dimension(), 512);
    }

    #[test]
    fn test_config_build_unknown_provider() {
        let cfg = EmbeddingConfig {
            provider: "imaginary".to_string(),
            ..Default::default()
        };
        assert!(cfg.build().is_err());
    }

    // ---------------------------------------------------------------------
    // fnv1a_hash
    // ---------------------------------------------------------------------

    #[test]
    fn test_fnv_hash_deterministic() {
        let h1 = fnv1a_hash(b"hello world");
        let h2 = fnv1a_hash(b"hello world");
        assert_eq!(h1, h2);
    }

    #[test]
    fn test_fnv_hash_different_inputs() {
        let h1 = fnv1a_hash(b"alpha");
        let h2 = fnv1a_hash(b"bravo");
        assert_ne!(h1, h2);
    }

    // ---------------------------------------------------------------------
    // stub_embedding
    // ---------------------------------------------------------------------

    #[test]
    fn test_stub_embedding_length() {
        let v = stub_embedding("hello", 128);
        assert_eq!(v.len(), 128);
    }

    #[test]
    fn test_stub_embedding_deterministic() {
        let v1 = stub_embedding("same input", 64);
        let v2 = stub_embedding("same input", 64);
        assert_eq!(v1, v2);
    }

    #[test]
    fn test_stub_embedding_normalized() {
        let v = stub_embedding("the quick brown fox", 256);
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01, "norm={norm}");
    }

    #[test]
    fn test_stub_embedding_different_inputs_differ() {
        let a = stub_embedding("alpha", 64);
        let b = stub_embedding("bravo", 64);
        assert_ne!(a, b);
    }

    #[test]
    fn test_stub_embedding_empty_text_zeroes() {
        let v = stub_embedding("", 32);
        assert_eq!(v.len(), 32);
        assert!(v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_stub_embedding_zero_dimension_safe() {
        let v = stub_embedding("hi", 0);
        // Pinned contract: a requested dimension of 0 yields one element
        // rather than an empty vector.
        assert_eq!(v.len(), 1);
    }

    // ---------------------------------------------------------------------
    // Jina provider
    // ---------------------------------------------------------------------

    #[test]
    fn test_jina_default_construction() {
        let p = JinaEmbeddingProvider::new("jina-key");
        assert_eq!(p.model(), "jina-embeddings-v3");
        assert_eq!(p.dimension(), 1024);
    }

    #[test]
    fn test_jina_with_model_clip() {
        let p = JinaEmbeddingProvider::with_model("k", "jina-clip-v2", 768);
        assert_eq!(p.model(), "jina-clip-v2");
        assert_eq!(p.dimension(), 768);
    }

    #[test]
    fn test_jina_with_base_url() {
        let p = JinaEmbeddingProvider::new("k").with_base_url("https://custom.jina/v1");
        assert_eq!(p.model(), "jina-embeddings-v3");
    }

    #[test]
    fn test_jina_build_payload_shape() {
        let p = JinaEmbeddingProvider::new("k");
        let payload = p.build_payload(&["hello".to_string(), "world".to_string()]);
        assert_eq!(payload["model"], "jina-embeddings-v3");
        assert_eq!(payload["input"][0], "hello");
        assert_eq!(payload["input"][1], "world");
    }

    #[tokio::test]
    async fn test_jina_embed_length_matches_dimension() {
        let p = JinaEmbeddingProvider::new("k");
        // `embed` is only exercised when HTTP support is compiled out, so the
        // test never touches the network.
        #[cfg(not(feature = "http-embeddings"))]
        {
            let v = p.embed("hello jina").await.unwrap();
            assert_eq!(v.len(), 1024);
        }
        assert_eq!(p.dimension(), 1024);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_jina_stub_is_normalized() {
        let p = JinaEmbeddingProvider::new("k");
        let v = p.embed("some input").await.unwrap();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_jina_stub_deterministic() {
        let p = JinaEmbeddingProvider::new("k");
        let a = p.embed("consistent").await.unwrap();
        let b = p.embed("consistent").await.unwrap();
        assert_eq!(a, b);
    }

    // ---------------------------------------------------------------------
    // Mistral provider
    // ---------------------------------------------------------------------

    #[test]
    fn test_mistral_default_construction() {
        let p = MistralEmbedProvider::new("mistral-key");
        assert_eq!(p.model(), "mistral-embed");
        assert_eq!(p.dimension(), 1024);
    }

    #[test]
    fn test_mistral_with_model_and_dimensions() {
        let p = MistralEmbedProvider::with_model("k", "mistral-embed-large", 2048);
        assert_eq!(p.model(), "mistral-embed-large");
        assert_eq!(p.dimension(), 2048);
    }

    #[test]
    fn test_mistral_build_payload_shape() {
        let p = MistralEmbedProvider::new("k");
        let payload = p.build_payload(&["alpha".to_string()]);
        assert_eq!(payload["model"], "mistral-embed");
        assert_eq!(payload["input"][0], "alpha");
    }

    #[test]
    fn test_mistral_with_base_url() {
        let p = MistralEmbedProvider::new("k").with_base_url("https://custom.mistral/v1");
        assert_eq!(p.dimension(), 1024);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_mistral_embed_length() {
        let p = MistralEmbedProvider::new("k");
        let v = p.embed("hello mistral").await.unwrap();
        assert_eq!(v.len(), 1024);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_mistral_stub_normalized() {
        let p = MistralEmbedProvider::new("k");
        let v = p.embed("normalized?").await.unwrap();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    // ---------------------------------------------------------------------
    // Nomic provider
    // ---------------------------------------------------------------------

    #[test]
    fn test_nomic_default_construction() {
        let p = NomicEmbedProvider::new("nomic-key");
        assert_eq!(p.model(), "nomic-embed-text-v1.5");
        assert_eq!(p.dimension(), 768);
        assert_eq!(p.task_type(), "search_document");
    }

    #[test]
    fn test_nomic_with_task_type() {
        let p = NomicEmbedProvider::new("k").with_task_type("search_query");
        assert_eq!(p.task_type(), "search_query");
    }

    #[test]
    fn test_nomic_build_payload_shape() {
        let p = NomicEmbedProvider::new("k").with_task_type("clustering");
        let payload = p.build_payload(&["doc a".to_string(), "doc b".to_string()]);
        assert_eq!(payload["model"], "nomic-embed-text-v1.5");
        assert_eq!(payload["texts"][0], "doc a");
        assert_eq!(payload["texts"][1], "doc b");
        assert_eq!(payload["task_type"], "clustering");
    }

    #[test]
    fn test_nomic_with_model_custom_dims() {
        let p = NomicEmbedProvider::with_model("k", "custom-nomic", 512);
        assert_eq!(p.dimension(), 512);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_nomic_embed_length() {
        let p = NomicEmbedProvider::new("k");
        let v = p.embed("nomic test").await.unwrap();
        assert_eq!(v.len(), 768);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_nomic_embed_normalized() {
        let p = NomicEmbedProvider::new("k");
        let v = p.embed("some text").await.unwrap();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    // ---------------------------------------------------------------------
    // SentenceTransformers provider
    // ---------------------------------------------------------------------

    #[test]
    fn test_sentence_transformers_default_construction() {
        let p = SentenceTransformersProvider::new("hf-key");
        assert_eq!(p.model(), "sentence-transformers/all-MiniLM-L6-v2");
        assert_eq!(p.dimension(), 384);
    }

    #[test]
    fn test_sentence_transformers_mpnet_dims() {
        let dims = SentenceTransformersProvider::default_dimensions(
            "sentence-transformers/all-mpnet-base-v2",
        );
        assert_eq!(dims, 768);
    }

    #[test]
    fn test_sentence_transformers_multi_qa_dims() {
        let dims = SentenceTransformersProvider::default_dimensions(
            "sentence-transformers/multi-qa-mpnet-base-dot-v1",
        );
        assert_eq!(dims, 768);
    }

    #[test]
    fn test_sentence_transformers_unknown_model_fallback() {
        let dims =
            SentenceTransformersProvider::default_dimensions("sentence-transformers/unknown");
        assert_eq!(dims, 384);
    }

    #[test]
    fn test_sentence_transformers_with_model() {
        let p = SentenceTransformersProvider::with_model(
            "k",
            "sentence-transformers/all-mpnet-base-v2",
            768,
        );
        assert_eq!(p.model(), "sentence-transformers/all-mpnet-base-v2");
        assert_eq!(p.dimension(), 768);
    }

    #[test]
    fn test_sentence_transformers_build_payload_shape() {
        let p = SentenceTransformersProvider::new("k");
        let payload = p.build_payload(&["hi".to_string()]);
        assert_eq!(payload["inputs"][0], "hi");
        assert_eq!(payload["options"]["wait_for_model"], true);
    }

    #[test]
    fn test_sentence_transformers_with_base_url() {
        let p =
            SentenceTransformersProvider::new("k").with_base_url("https://self-hosted.hf/embed");
        assert_eq!(p.dimension(), 384);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_sentence_transformers_embed_length() {
        let p = SentenceTransformersProvider::new("k");
        let v = p.embed("minilm test").await.unwrap();
        assert_eq!(v.len(), 384);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_sentence_transformers_embed_normalized() {
        let p = SentenceTransformersProvider::new("k");
        let v = p.embed("some input").await.unwrap();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    // ---------------------------------------------------------------------
    // Together provider
    // ---------------------------------------------------------------------

    #[test]
    fn test_together_default_construction() {
        let p = TogetherEmbedProvider::new("together-key");
        assert_eq!(p.model(), "togethercomputer/m2-bert-80M-32k-retrieval");
        assert_eq!(p.dimension(), 768);
    }

    #[test]
    fn test_together_with_model() {
        let p = TogetherEmbedProvider::with_model("k", "togethercomputer/custom", 1024);
        assert_eq!(p.model(), "togethercomputer/custom");
        assert_eq!(p.dimension(), 1024);
    }

    #[test]
    fn test_together_build_payload_shape() {
        let p = TogetherEmbedProvider::new("k");
        let payload = p.build_payload(&["x".to_string(), "y".to_string()]);
        assert_eq!(
            payload["model"],
            "togethercomputer/m2-bert-80M-32k-retrieval"
        );
        assert_eq!(payload["input"][0], "x");
        assert_eq!(payload["input"][1], "y");
    }

    #[test]
    fn test_together_with_base_url() {
        let p = TogetherEmbedProvider::new("k").with_base_url("https://custom.together/v1");
        assert_eq!(p.dimension(), 768);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_together_embed_length() {
        let p = TogetherEmbedProvider::new("k");
        let v = p.embed("together test").await.unwrap();
        assert_eq!(v.len(), 768);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_together_embed_normalized() {
        let p = TogetherEmbedProvider::new("k");
        let v = p.embed("text").await.unwrap();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    // ---------------------------------------------------------------------
    // Cohere v4 provider
    // ---------------------------------------------------------------------

    #[test]
    fn test_cohere_v4_default_construction() {
        let p = CohereEmbedV4Provider::new("cohere-key");
        assert_eq!(p.model(), "embed-english-v3.0");
        assert_eq!(p.dimension(), 1024);
        assert_eq!(p.input_type(), "search_document");
    }

    #[test]
    fn test_cohere_v4_multilingual_model() {
        let p = CohereEmbedV4Provider::with_model("k", "embed-multilingual-v3.0", 1024);
        assert_eq!(p.model(), "embed-multilingual-v3.0");
        assert_eq!(p.dimension(), 1024);
    }

    #[test]
    fn test_cohere_v4_for_search_document() {
        let p = CohereEmbedV4Provider::new("k").for_search_document();
        assert_eq!(p.input_type(), "search_document");
    }

    #[test]
    fn test_cohere_v4_for_search_query() {
        let p = CohereEmbedV4Provider::new("k").for_search_query();
        assert_eq!(p.input_type(), "search_query");
    }

    #[test]
    fn test_cohere_v4_with_input_type() {
        let p = CohereEmbedV4Provider::new("k").with_input_type("classification");
        assert_eq!(p.input_type(), "classification");
    }

    #[test]
    fn test_cohere_v4_build_payload_shape_document() {
        let p = CohereEmbedV4Provider::new("k").for_search_document();
        let payload = p.build_payload(&["doc".to_string()]);
        assert_eq!(payload["model"], "embed-english-v3.0");
        assert_eq!(payload["texts"][0], "doc");
        assert_eq!(payload["input_type"], "search_document");
        assert_eq!(payload["embedding_types"][0], "float");
    }

    #[test]
    fn test_cohere_v4_build_payload_shape_query() {
        let p = CohereEmbedV4Provider::new("k").for_search_query();
        let payload = p.build_payload(&["q".to_string()]);
        assert_eq!(payload["input_type"], "search_query");
    }

    #[test]
    fn test_cohere_v4_with_base_url() {
        let p = CohereEmbedV4Provider::new("k").with_base_url("https://custom.cohere/v2/embed");
        // Check the private endpoint field directly: asserting only the
        // dimension would pass even if `with_base_url` silently did nothing.
        assert_eq!(p.base_url, "https://custom.cohere/v2/embed");
        assert_eq!(p.dimension(), 1024);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_cohere_v4_embed_length() {
        let p = CohereEmbedV4Provider::new("k");
        let v = p.embed("cohere v4 test").await.unwrap();
        assert_eq!(v.len(), 1024);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_cohere_v4_embed_normalized() {
        let p = CohereEmbedV4Provider::new("k");
        let v = p.embed("x").await.unwrap();
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    #[cfg(not(feature = "http-embeddings"))]
    #[tokio::test]
    async fn test_cohere_v4_embed_deterministic() {
        let p = CohereEmbedV4Provider::new("k");
        let a = p.embed("same").await.unwrap();
        let b = p.embed("same").await.unwrap();
        assert_eq!(a, b);
    }

    // ---------------------------------------------------------------------
    // Cross-provider checks
    // ---------------------------------------------------------------------

    #[test]
    fn test_all_new_providers_implement_embedding_provider_trait() {
        // Compile-time check: each provider must coerce to the trait object.
        let _boxes: Vec<Box<dyn EmbeddingProvider>> = vec![
            Box::new(JinaEmbeddingProvider::new("k")),
            Box::new(MistralEmbedProvider::new("k")),
            Box::new(NomicEmbedProvider::new("k")),
            Box::new(SentenceTransformersProvider::new("k")),
            Box::new(TogetherEmbedProvider::new("k")),
            Box::new(CohereEmbedV4Provider::new("k")),
        ];
    }

    #[test]
    fn test_new_providers_have_expected_dimensions() {
        assert_eq!(JinaEmbeddingProvider::new("k").dimension(), 1024);
        assert_eq!(MistralEmbedProvider::new("k").dimension(), 1024);
        assert_eq!(NomicEmbedProvider::new("k").dimension(), 768);
        assert_eq!(SentenceTransformersProvider::new("k").dimension(), 384);
        assert_eq!(TogetherEmbedProvider::new("k").dimension(), 768);
        assert_eq!(CohereEmbedV4Provider::new("k").dimension(), 1024);
    }
}