Skip to main content

argentor_memory/
embeddings_providers.rs

1//! Multiple embedding provider backends implementing [`EmbeddingProvider`].
2//!
//! Includes API-backed providers ([`OpenAiEmbeddingProvider`],
//! [`CohereEmbeddingProvider`], [`VoyageEmbeddingProvider`],
//! [`JinaEmbeddingProvider`], and others) that call their respective HTTP APIs
//! to compute real embeddings.
//!
//! # Feature flag: `http-embeddings`
//!
//! The actual HTTP calls require the **`http-embeddings`** Cargo feature, which
//! pulls in `reqwest`. When the feature is **disabled** (the default),
//! `embed()` on [`OpenAiEmbeddingProvider`], [`CohereEmbeddingProvider`], and
//! [`VoyageEmbeddingProvider`] returns a descriptive error suggesting the user
//! either enable the feature or use [`LocalEmbedding`]. The newer providers
//! (e.g. [`JinaEmbeddingProvider`]) instead return a deterministic,
//! L2-normalized stub vector so offline tests still get usable output.
13//!
14//! ```toml
15//! # Enable real HTTP embedding calls:
16//! argentor-memory = { version = "0.1", features = ["http-embeddings"] }
17//! ```
18//!
19//! For embeddings that work out of the box without external dependencies, use
20//! the [`LocalEmbedding`] provider (deterministic, hash-based, zero API keys).
21//!
22//! Also provides fully functional utilities:
23//! [`CachedEmbeddingProvider`], [`BatchEmbeddingProvider`],
24//! [`EmbeddingProviderFactory`], and [`EmbeddingConfig`].
25
26use std::collections::HashMap;
27use std::sync::Arc;
28
29use async_trait::async_trait;
30use serde::{Deserialize, Serialize};
31use tokio::sync::RwLock;
32
33use argentor_core::{ArgentorError, ArgentorResult};
34
35use crate::embedding::{EmbeddingProvider, LocalEmbedding};
36
37// ---------------------------------------------------------------------------
38// FNV-1a hash (same algorithm as embedding.rs, re-implemented here to avoid
39// depending on a private function).
40// ---------------------------------------------------------------------------
41
/// 64-bit FNV-1a hash of `data`.
///
/// Starts from the standard FNV offset basis. For every byte, it XORs the
/// byte into the running state and then multiplies by the FNV prime with
/// wrapping arithmetic.
fn fnv1a_hash(data: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;

    data.iter()
        .fold(OFFSET_BASIS, |state, &byte| (state ^ u64::from(byte)).wrapping_mul(PRIME))
}
50
51// ===========================================================================
52// API request/response structures
53// ===========================================================================
54
55/// OpenAI embeddings API request body.
56#[derive(Debug, Serialize)]
57pub struct OpenAiEmbeddingRequest {
58    /// Model identifier (e.g., `"text-embedding-3-small"`).
59    pub model: String,
60    /// Texts to embed.
61    pub input: Vec<String>,
62}
63
64/// A single embedding object from the OpenAI response.
65#[derive(Debug, Deserialize)]
66pub struct OpenAiEmbeddingObject {
67    /// The embedding vector.
68    pub embedding: Vec<f32>,
69    /// Position of this embedding in the input batch.
70    pub index: usize,
71}
72
73/// OpenAI embeddings API response body.
74#[derive(Debug, Deserialize)]
75pub struct OpenAiEmbeddingResponse {
76    /// Embedding results, one per input text.
77    pub data: Vec<OpenAiEmbeddingObject>,
78    /// Model that produced the embeddings.
79    pub model: String,
80}
81
/// Cohere embed API request body (v2).
///
/// NOTE(review): `CohereEmbeddingProvider::embed` in this file builds an
/// equivalent body with `serde_json::json!` instead of using this type.
#[derive(Debug, Serialize)]
pub struct CohereEmbedRequest {
    /// Model identifier (e.g., `"embed-english-v3.0"`).
    pub model: String,
    /// Texts to embed.
    pub texts: Vec<String>,
    /// Input type hint (e.g., `"search_document"` for indexing,
    /// `"search_query"` for queries).
    pub input_type: String,
    /// Which embedding types to return; this module's provider requests
    /// `["float"]`.
    pub embedding_types: Vec<String>,
}

/// Cohere embed API v2 response body.
///
/// Only `embeddings` is declared; serde ignores any other keys in the JSON.
#[derive(Debug, Deserialize)]
pub struct CohereEmbedResponse {
    /// Embedding vectors keyed by type. We request `"float"`.
    pub embeddings: CohereEmbeddingsMap,
}

/// Container for the per-type embedding outputs from Cohere.
///
/// Only the `float` type is modelled; other embedding types present in the
/// JSON are ignored.
#[derive(Debug, Deserialize)]
pub struct CohereEmbeddingsMap {
    /// Float embeddings, one vector per input text. Defaults to empty when the
    /// key is absent, which `parse_cohere_embedding_response` reports as an
    /// error.
    #[serde(default)]
    pub float: Vec<Vec<f32>>,
}
109
/// Voyage AI embeddings API request body.
///
/// NOTE(review): `VoyageEmbeddingProvider::embed` in this file builds an
/// equivalent body with `serde_json::json!` instead of using this type.
#[derive(Debug, Serialize)]
pub struct VoyageEmbeddingRequest {
    /// Model identifier (e.g., `"voyage-2"`).
    pub model: String,
    /// Texts to embed.
    pub input: Vec<String>,
}

/// A single embedding object from the `data` array of the Voyage response.
#[derive(Debug, Deserialize)]
pub struct VoyageEmbeddingObject {
    /// The embedding vector.
    pub embedding: Vec<f32>,
    /// Position of the corresponding text in the request's input batch.
    pub index: usize,
}

/// Voyage AI embeddings API response body.
///
/// Only `data` is declared; serde ignores any other keys in the JSON.
#[derive(Debug, Deserialize)]
pub struct VoyageEmbeddingResponse {
    /// Embedding results, one per input text.
    pub data: Vec<VoyageEmbeddingObject>,
}
134
135// ===========================================================================
136// Response parsing helpers (testable without HTTP)
137// ===========================================================================
138
139/// Parse an OpenAI embedding response JSON into a single embedding vector.
140///
141/// Expects the standard OpenAI response shape with `data[0].embedding`.
142/// Returns `Err` if the response is missing required fields.
143pub fn parse_openai_embedding_response(json: &serde_json::Value) -> ArgentorResult<Vec<f32>> {
144    let response: OpenAiEmbeddingResponse = serde_json::from_value(json.clone())
145        .map_err(|e| ArgentorError::Agent(format!("Failed to parse OpenAI response: {e}")))?;
146    response
147        .data
148        .into_iter()
149        .next()
150        .map(|obj| obj.embedding)
151        .ok_or_else(|| {
152            ArgentorError::Agent("OpenAI response contains no embedding data".to_string())
153        })
154}
155
156/// Parse a Cohere v2 embed response JSON into a single embedding vector.
157///
158/// Expects the v2 shape: `embeddings.float[0]`.
159/// Returns `Err` if the response is missing required fields.
160pub fn parse_cohere_embedding_response(json: &serde_json::Value) -> ArgentorResult<Vec<f32>> {
161    let response: CohereEmbedResponse = serde_json::from_value(json.clone())
162        .map_err(|e| ArgentorError::Agent(format!("Failed to parse Cohere response: {e}")))?;
163    response.embeddings.float.into_iter().next().ok_or_else(|| {
164        ArgentorError::Agent("Cohere response contains no float embeddings".to_string())
165    })
166}
167
168/// Parse a Voyage AI embedding response JSON into a single embedding vector.
169///
170/// Expects the standard Voyage shape: `data[0].embedding`.
171/// Returns `Err` if the response is missing required fields.
172pub fn parse_voyage_embedding_response(json: &serde_json::Value) -> ArgentorResult<Vec<f32>> {
173    let response: VoyageEmbeddingResponse = serde_json::from_value(json.clone())
174        .map_err(|e| ArgentorError::Agent(format!("Failed to parse Voyage response: {e}")))?;
175    response
176        .data
177        .into_iter()
178        .next()
179        .map(|obj| obj.embedding)
180        .ok_or_else(|| {
181            ArgentorError::Agent("Voyage response contains no embedding data".to_string())
182        })
183}
184
185// ===========================================================================
186// 1. OpenAiEmbeddingProvider
187// ===========================================================================
188
/// Embedding provider backed by the OpenAI embeddings API.
///
/// Holds the API key, model name, reported output dimension, and endpoint
/// URL. With the `http-embeddings` feature enabled, [`embed()`] performs a
/// real HTTP request; without it, `embed()` returns an error. For
/// local/offline embeddings, use [`LocalEmbedding`].
pub struct OpenAiEmbeddingProvider {
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    model: String,
    dimensions: usize,
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}

impl OpenAiEmbeddingProvider {
    /// Model selected when the caller passes `None`.
    const DEFAULT_MODEL: &'static str = "text-embedding-3-small";
    /// Public OpenAI endpoint used by [`OpenAiEmbeddingProvider::new`].
    const DEFAULT_BASE_URL: &'static str = "https://api.openai.com/v1/embeddings";

    /// Create a provider that talks to the public OpenAI endpoint.
    ///
    /// `model` defaults to `"text-embedding-3-small"` when `None`; the initial
    /// dimension is derived from the model name.
    pub fn new(api_key: impl Into<String>, model: Option<String>) -> Self {
        Self::with_base_url(api_key, model, Self::DEFAULT_BASE_URL)
    }

    /// Create a provider that posts to `base_url` instead of the public
    /// endpoint (e.g. an Azure OpenAI endpoint).
    ///
    /// `model` defaults to `"text-embedding-3-small"` when `None`.
    pub fn with_base_url(
        api_key: impl Into<String>,
        model: Option<String>,
        base_url: impl Into<String>,
    ) -> Self {
        let model = model.unwrap_or_else(|| Self::DEFAULT_MODEL.to_owned());
        Self {
            dimensions: Self::default_dimensions(&model),
            api_key: api_key.into(),
            model,
            base_url: base_url.into(),
        }
    }

    /// Override the output dimension count reported by `dimension()`.
    pub fn with_dimensions(self, dimensions: usize) -> Self {
        Self { dimensions, ..self }
    }

    /// Native output size for `model`.
    ///
    /// Returns 3072 for `text-embedding-3-large` and 1536 for every other name,
    /// including `text-embedding-3-small`, `text-embedding-ada-002`, and
    /// unrecognized models.
    fn default_dimensions(model: &str) -> usize {
        if model == "text-embedding-3-large" {
            3072
        } else {
            1536
        }
    }

    /// Returns the model name this provider is configured with.
    pub fn model(&self) -> &str {
        &self.model
    }
}
255
256#[async_trait]
257impl EmbeddingProvider for OpenAiEmbeddingProvider {
258    #[cfg(feature = "http-embeddings")]
259    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
260        let client = reqwest::Client::new();
261        let response = client
262            .post(&self.base_url)
263            .header("Authorization", format!("Bearer {}", self.api_key))
264            .json(&serde_json::json!({
265                "model": self.model,
266                "input": text,
267            }))
268            .send()
269            .await
270            .map_err(|e| ArgentorError::Http(format!("OpenAI embedding request failed: {e}")))?;
271
272        let status = response.status();
273        if !status.is_success() {
274            let body = response.text().await.unwrap_or_default();
275            return Err(ArgentorError::Http(format!(
276                "OpenAI API error {status}: {body}"
277            )));
278        }
279
280        let json: serde_json::Value = response.json().await.map_err(|e| {
281            ArgentorError::Http(format!("Failed to read OpenAI response body: {e}"))
282        })?;
283
284        parse_openai_embedding_response(&json)
285    }
286
287    #[cfg(not(feature = "http-embeddings"))]
288    async fn embed(&self, _text: &str) -> ArgentorResult<Vec<f32>> {
289        Err(ArgentorError::Http(
290            "HTTP embeddings not enabled. Enable the 'http-embeddings' feature flag \
291             or use LocalEmbedding for offline embeddings."
292                .to_string(),
293        ))
294    }
295
296    fn dimension(&self) -> usize {
297        self.dimensions
298    }
299}
300
301// ===========================================================================
302// 2. CohereEmbeddingProvider
303// ===========================================================================
304
/// Embedding provider backed by the Cohere embed API (v2).
///
/// Holds the API key, model name, reported output dimension, and the
/// `input_type` hint sent with each request. With the `http-embeddings`
/// feature enabled, [`embed()`] performs a real HTTP request; without it,
/// `embed()` returns an error. For local/offline embeddings, use
/// [`LocalEmbedding`].
pub struct CohereEmbeddingProvider {
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    model: String,
    dimensions: usize,
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    input_type: String,
}

impl CohereEmbeddingProvider {
    /// Model selected when the caller passes `None`.
    const DEFAULT_MODEL: &'static str = "embed-english-v3.0";
    /// Input type used until `with_input_type` is called.
    const DEFAULT_INPUT_TYPE: &'static str = "search_document";

    /// Create a new Cohere embedding provider.
    ///
    /// `model` defaults to `"embed-english-v3.0"` and the input type to
    /// `"search_document"`; the initial dimension is derived from the model.
    pub fn new(api_key: impl Into<String>, model: Option<String>) -> Self {
        let model = match model {
            Some(name) => name,
            None => Self::DEFAULT_MODEL.to_owned(),
        };
        Self {
            dimensions: Self::default_dimensions(&model),
            api_key: api_key.into(),
            model,
            input_type: Self::DEFAULT_INPUT_TYPE.to_owned(),
        }
    }

    /// Set the input type (`"search_document"` for indexing, `"search_query"`
    /// for querying).
    pub fn with_input_type(self, input_type: impl Into<String>) -> Self {
        Self {
            input_type: input_type.into(),
            ..self
        }
    }

    /// Override the output dimension count reported by `dimension()`.
    pub fn with_dimensions(self, dimensions: usize) -> Self {
        Self { dimensions, ..self }
    }

    /// Native output size for `model`.
    ///
    /// Returns 384 for the two `*-light-v3.0` models and 1024 for everything
    /// else, including unrecognized names.
    fn default_dimensions(model: &str) -> usize {
        match model {
            "embed-english-light-v3.0" | "embed-multilingual-light-v3.0" => 384,
            _ => 1024,
        }
    }

    /// Returns the model name.
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Returns the current input type.
    pub fn input_type(&self) -> &str {
        &self.input_type
    }
}
365
366#[async_trait]
367impl EmbeddingProvider for CohereEmbeddingProvider {
368    #[cfg(feature = "http-embeddings")]
369    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
370        let client = reqwest::Client::new();
371        let response = client
372            .post("https://api.cohere.com/v2/embed")
373            .header("Authorization", format!("Bearer {}", self.api_key))
374            .json(&serde_json::json!({
375                "model": self.model,
376                "texts": [text],
377                "input_type": self.input_type,
378                "embedding_types": ["float"],
379            }))
380            .send()
381            .await
382            .map_err(|e| ArgentorError::Http(format!("Cohere embedding request failed: {e}")))?;
383
384        let status = response.status();
385        if !status.is_success() {
386            let body = response.text().await.unwrap_or_default();
387            return Err(ArgentorError::Http(format!(
388                "Cohere API error {status}: {body}"
389            )));
390        }
391
392        let json: serde_json::Value = response.json().await.map_err(|e| {
393            ArgentorError::Http(format!("Failed to read Cohere response body: {e}"))
394        })?;
395
396        parse_cohere_embedding_response(&json)
397    }
398
399    #[cfg(not(feature = "http-embeddings"))]
400    async fn embed(&self, _text: &str) -> ArgentorResult<Vec<f32>> {
401        Err(ArgentorError::Http(
402            "HTTP embeddings not enabled. Enable the 'http-embeddings' feature flag \
403             or use LocalEmbedding for offline embeddings."
404                .to_string(),
405        ))
406    }
407
408    fn dimension(&self) -> usize {
409        self.dimensions
410    }
411}
412
413// ===========================================================================
414// 3. VoyageEmbeddingProvider
415// ===========================================================================
416
/// Embedding provider backed by the Voyage AI embeddings API.
///
/// Holds the API key, model name, and the dimension reported by
/// `dimension()`. With the `http-embeddings` feature enabled, [`embed()`]
/// performs a real HTTP request; without it, `embed()` returns an error. For
/// local/offline embeddings, use [`LocalEmbedding`].
pub struct VoyageEmbeddingProvider {
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    model: String,
    dimensions: usize,
}

impl VoyageEmbeddingProvider {
    /// Create a new Voyage embedding provider.
    ///
    /// `model` defaults to `"voyage-2"`; the initial dimension is derived from
    /// the model name.
    pub fn new(api_key: impl Into<String>, model: Option<String>) -> Self {
        let model = model.unwrap_or_else(|| String::from("voyage-2"));
        Self {
            dimensions: Self::default_dimensions(&model),
            api_key: api_key.into(),
            model,
        }
    }

    /// Override the output dimension count reported by `dimension()`.
    pub fn with_dimensions(self, dimensions: usize) -> Self {
        Self { dimensions, ..self }
    }

    /// Native output size for `model`.
    ///
    /// Returns 1536 for `voyage-code-2` and 1024 for every other name,
    /// including unrecognized ones.
    fn default_dimensions(model: &str) -> usize {
        if model == "voyage-code-2" {
            1536
        } else {
            1024
        }
    }

    /// Returns the model name.
    pub fn model(&self) -> &str {
        &self.model
    }
}
465
466#[async_trait]
467impl EmbeddingProvider for VoyageEmbeddingProvider {
468    #[cfg(feature = "http-embeddings")]
469    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
470        let client = reqwest::Client::new();
471        let response = client
472            .post("https://api.voyageai.com/v1/embeddings")
473            .header("Authorization", format!("Bearer {}", self.api_key))
474            .json(&serde_json::json!({
475                "model": self.model,
476                "input": [text],
477            }))
478            .send()
479            .await
480            .map_err(|e| ArgentorError::Http(format!("Voyage embedding request failed: {e}")))?;
481
482        let status = response.status();
483        if !status.is_success() {
484            let body = response.text().await.unwrap_or_default();
485            return Err(ArgentorError::Http(format!(
486                "Voyage API error {status}: {body}"
487            )));
488        }
489
490        let json: serde_json::Value = response.json().await.map_err(|e| {
491            ArgentorError::Http(format!("Failed to read Voyage response body: {e}"))
492        })?;
493
494        parse_voyage_embedding_response(&json)
495    }
496
497    #[cfg(not(feature = "http-embeddings"))]
498    async fn embed(&self, _text: &str) -> ArgentorResult<Vec<f32>> {
499        Err(ArgentorError::Http(
500            "HTTP embeddings not enabled. Enable the 'http-embeddings' feature flag \
501             or use LocalEmbedding for offline embeddings."
502                .to_string(),
503        ))
504    }
505
506    fn dimension(&self) -> usize {
507        self.dimensions
508    }
509}
510
511// ===========================================================================
512// 4. CachedEmbeddingProvider
513// ===========================================================================
514
/// Counters describing how a [`CachedEmbeddingProvider`] has been used.
///
/// `CachedEmbeddingProvider::cache_stats` returns a clone, so a value of this
/// type is a point-in-time snapshot.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of `embed` calls answered from the cache.
    pub hits: u64,
    /// Number of `embed` calls that successfully computed a fresh embedding
    /// through the wrapped provider.
    pub misses: u64,
    /// Number of cached entries, as recorded at the last miss or `clear`.
    pub size: usize,
}
525
526/// Wraps any [`EmbeddingProvider`] with a thread-safe in-memory LRU-ish cache.
527///
528/// Embeddings are cached by FNV-1a hash of the input text. When the cache
529/// exceeds `max_cache_size`, the oldest entry (by insertion order) is evicted.
530pub struct CachedEmbeddingProvider {
531    inner: Arc<dyn EmbeddingProvider>,
532    cache: Arc<RwLock<HashMap<u64, Vec<f32>>>>,
533    max_cache_size: usize,
534    stats: Arc<RwLock<CacheStats>>,
535}
536
537impl CachedEmbeddingProvider {
538    /// Wrap an existing provider with caching.
539    pub fn new(inner: Arc<dyn EmbeddingProvider>, max_cache_size: usize) -> Self {
540        Self {
541            inner,
542            cache: Arc::new(RwLock::new(HashMap::new())),
543            max_cache_size,
544            stats: Arc::new(RwLock::new(CacheStats::default())),
545        }
546    }
547
548    /// Returns current cache statistics.
549    pub async fn cache_stats(&self) -> CacheStats {
550        self.stats.read().await.clone()
551    }
552
553    /// Clears the cache and resets statistics.
554    pub async fn clear(&self) {
555        self.cache.write().await.clear();
556        let mut stats = self.stats.write().await;
557        stats.size = 0;
558    }
559
560    fn text_hash(text: &str) -> u64 {
561        fnv1a_hash(text.as_bytes())
562    }
563}
564
565#[async_trait]
566impl EmbeddingProvider for CachedEmbeddingProvider {
567    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
568        let key = Self::text_hash(text);
569
570        // Check cache (read lock).
571        {
572            let cache = self.cache.read().await;
573            if let Some(cached) = cache.get(&key) {
574                let mut stats = self.stats.write().await;
575                stats.hits += 1;
576                return Ok(cached.clone());
577            }
578        }
579
580        // Cache miss — compute embedding.
581        let embedding = self.inner.embed(text).await?;
582
583        // Insert into cache (write lock).
584        {
585            let mut cache = self.cache.write().await;
586
587            // Evict if at capacity.
588            if cache.len() >= self.max_cache_size {
589                // Remove an arbitrary entry (HashMap iteration order is random,
590                // which acts as a simple eviction strategy).
591                if let Some(&evict_key) = cache.keys().next() {
592                    cache.remove(&evict_key);
593                }
594            }
595
596            cache.insert(key, embedding.clone());
597
598            let mut stats = self.stats.write().await;
599            stats.misses += 1;
600            stats.size = cache.len();
601        }
602
603        Ok(embedding)
604    }
605
606    fn dimension(&self) -> usize {
607        self.inner.dimension()
608    }
609}
610
611// ===========================================================================
612// 5. BatchEmbeddingProvider
613// ===========================================================================
614
615/// Wraps any [`EmbeddingProvider`] to expose a convenience batch method.
616///
617/// Delegates to the inner provider's `embed_batch` (which by default calls
618/// `embed` sequentially). Providers that support native batching can override
619/// `embed_batch` on the trait for better performance.
620pub struct BatchEmbeddingProvider {
621    inner: Arc<dyn EmbeddingProvider>,
622}
623
624impl BatchEmbeddingProvider {
625    /// Wrap an existing provider for batch operations.
626    pub fn new(inner: Arc<dyn EmbeddingProvider>) -> Self {
627        Self { inner }
628    }
629
630    /// Embed multiple texts, returning one vector per input.
631    pub async fn embed_batch(&self, texts: &[&str]) -> ArgentorResult<Vec<Vec<f32>>> {
632        self.inner.embed_batch(texts).await
633    }
634}
635
636#[async_trait]
637impl EmbeddingProvider for BatchEmbeddingProvider {
638    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
639        self.inner.embed(text).await
640    }
641
642    async fn embed_batch(&self, texts: &[&str]) -> ArgentorResult<Vec<Vec<f32>>> {
643        self.inner.embed_batch(texts).await
644    }
645
646    fn dimension(&self) -> usize {
647        self.inner.dimension()
648    }
649}
650
651// ===========================================================================
652// 6. EmbeddingProviderFactory
653// ===========================================================================
654
655/// Factory that creates [`EmbeddingProvider`] instances by name.
656pub struct EmbeddingProviderFactory;
657
658impl EmbeddingProviderFactory {
659    /// Create an embedding provider from its string name.
660    ///
661    /// Supported names: `"openai"`, `"cohere"`, `"voyage"`, `"local"`.
662    pub fn create(
663        provider_name: &str,
664        api_key: impl Into<String>,
665        model: Option<String>,
666    ) -> ArgentorResult<Box<dyn EmbeddingProvider>> {
667        let api_key = api_key.into();
668        match provider_name {
669            "openai" => Ok(Box::new(OpenAiEmbeddingProvider::new(api_key, model))),
670            "cohere" => Ok(Box::new(CohereEmbeddingProvider::new(api_key, model))),
671            "voyage" => Ok(Box::new(VoyageEmbeddingProvider::new(api_key, model))),
672            "local" => {
673                let dim = model
674                    .as_deref()
675                    .and_then(|m| m.parse::<usize>().ok())
676                    .unwrap_or(256);
677                Ok(Box::new(LocalEmbedding::new(dim)))
678            }
679            other => Err(ArgentorError::Config(format!(
680                "Unknown embedding provider: {other}"
681            ))),
682        }
683    }
684
685    /// List all supported provider names.
686    pub fn available_providers() -> &'static [&'static str] {
687        &["openai", "cohere", "voyage", "local"]
688    }
689}
690
691// ===========================================================================
692// 7. EmbeddingConfig
693// ===========================================================================
694
/// Serializable configuration for constructing an embedding provider via
/// [`EmbeddingConfig::build`].
///
/// Only `provider` is required when deserializing. Every other field falls
/// back to its default (empty string or `None`) when absent.
///
/// NOTE(review): the derived `Debug` and `Serialize` impls include `api_key`
/// verbatim, so logging or persisting this struct exposes the secret.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Provider name (`"openai"`, `"cohere"`, `"voyage"`, `"local"`).
    pub provider: String,
    /// API key for the remote providers (ignored for `"local"`).
    #[serde(default)]
    pub api_key: String,
    /// Model name override; `None` selects the provider's default model.
    /// Ignored for `"local"`.
    #[serde(default)]
    pub model: Option<String>,
    /// Override for output dimensions. For `"local"` this is the embedding
    /// size (256 when `None`).
    #[serde(default)]
    pub dimensions: Option<usize>,
    /// Custom API base URL (e.g. Azure OpenAI). Only used by `"openai"`.
    #[serde(default)]
    pub base_url: Option<String>,
    /// If set, wraps the provider with a [`CachedEmbeddingProvider`] of this
    /// capacity.
    #[serde(default)]
    pub cache_size: Option<usize>,
}
716
717impl EmbeddingConfig {
718    /// Build an [`EmbeddingProvider`] from this configuration.
719    ///
720    /// Returns an `Arc`-wrapped provider, optionally wrapped in a cache layer.
721    pub fn build(&self) -> ArgentorResult<Arc<dyn EmbeddingProvider>> {
722        let mut provider: Box<dyn EmbeddingProvider> = match self.provider.as_str() {
723            "openai" => {
724                let mut p = if let Some(ref url) = self.base_url {
725                    OpenAiEmbeddingProvider::with_base_url(&self.api_key, self.model.clone(), url)
726                } else {
727                    OpenAiEmbeddingProvider::new(&self.api_key, self.model.clone())
728                };
729                if let Some(dim) = self.dimensions {
730                    p = p.with_dimensions(dim);
731                }
732                Box::new(p)
733            }
734            "cohere" => {
735                let mut p = CohereEmbeddingProvider::new(&self.api_key, self.model.clone());
736                if let Some(dim) = self.dimensions {
737                    p = p.with_dimensions(dim);
738                }
739                Box::new(p)
740            }
741            "voyage" => {
742                let mut p = VoyageEmbeddingProvider::new(&self.api_key, self.model.clone());
743                if let Some(dim) = self.dimensions {
744                    p = p.with_dimensions(dim);
745                }
746                Box::new(p)
747            }
748            "local" => {
749                let dim = self.dimensions.unwrap_or(256);
750                Box::new(LocalEmbedding::new(dim))
751            }
752            other => {
753                return Err(ArgentorError::Config(format!(
754                    "Unknown embedding provider: {other}"
755                )));
756            }
757        };
758
759        // Wrap with dimensions override if the provider itself doesn't support
760        // it natively (already handled above for each provider).
761        let _ = &mut provider; // suppress unused-mut if branch is empty
762
763        let arc: Arc<dyn EmbeddingProvider> = Arc::from(provider);
764
765        // Optionally wrap with caching.
766        if let Some(cache_size) = self.cache_size {
767            Ok(Arc::new(CachedEmbeddingProvider::new(arc, cache_size)))
768        } else {
769            Ok(arc)
770        }
771    }
772}
773
774impl Default for EmbeddingConfig {
775    fn default() -> Self {
776        Self {
777            provider: "local".to_string(),
778            api_key: String::new(),
779            model: None,
780            dimensions: None,
781            base_url: None,
782            cache_size: None,
783        }
784    }
785}
786
787// ===========================================================================
788// Shared helpers for new providers (stub + payload builders)
789// ===========================================================================
790
/// Deterministic, dependency-free stand-in embedding for `text`.
///
/// The output has `dimensions` components, clamped to at least 1. Each byte
/// `b` at position `i` adds `b / 255` to component `i % dim`. The vector is
/// then scaled to unit L2 norm; an all-zero vector (e.g. for empty `text`) is
/// returned unchanged.
///
/// The newer providers fall back to this when the `http-embeddings` feature is
/// disabled, so offline tests still receive a usable vector. Unit tests call
/// it directly even when the feature is on.
#[cfg_attr(
    all(feature = "http-embeddings", not(test)),
    allow(dead_code)
)]
fn stub_embedding(text: &str, dimensions: usize) -> Vec<f32> {
    let width = std::cmp::max(dimensions, 1);
    let mut vector = vec![0.0f32; width];
    // Walk slot indices 0, 1, ..., width-1, 0, 1, ... in lockstep with the bytes.
    for (slot, byte) in (0..width).cycle().zip(text.bytes()) {
        vector[slot] += f32::from(byte) / 255.0;
    }

    let magnitude = vector.iter().map(|c| c * c).sum::<f32>().sqrt();
    if magnitude > 0.0 {
        vector.iter_mut().for_each(|c| *c /= magnitude);
    }
    vector
}
814
815// ===========================================================================
816// 8. JinaEmbeddingProvider
817// ===========================================================================
818
819/// Embedding provider backed by the Jina AI embeddings API.
820///
821/// Default model: `jina-embeddings-v3` (1024 dims). Also supports multimodal
822/// models such as `jina-clip-v2`. Calling [`embed()`] performs a real HTTP
823/// request when the `http-embeddings` feature is enabled. Without the feature,
824/// returns a deterministic stub vector (useful for offline tests).
825pub struct JinaEmbeddingProvider {
826    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
827    api_key: String,
828    model: String,
829    dimensions: usize,
830    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
831    base_url: String,
832}
833
834impl JinaEmbeddingProvider {
835    /// Create a new Jina provider with the default model (`jina-embeddings-v3`).
836    pub fn new(api_key: impl Into<String>) -> Self {
837        Self::with_model(api_key, "jina-embeddings-v3", 1024)
838    }
839
840    /// Create with an explicit model and dimension override.
841    pub fn with_model(
842        api_key: impl Into<String>,
843        model: impl Into<String>,
844        dimensions: usize,
845    ) -> Self {
846        Self {
847            api_key: api_key.into(),
848            model: model.into(),
849            dimensions,
850            base_url: "https://api.jina.ai/v1/embeddings".to_string(),
851        }
852    }
853
854    /// Override the API base URL.
855    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
856        self.base_url = base_url.into();
857        self
858    }
859
860    /// Returns the configured model name.
861    pub fn model(&self) -> &str {
862        &self.model
863    }
864
865    /// Build the request payload for the Jina embeddings API.
866    pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
867        serde_json::json!({
868            "model": self.model,
869            "input": texts,
870        })
871    }
872}
873
874#[async_trait]
875impl EmbeddingProvider for JinaEmbeddingProvider {
876    #[cfg(feature = "http-embeddings")]
877    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
878        let client = reqwest::Client::new();
879        let payload = self.build_payload(&[text.to_string()]);
880        let response = client
881            .post(&self.base_url)
882            .header("Authorization", format!("Bearer {}", self.api_key))
883            .json(&payload)
884            .send()
885            .await
886            .map_err(|e| ArgentorError::Http(format!("Jina embedding request failed: {e}")))?;
887
888        let status = response.status();
889        if !status.is_success() {
890            let body = response.text().await.unwrap_or_default();
891            return Err(ArgentorError::Http(format!(
892                "Jina API error {status}: {body}"
893            )));
894        }
895
896        let json: serde_json::Value = response
897            .json()
898            .await
899            .map_err(|e| ArgentorError::Http(format!("Failed to read Jina response body: {e}")))?;
900
901        // Jina follows the OpenAI-compatible `data[].embedding` shape.
902        parse_openai_embedding_response(&json)
903    }
904
905    #[cfg(not(feature = "http-embeddings"))]
906    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
907        Ok(stub_embedding(text, self.dimensions))
908    }
909
910    fn dimension(&self) -> usize {
911        self.dimensions
912    }
913}
914
915// ===========================================================================
916// 9. MistralEmbedProvider
917// ===========================================================================
918
/// Embedding provider for the Mistral AI embeddings API.
///
/// Defaults to `mistral-embed` (1024 dims). Requests and responses use the
/// OpenAI-compatible shape (`input` in, `data[].embedding` out). Without the
/// `http-embeddings` feature, `embed()` returns a deterministic offline stub
/// vector instead of calling the API.
pub struct MistralEmbedProvider {
    /// API key, sent as a `Bearer` token when HTTP calls are enabled.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    /// Mistral model id included in every request payload.
    model: String,
    /// Vector length reported by `dimension()` and used for offline stubs.
    dimensions: usize,
    /// Embeddings endpoint URL (only contacted with `http-embeddings`).
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}
931
932impl MistralEmbedProvider {
933    /// Create a new Mistral provider with the default model (`mistral-embed`).
934    pub fn new(api_key: impl Into<String>) -> Self {
935        Self::with_model(api_key, "mistral-embed", 1024)
936    }
937
938    /// Create with an explicit model and dimension override.
939    pub fn with_model(
940        api_key: impl Into<String>,
941        model: impl Into<String>,
942        dimensions: usize,
943    ) -> Self {
944        Self {
945            api_key: api_key.into(),
946            model: model.into(),
947            dimensions,
948            base_url: "https://api.mistral.ai/v1/embeddings".to_string(),
949        }
950    }
951
952    /// Override the API base URL.
953    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
954        self.base_url = base_url.into();
955        self
956    }
957
958    /// Returns the configured model name.
959    pub fn model(&self) -> &str {
960        &self.model
961    }
962
963    /// Build the request payload for the Mistral embeddings API.
964    pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
965        serde_json::json!({
966            "model": self.model,
967            "input": texts,
968        })
969    }
970}
971
972#[async_trait]
973impl EmbeddingProvider for MistralEmbedProvider {
974    #[cfg(feature = "http-embeddings")]
975    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
976        let client = reqwest::Client::new();
977        let payload = self.build_payload(&[text.to_string()]);
978        let response = client
979            .post(&self.base_url)
980            .header("Authorization", format!("Bearer {}", self.api_key))
981            .json(&payload)
982            .send()
983            .await
984            .map_err(|e| ArgentorError::Http(format!("Mistral embedding request failed: {e}")))?;
985
986        let status = response.status();
987        if !status.is_success() {
988            let body = response.text().await.unwrap_or_default();
989            return Err(ArgentorError::Http(format!(
990                "Mistral API error {status}: {body}"
991            )));
992        }
993
994        let json: serde_json::Value = response.json().await.map_err(|e| {
995            ArgentorError::Http(format!("Failed to read Mistral response body: {e}"))
996        })?;
997
998        parse_openai_embedding_response(&json)
999    }
1000
1001    #[cfg(not(feature = "http-embeddings"))]
1002    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1003        Ok(stub_embedding(text, self.dimensions))
1004    }
1005
1006    fn dimension(&self) -> usize {
1007        self.dimensions
1008    }
1009}
1010
1011// ===========================================================================
1012// 10. NomicEmbedProvider
1013// ===========================================================================
1014
/// Embedding provider for the Nomic Atlas text-embedding API.
///
/// Defaults to `nomic-embed-text-v1.5` (768 dims). Requests carry the inputs
/// under a `texts` key together with a `task_type`. Responses are read as
/// `{ "embeddings": [[...], ...] }`. Without the `http-embeddings` feature,
/// `embed()` produces a deterministic offline stub vector instead.
pub struct NomicEmbedProvider {
    /// Bearer token for the Atlas API (only read with `http-embeddings`).
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    /// Nomic model id sent in the payload.
    model: String,
    /// Reported vector length; also sizes the offline stub.
    dimensions: usize,
    /// Endpoint receiving the POST (only read with `http-embeddings`).
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
    /// Nomic `task_type` sent with every request (e.g. `search_document`).
    task_type: String,
}
1029
1030impl NomicEmbedProvider {
1031    /// Create a new Nomic provider with the default model (`nomic-embed-text-v1.5`).
1032    pub fn new(api_key: impl Into<String>) -> Self {
1033        Self::with_model(api_key, "nomic-embed-text-v1.5", 768)
1034    }
1035
1036    /// Create with an explicit model and dimension override.
1037    pub fn with_model(
1038        api_key: impl Into<String>,
1039        model: impl Into<String>,
1040        dimensions: usize,
1041    ) -> Self {
1042        Self {
1043            api_key: api_key.into(),
1044            model: model.into(),
1045            dimensions,
1046            base_url: "https://api-atlas.nomic.ai/v1/embedding/text".to_string(),
1047            task_type: "search_document".to_string(),
1048        }
1049    }
1050
1051    /// Set the `task_type` (`search_document`, `search_query`, `clustering`, `classification`).
1052    pub fn with_task_type(mut self, task_type: impl Into<String>) -> Self {
1053        self.task_type = task_type.into();
1054        self
1055    }
1056
1057    /// Returns the configured model name.
1058    pub fn model(&self) -> &str {
1059        &self.model
1060    }
1061
1062    /// Returns the current task type.
1063    pub fn task_type(&self) -> &str {
1064        &self.task_type
1065    }
1066
1067    /// Build the request payload for the Nomic embeddings API.
1068    pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
1069        serde_json::json!({
1070            "model": self.model,
1071            "texts": texts,
1072            "task_type": self.task_type,
1073        })
1074    }
1075}
1076
1077#[async_trait]
1078impl EmbeddingProvider for NomicEmbedProvider {
1079    #[cfg(feature = "http-embeddings")]
1080    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1081        let client = reqwest::Client::new();
1082        let payload = self.build_payload(&[text.to_string()]);
1083        let response = client
1084            .post(&self.base_url)
1085            .header("Authorization", format!("Bearer {}", self.api_key))
1086            .json(&payload)
1087            .send()
1088            .await
1089            .map_err(|e| ArgentorError::Http(format!("Nomic embedding request failed: {e}")))?;
1090
1091        let status = response.status();
1092        if !status.is_success() {
1093            let body = response.text().await.unwrap_or_default();
1094            return Err(ArgentorError::Http(format!(
1095                "Nomic API error {status}: {body}"
1096            )));
1097        }
1098
1099        let json: serde_json::Value = response.json().await.map_err(|e| {
1100            ArgentorError::Http(format!("Failed to read Nomic response body: {e}"))
1101        })?;
1102
1103        // Nomic response shape: { "embeddings": [[...]] }
1104        let embeddings = json
1105            .get("embeddings")
1106            .and_then(|v| v.as_array())
1107            .ok_or_else(|| {
1108                ArgentorError::Agent("Nomic response missing 'embeddings' array".to_string())
1109            })?;
1110        let first = embeddings.first().ok_or_else(|| {
1111            ArgentorError::Agent("Nomic response contains no embedding vectors".to_string())
1112        })?;
1113        let vec: Vec<f32> = serde_json::from_value(first.clone()).map_err(|e| {
1114            ArgentorError::Agent(format!("Failed to parse Nomic embedding vector: {e}"))
1115        })?;
1116        Ok(vec)
1117    }
1118
1119    #[cfg(not(feature = "http-embeddings"))]
1120    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1121        Ok(stub_embedding(text, self.dimensions))
1122    }
1123
1124    fn dimension(&self) -> usize {
1125        self.dimensions
1126    }
1127}
1128
1129// ===========================================================================
1130// 11. SentenceTransformersProvider (via Hugging Face Inference API)
1131// ===========================================================================
1132
/// Embedding provider for `sentence-transformers/*` models served by the
/// Hugging Face Inference API feature-extraction pipeline.
///
/// Defaults to `sentence-transformers/all-MiniLM-L6-v2` (384 dims). Other
/// supported checkpoints include `sentence-transformers/all-mpnet-base-v2` and
/// `sentence-transformers/multi-qa-mpnet-base-dot-v1` (both 768 dims).
/// Without the `http-embeddings` feature, `embed()` yields a deterministic
/// offline stub vector rather than calling the API.
pub struct SentenceTransformersProvider {
    /// Hugging Face access token, sent as `Bearer` (only read with `http-embeddings`).
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    /// Hub model id; also baked into the default `base_url`.
    model: String,
    /// Reported vector length; also sizes the offline stub.
    dimensions: usize,
    /// Feature-extraction endpoint URL (only read with `http-embeddings`).
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}
1146
1147impl SentenceTransformersProvider {
1148    /// Create a new provider with the default model (`all-MiniLM-L6-v2`, 384 dims).
1149    pub fn new(api_key: impl Into<String>) -> Self {
1150        Self::with_model(api_key, "sentence-transformers/all-MiniLM-L6-v2", 384)
1151    }
1152
1153    /// Create with an explicit model and dimension override.
1154    pub fn with_model(
1155        api_key: impl Into<String>,
1156        model: impl Into<String>,
1157        dimensions: usize,
1158    ) -> Self {
1159        let model = model.into();
1160        let base_url =
1161            format!("https://api-inference.huggingface.co/pipeline/feature-extraction/{model}");
1162        Self {
1163            api_key: api_key.into(),
1164            model,
1165            dimensions,
1166            base_url,
1167        }
1168    }
1169
1170    /// Override the API base URL (useful for self-hosted HF inference endpoints).
1171    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
1172        self.base_url = base_url.into();
1173        self
1174    }
1175
1176    /// Returns the configured model name.
1177    pub fn model(&self) -> &str {
1178        &self.model
1179    }
1180
1181    /// Default dimension for well-known sentence-transformer models.
1182    pub fn default_dimensions(model: &str) -> usize {
1183        match model {
1184            "sentence-transformers/all-MiniLM-L6-v2" => 384,
1185            "sentence-transformers/all-mpnet-base-v2"
1186            | "sentence-transformers/multi-qa-mpnet-base-dot-v1" => 768,
1187            _ => 384,
1188        }
1189    }
1190
1191    /// Build the request payload for the HF Inference API.
1192    pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
1193        serde_json::json!({
1194            "inputs": texts,
1195            "options": { "wait_for_model": true },
1196        })
1197    }
1198}
1199
1200#[async_trait]
1201impl EmbeddingProvider for SentenceTransformersProvider {
1202    #[cfg(feature = "http-embeddings")]
1203    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1204        let client = reqwest::Client::new();
1205        let payload = self.build_payload(&[text.to_string()]);
1206        let response = client
1207            .post(&self.base_url)
1208            .header("Authorization", format!("Bearer {}", self.api_key))
1209            .json(&payload)
1210            .send()
1211            .await
1212            .map_err(|e| {
1213                ArgentorError::Http(format!("HuggingFace embedding request failed: {e}"))
1214            })?;
1215
1216        let status = response.status();
1217        if !status.is_success() {
1218            let body = response.text().await.unwrap_or_default();
1219            return Err(ArgentorError::Http(format!(
1220                "HuggingFace API error {status}: {body}"
1221            )));
1222        }
1223
1224        let json: serde_json::Value = response.json().await.map_err(|e| {
1225            ArgentorError::Http(format!("Failed to read HuggingFace response body: {e}"))
1226        })?;
1227
1228        // HF feature-extraction returns either `[[f32; D]]` (batch) or `[f32; D]` (single).
1229        match &json {
1230            serde_json::Value::Array(arr)
1231                if arr.first().is_some_and(serde_json::Value::is_array) =>
1232            {
1233                let first = arr.first().cloned().ok_or_else(|| {
1234                    ArgentorError::Agent("HuggingFace response empty".to_string())
1235                })?;
1236                serde_json::from_value(first).map_err(|e| {
1237                    ArgentorError::Agent(format!("Failed to parse HF vector: {e}"))
1238                })
1239            }
1240            serde_json::Value::Array(_) => serde_json::from_value(json).map_err(|e| {
1241                ArgentorError::Agent(format!("Failed to parse HF vector: {e}"))
1242            }),
1243            _ => Err(ArgentorError::Agent(
1244                "HuggingFace response is not an array".to_string(),
1245            )),
1246        }
1247    }
1248
1249    #[cfg(not(feature = "http-embeddings"))]
1250    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1251        Ok(stub_embedding(text, self.dimensions))
1252    }
1253
1254    fn dimension(&self) -> usize {
1255        self.dimensions
1256    }
1257}
1258
1259// ===========================================================================
1260// 12. TogetherEmbedProvider
1261// ===========================================================================
1262
/// Embedding provider for the Together AI embeddings API.
///
/// Defaults to `togethercomputer/m2-bert-80M-32k-retrieval` (768 dims). The
/// wire format is OpenAI-compatible (`input` in, `data[].embedding` out).
/// Without the `http-embeddings` feature, `embed()` returns a deterministic
/// offline stub vector.
pub struct TogetherEmbedProvider {
    /// Together API key, sent as `Bearer` (only read with `http-embeddings`).
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    /// Together model id placed in the payload.
    model: String,
    /// Reported vector length; also sizes the offline stub.
    dimensions: usize,
    /// Embeddings endpoint URL (only read with `http-embeddings`).
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}
1275
1276impl TogetherEmbedProvider {
1277    /// Create a new Together provider with the default BERT retrieval model.
1278    pub fn new(api_key: impl Into<String>) -> Self {
1279        Self::with_model(
1280            api_key,
1281            "togethercomputer/m2-bert-80M-32k-retrieval",
1282            768,
1283        )
1284    }
1285
1286    /// Create with an explicit model and dimension override.
1287    pub fn with_model(
1288        api_key: impl Into<String>,
1289        model: impl Into<String>,
1290        dimensions: usize,
1291    ) -> Self {
1292        Self {
1293            api_key: api_key.into(),
1294            model: model.into(),
1295            dimensions,
1296            base_url: "https://api.together.xyz/v1/embeddings".to_string(),
1297        }
1298    }
1299
1300    /// Override the API base URL.
1301    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
1302        self.base_url = base_url.into();
1303        self
1304    }
1305
1306    /// Returns the configured model name.
1307    pub fn model(&self) -> &str {
1308        &self.model
1309    }
1310
1311    /// Build the request payload for the Together API.
1312    pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
1313        serde_json::json!({
1314            "model": self.model,
1315            "input": texts,
1316        })
1317    }
1318}
1319
1320#[async_trait]
1321impl EmbeddingProvider for TogetherEmbedProvider {
1322    #[cfg(feature = "http-embeddings")]
1323    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1324        let client = reqwest::Client::new();
1325        let payload = self.build_payload(&[text.to_string()]);
1326        let response = client
1327            .post(&self.base_url)
1328            .header("Authorization", format!("Bearer {}", self.api_key))
1329            .json(&payload)
1330            .send()
1331            .await
1332            .map_err(|e| ArgentorError::Http(format!("Together embedding request failed: {e}")))?;
1333
1334        let status = response.status();
1335        if !status.is_success() {
1336            let body = response.text().await.unwrap_or_default();
1337            return Err(ArgentorError::Http(format!(
1338                "Together API error {status}: {body}"
1339            )));
1340        }
1341
1342        let json: serde_json::Value = response.json().await.map_err(|e| {
1343            ArgentorError::Http(format!("Failed to read Together response body: {e}"))
1344        })?;
1345
1346        parse_openai_embedding_response(&json)
1347    }
1348
1349    #[cfg(not(feature = "http-embeddings"))]
1350    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1351        Ok(stub_embedding(text, self.dimensions))
1352    }
1353
1354    fn dimension(&self) -> usize {
1355        self.dimensions
1356    }
1357}
1358
1359// ===========================================================================
1360// 13. CohereEmbedV4Provider (newer v4 embed endpoint)
1361// ===========================================================================
1362
/// Embedding provider for Cohere's v2 `/embed` endpoint.
///
/// The "v4" suffix only distinguishes this type from the older
/// [`CohereEmbeddingProvider`] and mirrors the naming used by higher-level
/// integrations.
///
/// Compared with that provider, it offers explicit `input_type` builders
/// ([`CohereEmbedV4Provider::for_search_document`],
/// [`CohereEmbedV4Provider::for_search_query`]). It targets
/// `embed-english-v3.0` / `embed-multilingual-v3.0`, both 1024-dimensional.
/// Without the `http-embeddings` feature, `embed()` returns a deterministic
/// offline stub vector.
pub struct CohereEmbedV4Provider {
    /// Cohere API key, sent as `Bearer` (only read with `http-embeddings`).
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    /// Cohere model id placed in the payload.
    model: String,
    /// Reported vector length; also sizes the offline stub.
    dimensions: usize,
    /// Cohere `input_type` sent with every request (e.g. `search_query`).
    input_type: String,
    /// Embed endpoint URL (only read with `http-embeddings`).
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}
1379
1380impl CohereEmbedV4Provider {
1381    /// Create a new v4 provider with the default model (`embed-english-v3.0`).
1382    pub fn new(api_key: impl Into<String>) -> Self {
1383        Self::with_model(api_key, "embed-english-v3.0", 1024)
1384    }
1385
1386    /// Create with an explicit model and dimension override.
1387    pub fn with_model(
1388        api_key: impl Into<String>,
1389        model: impl Into<String>,
1390        dimensions: usize,
1391    ) -> Self {
1392        Self {
1393            api_key: api_key.into(),
1394            model: model.into(),
1395            dimensions,
1396            input_type: "search_document".to_string(),
1397            base_url: "https://api.cohere.com/v2/embed".to_string(),
1398        }
1399    }
1400
1401    /// Configure this provider for indexing documents (`search_document`).
1402    pub fn for_search_document(mut self) -> Self {
1403        self.input_type = "search_document".to_string();
1404        self
1405    }
1406
1407    /// Configure this provider for querying (`search_query`).
1408    pub fn for_search_query(mut self) -> Self {
1409        self.input_type = "search_query".to_string();
1410        self
1411    }
1412
1413    /// Set an arbitrary `input_type` string.
1414    pub fn with_input_type(mut self, input_type: impl Into<String>) -> Self {
1415        self.input_type = input_type.into();
1416        self
1417    }
1418
1419    /// Override the API base URL.
1420    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
1421        self.base_url = base_url.into();
1422        self
1423    }
1424
1425    /// Returns the configured model name.
1426    pub fn model(&self) -> &str {
1427        &self.model
1428    }
1429
1430    /// Returns the current input type.
1431    pub fn input_type(&self) -> &str {
1432        &self.input_type
1433    }
1434
1435    /// Build the request payload for the Cohere v2 embed endpoint.
1436    pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
1437        serde_json::json!({
1438            "model": self.model,
1439            "texts": texts,
1440            "input_type": self.input_type,
1441            "embedding_types": ["float"],
1442        })
1443    }
1444}
1445
1446#[async_trait]
1447impl EmbeddingProvider for CohereEmbedV4Provider {
1448    #[cfg(feature = "http-embeddings")]
1449    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1450        let client = reqwest::Client::new();
1451        let payload = self.build_payload(&[text.to_string()]);
1452        let response = client
1453            .post(&self.base_url)
1454            .header("Authorization", format!("Bearer {}", self.api_key))
1455            .json(&payload)
1456            .send()
1457            .await
1458            .map_err(|e| {
1459                ArgentorError::Http(format!("Cohere v4 embedding request failed: {e}"))
1460            })?;
1461
1462        let status = response.status();
1463        if !status.is_success() {
1464            let body = response.text().await.unwrap_or_default();
1465            return Err(ArgentorError::Http(format!(
1466                "Cohere v4 API error {status}: {body}"
1467            )));
1468        }
1469
1470        let json: serde_json::Value = response.json().await.map_err(|e| {
1471            ArgentorError::Http(format!("Failed to read Cohere v4 response body: {e}"))
1472        })?;
1473
1474        parse_cohere_embedding_response(&json)
1475    }
1476
1477    #[cfg(not(feature = "http-embeddings"))]
1478    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1479        Ok(stub_embedding(text, self.dimensions))
1480    }
1481
1482    fn dimension(&self) -> usize {
1483        self.dimensions
1484    }
1485}
1486
1487// ===========================================================================
1488// Tests
1489// ===========================================================================
1490
1491#[cfg(test)]
1492#[allow(clippy::unwrap_used, clippy::expect_used)]
1493mod tests {
1494    use super::*;
1495
1496    // -- Provider creation tests ------------------------------------------
1497
1498    #[test]
1499    fn test_openai_provider_default_model() {
1500        let p = OpenAiEmbeddingProvider::new("sk-test", None);
1501        assert_eq!(p.model(), "text-embedding-3-small");
1502        assert_eq!(p.dimension(), 1536);
1503    }
1504
1505    #[test]
1506    fn test_openai_provider_large_model() {
1507        let p = OpenAiEmbeddingProvider::new("sk-test", Some("text-embedding-3-large".into()));
1508        assert_eq!(p.dimension(), 3072);
1509    }
1510
1511    #[test]
1512    fn test_openai_provider_custom_dimensions() {
1513        let p = OpenAiEmbeddingProvider::new("sk-test", None).with_dimensions(512);
1514        assert_eq!(p.dimension(), 512);
1515    }
1516
1517    #[test]
1518    fn test_openai_provider_custom_base_url() {
1519        let p = OpenAiEmbeddingProvider::with_base_url(
1520            "sk-test",
1521            None,
1522            "https://my-azure.openai.azure.com/openai/deployments/embed",
1523        );
1524        assert_eq!(p.dimension(), 1536);
1525    }
1526
1527    #[cfg(not(feature = "http-embeddings"))]
1528    #[tokio::test]
1529    async fn test_openai_provider_returns_feature_error() {
1530        let p = OpenAiEmbeddingProvider::new("sk-test", None);
1531        let err = p.embed("hello").await.unwrap_err();
1532        let msg = format!("{err}");
1533        assert!(msg.contains("HTTP embeddings not enabled"), "got: {msg}");
1534    }
1535
1536    #[test]
1537    fn test_cohere_provider_default() {
1538        let p = CohereEmbeddingProvider::new("key", None);
1539        assert_eq!(p.model(), "embed-english-v3.0");
1540        assert_eq!(p.dimension(), 1024);
1541        assert_eq!(p.input_type(), "search_document");
1542    }
1543
1544    #[test]
1545    fn test_cohere_provider_query_input_type() {
1546        let p = CohereEmbeddingProvider::new("key", None).with_input_type("search_query");
1547        assert_eq!(p.input_type(), "search_query");
1548    }
1549
1550    #[test]
1551    fn test_cohere_provider_light_model() {
1552        let p = CohereEmbeddingProvider::new("key", Some("embed-english-light-v3.0".into()));
1553        assert_eq!(p.dimension(), 384);
1554    }
1555
1556    #[cfg(not(feature = "http-embeddings"))]
1557    #[tokio::test]
1558    async fn test_cohere_provider_returns_feature_error() {
1559        let p = CohereEmbeddingProvider::new("key", None);
1560        let err = p.embed("hello").await.unwrap_err();
1561        let msg = format!("{err}");
1562        assert!(msg.contains("HTTP embeddings not enabled"), "got: {msg}");
1563    }
1564
1565    #[test]
1566    fn test_voyage_provider_default() {
1567        let p = VoyageEmbeddingProvider::new("key", None);
1568        assert_eq!(p.model(), "voyage-2");
1569        assert_eq!(p.dimension(), 1024);
1570    }
1571
1572    #[test]
1573    fn test_voyage_provider_code_model() {
1574        let p = VoyageEmbeddingProvider::new("key", Some("voyage-code-2".into()));
1575        assert_eq!(p.dimension(), 1536);
1576    }
1577
1578    #[cfg(not(feature = "http-embeddings"))]
1579    #[tokio::test]
1580    async fn test_voyage_provider_returns_feature_error() {
1581        let p = VoyageEmbeddingProvider::new("key", None);
1582        let err = p.embed("hello").await.unwrap_err();
1583        let msg = format!("{err}");
1584        assert!(msg.contains("HTTP embeddings not enabled"), "got: {msg}");
1585    }
1586
1587    // -- Response parsing tests -------------------------------------------
1588
1589    #[test]
1590    fn test_parse_openai_embedding_response_valid() {
1591        let json = serde_json::json!({
1592            "data": [
1593                {
1594                    "embedding": [0.1, 0.2, 0.3, 0.4],
1595                    "index": 0
1596                }
1597            ],
1598            "model": "text-embedding-3-small"
1599        });
1600        let result = parse_openai_embedding_response(&json).unwrap();
1601        assert_eq!(result, vec![0.1, 0.2, 0.3, 0.4]);
1602    }
1603
1604    #[test]
1605    fn test_parse_openai_embedding_response_empty_data() {
1606        let json = serde_json::json!({
1607            "data": [],
1608            "model": "text-embedding-3-small"
1609        });
1610        let err = parse_openai_embedding_response(&json).unwrap_err();
1611        let msg = format!("{err}");
1612        assert!(msg.contains("no embedding data"), "got: {msg}");
1613    }
1614
1615    #[test]
1616    fn test_parse_openai_embedding_response_invalid_shape() {
1617        let json = serde_json::json!({ "error": "bad request" });
1618        let err = parse_openai_embedding_response(&json).unwrap_err();
1619        let msg = format!("{err}");
1620        assert!(msg.contains("Failed to parse"), "got: {msg}");
1621    }
1622
1623    #[test]
1624    fn test_parse_openai_embedding_response_multiple_picks_first() {
1625        let json = serde_json::json!({
1626            "data": [
1627                { "embedding": [1.0, 2.0], "index": 0 },
1628                { "embedding": [3.0, 4.0], "index": 1 }
1629            ],
1630            "model": "text-embedding-3-small"
1631        });
1632        let result = parse_openai_embedding_response(&json).unwrap();
1633        assert_eq!(result, vec![1.0, 2.0]);
1634    }
1635
1636    #[test]
1637    fn test_parse_cohere_embedding_response_valid() {
1638        let json = serde_json::json!({
1639            "embeddings": {
1640                "float": [
1641                    [0.5, 0.6, 0.7]
1642                ]
1643            }
1644        });
1645        let result = parse_cohere_embedding_response(&json).unwrap();
1646        assert_eq!(result, vec![0.5, 0.6, 0.7]);
1647    }
1648
1649    #[test]
1650    fn test_parse_cohere_embedding_response_empty_float() {
1651        let json = serde_json::json!({
1652            "embeddings": {
1653                "float": []
1654            }
1655        });
1656        let err = parse_cohere_embedding_response(&json).unwrap_err();
1657        let msg = format!("{err}");
1658        assert!(msg.contains("no float embeddings"), "got: {msg}");
1659    }
1660
1661    #[test]
1662    fn test_parse_cohere_embedding_response_invalid_shape() {
1663        let json = serde_json::json!({ "message": "unauthorized" });
1664        let err = parse_cohere_embedding_response(&json).unwrap_err();
1665        let msg = format!("{err}");
1666        assert!(msg.contains("Failed to parse"), "got: {msg}");
1667    }
1668
1669    #[test]
1670    fn test_parse_cohere_embedding_response_missing_float_key() {
1671        // If "float" key is absent, serde default gives empty vec.
1672        let json = serde_json::json!({
1673            "embeddings": {}
1674        });
1675        let err = parse_cohere_embedding_response(&json).unwrap_err();
1676        let msg = format!("{err}");
1677        assert!(msg.contains("no float embeddings"), "got: {msg}");
1678    }
1679
1680    #[test]
1681    fn test_parse_voyage_embedding_response_valid() {
1682        let json = serde_json::json!({
1683            "data": [
1684                {
1685                    "embedding": [0.9, 0.8, 0.7, 0.6, 0.5],
1686                    "index": 0
1687                }
1688            ]
1689        });
1690        let result = parse_voyage_embedding_response(&json).unwrap();
1691        assert_eq!(result, vec![0.9, 0.8, 0.7, 0.6, 0.5]);
1692    }
1693
1694    #[test]
1695    fn test_parse_voyage_embedding_response_empty_data() {
1696        let json = serde_json::json!({ "data": [] });
1697        let err = parse_voyage_embedding_response(&json).unwrap_err();
1698        let msg = format!("{err}");
1699        assert!(msg.contains("no embedding data"), "got: {msg}");
1700    }
1701
1702    #[test]
1703    fn test_parse_voyage_embedding_response_invalid_shape() {
1704        let json = serde_json::json!({ "error": "invalid key" });
1705        let err = parse_voyage_embedding_response(&json).unwrap_err();
1706        let msg = format!("{err}");
1707        assert!(msg.contains("Failed to parse"), "got: {msg}");
1708    }
1709
1710    // -- CachedEmbeddingProvider tests ------------------------------------
1711
1712    #[tokio::test]
1713    async fn test_cache_hit() {
1714        let local = Arc::new(LocalEmbedding::new(64));
1715        let cached = CachedEmbeddingProvider::new(local, 100);
1716
1717        let v1 = cached.embed("hello world").await.unwrap();
1718        let v2 = cached.embed("hello world").await.unwrap();
1719        assert_eq!(v1, v2);
1720
1721        let stats = cached.cache_stats().await;
1722        assert_eq!(stats.hits, 1);
1723        assert_eq!(stats.misses, 1);
1724        assert_eq!(stats.size, 1);
1725    }
1726
1727    #[tokio::test]
1728    async fn test_cache_miss_different_texts() {
1729        let local = Arc::new(LocalEmbedding::new(64));
1730        let cached = CachedEmbeddingProvider::new(local, 100);
1731
1732        let _ = cached.embed("alpha").await.unwrap();
1733        let _ = cached.embed("bravo").await.unwrap();
1734
1735        let stats = cached.cache_stats().await;
1736        assert_eq!(stats.misses, 2);
1737        assert_eq!(stats.hits, 0);
1738        assert_eq!(stats.size, 2);
1739    }
1740
1741    #[tokio::test]
1742    async fn test_cache_eviction() {
1743        let local = Arc::new(LocalEmbedding::new(64));
1744        let cached = CachedEmbeddingProvider::new(local, 2);
1745
1746        let _ = cached.embed("one").await.unwrap();
1747        let _ = cached.embed("two").await.unwrap();
1748        let _ = cached.embed("three").await.unwrap();
1749
1750        let stats = cached.cache_stats().await;
1751        // After eviction, cache should still have at most max_cache_size entries.
1752        assert!(stats.size <= 2, "size={} should be <= 2", stats.size);
1753        assert_eq!(stats.misses, 3);
1754    }
1755
1756    #[tokio::test]
1757    async fn test_cache_clear() {
1758        let local = Arc::new(LocalEmbedding::new(64));
1759        let cached = CachedEmbeddingProvider::new(local, 100);
1760
1761        let _ = cached.embed("text").await.unwrap();
1762        cached.clear().await;
1763
1764        let stats = cached.cache_stats().await;
1765        assert_eq!(stats.size, 0);
1766    }
1767
1768    #[tokio::test]
1769    async fn test_cache_dimension_delegates() {
1770        let local = Arc::new(LocalEmbedding::new(128));
1771        let cached = CachedEmbeddingProvider::new(local, 10);
1772        assert_eq!(cached.dimension(), 128);
1773    }
1774
1775    // -- BatchEmbeddingProvider tests -------------------------------------
1776
1777    #[tokio::test]
1778    async fn test_batch_embed() {
1779        let local = Arc::new(LocalEmbedding::new(64));
1780        let batch = BatchEmbeddingProvider::new(local);
1781
1782        let results = batch
1783            .embed_batch(&["hello", "world", "test"])
1784            .await
1785            .unwrap();
1786        assert_eq!(results.len(), 3);
1787        for v in &results {
1788            assert_eq!(v.len(), 64);
1789        }
1790    }
1791
1792    #[tokio::test]
1793    async fn test_batch_single_embed_delegates() {
1794        let local = Arc::new(LocalEmbedding::new(64));
1795        let batch = BatchEmbeddingProvider::new(local);
1796
1797        let v = batch.embed("hello").await.unwrap();
1798        assert_eq!(v.len(), 64);
1799    }
1800
1801    #[tokio::test]
1802    async fn test_batch_empty() {
1803        let local = Arc::new(LocalEmbedding::new(64));
1804        let batch = BatchEmbeddingProvider::new(local);
1805
1806        let results = batch.embed_batch(&[]).await.unwrap();
1807        assert!(results.is_empty());
1808    }
1809
1810    #[tokio::test]
1811    async fn test_batch_dimension_delegates() {
1812        let local = Arc::new(LocalEmbedding::new(200));
1813        let batch = BatchEmbeddingProvider::new(local);
1814        assert_eq!(batch.dimension(), 200);
1815    }
1816
1817    // -- Factory tests ----------------------------------------------------
1818
1819    #[test]
1820    fn test_factory_create_local() {
1821        let p = EmbeddingProviderFactory::create("local", "", None).unwrap();
1822        assert_eq!(p.dimension(), 256);
1823    }
1824
1825    #[test]
1826    fn test_factory_create_local_custom_dim() {
1827        let p = EmbeddingProviderFactory::create("local", "", Some("128".into())).unwrap();
1828        assert_eq!(p.dimension(), 128);
1829    }
1830
1831    #[test]
1832    fn test_factory_create_openai() {
1833        let p = EmbeddingProviderFactory::create("openai", "sk-test", None).unwrap();
1834        assert_eq!(p.dimension(), 1536);
1835    }
1836
1837    #[test]
1838    fn test_factory_create_cohere() {
1839        let p = EmbeddingProviderFactory::create("cohere", "key", None).unwrap();
1840        assert_eq!(p.dimension(), 1024);
1841    }
1842
1843    #[test]
1844    fn test_factory_create_voyage() {
1845        let p = EmbeddingProviderFactory::create("voyage", "key", None).unwrap();
1846        assert_eq!(p.dimension(), 1024);
1847    }
1848
1849    #[test]
1850    fn test_factory_unknown_provider() {
1851        let result = EmbeddingProviderFactory::create("unknown", "", None);
1852        assert!(result.is_err(), "Unknown provider should return Err");
1853    }
1854
1855    #[test]
1856    fn test_factory_available_providers() {
1857        let names = EmbeddingProviderFactory::available_providers();
1858        assert!(names.contains(&"openai"));
1859        assert!(names.contains(&"cohere"));
1860        assert!(names.contains(&"voyage"));
1861        assert!(names.contains(&"local"));
1862    }
1863
1864    // -- Config tests -----------------------------------------------------
1865
1866    #[test]
1867    fn test_config_default() {
1868        let cfg = EmbeddingConfig::default();
1869        assert_eq!(cfg.provider, "local");
1870        assert!(cfg.api_key.is_empty());
1871        assert!(cfg.model.is_none());
1872        assert!(cfg.dimensions.is_none());
1873        assert!(cfg.base_url.is_none());
1874        assert!(cfg.cache_size.is_none());
1875    }
1876
1877    #[test]
1878    fn test_config_serialize_deserialize() {
1879        let cfg = EmbeddingConfig {
1880            provider: "openai".to_string(),
1881            api_key: "sk-123".to_string(),
1882            model: Some("text-embedding-3-small".to_string()),
1883            dimensions: Some(1536),
1884            base_url: None,
1885            cache_size: Some(500),
1886        };
1887        let json = serde_json::to_string(&cfg).unwrap();
1888        let parsed: EmbeddingConfig = serde_json::from_str(&json).unwrap();
1889        assert_eq!(parsed.provider, "openai");
1890        assert_eq!(parsed.api_key, "sk-123");
1891        assert_eq!(parsed.dimensions, Some(1536));
1892        assert_eq!(parsed.cache_size, Some(500));
1893    }
1894
1895    #[test]
1896    fn test_config_deserialize_minimal() {
1897        let json = r#"{"provider":"local"}"#;
1898        let cfg: EmbeddingConfig = serde_json::from_str(json).unwrap();
1899        assert_eq!(cfg.provider, "local");
1900        assert!(cfg.api_key.is_empty());
1901    }
1902
1903    #[tokio::test]
1904    async fn test_config_build_local() {
1905        let cfg = EmbeddingConfig::default();
1906        let provider = cfg.build().unwrap();
1907        assert_eq!(provider.dimension(), 256);
1908        let v = provider.embed("test text").await.unwrap();
1909        assert_eq!(v.len(), 256);
1910    }
1911
1912    #[tokio::test]
1913    async fn test_config_build_local_with_cache() {
1914        let cfg = EmbeddingConfig {
1915            provider: "local".to_string(),
1916            cache_size: Some(50),
1917            ..Default::default()
1918        };
1919        let provider = cfg.build().unwrap();
1920        // Dimension from local default.
1921        assert_eq!(provider.dimension(), 256);
1922        // Should work — cache wraps local.
1923        let v1 = provider.embed("cached text").await.unwrap();
1924        let v2 = provider.embed("cached text").await.unwrap();
1925        assert_eq!(v1, v2);
1926    }
1927
1928    #[tokio::test]
1929    async fn test_config_build_local_custom_dimensions() {
1930        let cfg = EmbeddingConfig {
1931            provider: "local".to_string(),
1932            dimensions: Some(512),
1933            ..Default::default()
1934        };
1935        let provider = cfg.build().unwrap();
1936        assert_eq!(provider.dimension(), 512);
1937    }
1938
1939    #[test]
1940    fn test_config_build_unknown_provider() {
1941        let cfg = EmbeddingConfig {
1942            provider: "imaginary".to_string(),
1943            ..Default::default()
1944        };
1945        assert!(cfg.build().is_err());
1946    }
1947
1948    // -- Misc / edge cases ------------------------------------------------
1949
1950    #[test]
1951    fn test_fnv_hash_deterministic() {
1952        let h1 = fnv1a_hash(b"hello world");
1953        let h2 = fnv1a_hash(b"hello world");
1954        assert_eq!(h1, h2);
1955    }
1956
1957    #[test]
1958    fn test_fnv_hash_different_inputs() {
1959        let h1 = fnv1a_hash(b"alpha");
1960        let h2 = fnv1a_hash(b"bravo");
1961        assert_ne!(h1, h2);
1962    }
1963
1964    // =====================================================================
1965    // Stub helper tests
1966    // =====================================================================
1967
1968    #[test]
1969    fn test_stub_embedding_length() {
1970        let v = stub_embedding("hello", 128);
1971        assert_eq!(v.len(), 128);
1972    }
1973
1974    #[test]
1975    fn test_stub_embedding_deterministic() {
1976        let v1 = stub_embedding("same input", 64);
1977        let v2 = stub_embedding("same input", 64);
1978        assert_eq!(v1, v2);
1979    }
1980
1981    #[test]
1982    fn test_stub_embedding_normalized() {
1983        let v = stub_embedding("the quick brown fox", 256);
1984        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
1985        assert!((norm - 1.0).abs() < 0.01, "norm={norm}");
1986    }
1987
1988    #[test]
1989    fn test_stub_embedding_different_inputs_differ() {
1990        let a = stub_embedding("alpha", 64);
1991        let b = stub_embedding("bravo", 64);
1992        assert_ne!(a, b);
1993    }
1994
1995    #[test]
1996    fn test_stub_embedding_empty_text_zeroes() {
1997        let v = stub_embedding("", 32);
1998        assert_eq!(v.len(), 32);
1999        assert!(v.iter().all(|&x| x == 0.0));
2000    }
2001
2002    #[test]
2003    fn test_stub_embedding_zero_dimension_safe() {
2004        // Must not panic; helper clamps dimension to at least 1.
2005        let v = stub_embedding("hi", 0);
2006        assert_eq!(v.len(), 1);
2007    }
2008
2009    // =====================================================================
2010    // JinaEmbeddingProvider tests
2011    // =====================================================================
2012
2013    #[test]
2014    fn test_jina_default_construction() {
2015        let p = JinaEmbeddingProvider::new("jina-key");
2016        assert_eq!(p.model(), "jina-embeddings-v3");
2017        assert_eq!(p.dimension(), 1024);
2018    }
2019
2020    #[test]
2021    fn test_jina_with_model_clip() {
2022        let p = JinaEmbeddingProvider::with_model("k", "jina-clip-v2", 768);
2023        assert_eq!(p.model(), "jina-clip-v2");
2024        assert_eq!(p.dimension(), 768);
2025    }
2026
2027    #[test]
2028    fn test_jina_with_base_url() {
2029        let p = JinaEmbeddingProvider::new("k").with_base_url("https://custom.jina/v1");
2030        // Indirect check: construction succeeds and model unchanged.
2031        assert_eq!(p.model(), "jina-embeddings-v3");
2032    }
2033
2034    #[test]
2035    fn test_jina_build_payload_shape() {
2036        let p = JinaEmbeddingProvider::new("k");
2037        let payload = p.build_payload(&["hello".to_string(), "world".to_string()]);
2038        assert_eq!(payload["model"], "jina-embeddings-v3");
2039        assert_eq!(payload["input"][0], "hello");
2040        assert_eq!(payload["input"][1], "world");
2041    }
2042
2043    #[tokio::test]
2044    async fn test_jina_embed_length_matches_dimension() {
2045        let p = JinaEmbeddingProvider::new("k");
2046        #[cfg(not(feature = "http-embeddings"))]
2047        {
2048            let v = p.embed("hello jina").await.unwrap();
2049            assert_eq!(v.len(), 1024);
2050        }
2051        // When http-embeddings is enabled, we don't hit the real API in tests;
2052        // just confirm dimension() reports 1024.
2053        assert_eq!(p.dimension(), 1024);
2054    }
2055
2056    #[cfg(not(feature = "http-embeddings"))]
2057    #[tokio::test]
2058    async fn test_jina_stub_is_normalized() {
2059        let p = JinaEmbeddingProvider::new("k");
2060        let v = p.embed("some input").await.unwrap();
2061        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
2062        assert!((norm - 1.0).abs() < 0.01);
2063    }
2064
2065    #[cfg(not(feature = "http-embeddings"))]
2066    #[tokio::test]
2067    async fn test_jina_stub_deterministic() {
2068        let p = JinaEmbeddingProvider::new("k");
2069        let a = p.embed("consistent").await.unwrap();
2070        let b = p.embed("consistent").await.unwrap();
2071        assert_eq!(a, b);
2072    }
2073
2074    // =====================================================================
2075    // MistralEmbedProvider tests
2076    // =====================================================================
2077
2078    #[test]
2079    fn test_mistral_default_construction() {
2080        let p = MistralEmbedProvider::new("mistral-key");
2081        assert_eq!(p.model(), "mistral-embed");
2082        assert_eq!(p.dimension(), 1024);
2083    }
2084
2085    #[test]
2086    fn test_mistral_with_model_and_dimensions() {
2087        let p = MistralEmbedProvider::with_model("k", "mistral-embed-large", 2048);
2088        assert_eq!(p.model(), "mistral-embed-large");
2089        assert_eq!(p.dimension(), 2048);
2090    }
2091
2092    #[test]
2093    fn test_mistral_build_payload_shape() {
2094        let p = MistralEmbedProvider::new("k");
2095        let payload = p.build_payload(&["alpha".to_string()]);
2096        assert_eq!(payload["model"], "mistral-embed");
2097        assert_eq!(payload["input"][0], "alpha");
2098    }
2099
2100    #[test]
2101    fn test_mistral_with_base_url() {
2102        let p = MistralEmbedProvider::new("k").with_base_url("https://custom.mistral/v1");
2103        assert_eq!(p.dimension(), 1024);
2104    }
2105
2106    #[cfg(not(feature = "http-embeddings"))]
2107    #[tokio::test]
2108    async fn test_mistral_embed_length() {
2109        let p = MistralEmbedProvider::new("k");
2110        let v = p.embed("hello mistral").await.unwrap();
2111        assert_eq!(v.len(), 1024);
2112    }
2113
2114    #[cfg(not(feature = "http-embeddings"))]
2115    #[tokio::test]
2116    async fn test_mistral_stub_normalized() {
2117        let p = MistralEmbedProvider::new("k");
2118        let v = p.embed("normalized?").await.unwrap();
2119        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
2120        assert!((norm - 1.0).abs() < 0.01);
2121    }
2122
2123    // =====================================================================
2124    // NomicEmbedProvider tests
2125    // =====================================================================
2126
2127    #[test]
2128    fn test_nomic_default_construction() {
2129        let p = NomicEmbedProvider::new("nomic-key");
2130        assert_eq!(p.model(), "nomic-embed-text-v1.5");
2131        assert_eq!(p.dimension(), 768);
2132        assert_eq!(p.task_type(), "search_document");
2133    }
2134
2135    #[test]
2136    fn test_nomic_with_task_type() {
2137        let p = NomicEmbedProvider::new("k").with_task_type("search_query");
2138        assert_eq!(p.task_type(), "search_query");
2139    }
2140
2141    #[test]
2142    fn test_nomic_build_payload_shape() {
2143        let p = NomicEmbedProvider::new("k").with_task_type("clustering");
2144        let payload = p.build_payload(&["doc a".to_string(), "doc b".to_string()]);
2145        assert_eq!(payload["model"], "nomic-embed-text-v1.5");
2146        assert_eq!(payload["texts"][0], "doc a");
2147        assert_eq!(payload["texts"][1], "doc b");
2148        assert_eq!(payload["task_type"], "clustering");
2149    }
2150
2151    #[test]
2152    fn test_nomic_with_model_custom_dims() {
2153        let p = NomicEmbedProvider::with_model("k", "custom-nomic", 512);
2154        assert_eq!(p.dimension(), 512);
2155    }
2156
2157    #[cfg(not(feature = "http-embeddings"))]
2158    #[tokio::test]
2159    async fn test_nomic_embed_length() {
2160        let p = NomicEmbedProvider::new("k");
2161        let v = p.embed("nomic test").await.unwrap();
2162        assert_eq!(v.len(), 768);
2163    }
2164
2165    #[cfg(not(feature = "http-embeddings"))]
2166    #[tokio::test]
2167    async fn test_nomic_embed_normalized() {
2168        let p = NomicEmbedProvider::new("k");
2169        let v = p.embed("some text").await.unwrap();
2170        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
2171        assert!((norm - 1.0).abs() < 0.01);
2172    }
2173
2174    // =====================================================================
2175    // SentenceTransformersProvider tests
2176    // =====================================================================
2177
2178    #[test]
2179    fn test_sentence_transformers_default_construction() {
2180        let p = SentenceTransformersProvider::new("hf-key");
2181        assert_eq!(p.model(), "sentence-transformers/all-MiniLM-L6-v2");
2182        assert_eq!(p.dimension(), 384);
2183    }
2184
2185    #[test]
2186    fn test_sentence_transformers_mpnet_dims() {
2187        let dims = SentenceTransformersProvider::default_dimensions(
2188            "sentence-transformers/all-mpnet-base-v2",
2189        );
2190        assert_eq!(dims, 768);
2191    }
2192
2193    #[test]
2194    fn test_sentence_transformers_multi_qa_dims() {
2195        let dims = SentenceTransformersProvider::default_dimensions(
2196            "sentence-transformers/multi-qa-mpnet-base-dot-v1",
2197        );
2198        assert_eq!(dims, 768);
2199    }
2200
2201    #[test]
2202    fn test_sentence_transformers_unknown_model_fallback() {
2203        let dims = SentenceTransformersProvider::default_dimensions("sentence-transformers/unknown");
2204        assert_eq!(dims, 384);
2205    }
2206
2207    #[test]
2208    fn test_sentence_transformers_with_model() {
2209        let p = SentenceTransformersProvider::with_model(
2210            "k",
2211            "sentence-transformers/all-mpnet-base-v2",
2212            768,
2213        );
2214        assert_eq!(p.model(), "sentence-transformers/all-mpnet-base-v2");
2215        assert_eq!(p.dimension(), 768);
2216    }
2217
2218    #[test]
2219    fn test_sentence_transformers_build_payload_shape() {
2220        let p = SentenceTransformersProvider::new("k");
2221        let payload = p.build_payload(&["hi".to_string()]);
2222        assert_eq!(payload["inputs"][0], "hi");
2223        assert_eq!(payload["options"]["wait_for_model"], true);
2224    }
2225
2226    #[test]
2227    fn test_sentence_transformers_with_base_url() {
2228        let p = SentenceTransformersProvider::new("k")
2229            .with_base_url("https://self-hosted.hf/embed");
2230        assert_eq!(p.dimension(), 384);
2231    }
2232
2233    #[cfg(not(feature = "http-embeddings"))]
2234    #[tokio::test]
2235    async fn test_sentence_transformers_embed_length() {
2236        let p = SentenceTransformersProvider::new("k");
2237        let v = p.embed("minilm test").await.unwrap();
2238        assert_eq!(v.len(), 384);
2239    }
2240
2241    #[cfg(not(feature = "http-embeddings"))]
2242    #[tokio::test]
2243    async fn test_sentence_transformers_embed_normalized() {
2244        let p = SentenceTransformersProvider::new("k");
2245        let v = p.embed("some input").await.unwrap();
2246        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
2247        assert!((norm - 1.0).abs() < 0.01);
2248    }
2249
2250    // =====================================================================
2251    // TogetherEmbedProvider tests
2252    // =====================================================================
2253
2254    #[test]
2255    fn test_together_default_construction() {
2256        let p = TogetherEmbedProvider::new("together-key");
2257        assert_eq!(p.model(), "togethercomputer/m2-bert-80M-32k-retrieval");
2258        assert_eq!(p.dimension(), 768);
2259    }
2260
2261    #[test]
2262    fn test_together_with_model() {
2263        let p = TogetherEmbedProvider::with_model("k", "togethercomputer/custom", 1024);
2264        assert_eq!(p.model(), "togethercomputer/custom");
2265        assert_eq!(p.dimension(), 1024);
2266    }
2267
2268    #[test]
2269    fn test_together_build_payload_shape() {
2270        let p = TogetherEmbedProvider::new("k");
2271        let payload = p.build_payload(&["x".to_string(), "y".to_string()]);
2272        assert_eq!(payload["model"], "togethercomputer/m2-bert-80M-32k-retrieval");
2273        assert_eq!(payload["input"][0], "x");
2274        assert_eq!(payload["input"][1], "y");
2275    }
2276
2277    #[test]
2278    fn test_together_with_base_url() {
2279        let p = TogetherEmbedProvider::new("k").with_base_url("https://custom.together/v1");
2280        assert_eq!(p.dimension(), 768);
2281    }
2282
2283    #[cfg(not(feature = "http-embeddings"))]
2284    #[tokio::test]
2285    async fn test_together_embed_length() {
2286        let p = TogetherEmbedProvider::new("k");
2287        let v = p.embed("together test").await.unwrap();
2288        assert_eq!(v.len(), 768);
2289    }
2290
2291    #[cfg(not(feature = "http-embeddings"))]
2292    #[tokio::test]
2293    async fn test_together_embed_normalized() {
2294        let p = TogetherEmbedProvider::new("k");
2295        let v = p.embed("text").await.unwrap();
2296        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
2297        assert!((norm - 1.0).abs() < 0.01);
2298    }
2299
2300    // =====================================================================
2301    // CohereEmbedV4Provider tests
2302    // =====================================================================
2303
2304    #[test]
2305    fn test_cohere_v4_default_construction() {
2306        let p = CohereEmbedV4Provider::new("cohere-key");
2307        assert_eq!(p.model(), "embed-english-v3.0");
2308        assert_eq!(p.dimension(), 1024);
2309        assert_eq!(p.input_type(), "search_document");
2310    }
2311
2312    #[test]
2313    fn test_cohere_v4_multilingual_model() {
2314        let p = CohereEmbedV4Provider::with_model("k", "embed-multilingual-v3.0", 1024);
2315        assert_eq!(p.model(), "embed-multilingual-v3.0");
2316        assert_eq!(p.dimension(), 1024);
2317    }
2318
2319    #[test]
2320    fn test_cohere_v4_for_search_document() {
2321        let p = CohereEmbedV4Provider::new("k").for_search_document();
2322        assert_eq!(p.input_type(), "search_document");
2323    }
2324
2325    #[test]
2326    fn test_cohere_v4_for_search_query() {
2327        let p = CohereEmbedV4Provider::new("k").for_search_query();
2328        assert_eq!(p.input_type(), "search_query");
2329    }
2330
2331    #[test]
2332    fn test_cohere_v4_with_input_type() {
2333        let p = CohereEmbedV4Provider::new("k").with_input_type("classification");
2334        assert_eq!(p.input_type(), "classification");
2335    }
2336
2337    #[test]
2338    fn test_cohere_v4_build_payload_shape_document() {
2339        let p = CohereEmbedV4Provider::new("k").for_search_document();
2340        let payload = p.build_payload(&["doc".to_string()]);
2341        assert_eq!(payload["model"], "embed-english-v3.0");
2342        assert_eq!(payload["texts"][0], "doc");
2343        assert_eq!(payload["input_type"], "search_document");
2344        assert_eq!(payload["embedding_types"][0], "float");
2345    }
2346
2347    #[test]
2348    fn test_cohere_v4_build_payload_shape_query() {
2349        let p = CohereEmbedV4Provider::new("k").for_search_query();
2350        let payload = p.build_payload(&["q".to_string()]);
2351        assert_eq!(payload["input_type"], "search_query");
2352    }
2353
2354    #[test]
2355    fn test_cohere_v4_with_base_url() {
2356        let p = CohereEmbedV4Provider::new("k").with_base_url("https://custom.cohere/v2/embed");
2357        assert_eq!(p.dimension(), 1024);
2358    }
2359
2360    #[cfg(not(feature = "http-embeddings"))]
2361    #[tokio::test]
2362    async fn test_cohere_v4_embed_length() {
2363        let p = CohereEmbedV4Provider::new("k");
2364        let v = p.embed("cohere v4 test").await.unwrap();
2365        assert_eq!(v.len(), 1024);
2366    }
2367
2368    #[cfg(not(feature = "http-embeddings"))]
2369    #[tokio::test]
2370    async fn test_cohere_v4_embed_normalized() {
2371        let p = CohereEmbedV4Provider::new("k");
2372        let v = p.embed("x").await.unwrap();
2373        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
2374        assert!((norm - 1.0).abs() < 0.01);
2375    }
2376
2377    #[cfg(not(feature = "http-embeddings"))]
2378    #[tokio::test]
2379    async fn test_cohere_v4_embed_deterministic() {
2380        let p = CohereEmbedV4Provider::new("k");
2381        let a = p.embed("same").await.unwrap();
2382        let b = p.embed("same").await.unwrap();
2383        assert_eq!(a, b);
2384    }
2385
2386    // =====================================================================
2387    // Cross-provider checks
2388    // =====================================================================
2389
2390    #[test]
2391    fn test_all_new_providers_implement_embedding_provider_trait() {
2392        // Compile-time check — if these coerce into `Box<dyn EmbeddingProvider>`,
2393        // they correctly implement the trait.
2394        let _boxes: Vec<Box<dyn EmbeddingProvider>> = vec![
2395            Box::new(JinaEmbeddingProvider::new("k")),
2396            Box::new(MistralEmbedProvider::new("k")),
2397            Box::new(NomicEmbedProvider::new("k")),
2398            Box::new(SentenceTransformersProvider::new("k")),
2399            Box::new(TogetherEmbedProvider::new("k")),
2400            Box::new(CohereEmbedV4Provider::new("k")),
2401        ];
2402    }
2403
2404    #[test]
2405    fn test_new_providers_have_expected_dimensions() {
2406        assert_eq!(JinaEmbeddingProvider::new("k").dimension(), 1024);
2407        assert_eq!(MistralEmbedProvider::new("k").dimension(), 1024);
2408        assert_eq!(NomicEmbedProvider::new("k").dimension(), 768);
2409        assert_eq!(SentenceTransformersProvider::new("k").dimension(), 384);
2410        assert_eq!(TogetherEmbedProvider::new("k").dimension(), 768);
2411        assert_eq!(CohereEmbedV4Provider::new("k").dimension(), 1024);
2412    }
2413}