Skip to main content

argentor_memory/
embeddings_providers.rs

1//! Multiple embedding provider backends implementing [`EmbeddingProvider`].
2//!
3//! Includes API-backed providers ([`OpenAiEmbeddingProvider`],
4//! [`CohereEmbeddingProvider`], [`VoyageEmbeddingProvider`]) that call their
5//! respective HTTP APIs to compute real embeddings.
6//!
7//! # Feature flag: `http-embeddings`
8//!
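//! The actual HTTP calls require the **`http-embeddings`** Cargo feature, which
//! pulls in `reqwest`. When the feature is **disabled** (the default), calling
//! `embed()` on [`OpenAiEmbeddingProvider`], [`CohereEmbeddingProvider`] or
//! [`VoyageEmbeddingProvider`] returns a descriptive error suggesting the user
//! either enable the feature or use [`LocalEmbedding`]. [`JinaEmbeddingProvider`]
//! is an exception: without the feature it returns a deterministic,
//! L2-normalized stub vector so offline tests keep working.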
13//!
14//! ```toml
15//! # Enable real HTTP embedding calls:
16//! argentor-memory = { version = "0.1", features = ["http-embeddings"] }
17//! ```
18//!
19//! For embeddings that work out of the box without external dependencies, use
20//! the [`LocalEmbedding`] provider (deterministic, hash-based, zero API keys).
21//!
22//! Also provides fully functional utilities:
23//! [`CachedEmbeddingProvider`], [`BatchEmbeddingProvider`],
24//! [`EmbeddingProviderFactory`], and [`EmbeddingConfig`].
25
26use std::collections::HashMap;
27use std::sync::Arc;
28
29use async_trait::async_trait;
30use serde::{Deserialize, Serialize};
31use tokio::sync::RwLock;
32
33use argentor_core::{ArgentorError, ArgentorResult};
34
35use crate::embedding::{EmbeddingProvider, LocalEmbedding};
36
37// ---------------------------------------------------------------------------
38// FNV-1a hash (same algorithm as embedding.rs, re-implemented here to avoid
39// depending on a private function).
40// ---------------------------------------------------------------------------
41
/// 64-bit FNV-1a hash of `data`.
///
/// Starts from the standard 64-bit offset basis. For each byte, it XORs the
/// byte into the state and then multiplies by the FNV prime, wrapping on
/// overflow.
fn fnv1a_hash(data: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    data.iter().fold(OFFSET_BASIS, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
50
51// ===========================================================================
52// API request/response structures
53// ===========================================================================
54
55/// OpenAI embeddings API request body.
56#[derive(Debug, Serialize)]
57pub struct OpenAiEmbeddingRequest {
58    /// Model identifier (e.g., `"text-embedding-3-small"`).
59    pub model: String,
60    /// Texts to embed.
61    pub input: Vec<String>,
62}
63
64/// A single embedding object from the OpenAI response.
65#[derive(Debug, Deserialize)]
66pub struct OpenAiEmbeddingObject {
67    /// The embedding vector.
68    pub embedding: Vec<f32>,
69    /// Position of this embedding in the input batch.
70    pub index: usize,
71}
72
73/// OpenAI embeddings API response body.
74#[derive(Debug, Deserialize)]
75pub struct OpenAiEmbeddingResponse {
76    /// Embedding results, one per input text.
77    pub data: Vec<OpenAiEmbeddingObject>,
78    /// Model that produced the embeddings.
79    pub model: String,
80}
81
82/// Cohere embed API request body (v2).
83#[derive(Debug, Serialize)]
84pub struct CohereEmbedRequest {
85    /// Model identifier (e.g., `"embed-english-v3.0"`).
86    pub model: String,
87    /// Texts to embed.
88    pub texts: Vec<String>,
89    /// Input type hint (e.g., `"search_document"`, `"search_query"`).
90    pub input_type: String,
91    /// Which embedding types to return.
92    pub embedding_types: Vec<String>,
93}
94
95/// Cohere embed API v2 response body.
96#[derive(Debug, Deserialize)]
97pub struct CohereEmbedResponse {
98    /// Embedding vectors keyed by type. We request `"float"`.
99    pub embeddings: CohereEmbeddingsMap,
100}
101
102/// Container for different embedding type outputs from Cohere.
103#[derive(Debug, Deserialize)]
104pub struct CohereEmbeddingsMap {
105    /// Float embeddings, one vector per input text.
106    #[serde(default)]
107    pub float: Vec<Vec<f32>>,
108}
109
110/// Voyage AI embeddings API request body.
111#[derive(Debug, Serialize)]
112pub struct VoyageEmbeddingRequest {
113    /// Model identifier.
114    pub model: String,
115    /// Texts to embed.
116    pub input: Vec<String>,
117}
118
119/// A single embedding object from the Voyage response.
120#[derive(Debug, Deserialize)]
121pub struct VoyageEmbeddingObject {
122    /// The embedding vector.
123    pub embedding: Vec<f32>,
124    /// Position of this embedding in the input batch.
125    pub index: usize,
126}
127
128/// Voyage AI embeddings API response body.
129#[derive(Debug, Deserialize)]
130pub struct VoyageEmbeddingResponse {
131    /// Embedding results, one per input text.
132    pub data: Vec<VoyageEmbeddingObject>,
133}
134
135// ===========================================================================
136// Response parsing helpers (testable without HTTP)
137// ===========================================================================
138
139/// Parse an OpenAI embedding response JSON into a single embedding vector.
140///
141/// Expects the standard OpenAI response shape with `data[0].embedding`.
142/// Returns `Err` if the response is missing required fields.
143pub fn parse_openai_embedding_response(json: &serde_json::Value) -> ArgentorResult<Vec<f32>> {
144    let response: OpenAiEmbeddingResponse = serde_json::from_value(json.clone())
145        .map_err(|e| ArgentorError::Agent(format!("Failed to parse OpenAI response: {e}")))?;
146    response
147        .data
148        .into_iter()
149        .next()
150        .map(|obj| obj.embedding)
151        .ok_or_else(|| {
152            ArgentorError::Agent("OpenAI response contains no embedding data".to_string())
153        })
154}
155
156/// Parse a Cohere v2 embed response JSON into a single embedding vector.
157///
158/// Expects the v2 shape: `embeddings.float[0]`.
159/// Returns `Err` if the response is missing required fields.
160pub fn parse_cohere_embedding_response(json: &serde_json::Value) -> ArgentorResult<Vec<f32>> {
161    let response: CohereEmbedResponse = serde_json::from_value(json.clone())
162        .map_err(|e| ArgentorError::Agent(format!("Failed to parse Cohere response: {e}")))?;
163    response.embeddings.float.into_iter().next().ok_or_else(|| {
164        ArgentorError::Agent("Cohere response contains no float embeddings".to_string())
165    })
166}
167
168/// Parse a Voyage AI embedding response JSON into a single embedding vector.
169///
170/// Expects the standard Voyage shape: `data[0].embedding`.
171/// Returns `Err` if the response is missing required fields.
172pub fn parse_voyage_embedding_response(json: &serde_json::Value) -> ArgentorResult<Vec<f32>> {
173    let response: VoyageEmbeddingResponse = serde_json::from_value(json.clone())
174        .map_err(|e| ArgentorError::Agent(format!("Failed to parse Voyage response: {e}")))?;
175    response
176        .data
177        .into_iter()
178        .next()
179        .map(|obj| obj.embedding)
180        .ok_or_else(|| {
181            ArgentorError::Agent("Voyage response contains no embedding data".to_string())
182        })
183}
184
185// ===========================================================================
186// 1. OpenAiEmbeddingProvider
187// ===========================================================================
188
/// Embedding provider backed by the OpenAI embeddings API.
///
/// Holds the API key, model name, endpoint URL and the dimension count
/// reported by [`EmbeddingProvider::dimension`]. [`EmbeddingProvider::embed`]
/// only performs a real HTTP request when the `http-embeddings` feature is
/// enabled; otherwise it returns an error. For local/offline embeddings, use
/// [`LocalEmbedding`].
///
/// NOTE(review): requests authenticate with `Authorization: Bearer <api_key>`.
/// Confirm that a custom endpoint (see [`Self::with_base_url`]) accepts that
/// scheme.
pub struct OpenAiEmbeddingProvider {
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    model: String,
    dimensions: usize,
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}

impl OpenAiEmbeddingProvider {
    /// Model used when the caller passes `None`.
    const DEFAULT_MODEL: &'static str = "text-embedding-3-small";
    /// Public OpenAI embeddings endpoint used by [`Self::new`].
    const DEFAULT_BASE_URL: &'static str = "https://api.openai.com/v1/embeddings";

    /// Create a provider that targets the public OpenAI endpoint.
    ///
    /// `model` defaults to `"text-embedding-3-small"` when `None`. The
    /// dimension count is derived from the model name.
    pub fn new(api_key: impl Into<String>, model: Option<String>) -> Self {
        Self::with_base_url(api_key, model, Self::DEFAULT_BASE_URL)
    }

    /// Create a provider that posts to `base_url` instead of the public
    /// endpoint (e.g. an Azure OpenAI deployment or a proxy).
    ///
    /// `model` defaults to `"text-embedding-3-small"` when `None`.
    pub fn with_base_url(
        api_key: impl Into<String>,
        model: Option<String>,
        base_url: impl Into<String>,
    ) -> Self {
        let model = model.unwrap_or_else(|| Self::DEFAULT_MODEL.to_owned());
        OpenAiEmbeddingProvider {
            api_key: api_key.into(),
            dimensions: Self::default_dimensions(&model),
            model,
            base_url: base_url.into(),
        }
    }

    /// Override the dimension count this provider reports.
    pub fn with_dimensions(self, dimensions: usize) -> Self {
        Self { dimensions, ..self }
    }

    /// Known output size for `model`. Returns 3072 for
    /// `text-embedding-3-large` and 1536 for every other model name,
    /// including unrecognised ones.
    fn default_dimensions(model: &str) -> usize {
        if model == "text-embedding-3-large" {
            3072
        } else {
            1536
        }
    }

    /// Returns the model name this provider is configured with.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }
}
255
256#[async_trait]
257impl EmbeddingProvider for OpenAiEmbeddingProvider {
258    #[cfg(feature = "http-embeddings")]
259    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
260        let client = reqwest::Client::new();
261        let response = client
262            .post(&self.base_url)
263            .header("Authorization", format!("Bearer {}", self.api_key))
264            .json(&serde_json::json!({
265                "model": self.model,
266                "input": text,
267            }))
268            .send()
269            .await
270            .map_err(|e| ArgentorError::Http(format!("OpenAI embedding request failed: {e}")))?;
271
272        let status = response.status();
273        if !status.is_success() {
274            let body = response.text().await.unwrap_or_default();
275            return Err(ArgentorError::Http(format!(
276                "OpenAI API error {status}: {body}"
277            )));
278        }
279
280        let json: serde_json::Value = response.json().await.map_err(|e| {
281            ArgentorError::Http(format!("Failed to read OpenAI response body: {e}"))
282        })?;
283
284        parse_openai_embedding_response(&json)
285    }
286
287    #[cfg(not(feature = "http-embeddings"))]
288    async fn embed(&self, _text: &str) -> ArgentorResult<Vec<f32>> {
289        Err(ArgentorError::Http(
290            "HTTP embeddings not enabled. Enable the 'http-embeddings' feature flag \
291             or use LocalEmbedding for offline embeddings."
292                .to_string(),
293        ))
294    }
295
296    fn dimension(&self) -> usize {
297        self.dimensions
298    }
299}
300
301// ===========================================================================
302// 2. CohereEmbeddingProvider
303// ===========================================================================
304
/// Embedding provider backed by the Cohere embed API (v2).
///
/// Holds the API key, model name, request `input_type` (default
/// `"search_document"`) and the dimension count reported by
/// [`EmbeddingProvider::dimension`]. [`EmbeddingProvider::embed`] only
/// performs a real HTTP request when the `http-embeddings` feature is enabled;
/// otherwise it returns an error. For local/offline embeddings, use
/// [`LocalEmbedding`].
pub struct CohereEmbeddingProvider {
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    model: String,
    dimensions: usize,
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    input_type: String,
}

impl CohereEmbeddingProvider {
    /// Model used when the caller passes `None`.
    const DEFAULT_MODEL: &'static str = "embed-english-v3.0";
    /// Input type used until [`Self::with_input_type`] overrides it.
    const DEFAULT_INPUT_TYPE: &'static str = "search_document";

    /// Create a new Cohere embedding provider.
    ///
    /// `model` defaults to `"embed-english-v3.0"`. The dimension count is
    /// derived from the model name.
    pub fn new(api_key: impl Into<String>, model: Option<String>) -> Self {
        let model = model.unwrap_or_else(|| Self::DEFAULT_MODEL.to_owned());
        CohereEmbeddingProvider {
            api_key: api_key.into(),
            dimensions: Self::default_dimensions(&model),
            model,
            input_type: Self::DEFAULT_INPUT_TYPE.to_owned(),
        }
    }

    /// Set the input type: `"search_document"` for indexing, `"search_query"`
    /// for querying.
    pub fn with_input_type(self, input_type: impl Into<String>) -> Self {
        Self {
            input_type: input_type.into(),
            ..self
        }
    }

    /// Override the dimension count this provider reports.
    pub fn with_dimensions(self, dimensions: usize) -> Self {
        Self { dimensions, ..self }
    }

    /// Known output size for `model`. Returns 384 for the two "light" v3
    /// models and 1024 for everything else, including unrecognised names.
    fn default_dimensions(model: &str) -> usize {
        let is_light = matches!(
            model,
            "embed-english-light-v3.0" | "embed-multilingual-light-v3.0"
        );
        if is_light {
            384
        } else {
            1024
        }
    }

    /// Returns the model name.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }

    /// Returns the current input type.
    pub fn input_type(&self) -> &str {
        self.input_type.as_str()
    }
}
365
366#[async_trait]
367impl EmbeddingProvider for CohereEmbeddingProvider {
368    #[cfg(feature = "http-embeddings")]
369    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
370        let client = reqwest::Client::new();
371        let response = client
372            .post("https://api.cohere.com/v2/embed")
373            .header("Authorization", format!("Bearer {}", self.api_key))
374            .json(&serde_json::json!({
375                "model": self.model,
376                "texts": [text],
377                "input_type": self.input_type,
378                "embedding_types": ["float"],
379            }))
380            .send()
381            .await
382            .map_err(|e| ArgentorError::Http(format!("Cohere embedding request failed: {e}")))?;
383
384        let status = response.status();
385        if !status.is_success() {
386            let body = response.text().await.unwrap_or_default();
387            return Err(ArgentorError::Http(format!(
388                "Cohere API error {status}: {body}"
389            )));
390        }
391
392        let json: serde_json::Value = response.json().await.map_err(|e| {
393            ArgentorError::Http(format!("Failed to read Cohere response body: {e}"))
394        })?;
395
396        parse_cohere_embedding_response(&json)
397    }
398
399    #[cfg(not(feature = "http-embeddings"))]
400    async fn embed(&self, _text: &str) -> ArgentorResult<Vec<f32>> {
401        Err(ArgentorError::Http(
402            "HTTP embeddings not enabled. Enable the 'http-embeddings' feature flag \
403             or use LocalEmbedding for offline embeddings."
404                .to_string(),
405        ))
406    }
407
408    fn dimension(&self) -> usize {
409        self.dimensions
410    }
411}
412
413// ===========================================================================
414// 3. VoyageEmbeddingProvider
415// ===========================================================================
416
/// Embedding provider backed by the Voyage AI embeddings API.
///
/// Holds the API key, model name and the dimension count reported by
/// [`EmbeddingProvider::dimension`]. [`EmbeddingProvider::embed`] only
/// performs a real HTTP request when the `http-embeddings` feature is enabled;
/// otherwise it returns an error. For local/offline embeddings, use
/// [`LocalEmbedding`].
pub struct VoyageEmbeddingProvider {
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    model: String,
    dimensions: usize,
}

impl VoyageEmbeddingProvider {
    /// Model used when the caller passes `None`.
    const DEFAULT_MODEL: &'static str = "voyage-2";

    /// Create a new Voyage embedding provider.
    ///
    /// `model` defaults to `"voyage-2"`. The dimension count is derived from
    /// the model name.
    pub fn new(api_key: impl Into<String>, model: Option<String>) -> Self {
        let model = model.unwrap_or_else(|| Self::DEFAULT_MODEL.to_owned());
        VoyageEmbeddingProvider {
            api_key: api_key.into(),
            dimensions: Self::default_dimensions(&model),
            model,
        }
    }

    /// Override the dimension count this provider reports.
    pub fn with_dimensions(self, dimensions: usize) -> Self {
        Self { dimensions, ..self }
    }

    /// Known output size for `model`. Returns 1536 for `voyage-code-2` and
    /// 1024 for every other name, including unrecognised ones.
    fn default_dimensions(model: &str) -> usize {
        if model == "voyage-code-2" {
            1536
        } else {
            1024
        }
    }

    /// Returns the model name.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }
}
465
466#[async_trait]
467impl EmbeddingProvider for VoyageEmbeddingProvider {
468    #[cfg(feature = "http-embeddings")]
469    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
470        let client = reqwest::Client::new();
471        let response = client
472            .post("https://api.voyageai.com/v1/embeddings")
473            .header("Authorization", format!("Bearer {}", self.api_key))
474            .json(&serde_json::json!({
475                "model": self.model,
476                "input": [text],
477            }))
478            .send()
479            .await
480            .map_err(|e| ArgentorError::Http(format!("Voyage embedding request failed: {e}")))?;
481
482        let status = response.status();
483        if !status.is_success() {
484            let body = response.text().await.unwrap_or_default();
485            return Err(ArgentorError::Http(format!(
486                "Voyage API error {status}: {body}"
487            )));
488        }
489
490        let json: serde_json::Value = response.json().await.map_err(|e| {
491            ArgentorError::Http(format!("Failed to read Voyage response body: {e}"))
492        })?;
493
494        parse_voyage_embedding_response(&json)
495    }
496
497    #[cfg(not(feature = "http-embeddings"))]
498    async fn embed(&self, _text: &str) -> ArgentorResult<Vec<f32>> {
499        Err(ArgentorError::Http(
500            "HTTP embeddings not enabled. Enable the 'http-embeddings' feature flag \
501             or use LocalEmbedding for offline embeddings."
502                .to_string(),
503        ))
504    }
505
506    fn dimension(&self) -> usize {
507        self.dimensions
508    }
509}
510
511// ===========================================================================
512// 4. CachedEmbeddingProvider
513// ===========================================================================
514
/// Usage counters for a [`CachedEmbeddingProvider`].
///
/// `Default` yields all-zero counters.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups answered from the cache without calling the inner provider.
    pub hits: u64,
    /// Lookups that had to be computed by the inner provider (successful
    /// calls only).
    pub misses: u64,
    /// Number of entries held in the cache, as last recorded by the provider.
    pub size: usize,
}
525
526/// Wraps any [`EmbeddingProvider`] with a thread-safe in-memory LRU-ish cache.
527///
528/// Embeddings are cached by FNV-1a hash of the input text. When the cache
529/// exceeds `max_cache_size`, the oldest entry (by insertion order) is evicted.
530pub struct CachedEmbeddingProvider {
531    inner: Arc<dyn EmbeddingProvider>,
532    cache: Arc<RwLock<HashMap<u64, Vec<f32>>>>,
533    max_cache_size: usize,
534    stats: Arc<RwLock<CacheStats>>,
535}
536
537impl CachedEmbeddingProvider {
538    /// Wrap an existing provider with caching.
539    pub fn new(inner: Arc<dyn EmbeddingProvider>, max_cache_size: usize) -> Self {
540        Self {
541            inner,
542            cache: Arc::new(RwLock::new(HashMap::new())),
543            max_cache_size,
544            stats: Arc::new(RwLock::new(CacheStats::default())),
545        }
546    }
547
548    /// Returns current cache statistics.
549    pub async fn cache_stats(&self) -> CacheStats {
550        self.stats.read().await.clone()
551    }
552
553    /// Clears the cache and resets statistics.
554    pub async fn clear(&self) {
555        self.cache.write().await.clear();
556        let mut stats = self.stats.write().await;
557        stats.size = 0;
558    }
559
560    fn text_hash(text: &str) -> u64 {
561        fnv1a_hash(text.as_bytes())
562    }
563}
564
565#[async_trait]
566impl EmbeddingProvider for CachedEmbeddingProvider {
567    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
568        let key = Self::text_hash(text);
569
570        // Check cache (read lock).
571        {
572            let cache = self.cache.read().await;
573            if let Some(cached) = cache.get(&key) {
574                let mut stats = self.stats.write().await;
575                stats.hits += 1;
576                return Ok(cached.clone());
577            }
578        }
579
580        // Cache miss — compute embedding.
581        let embedding = self.inner.embed(text).await?;
582
583        // Insert into cache (write lock).
584        {
585            let mut cache = self.cache.write().await;
586
587            // Evict if at capacity.
588            if cache.len() >= self.max_cache_size {
589                // Remove an arbitrary entry (HashMap iteration order is random,
590                // which acts as a simple eviction strategy).
591                if let Some(&evict_key) = cache.keys().next() {
592                    cache.remove(&evict_key);
593                }
594            }
595
596            cache.insert(key, embedding.clone());
597
598            let mut stats = self.stats.write().await;
599            stats.misses += 1;
600            stats.size = cache.len();
601        }
602
603        Ok(embedding)
604    }
605
606    fn dimension(&self) -> usize {
607        self.inner.dimension()
608    }
609}
610
611// ===========================================================================
612// 5. BatchEmbeddingProvider
613// ===========================================================================
614
615/// Wraps any [`EmbeddingProvider`] to expose a convenience batch method.
616///
617/// Delegates to the inner provider's `embed_batch` (which by default calls
618/// `embed` sequentially). Providers that support native batching can override
619/// `embed_batch` on the trait for better performance.
620pub struct BatchEmbeddingProvider {
621    inner: Arc<dyn EmbeddingProvider>,
622}
623
624impl BatchEmbeddingProvider {
625    /// Wrap an existing provider for batch operations.
626    pub fn new(inner: Arc<dyn EmbeddingProvider>) -> Self {
627        Self { inner }
628    }
629
630    /// Embed multiple texts, returning one vector per input.
631    pub async fn embed_batch(&self, texts: &[&str]) -> ArgentorResult<Vec<Vec<f32>>> {
632        self.inner.embed_batch(texts).await
633    }
634}
635
636#[async_trait]
637impl EmbeddingProvider for BatchEmbeddingProvider {
638    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
639        self.inner.embed(text).await
640    }
641
642    async fn embed_batch(&self, texts: &[&str]) -> ArgentorResult<Vec<Vec<f32>>> {
643        self.inner.embed_batch(texts).await
644    }
645
646    fn dimension(&self) -> usize {
647        self.inner.dimension()
648    }
649}
650
651// ===========================================================================
652// 6. EmbeddingProviderFactory
653// ===========================================================================
654
655/// Factory that creates [`EmbeddingProvider`] instances by name.
656pub struct EmbeddingProviderFactory;
657
658impl EmbeddingProviderFactory {
659    /// Create an embedding provider from its string name.
660    ///
661    /// Supported names: `"openai"`, `"cohere"`, `"voyage"`, `"local"`.
662    pub fn create(
663        provider_name: &str,
664        api_key: impl Into<String>,
665        model: Option<String>,
666    ) -> ArgentorResult<Box<dyn EmbeddingProvider>> {
667        let api_key = api_key.into();
668        match provider_name {
669            "openai" => Ok(Box::new(OpenAiEmbeddingProvider::new(api_key, model))),
670            "cohere" => Ok(Box::new(CohereEmbeddingProvider::new(api_key, model))),
671            "voyage" => Ok(Box::new(VoyageEmbeddingProvider::new(api_key, model))),
672            "local" => {
673                let dim = model
674                    .as_deref()
675                    .and_then(|m| m.parse::<usize>().ok())
676                    .unwrap_or(256);
677                Ok(Box::new(LocalEmbedding::new(dim)))
678            }
679            other => Err(ArgentorError::Config(format!(
680                "Unknown embedding provider: {other}"
681            ))),
682        }
683    }
684
685    /// List all supported provider names.
686    pub fn available_providers() -> &'static [&'static str] {
687        &["openai", "cohere", "voyage", "local"]
688    }
689}
690
691// ===========================================================================
692// 7. EmbeddingConfig
693// ===========================================================================
694
695/// Serializable configuration for constructing an embedding provider.
696#[derive(Debug, Clone, Serialize, Deserialize)]
697pub struct EmbeddingConfig {
698    /// Provider name (`"openai"`, `"cohere"`, `"voyage"`, `"local"`).
699    pub provider: String,
700    /// API key (ignored for `"local"`).
701    #[serde(default)]
702    pub api_key: String,
703    /// Model name override.
704    #[serde(default)]
705    pub model: Option<String>,
706    /// Override for output dimensions.
707    #[serde(default)]
708    pub dimensions: Option<usize>,
709    /// Custom API base URL (e.g. Azure OpenAI).
710    #[serde(default)]
711    pub base_url: Option<String>,
712    /// If set, wraps the provider with a [`CachedEmbeddingProvider`].
713    #[serde(default)]
714    pub cache_size: Option<usize>,
715}
716
717impl EmbeddingConfig {
718    /// Build an [`EmbeddingProvider`] from this configuration.
719    ///
720    /// Returns an `Arc`-wrapped provider, optionally wrapped in a cache layer.
721    pub fn build(&self) -> ArgentorResult<Arc<dyn EmbeddingProvider>> {
722        let mut provider: Box<dyn EmbeddingProvider> = match self.provider.as_str() {
723            "openai" => {
724                let mut p = if let Some(ref url) = self.base_url {
725                    OpenAiEmbeddingProvider::with_base_url(&self.api_key, self.model.clone(), url)
726                } else {
727                    OpenAiEmbeddingProvider::new(&self.api_key, self.model.clone())
728                };
729                if let Some(dim) = self.dimensions {
730                    p = p.with_dimensions(dim);
731                }
732                Box::new(p)
733            }
734            "cohere" => {
735                let mut p = CohereEmbeddingProvider::new(&self.api_key, self.model.clone());
736                if let Some(dim) = self.dimensions {
737                    p = p.with_dimensions(dim);
738                }
739                Box::new(p)
740            }
741            "voyage" => {
742                let mut p = VoyageEmbeddingProvider::new(&self.api_key, self.model.clone());
743                if let Some(dim) = self.dimensions {
744                    p = p.with_dimensions(dim);
745                }
746                Box::new(p)
747            }
748            "local" => {
749                let dim = self.dimensions.unwrap_or(256);
750                Box::new(LocalEmbedding::new(dim))
751            }
752            other => {
753                return Err(ArgentorError::Config(format!(
754                    "Unknown embedding provider: {other}"
755                )));
756            }
757        };
758
759        // Wrap with dimensions override if the provider itself doesn't support
760        // it natively (already handled above for each provider).
761        let _ = &mut provider; // suppress unused-mut if branch is empty
762
763        let arc: Arc<dyn EmbeddingProvider> = Arc::from(provider);
764
765        // Optionally wrap with caching.
766        if let Some(cache_size) = self.cache_size {
767            Ok(Arc::new(CachedEmbeddingProvider::new(arc, cache_size)))
768        } else {
769            Ok(arc)
770        }
771    }
772}
773
774impl Default for EmbeddingConfig {
775    fn default() -> Self {
776        Self {
777            provider: "local".to_string(),
778            api_key: String::new(),
779            model: None,
780            dimensions: None,
781            base_url: None,
782            cache_size: None,
783        }
784    }
785}
786
787// ===========================================================================
788// Shared helpers for new providers (stub + payload builders)
789// ===========================================================================
790
/// Build a deterministic, L2-normalized stub embedding from `text`.
///
/// Bytes are scaled to `[0, 1]` and summed round-robin into
/// `max(dimensions, 1)` slots. The result is then scaled to unit length; an
/// all-zero vector (empty text) is returned unchanged. The providers use this
/// when the `http-embeddings` feature is disabled, and unit tests exercise it
/// even when the feature is on.
#[cfg_attr(all(feature = "http-embeddings", not(test)), allow(dead_code))]
fn stub_embedding(text: &str, dimensions: usize) -> Vec<f32> {
    let width = dimensions.max(1);
    let mut out = vec![0.0f32; width];
    text.bytes()
        .zip((0..width).cycle())
        .for_each(|(byte, slot)| out[slot] += f32::from(byte) / 255.0);

    let magnitude = out.iter().map(|x| x * x).sum::<f32>().sqrt();
    if magnitude > 0.0 {
        out.iter_mut().for_each(|x| *x /= magnitude);
    }
    out
}
811
812// ===========================================================================
813// 8. JinaEmbeddingProvider
814// ===========================================================================
815
816/// Embedding provider backed by the Jina AI embeddings API.
817///
818/// Default model: `jina-embeddings-v3` (1024 dims). Also supports multimodal
819/// models such as `jina-clip-v2`. Calling [`embed()`] performs a real HTTP
820/// request when the `http-embeddings` feature is enabled. Without the feature,
821/// returns a deterministic stub vector (useful for offline tests).
822pub struct JinaEmbeddingProvider {
823    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
824    api_key: String,
825    model: String,
826    dimensions: usize,
827    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
828    base_url: String,
829}
830
831impl JinaEmbeddingProvider {
832    /// Create a new Jina provider with the default model (`jina-embeddings-v3`).
833    pub fn new(api_key: impl Into<String>) -> Self {
834        Self::with_model(api_key, "jina-embeddings-v3", 1024)
835    }
836
837    /// Create with an explicit model and dimension override.
838    pub fn with_model(
839        api_key: impl Into<String>,
840        model: impl Into<String>,
841        dimensions: usize,
842    ) -> Self {
843        Self {
844            api_key: api_key.into(),
845            model: model.into(),
846            dimensions,
847            base_url: "https://api.jina.ai/v1/embeddings".to_string(),
848        }
849    }
850
851    /// Override the API base URL.
852    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
853        self.base_url = base_url.into();
854        self
855    }
856
857    /// Returns the configured model name.
858    pub fn model(&self) -> &str {
859        &self.model
860    }
861
862    /// Build the request payload for the Jina embeddings API.
863    pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
864        serde_json::json!({
865            "model": self.model,
866            "input": texts,
867        })
868    }
869}
870
871#[async_trait]
872impl EmbeddingProvider for JinaEmbeddingProvider {
873    #[cfg(feature = "http-embeddings")]
874    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
875        let client = reqwest::Client::new();
876        let payload = self.build_payload(&[text.to_string()]);
877        let response = client
878            .post(&self.base_url)
879            .header("Authorization", format!("Bearer {}", self.api_key))
880            .json(&payload)
881            .send()
882            .await
883            .map_err(|e| ArgentorError::Http(format!("Jina embedding request failed: {e}")))?;
884
885        let status = response.status();
886        if !status.is_success() {
887            let body = response.text().await.unwrap_or_default();
888            return Err(ArgentorError::Http(format!(
889                "Jina API error {status}: {body}"
890            )));
891        }
892
893        let json: serde_json::Value = response
894            .json()
895            .await
896            .map_err(|e| ArgentorError::Http(format!("Failed to read Jina response body: {e}")))?;
897
898        // Jina follows the OpenAI-compatible `data[].embedding` shape.
899        parse_openai_embedding_response(&json)
900    }
901
902    #[cfg(not(feature = "http-embeddings"))]
903    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
904        Ok(stub_embedding(text, self.dimensions))
905    }
906
907    fn dimension(&self) -> usize {
908        self.dimensions
909    }
910}
911
912// ===========================================================================
913// 9. MistralEmbedProvider
914// ===========================================================================
915
/// Embedding provider backed by the Mistral AI embeddings API.
///
/// Default model: `mistral-embed` (1024 dims). Mistral's embedding endpoint
/// follows an OpenAI-compatible request/response shape: the request body is
/// `{ "model", "input" }` and the vector is read from `data[].embedding`.
///
/// Without the `http-embeddings` feature, [`EmbeddingProvider::embed`] makes no
/// request and never fails. It returns a locally generated stub vector built
/// from the text and the configured dimension. Those vectors are placeholders,
/// not Mistral embeddings.
pub struct MistralEmbedProvider {
    /// API key sent as `Authorization: Bearer <api_key>`.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    /// Model identifier placed in every request payload.
    model: String,
    /// Vector length reported by `dimension()`; not sent to the API.
    dimensions: usize,
    /// Full URL of the embeddings endpoint that requests are POSTed to.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}
928
929impl MistralEmbedProvider {
930    /// Create a new Mistral provider with the default model (`mistral-embed`).
931    pub fn new(api_key: impl Into<String>) -> Self {
932        Self::with_model(api_key, "mistral-embed", 1024)
933    }
934
935    /// Create with an explicit model and dimension override.
936    pub fn with_model(
937        api_key: impl Into<String>,
938        model: impl Into<String>,
939        dimensions: usize,
940    ) -> Self {
941        Self {
942            api_key: api_key.into(),
943            model: model.into(),
944            dimensions,
945            base_url: "https://api.mistral.ai/v1/embeddings".to_string(),
946        }
947    }
948
949    /// Override the API base URL.
950    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
951        self.base_url = base_url.into();
952        self
953    }
954
955    /// Returns the configured model name.
956    pub fn model(&self) -> &str {
957        &self.model
958    }
959
960    /// Build the request payload for the Mistral embeddings API.
961    pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
962        serde_json::json!({
963            "model": self.model,
964            "input": texts,
965        })
966    }
967}
968
969#[async_trait]
970impl EmbeddingProvider for MistralEmbedProvider {
971    #[cfg(feature = "http-embeddings")]
972    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
973        let client = reqwest::Client::new();
974        let payload = self.build_payload(&[text.to_string()]);
975        let response = client
976            .post(&self.base_url)
977            .header("Authorization", format!("Bearer {}", self.api_key))
978            .json(&payload)
979            .send()
980            .await
981            .map_err(|e| ArgentorError::Http(format!("Mistral embedding request failed: {e}")))?;
982
983        let status = response.status();
984        if !status.is_success() {
985            let body = response.text().await.unwrap_or_default();
986            return Err(ArgentorError::Http(format!(
987                "Mistral API error {status}: {body}"
988            )));
989        }
990
991        let json: serde_json::Value = response.json().await.map_err(|e| {
992            ArgentorError::Http(format!("Failed to read Mistral response body: {e}"))
993        })?;
994
995        parse_openai_embedding_response(&json)
996    }
997
998    #[cfg(not(feature = "http-embeddings"))]
999    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1000        Ok(stub_embedding(text, self.dimensions))
1001    }
1002
1003    fn dimension(&self) -> usize {
1004        self.dimensions
1005    }
1006}
1007
1008// ===========================================================================
1009// 10. NomicEmbedProvider
1010// ===========================================================================
1011
/// Embedding provider backed by the Nomic Atlas embeddings API.
///
/// Default model: `nomic-embed-text-v1.5` (768 dims). The Nomic endpoint
/// accepts an array of texts under the `texts` key (together with `model` and
/// `task_type`) and responds with `{ "embeddings": [ [...], ... ] }`.
///
/// Without the `http-embeddings` feature, [`EmbeddingProvider::embed`] makes no
/// request and never fails. It returns a locally generated stub vector built
/// from the text and the configured dimension, not a Nomic embedding.
pub struct NomicEmbedProvider {
    /// API key sent as `Authorization: Bearer <api_key>`.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    /// Model identifier placed in every request payload.
    model: String,
    /// Vector length reported by `dimension()`; not sent to the API.
    dimensions: usize,
    /// Full URL of the embeddings endpoint that requests are POSTed to.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
    /// Nomic `task_type` sent with each request (e.g. `search_document`).
    task_type: String,
}
1026
1027impl NomicEmbedProvider {
1028    /// Create a new Nomic provider with the default model (`nomic-embed-text-v1.5`).
1029    pub fn new(api_key: impl Into<String>) -> Self {
1030        Self::with_model(api_key, "nomic-embed-text-v1.5", 768)
1031    }
1032
1033    /// Create with an explicit model and dimension override.
1034    pub fn with_model(
1035        api_key: impl Into<String>,
1036        model: impl Into<String>,
1037        dimensions: usize,
1038    ) -> Self {
1039        Self {
1040            api_key: api_key.into(),
1041            model: model.into(),
1042            dimensions,
1043            base_url: "https://api-atlas.nomic.ai/v1/embedding/text".to_string(),
1044            task_type: "search_document".to_string(),
1045        }
1046    }
1047
1048    /// Set the `task_type` (`search_document`, `search_query`, `clustering`, `classification`).
1049    pub fn with_task_type(mut self, task_type: impl Into<String>) -> Self {
1050        self.task_type = task_type.into();
1051        self
1052    }
1053
1054    /// Returns the configured model name.
1055    pub fn model(&self) -> &str {
1056        &self.model
1057    }
1058
1059    /// Returns the current task type.
1060    pub fn task_type(&self) -> &str {
1061        &self.task_type
1062    }
1063
1064    /// Build the request payload for the Nomic embeddings API.
1065    pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
1066        serde_json::json!({
1067            "model": self.model,
1068            "texts": texts,
1069            "task_type": self.task_type,
1070        })
1071    }
1072}
1073
1074#[async_trait]
1075impl EmbeddingProvider for NomicEmbedProvider {
1076    #[cfg(feature = "http-embeddings")]
1077    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1078        let client = reqwest::Client::new();
1079        let payload = self.build_payload(&[text.to_string()]);
1080        let response = client
1081            .post(&self.base_url)
1082            .header("Authorization", format!("Bearer {}", self.api_key))
1083            .json(&payload)
1084            .send()
1085            .await
1086            .map_err(|e| ArgentorError::Http(format!("Nomic embedding request failed: {e}")))?;
1087
1088        let status = response.status();
1089        if !status.is_success() {
1090            let body = response.text().await.unwrap_or_default();
1091            return Err(ArgentorError::Http(format!(
1092                "Nomic API error {status}: {body}"
1093            )));
1094        }
1095
1096        let json: serde_json::Value = response
1097            .json()
1098            .await
1099            .map_err(|e| ArgentorError::Http(format!("Failed to read Nomic response body: {e}")))?;
1100
1101        // Nomic response shape: { "embeddings": [[...]] }
1102        let embeddings = json
1103            .get("embeddings")
1104            .and_then(|v| v.as_array())
1105            .ok_or_else(|| {
1106                ArgentorError::Agent("Nomic response missing 'embeddings' array".to_string())
1107            })?;
1108        let first = embeddings.first().ok_or_else(|| {
1109            ArgentorError::Agent("Nomic response contains no embedding vectors".to_string())
1110        })?;
1111        let vec: Vec<f32> = serde_json::from_value(first.clone()).map_err(|e| {
1112            ArgentorError::Agent(format!("Failed to parse Nomic embedding vector: {e}"))
1113        })?;
1114        Ok(vec)
1115    }
1116
1117    #[cfg(not(feature = "http-embeddings"))]
1118    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1119        Ok(stub_embedding(text, self.dimensions))
1120    }
1121
1122    fn dimension(&self) -> usize {
1123        self.dimensions
1124    }
1125}
1126
1127// ===========================================================================
1128// 11. SentenceTransformersProvider (via Hugging Face Inference API)
1129// ===========================================================================
1130
/// Embedding provider backed by the Hugging Face Inference API for
/// `sentence-transformers/*` models.
///
/// Default model: `sentence-transformers/all-MiniLM-L6-v2` (384 dims).
/// Also supports `all-mpnet-base-v2` (768) and `multi-qa-mpnet-base-dot-v1` (768).
/// The endpoint URL is derived from the model id
/// (`.../pipeline/feature-extraction/{model}`). It can be overridden for
/// self-hosted inference endpoints.
///
/// Without the `http-embeddings` feature, [`EmbeddingProvider::embed`] makes no
/// request and never fails. It returns a locally generated stub vector built
/// from the text and the configured dimension, not a model embedding.
pub struct SentenceTransformersProvider {
    /// Hugging Face token sent as `Authorization: Bearer <api_key>`.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    /// Hugging Face model id; also embedded in the default endpoint URL.
    model: String,
    /// Vector length reported by `dimension()`; not sent to the API.
    dimensions: usize,
    /// Full URL of the feature-extraction endpoint that requests are POSTed to.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}
1144
1145impl SentenceTransformersProvider {
1146    /// Create a new provider with the default model (`all-MiniLM-L6-v2`, 384 dims).
1147    pub fn new(api_key: impl Into<String>) -> Self {
1148        Self::with_model(api_key, "sentence-transformers/all-MiniLM-L6-v2", 384)
1149    }
1150
1151    /// Create with an explicit model and dimension override.
1152    pub fn with_model(
1153        api_key: impl Into<String>,
1154        model: impl Into<String>,
1155        dimensions: usize,
1156    ) -> Self {
1157        let model = model.into();
1158        let base_url =
1159            format!("https://api-inference.huggingface.co/pipeline/feature-extraction/{model}");
1160        Self {
1161            api_key: api_key.into(),
1162            model,
1163            dimensions,
1164            base_url,
1165        }
1166    }
1167
1168    /// Override the API base URL (useful for self-hosted HF inference endpoints).
1169    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
1170        self.base_url = base_url.into();
1171        self
1172    }
1173
1174    /// Returns the configured model name.
1175    pub fn model(&self) -> &str {
1176        &self.model
1177    }
1178
1179    /// Default dimension for well-known sentence-transformer models.
1180    pub fn default_dimensions(model: &str) -> usize {
1181        match model {
1182            "sentence-transformers/all-MiniLM-L6-v2" => 384,
1183            "sentence-transformers/all-mpnet-base-v2"
1184            | "sentence-transformers/multi-qa-mpnet-base-dot-v1" => 768,
1185            _ => 384,
1186        }
1187    }
1188
1189    /// Build the request payload for the HF Inference API.
1190    pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
1191        serde_json::json!({
1192            "inputs": texts,
1193            "options": { "wait_for_model": true },
1194        })
1195    }
1196}
1197
1198#[async_trait]
1199impl EmbeddingProvider for SentenceTransformersProvider {
1200    #[cfg(feature = "http-embeddings")]
1201    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1202        let client = reqwest::Client::new();
1203        let payload = self.build_payload(&[text.to_string()]);
1204        let response = client
1205            .post(&self.base_url)
1206            .header("Authorization", format!("Bearer {}", self.api_key))
1207            .json(&payload)
1208            .send()
1209            .await
1210            .map_err(|e| {
1211                ArgentorError::Http(format!("HuggingFace embedding request failed: {e}"))
1212            })?;
1213
1214        let status = response.status();
1215        if !status.is_success() {
1216            let body = response.text().await.unwrap_or_default();
1217            return Err(ArgentorError::Http(format!(
1218                "HuggingFace API error {status}: {body}"
1219            )));
1220        }
1221
1222        let json: serde_json::Value = response.json().await.map_err(|e| {
1223            ArgentorError::Http(format!("Failed to read HuggingFace response body: {e}"))
1224        })?;
1225
1226        // HF feature-extraction returns either `[[f32; D]]` (batch) or `[f32; D]` (single).
1227        match &json {
1228            serde_json::Value::Array(arr)
1229                if arr.first().is_some_and(serde_json::Value::is_array) =>
1230            {
1231                let first = arr.first().cloned().ok_or_else(|| {
1232                    ArgentorError::Agent("HuggingFace response empty".to_string())
1233                })?;
1234                serde_json::from_value(first)
1235                    .map_err(|e| ArgentorError::Agent(format!("Failed to parse HF vector: {e}")))
1236            }
1237            serde_json::Value::Array(_) => serde_json::from_value(json)
1238                .map_err(|e| ArgentorError::Agent(format!("Failed to parse HF vector: {e}"))),
1239            _ => Err(ArgentorError::Agent(
1240                "HuggingFace response is not an array".to_string(),
1241            )),
1242        }
1243    }
1244
1245    #[cfg(not(feature = "http-embeddings"))]
1246    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1247        Ok(stub_embedding(text, self.dimensions))
1248    }
1249
1250    fn dimension(&self) -> usize {
1251        self.dimensions
1252    }
1253}
1254
1255// ===========================================================================
1256// 12. TogetherEmbedProvider
1257// ===========================================================================
1258
/// Embedding provider backed by the Together AI embeddings API.
///
/// Default model: `togethercomputer/m2-bert-80M-32k-retrieval` (768 dims).
/// Together uses an OpenAI-compatible request/response shape: the request
/// body is `{ "model", "input" }` and the vector is read from `data[].embedding`.
///
/// Without the `http-embeddings` feature, [`EmbeddingProvider::embed`] makes no
/// request and never fails. It returns a locally generated stub vector built
/// from the text and the configured dimension, not a Together embedding.
pub struct TogetherEmbedProvider {
    /// API key sent as `Authorization: Bearer <api_key>`.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    /// Model identifier placed in every request payload.
    model: String,
    /// Vector length reported by `dimension()`; not sent to the API.
    dimensions: usize,
    /// Full URL of the embeddings endpoint that requests are POSTed to.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}
1271
1272impl TogetherEmbedProvider {
1273    /// Create a new Together provider with the default BERT retrieval model.
1274    pub fn new(api_key: impl Into<String>) -> Self {
1275        Self::with_model(api_key, "togethercomputer/m2-bert-80M-32k-retrieval", 768)
1276    }
1277
1278    /// Create with an explicit model and dimension override.
1279    pub fn with_model(
1280        api_key: impl Into<String>,
1281        model: impl Into<String>,
1282        dimensions: usize,
1283    ) -> Self {
1284        Self {
1285            api_key: api_key.into(),
1286            model: model.into(),
1287            dimensions,
1288            base_url: "https://api.together.xyz/v1/embeddings".to_string(),
1289        }
1290    }
1291
1292    /// Override the API base URL.
1293    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
1294        self.base_url = base_url.into();
1295        self
1296    }
1297
1298    /// Returns the configured model name.
1299    pub fn model(&self) -> &str {
1300        &self.model
1301    }
1302
1303    /// Build the request payload for the Together API.
1304    pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
1305        serde_json::json!({
1306            "model": self.model,
1307            "input": texts,
1308        })
1309    }
1310}
1311
1312#[async_trait]
1313impl EmbeddingProvider for TogetherEmbedProvider {
1314    #[cfg(feature = "http-embeddings")]
1315    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1316        let client = reqwest::Client::new();
1317        let payload = self.build_payload(&[text.to_string()]);
1318        let response = client
1319            .post(&self.base_url)
1320            .header("Authorization", format!("Bearer {}", self.api_key))
1321            .json(&payload)
1322            .send()
1323            .await
1324            .map_err(|e| ArgentorError::Http(format!("Together embedding request failed: {e}")))?;
1325
1326        let status = response.status();
1327        if !status.is_success() {
1328            let body = response.text().await.unwrap_or_default();
1329            return Err(ArgentorError::Http(format!(
1330                "Together API error {status}: {body}"
1331            )));
1332        }
1333
1334        let json: serde_json::Value = response.json().await.map_err(|e| {
1335            ArgentorError::Http(format!("Failed to read Together response body: {e}"))
1336        })?;
1337
1338        parse_openai_embedding_response(&json)
1339    }
1340
1341    #[cfg(not(feature = "http-embeddings"))]
1342    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1343        Ok(stub_embedding(text, self.dimensions))
1344    }
1345
1346    fn dimension(&self) -> usize {
1347        self.dimensions
1348    }
1349}
1350
1351// ===========================================================================
// 13. CohereEmbedV4Provider (Cohere v2 `/embed` endpoint)
1353// ===========================================================================
1354
/// Embedding provider backed by the Cohere v2 embed API (labeled "v4" here
/// to disambiguate from the existing [`CohereEmbeddingProvider`] and to
/// mirror the naming in higher-level integrations).
///
/// Differs from [`CohereEmbeddingProvider`] in that it exposes explicit
/// `input_type` helpers (`for_search_document`, `for_search_query`) and
/// supports `embed-english-v3.0` / `embed-multilingual-v3.0` at 1024 dims.
/// The default model is `embed-english-v3.0`, the default endpoint is
/// `https://api.cohere.com/v2/embed`, and only `float` embeddings are requested.
///
/// Without the `http-embeddings` feature, [`EmbeddingProvider::embed`] makes no
/// request and never fails. It returns a locally generated stub vector built
/// from the text and the configured dimension, not a Cohere embedding.
pub struct CohereEmbedV4Provider {
    /// API key sent as `Authorization: Bearer <api_key>`.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    /// Model identifier placed in every request payload.
    model: String,
    /// Vector length reported by `dimension()`; not sent to the API.
    dimensions: usize,
    /// Cohere `input_type` sent with each request (e.g. `search_document`).
    input_type: String,
    /// Full URL of the embed endpoint that requests are POSTed to.
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}
1371
1372impl CohereEmbedV4Provider {
1373    /// Create a new v4 provider with the default model (`embed-english-v3.0`).
1374    pub fn new(api_key: impl Into<String>) -> Self {
1375        Self::with_model(api_key, "embed-english-v3.0", 1024)
1376    }
1377
1378    /// Create with an explicit model and dimension override.
1379    pub fn with_model(
1380        api_key: impl Into<String>,
1381        model: impl Into<String>,
1382        dimensions: usize,
1383    ) -> Self {
1384        Self {
1385            api_key: api_key.into(),
1386            model: model.into(),
1387            dimensions,
1388            input_type: "search_document".to_string(),
1389            base_url: "https://api.cohere.com/v2/embed".to_string(),
1390        }
1391    }
1392
1393    /// Configure this provider for indexing documents (`search_document`).
1394    pub fn for_search_document(mut self) -> Self {
1395        self.input_type = "search_document".to_string();
1396        self
1397    }
1398
1399    /// Configure this provider for querying (`search_query`).
1400    pub fn for_search_query(mut self) -> Self {
1401        self.input_type = "search_query".to_string();
1402        self
1403    }
1404
1405    /// Set an arbitrary `input_type` string.
1406    pub fn with_input_type(mut self, input_type: impl Into<String>) -> Self {
1407        self.input_type = input_type.into();
1408        self
1409    }
1410
1411    /// Override the API base URL.
1412    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
1413        self.base_url = base_url.into();
1414        self
1415    }
1416
1417    /// Returns the configured model name.
1418    pub fn model(&self) -> &str {
1419        &self.model
1420    }
1421
1422    /// Returns the current input type.
1423    pub fn input_type(&self) -> &str {
1424        &self.input_type
1425    }
1426
1427    /// Build the request payload for the Cohere v2 embed endpoint.
1428    pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
1429        serde_json::json!({
1430            "model": self.model,
1431            "texts": texts,
1432            "input_type": self.input_type,
1433            "embedding_types": ["float"],
1434        })
1435    }
1436}
1437
1438#[async_trait]
1439impl EmbeddingProvider for CohereEmbedV4Provider {
1440    #[cfg(feature = "http-embeddings")]
1441    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1442        let client = reqwest::Client::new();
1443        let payload = self.build_payload(&[text.to_string()]);
1444        let response = client
1445            .post(&self.base_url)
1446            .header("Authorization", format!("Bearer {}", self.api_key))
1447            .json(&payload)
1448            .send()
1449            .await
1450            .map_err(|e| ArgentorError::Http(format!("Cohere v4 embedding request failed: {e}")))?;
1451
1452        let status = response.status();
1453        if !status.is_success() {
1454            let body = response.text().await.unwrap_or_default();
1455            return Err(ArgentorError::Http(format!(
1456                "Cohere v4 API error {status}: {body}"
1457            )));
1458        }
1459
1460        let json: serde_json::Value = response.json().await.map_err(|e| {
1461            ArgentorError::Http(format!("Failed to read Cohere v4 response body: {e}"))
1462        })?;
1463
1464        parse_cohere_embedding_response(&json)
1465    }
1466
1467    #[cfg(not(feature = "http-embeddings"))]
1468    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
1469        Ok(stub_embedding(text, self.dimensions))
1470    }
1471
1472    fn dimension(&self) -> usize {
1473        self.dimensions
1474    }
1475}
1476
1477// ===========================================================================
1478// Tests
1479// ===========================================================================
1480
1481#[cfg(test)]
1482#[allow(clippy::unwrap_used, clippy::expect_used)]
1483mod tests {
1484    use super::*;
1485
1486    // -- Provider creation tests ------------------------------------------
1487
1488    #[test]
1489    fn test_openai_provider_default_model() {
1490        let p = OpenAiEmbeddingProvider::new("sk-test", None);
1491        assert_eq!(p.model(), "text-embedding-3-small");
1492        assert_eq!(p.dimension(), 1536);
1493    }
1494
1495    #[test]
1496    fn test_openai_provider_large_model() {
1497        let p = OpenAiEmbeddingProvider::new("sk-test", Some("text-embedding-3-large".into()));
1498        assert_eq!(p.dimension(), 3072);
1499    }
1500
1501    #[test]
1502    fn test_openai_provider_custom_dimensions() {
1503        let p = OpenAiEmbeddingProvider::new("sk-test", None).with_dimensions(512);
1504        assert_eq!(p.dimension(), 512);
1505    }
1506
1507    #[test]
1508    fn test_openai_provider_custom_base_url() {
1509        let p = OpenAiEmbeddingProvider::with_base_url(
1510            "sk-test",
1511            None,
1512            "https://my-azure.openai.azure.com/openai/deployments/embed",
1513        );
1514        assert_eq!(p.dimension(), 1536);
1515    }
1516
1517    #[cfg(not(feature = "http-embeddings"))]
1518    #[tokio::test]
1519    async fn test_openai_provider_returns_feature_error() {
1520        let p = OpenAiEmbeddingProvider::new("sk-test", None);
1521        let err = p.embed("hello").await.unwrap_err();
1522        let msg = format!("{err}");
1523        assert!(msg.contains("HTTP embeddings not enabled"), "got: {msg}");
1524    }
1525
1526    #[test]
1527    fn test_cohere_provider_default() {
1528        let p = CohereEmbeddingProvider::new("key", None);
1529        assert_eq!(p.model(), "embed-english-v3.0");
1530        assert_eq!(p.dimension(), 1024);
1531        assert_eq!(p.input_type(), "search_document");
1532    }
1533
1534    #[test]
1535    fn test_cohere_provider_query_input_type() {
1536        let p = CohereEmbeddingProvider::new("key", None).with_input_type("search_query");
1537        assert_eq!(p.input_type(), "search_query");
1538    }
1539
1540    #[test]
1541    fn test_cohere_provider_light_model() {
1542        let p = CohereEmbeddingProvider::new("key", Some("embed-english-light-v3.0".into()));
1543        assert_eq!(p.dimension(), 384);
1544    }
1545
1546    #[cfg(not(feature = "http-embeddings"))]
1547    #[tokio::test]
1548    async fn test_cohere_provider_returns_feature_error() {
1549        let p = CohereEmbeddingProvider::new("key", None);
1550        let err = p.embed("hello").await.unwrap_err();
1551        let msg = format!("{err}");
1552        assert!(msg.contains("HTTP embeddings not enabled"), "got: {msg}");
1553    }
1554
1555    #[test]
1556    fn test_voyage_provider_default() {
1557        let p = VoyageEmbeddingProvider::new("key", None);
1558        assert_eq!(p.model(), "voyage-2");
1559        assert_eq!(p.dimension(), 1024);
1560    }
1561
1562    #[test]
1563    fn test_voyage_provider_code_model() {
1564        let p = VoyageEmbeddingProvider::new("key", Some("voyage-code-2".into()));
1565        assert_eq!(p.dimension(), 1536);
1566    }
1567
1568    #[cfg(not(feature = "http-embeddings"))]
1569    #[tokio::test]
1570    async fn test_voyage_provider_returns_feature_error() {
1571        let p = VoyageEmbeddingProvider::new("key", None);
1572        let err = p.embed("hello").await.unwrap_err();
1573        let msg = format!("{err}");
1574        assert!(msg.contains("HTTP embeddings not enabled"), "got: {msg}");
1575    }
1576
1577    // -- Response parsing tests -------------------------------------------
1578
1579    #[test]
1580    fn test_parse_openai_embedding_response_valid() {
1581        let json = serde_json::json!({
1582            "data": [
1583                {
1584                    "embedding": [0.1, 0.2, 0.3, 0.4],
1585                    "index": 0
1586                }
1587            ],
1588            "model": "text-embedding-3-small"
1589        });
1590        let result = parse_openai_embedding_response(&json).unwrap();
1591        assert_eq!(result, vec![0.1, 0.2, 0.3, 0.4]);
1592    }
1593
1594    #[test]
1595    fn test_parse_openai_embedding_response_empty_data() {
1596        let json = serde_json::json!({
1597            "data": [],
1598            "model": "text-embedding-3-small"
1599        });
1600        let err = parse_openai_embedding_response(&json).unwrap_err();
1601        let msg = format!("{err}");
1602        assert!(msg.contains("no embedding data"), "got: {msg}");
1603    }
1604
1605    #[test]
1606    fn test_parse_openai_embedding_response_invalid_shape() {
1607        let json = serde_json::json!({ "error": "bad request" });
1608        let err = parse_openai_embedding_response(&json).unwrap_err();
1609        let msg = format!("{err}");
1610        assert!(msg.contains("Failed to parse"), "got: {msg}");
1611    }
1612
1613    #[test]
1614    fn test_parse_openai_embedding_response_multiple_picks_first() {
1615        let json = serde_json::json!({
1616            "data": [
1617                { "embedding": [1.0, 2.0], "index": 0 },
1618                { "embedding": [3.0, 4.0], "index": 1 }
1619            ],
1620            "model": "text-embedding-3-small"
1621        });
1622        let result = parse_openai_embedding_response(&json).unwrap();
1623        assert_eq!(result, vec![1.0, 2.0]);
1624    }
1625
1626    #[test]
1627    fn test_parse_cohere_embedding_response_valid() {
1628        let json = serde_json::json!({
1629            "embeddings": {
1630                "float": [
1631                    [0.5, 0.6, 0.7]
1632                ]
1633            }
1634        });
1635        let result = parse_cohere_embedding_response(&json).unwrap();
1636        assert_eq!(result, vec![0.5, 0.6, 0.7]);
1637    }
1638
1639    #[test]
1640    fn test_parse_cohere_embedding_response_empty_float() {
1641        let json = serde_json::json!({
1642            "embeddings": {
1643                "float": []
1644            }
1645        });
1646        let err = parse_cohere_embedding_response(&json).unwrap_err();
1647        let msg = format!("{err}");
1648        assert!(msg.contains("no float embeddings"), "got: {msg}");
1649    }
1650
1651    #[test]
1652    fn test_parse_cohere_embedding_response_invalid_shape() {
1653        let json = serde_json::json!({ "message": "unauthorized" });
1654        let err = parse_cohere_embedding_response(&json).unwrap_err();
1655        let msg = format!("{err}");
1656        assert!(msg.contains("Failed to parse"), "got: {msg}");
1657    }
1658
1659    #[test]
1660    fn test_parse_cohere_embedding_response_missing_float_key() {
1661        // If "float" key is absent, serde default gives empty vec.
1662        let json = serde_json::json!({
1663            "embeddings": {}
1664        });
1665        let err = parse_cohere_embedding_response(&json).unwrap_err();
1666        let msg = format!("{err}");
1667        assert!(msg.contains("no float embeddings"), "got: {msg}");
1668    }
1669
1670    #[test]
1671    fn test_parse_voyage_embedding_response_valid() {
1672        let json = serde_json::json!({
1673            "data": [
1674                {
1675                    "embedding": [0.9, 0.8, 0.7, 0.6, 0.5],
1676                    "index": 0
1677                }
1678            ]
1679        });
1680        let result = parse_voyage_embedding_response(&json).unwrap();
1681        assert_eq!(result, vec![0.9, 0.8, 0.7, 0.6, 0.5]);
1682    }
1683
1684    #[test]
1685    fn test_parse_voyage_embedding_response_empty_data() {
1686        let json = serde_json::json!({ "data": [] });
1687        let err = parse_voyage_embedding_response(&json).unwrap_err();
1688        let msg = format!("{err}");
1689        assert!(msg.contains("no embedding data"), "got: {msg}");
1690    }
1691
1692    #[test]
1693    fn test_parse_voyage_embedding_response_invalid_shape() {
1694        let json = serde_json::json!({ "error": "invalid key" });
1695        let err = parse_voyage_embedding_response(&json).unwrap_err();
1696        let msg = format!("{err}");
1697        assert!(msg.contains("Failed to parse"), "got: {msg}");
1698    }
1699
1700    // -- CachedEmbeddingProvider tests ------------------------------------
1701
1702    #[tokio::test]
1703    async fn test_cache_hit() {
1704        let local = Arc::new(LocalEmbedding::new(64));
1705        let cached = CachedEmbeddingProvider::new(local, 100);
1706
1707        let v1 = cached.embed("hello world").await.unwrap();
1708        let v2 = cached.embed("hello world").await.unwrap();
1709        assert_eq!(v1, v2);
1710
1711        let stats = cached.cache_stats().await;
1712        assert_eq!(stats.hits, 1);
1713        assert_eq!(stats.misses, 1);
1714        assert_eq!(stats.size, 1);
1715    }
1716
1717    #[tokio::test]
1718    async fn test_cache_miss_different_texts() {
1719        let local = Arc::new(LocalEmbedding::new(64));
1720        let cached = CachedEmbeddingProvider::new(local, 100);
1721
1722        let _ = cached.embed("alpha").await.unwrap();
1723        let _ = cached.embed("bravo").await.unwrap();
1724
1725        let stats = cached.cache_stats().await;
1726        assert_eq!(stats.misses, 2);
1727        assert_eq!(stats.hits, 0);
1728        assert_eq!(stats.size, 2);
1729    }
1730
1731    #[tokio::test]
1732    async fn test_cache_eviction() {
1733        let local = Arc::new(LocalEmbedding::new(64));
1734        let cached = CachedEmbeddingProvider::new(local, 2);
1735
1736        let _ = cached.embed("one").await.unwrap();
1737        let _ = cached.embed("two").await.unwrap();
1738        let _ = cached.embed("three").await.unwrap();
1739
1740        let stats = cached.cache_stats().await;
1741        // After eviction, cache should still have at most max_cache_size entries.
1742        assert!(stats.size <= 2, "size={} should be <= 2", stats.size);
1743        assert_eq!(stats.misses, 3);
1744    }
1745
1746    #[tokio::test]
1747    async fn test_cache_clear() {
1748        let local = Arc::new(LocalEmbedding::new(64));
1749        let cached = CachedEmbeddingProvider::new(local, 100);
1750
1751        let _ = cached.embed("text").await.unwrap();
1752        cached.clear().await;
1753
1754        let stats = cached.cache_stats().await;
1755        assert_eq!(stats.size, 0);
1756    }
1757
1758    #[tokio::test]
1759    async fn test_cache_dimension_delegates() {
1760        let local = Arc::new(LocalEmbedding::new(128));
1761        let cached = CachedEmbeddingProvider::new(local, 10);
1762        assert_eq!(cached.dimension(), 128);
1763    }
1764
1765    // -- BatchEmbeddingProvider tests -------------------------------------
1766
1767    #[tokio::test]
1768    async fn test_batch_embed() {
1769        let local = Arc::new(LocalEmbedding::new(64));
1770        let batch = BatchEmbeddingProvider::new(local);
1771
1772        let results = batch
1773            .embed_batch(&["hello", "world", "test"])
1774            .await
1775            .unwrap();
1776        assert_eq!(results.len(), 3);
1777        for v in &results {
1778            assert_eq!(v.len(), 64);
1779        }
1780    }
1781
1782    #[tokio::test]
1783    async fn test_batch_single_embed_delegates() {
1784        let local = Arc::new(LocalEmbedding::new(64));
1785        let batch = BatchEmbeddingProvider::new(local);
1786
1787        let v = batch.embed("hello").await.unwrap();
1788        assert_eq!(v.len(), 64);
1789    }
1790
1791    #[tokio::test]
1792    async fn test_batch_empty() {
1793        let local = Arc::new(LocalEmbedding::new(64));
1794        let batch = BatchEmbeddingProvider::new(local);
1795
1796        let results = batch.embed_batch(&[]).await.unwrap();
1797        assert!(results.is_empty());
1798    }
1799
1800    #[tokio::test]
1801    async fn test_batch_dimension_delegates() {
1802        let local = Arc::new(LocalEmbedding::new(200));
1803        let batch = BatchEmbeddingProvider::new(local);
1804        assert_eq!(batch.dimension(), 200);
1805    }
1806
1807    // -- Factory tests ----------------------------------------------------
1808
1809    #[test]
1810    fn test_factory_create_local() {
1811        let p = EmbeddingProviderFactory::create("local", "", None).unwrap();
1812        assert_eq!(p.dimension(), 256);
1813    }
1814
1815    #[test]
1816    fn test_factory_create_local_custom_dim() {
1817        let p = EmbeddingProviderFactory::create("local", "", Some("128".into())).unwrap();
1818        assert_eq!(p.dimension(), 128);
1819    }
1820
1821    #[test]
1822    fn test_factory_create_openai() {
1823        let p = EmbeddingProviderFactory::create("openai", "sk-test", None).unwrap();
1824        assert_eq!(p.dimension(), 1536);
1825    }
1826
1827    #[test]
1828    fn test_factory_create_cohere() {
1829        let p = EmbeddingProviderFactory::create("cohere", "key", None).unwrap();
1830        assert_eq!(p.dimension(), 1024);
1831    }
1832
1833    #[test]
1834    fn test_factory_create_voyage() {
1835        let p = EmbeddingProviderFactory::create("voyage", "key", None).unwrap();
1836        assert_eq!(p.dimension(), 1024);
1837    }
1838
1839    #[test]
1840    fn test_factory_unknown_provider() {
1841        let result = EmbeddingProviderFactory::create("unknown", "", None);
1842        assert!(result.is_err(), "Unknown provider should return Err");
1843    }
1844
1845    #[test]
1846    fn test_factory_available_providers() {
1847        let names = EmbeddingProviderFactory::available_providers();
1848        assert!(names.contains(&"openai"));
1849        assert!(names.contains(&"cohere"));
1850        assert!(names.contains(&"voyage"));
1851        assert!(names.contains(&"local"));
1852    }
1853
1854    // -- Config tests -----------------------------------------------------
1855
1856    #[test]
1857    fn test_config_default() {
1858        let cfg = EmbeddingConfig::default();
1859        assert_eq!(cfg.provider, "local");
1860        assert!(cfg.api_key.is_empty());
1861        assert!(cfg.model.is_none());
1862        assert!(cfg.dimensions.is_none());
1863        assert!(cfg.base_url.is_none());
1864        assert!(cfg.cache_size.is_none());
1865    }
1866
1867    #[test]
1868    fn test_config_serialize_deserialize() {
1869        let cfg = EmbeddingConfig {
1870            provider: "openai".to_string(),
1871            api_key: "sk-123".to_string(),
1872            model: Some("text-embedding-3-small".to_string()),
1873            dimensions: Some(1536),
1874            base_url: None,
1875            cache_size: Some(500),
1876        };
1877        let json = serde_json::to_string(&cfg).unwrap();
1878        let parsed: EmbeddingConfig = serde_json::from_str(&json).unwrap();
1879        assert_eq!(parsed.provider, "openai");
1880        assert_eq!(parsed.api_key, "sk-123");
1881        assert_eq!(parsed.dimensions, Some(1536));
1882        assert_eq!(parsed.cache_size, Some(500));
1883    }
1884
1885    #[test]
1886    fn test_config_deserialize_minimal() {
1887        let json = r#"{"provider":"local"}"#;
1888        let cfg: EmbeddingConfig = serde_json::from_str(json).unwrap();
1889        assert_eq!(cfg.provider, "local");
1890        assert!(cfg.api_key.is_empty());
1891    }
1892
1893    #[tokio::test]
1894    async fn test_config_build_local() {
1895        let cfg = EmbeddingConfig::default();
1896        let provider = cfg.build().unwrap();
1897        assert_eq!(provider.dimension(), 256);
1898        let v = provider.embed("test text").await.unwrap();
1899        assert_eq!(v.len(), 256);
1900    }
1901
1902    #[tokio::test]
1903    async fn test_config_build_local_with_cache() {
1904        let cfg = EmbeddingConfig {
1905            provider: "local".to_string(),
1906            cache_size: Some(50),
1907            ..Default::default()
1908        };
1909        let provider = cfg.build().unwrap();
1910        // Dimension from local default.
1911        assert_eq!(provider.dimension(), 256);
1912        // Should work — cache wraps local.
1913        let v1 = provider.embed("cached text").await.unwrap();
1914        let v2 = provider.embed("cached text").await.unwrap();
1915        assert_eq!(v1, v2);
1916    }
1917
1918    #[tokio::test]
1919    async fn test_config_build_local_custom_dimensions() {
1920        let cfg = EmbeddingConfig {
1921            provider: "local".to_string(),
1922            dimensions: Some(512),
1923            ..Default::default()
1924        };
1925        let provider = cfg.build().unwrap();
1926        assert_eq!(provider.dimension(), 512);
1927    }
1928
1929    #[test]
1930    fn test_config_build_unknown_provider() {
1931        let cfg = EmbeddingConfig {
1932            provider: "imaginary".to_string(),
1933            ..Default::default()
1934        };
1935        assert!(cfg.build().is_err());
1936    }
1937
1938    // -- Misc / edge cases ------------------------------------------------
1939
1940    #[test]
1941    fn test_fnv_hash_deterministic() {
1942        let h1 = fnv1a_hash(b"hello world");
1943        let h2 = fnv1a_hash(b"hello world");
1944        assert_eq!(h1, h2);
1945    }
1946
1947    #[test]
1948    fn test_fnv_hash_different_inputs() {
1949        let h1 = fnv1a_hash(b"alpha");
1950        let h2 = fnv1a_hash(b"bravo");
1951        assert_ne!(h1, h2);
1952    }
1953
1954    // =====================================================================
1955    // Stub helper tests
1956    // =====================================================================
1957
1958    #[test]
1959    fn test_stub_embedding_length() {
1960        let v = stub_embedding("hello", 128);
1961        assert_eq!(v.len(), 128);
1962    }
1963
1964    #[test]
1965    fn test_stub_embedding_deterministic() {
1966        let v1 = stub_embedding("same input", 64);
1967        let v2 = stub_embedding("same input", 64);
1968        assert_eq!(v1, v2);
1969    }
1970
1971    #[test]
1972    fn test_stub_embedding_normalized() {
1973        let v = stub_embedding("the quick brown fox", 256);
1974        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
1975        assert!((norm - 1.0).abs() < 0.01, "norm={norm}");
1976    }
1977
1978    #[test]
1979    fn test_stub_embedding_different_inputs_differ() {
1980        let a = stub_embedding("alpha", 64);
1981        let b = stub_embedding("bravo", 64);
1982        assert_ne!(a, b);
1983    }
1984
1985    #[test]
1986    fn test_stub_embedding_empty_text_zeroes() {
1987        let v = stub_embedding("", 32);
1988        assert_eq!(v.len(), 32);
1989        assert!(v.iter().all(|&x| x == 0.0));
1990    }
1991
1992    #[test]
1993    fn test_stub_embedding_zero_dimension_safe() {
1994        // Must not panic; helper clamps dimension to at least 1.
1995        let v = stub_embedding("hi", 0);
1996        assert_eq!(v.len(), 1);
1997    }
1998
1999    // =====================================================================
2000    // JinaEmbeddingProvider tests
2001    // =====================================================================
2002
2003    #[test]
2004    fn test_jina_default_construction() {
2005        let p = JinaEmbeddingProvider::new("jina-key");
2006        assert_eq!(p.model(), "jina-embeddings-v3");
2007        assert_eq!(p.dimension(), 1024);
2008    }
2009
2010    #[test]
2011    fn test_jina_with_model_clip() {
2012        let p = JinaEmbeddingProvider::with_model("k", "jina-clip-v2", 768);
2013        assert_eq!(p.model(), "jina-clip-v2");
2014        assert_eq!(p.dimension(), 768);
2015    }
2016
2017    #[test]
2018    fn test_jina_with_base_url() {
2019        let p = JinaEmbeddingProvider::new("k").with_base_url("https://custom.jina/v1");
2020        // Indirect check: construction succeeds and model unchanged.
2021        assert_eq!(p.model(), "jina-embeddings-v3");
2022    }
2023
2024    #[test]
2025    fn test_jina_build_payload_shape() {
2026        let p = JinaEmbeddingProvider::new("k");
2027        let payload = p.build_payload(&["hello".to_string(), "world".to_string()]);
2028        assert_eq!(payload["model"], "jina-embeddings-v3");
2029        assert_eq!(payload["input"][0], "hello");
2030        assert_eq!(payload["input"][1], "world");
2031    }
2032
2033    #[tokio::test]
2034    async fn test_jina_embed_length_matches_dimension() {
2035        let p = JinaEmbeddingProvider::new("k");
2036        #[cfg(not(feature = "http-embeddings"))]
2037        {
2038            let v = p.embed("hello jina").await.unwrap();
2039            assert_eq!(v.len(), 1024);
2040        }
2041        // When http-embeddings is enabled, we don't hit the real API in tests;
2042        // just confirm dimension() reports 1024.
2043        assert_eq!(p.dimension(), 1024);
2044    }
2045
2046    #[cfg(not(feature = "http-embeddings"))]
2047    #[tokio::test]
2048    async fn test_jina_stub_is_normalized() {
2049        let p = JinaEmbeddingProvider::new("k");
2050        let v = p.embed("some input").await.unwrap();
2051        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
2052        assert!((norm - 1.0).abs() < 0.01);
2053    }
2054
2055    #[cfg(not(feature = "http-embeddings"))]
2056    #[tokio::test]
2057    async fn test_jina_stub_deterministic() {
2058        let p = JinaEmbeddingProvider::new("k");
2059        let a = p.embed("consistent").await.unwrap();
2060        let b = p.embed("consistent").await.unwrap();
2061        assert_eq!(a, b);
2062    }
2063
2064    // =====================================================================
2065    // MistralEmbedProvider tests
2066    // =====================================================================
2067
2068    #[test]
2069    fn test_mistral_default_construction() {
2070        let p = MistralEmbedProvider::new("mistral-key");
2071        assert_eq!(p.model(), "mistral-embed");
2072        assert_eq!(p.dimension(), 1024);
2073    }
2074
2075    #[test]
2076    fn test_mistral_with_model_and_dimensions() {
2077        let p = MistralEmbedProvider::with_model("k", "mistral-embed-large", 2048);
2078        assert_eq!(p.model(), "mistral-embed-large");
2079        assert_eq!(p.dimension(), 2048);
2080    }
2081
2082    #[test]
2083    fn test_mistral_build_payload_shape() {
2084        let p = MistralEmbedProvider::new("k");
2085        let payload = p.build_payload(&["alpha".to_string()]);
2086        assert_eq!(payload["model"], "mistral-embed");
2087        assert_eq!(payload["input"][0], "alpha");
2088    }
2089
2090    #[test]
2091    fn test_mistral_with_base_url() {
2092        let p = MistralEmbedProvider::new("k").with_base_url("https://custom.mistral/v1");
2093        assert_eq!(p.dimension(), 1024);
2094    }
2095
2096    #[cfg(not(feature = "http-embeddings"))]
2097    #[tokio::test]
2098    async fn test_mistral_embed_length() {
2099        let p = MistralEmbedProvider::new("k");
2100        let v = p.embed("hello mistral").await.unwrap();
2101        assert_eq!(v.len(), 1024);
2102    }
2103
2104    #[cfg(not(feature = "http-embeddings"))]
2105    #[tokio::test]
2106    async fn test_mistral_stub_normalized() {
2107        let p = MistralEmbedProvider::new("k");
2108        let v = p.embed("normalized?").await.unwrap();
2109        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
2110        assert!((norm - 1.0).abs() < 0.01);
2111    }
2112
2113    // =====================================================================
2114    // NomicEmbedProvider tests
2115    // =====================================================================
2116
2117    #[test]
2118    fn test_nomic_default_construction() {
2119        let p = NomicEmbedProvider::new("nomic-key");
2120        assert_eq!(p.model(), "nomic-embed-text-v1.5");
2121        assert_eq!(p.dimension(), 768);
2122        assert_eq!(p.task_type(), "search_document");
2123    }
2124
2125    #[test]
2126    fn test_nomic_with_task_type() {
2127        let p = NomicEmbedProvider::new("k").with_task_type("search_query");
2128        assert_eq!(p.task_type(), "search_query");
2129    }
2130
2131    #[test]
2132    fn test_nomic_build_payload_shape() {
2133        let p = NomicEmbedProvider::new("k").with_task_type("clustering");
2134        let payload = p.build_payload(&["doc a".to_string(), "doc b".to_string()]);
2135        assert_eq!(payload["model"], "nomic-embed-text-v1.5");
2136        assert_eq!(payload["texts"][0], "doc a");
2137        assert_eq!(payload["texts"][1], "doc b");
2138        assert_eq!(payload["task_type"], "clustering");
2139    }
2140
2141    #[test]
2142    fn test_nomic_with_model_custom_dims() {
2143        let p = NomicEmbedProvider::with_model("k", "custom-nomic", 512);
2144        assert_eq!(p.dimension(), 512);
2145    }
2146
2147    #[cfg(not(feature = "http-embeddings"))]
2148    #[tokio::test]
2149    async fn test_nomic_embed_length() {
2150        let p = NomicEmbedProvider::new("k");
2151        let v = p.embed("nomic test").await.unwrap();
2152        assert_eq!(v.len(), 768);
2153    }
2154
2155    #[cfg(not(feature = "http-embeddings"))]
2156    #[tokio::test]
2157    async fn test_nomic_embed_normalized() {
2158        let p = NomicEmbedProvider::new("k");
2159        let v = p.embed("some text").await.unwrap();
2160        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
2161        assert!((norm - 1.0).abs() < 0.01);
2162    }
2163
2164    // =====================================================================
2165    // SentenceTransformersProvider tests
2166    // =====================================================================
2167
2168    #[test]
2169    fn test_sentence_transformers_default_construction() {
2170        let p = SentenceTransformersProvider::new("hf-key");
2171        assert_eq!(p.model(), "sentence-transformers/all-MiniLM-L6-v2");
2172        assert_eq!(p.dimension(), 384);
2173    }
2174
2175    #[test]
2176    fn test_sentence_transformers_mpnet_dims() {
2177        let dims = SentenceTransformersProvider::default_dimensions(
2178            "sentence-transformers/all-mpnet-base-v2",
2179        );
2180        assert_eq!(dims, 768);
2181    }
2182
2183    #[test]
2184    fn test_sentence_transformers_multi_qa_dims() {
2185        let dims = SentenceTransformersProvider::default_dimensions(
2186            "sentence-transformers/multi-qa-mpnet-base-dot-v1",
2187        );
2188        assert_eq!(dims, 768);
2189    }
2190
2191    #[test]
2192    fn test_sentence_transformers_unknown_model_fallback() {
2193        let dims =
2194            SentenceTransformersProvider::default_dimensions("sentence-transformers/unknown");
2195        assert_eq!(dims, 384);
2196    }
2197
2198    #[test]
2199    fn test_sentence_transformers_with_model() {
2200        let p = SentenceTransformersProvider::with_model(
2201            "k",
2202            "sentence-transformers/all-mpnet-base-v2",
2203            768,
2204        );
2205        assert_eq!(p.model(), "sentence-transformers/all-mpnet-base-v2");
2206        assert_eq!(p.dimension(), 768);
2207    }
2208
2209    #[test]
2210    fn test_sentence_transformers_build_payload_shape() {
2211        let p = SentenceTransformersProvider::new("k");
2212        let payload = p.build_payload(&["hi".to_string()]);
2213        assert_eq!(payload["inputs"][0], "hi");
2214        assert_eq!(payload["options"]["wait_for_model"], true);
2215    }
2216
2217    #[test]
2218    fn test_sentence_transformers_with_base_url() {
2219        let p =
2220            SentenceTransformersProvider::new("k").with_base_url("https://self-hosted.hf/embed");
2221        assert_eq!(p.dimension(), 384);
2222    }
2223
2224    #[cfg(not(feature = "http-embeddings"))]
2225    #[tokio::test]
2226    async fn test_sentence_transformers_embed_length() {
2227        let p = SentenceTransformersProvider::new("k");
2228        let v = p.embed("minilm test").await.unwrap();
2229        assert_eq!(v.len(), 384);
2230    }
2231
2232    #[cfg(not(feature = "http-embeddings"))]
2233    #[tokio::test]
2234    async fn test_sentence_transformers_embed_normalized() {
2235        let p = SentenceTransformersProvider::new("k");
2236        let v = p.embed("some input").await.unwrap();
2237        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
2238        assert!((norm - 1.0).abs() < 0.01);
2239    }
2240
2241    // =====================================================================
2242    // TogetherEmbedProvider tests
2243    // =====================================================================
2244
2245    #[test]
2246    fn test_together_default_construction() {
2247        let p = TogetherEmbedProvider::new("together-key");
2248        assert_eq!(p.model(), "togethercomputer/m2-bert-80M-32k-retrieval");
2249        assert_eq!(p.dimension(), 768);
2250    }
2251
2252    #[test]
2253    fn test_together_with_model() {
2254        let p = TogetherEmbedProvider::with_model("k", "togethercomputer/custom", 1024);
2255        assert_eq!(p.model(), "togethercomputer/custom");
2256        assert_eq!(p.dimension(), 1024);
2257    }
2258
2259    #[test]
2260    fn test_together_build_payload_shape() {
2261        let p = TogetherEmbedProvider::new("k");
2262        let payload = p.build_payload(&["x".to_string(), "y".to_string()]);
2263        assert_eq!(
2264            payload["model"],
2265            "togethercomputer/m2-bert-80M-32k-retrieval"
2266        );
2267        assert_eq!(payload["input"][0], "x");
2268        assert_eq!(payload["input"][1], "y");
2269    }
2270
2271    #[test]
2272    fn test_together_with_base_url() {
2273        let p = TogetherEmbedProvider::new("k").with_base_url("https://custom.together/v1");
2274        assert_eq!(p.dimension(), 768);
2275    }
2276
2277    #[cfg(not(feature = "http-embeddings"))]
2278    #[tokio::test]
2279    async fn test_together_embed_length() {
2280        let p = TogetherEmbedProvider::new("k");
2281        let v = p.embed("together test").await.unwrap();
2282        assert_eq!(v.len(), 768);
2283    }
2284
2285    #[cfg(not(feature = "http-embeddings"))]
2286    #[tokio::test]
2287    async fn test_together_embed_normalized() {
2288        let p = TogetherEmbedProvider::new("k");
2289        let v = p.embed("text").await.unwrap();
2290        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
2291        assert!((norm - 1.0).abs() < 0.01);
2292    }
2293
2294    // =====================================================================
2295    // CohereEmbedV4Provider tests
2296    // =====================================================================
2297
2298    #[test]
2299    fn test_cohere_v4_default_construction() {
2300        let p = CohereEmbedV4Provider::new("cohere-key");
2301        assert_eq!(p.model(), "embed-english-v3.0");
2302        assert_eq!(p.dimension(), 1024);
2303        assert_eq!(p.input_type(), "search_document");
2304    }
2305
2306    #[test]
2307    fn test_cohere_v4_multilingual_model() {
2308        let p = CohereEmbedV4Provider::with_model("k", "embed-multilingual-v3.0", 1024);
2309        assert_eq!(p.model(), "embed-multilingual-v3.0");
2310        assert_eq!(p.dimension(), 1024);
2311    }
2312
2313    #[test]
2314    fn test_cohere_v4_for_search_document() {
2315        let p = CohereEmbedV4Provider::new("k").for_search_document();
2316        assert_eq!(p.input_type(), "search_document");
2317    }
2318
2319    #[test]
2320    fn test_cohere_v4_for_search_query() {
2321        let p = CohereEmbedV4Provider::new("k").for_search_query();
2322        assert_eq!(p.input_type(), "search_query");
2323    }
2324
2325    #[test]
2326    fn test_cohere_v4_with_input_type() {
2327        let p = CohereEmbedV4Provider::new("k").with_input_type("classification");
2328        assert_eq!(p.input_type(), "classification");
2329    }
2330
2331    #[test]
2332    fn test_cohere_v4_build_payload_shape_document() {
2333        let p = CohereEmbedV4Provider::new("k").for_search_document();
2334        let payload = p.build_payload(&["doc".to_string()]);
2335        assert_eq!(payload["model"], "embed-english-v3.0");
2336        assert_eq!(payload["texts"][0], "doc");
2337        assert_eq!(payload["input_type"], "search_document");
2338        assert_eq!(payload["embedding_types"][0], "float");
2339    }
2340
2341    #[test]
2342    fn test_cohere_v4_build_payload_shape_query() {
2343        let p = CohereEmbedV4Provider::new("k").for_search_query();
2344        let payload = p.build_payload(&["q".to_string()]);
2345        assert_eq!(payload["input_type"], "search_query");
2346    }
2347
2348    #[test]
2349    fn test_cohere_v4_with_base_url() {
2350        let p = CohereEmbedV4Provider::new("k").with_base_url("https://custom.cohere/v2/embed");
2351        assert_eq!(p.dimension(), 1024);
2352    }
2353
2354    #[cfg(not(feature = "http-embeddings"))]
2355    #[tokio::test]
2356    async fn test_cohere_v4_embed_length() {
2357        let p = CohereEmbedV4Provider::new("k");
2358        let v = p.embed("cohere v4 test").await.unwrap();
2359        assert_eq!(v.len(), 1024);
2360    }
2361
2362    #[cfg(not(feature = "http-embeddings"))]
2363    #[tokio::test]
2364    async fn test_cohere_v4_embed_normalized() {
2365        let p = CohereEmbedV4Provider::new("k");
2366        let v = p.embed("x").await.unwrap();
2367        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
2368        assert!((norm - 1.0).abs() < 0.01);
2369    }
2370
2371    #[cfg(not(feature = "http-embeddings"))]
2372    #[tokio::test]
2373    async fn test_cohere_v4_embed_deterministic() {
2374        let p = CohereEmbedV4Provider::new("k");
2375        let a = p.embed("same").await.unwrap();
2376        let b = p.embed("same").await.unwrap();
2377        assert_eq!(a, b);
2378    }
2379
2380    // =====================================================================
2381    // Cross-provider checks
2382    // =====================================================================
2383
2384    #[test]
2385    fn test_all_new_providers_implement_embedding_provider_trait() {
2386        // Compile-time check — if these coerce into `Box<dyn EmbeddingProvider>`,
2387        // they correctly implement the trait.
2388        let _boxes: Vec<Box<dyn EmbeddingProvider>> = vec![
2389            Box::new(JinaEmbeddingProvider::new("k")),
2390            Box::new(MistralEmbedProvider::new("k")),
2391            Box::new(NomicEmbedProvider::new("k")),
2392            Box::new(SentenceTransformersProvider::new("k")),
2393            Box::new(TogetherEmbedProvider::new("k")),
2394            Box::new(CohereEmbedV4Provider::new("k")),
2395        ];
2396    }
2397
2398    #[test]
2399    fn test_new_providers_have_expected_dimensions() {
2400        assert_eq!(JinaEmbeddingProvider::new("k").dimension(), 1024);
2401        assert_eq!(MistralEmbedProvider::new("k").dimension(), 1024);
2402        assert_eq!(NomicEmbedProvider::new("k").dimension(), 768);
2403        assert_eq!(SentenceTransformersProvider::new("k").dimension(), 384);
2404        assert_eq!(TogetherEmbedProvider::new("k").dimension(), 768);
2405        assert_eq!(CohereEmbedV4Provider::new("k").dimension(), 1024);
2406    }
2407}