use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use argentor_core::{ArgentorError, ArgentorResult};
use crate::embedding::{EmbeddingProvider, LocalEmbedding};
/// Computes the 64-bit FNV-1a hash of `data`.
///
/// Starting from the FNV offset basis, each byte is XOR-ed into the running
/// state, which is then multiplied by the FNV prime using wrapping arithmetic.
/// An empty slice therefore hashes to the offset basis itself.
fn fnv1a_hash(data: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 14695981039346656037;
    const PRIME: u64 = 1099511628211;

    data.iter().fold(OFFSET_BASIS, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
/// Serializable request payload carrying a model name and the texts to embed.
///
/// NOTE(review): presumably mirrors the OpenAI embeddings request body —
/// confirm against the API reference before relying on the exact shape.
#[derive(Debug, Serialize)]
pub struct OpenAiEmbeddingRequest {
/// Name of the embedding model to use.
pub model: String,
/// Input texts to embed.
pub input: Vec<String>,
}
/// One entry of [`OpenAiEmbeddingResponse::data`]: an embedding vector
/// together with the index reported for it.
#[derive(Debug, Deserialize)]
pub struct OpenAiEmbeddingObject {
/// The embedding vector.
pub embedding: Vec<f32>,
/// Index reported for this entry — presumably the position of the
/// corresponding input text; confirm against the API reference.
pub index: usize,
}
/// Deserialized response body used by [`parse_openai_embedding_response`].
///
/// Neither field carries `#[serde(default)]`, so both `data` and `model`
/// must be present for deserialization to succeed.
#[derive(Debug, Deserialize)]
pub struct OpenAiEmbeddingResponse {
/// Returned embedding entries.
pub data: Vec<OpenAiEmbeddingObject>,
/// Model name echoed in the response.
pub model: String,
}
/// Serializable request payload with the same field names the Cohere
/// providers in this file send (`model`, `texts`, `input_type`,
/// `embedding_types`).
#[derive(Debug, Serialize)]
pub struct CohereEmbedRequest {
/// Name of the embedding model to use.
pub model: String,
/// Input texts to embed.
pub texts: Vec<String>,
/// Input type hint, e.g. `"search_document"` or `"search_query"`.
pub input_type: String,
/// Requested embedding encodings, e.g. `"float"`.
pub embedding_types: Vec<String>,
}
/// Deserialized response body used by [`parse_cohere_embedding_response`].
///
/// The `embeddings` key is required; its contents are described by
/// [`CohereEmbeddingsMap`].
#[derive(Debug, Deserialize)]
pub struct CohereEmbedResponse {
/// Embeddings grouped by encoding type.
pub embeddings: CohereEmbeddingsMap,
}
/// The `embeddings` object of a Cohere response, restricted to the `float`
/// encoding.
#[derive(Debug, Deserialize)]
pub struct CohereEmbeddingsMap {
/// Float embedding vectors, one per entry; defaults to an empty list when
/// the `float` key is absent.
#[serde(default)]
pub float: Vec<Vec<f32>>,
}
/// Serializable request payload with the same field names the Voyage
/// provider in this file sends (`model`, `input`).
#[derive(Debug, Serialize)]
pub struct VoyageEmbeddingRequest {
/// Name of the embedding model to use.
pub model: String,
/// Input texts to embed.
pub input: Vec<String>,
}
/// One entry of [`VoyageEmbeddingResponse::data`]: an embedding vector
/// together with the index reported for it.
#[derive(Debug, Deserialize)]
pub struct VoyageEmbeddingObject {
/// The embedding vector.
pub embedding: Vec<f32>,
/// Index reported for this entry — presumably the position of the
/// corresponding input text; confirm against the API reference.
pub index: usize,
}
/// Deserialized response body used by [`parse_voyage_embedding_response`].
///
/// Only the required `data` list is modelled.
#[derive(Debug, Deserialize)]
pub struct VoyageEmbeddingResponse {
/// Returned embedding entries.
pub data: Vec<VoyageEmbeddingObject>,
}
pub fn parse_openai_embedding_response(json: &serde_json::Value) -> ArgentorResult<Vec<f32>> {
let response: OpenAiEmbeddingResponse = serde_json::from_value(json.clone())
.map_err(|e| ArgentorError::Agent(format!("Failed to parse OpenAI response: {e}")))?;
response
.data
.into_iter()
.next()
.map(|obj| obj.embedding)
.ok_or_else(|| {
ArgentorError::Agent("OpenAI response contains no embedding data".to_string())
})
}
pub fn parse_cohere_embedding_response(json: &serde_json::Value) -> ArgentorResult<Vec<f32>> {
let response: CohereEmbedResponse = serde_json::from_value(json.clone())
.map_err(|e| ArgentorError::Agent(format!("Failed to parse Cohere response: {e}")))?;
response.embeddings.float.into_iter().next().ok_or_else(|| {
ArgentorError::Agent("Cohere response contains no float embeddings".to_string())
})
}
pub fn parse_voyage_embedding_response(json: &serde_json::Value) -> ArgentorResult<Vec<f32>> {
let response: VoyageEmbeddingResponse = serde_json::from_value(json.clone())
.map_err(|e| ArgentorError::Agent(format!("Failed to parse Voyage response: {e}")))?;
response
.data
.into_iter()
.next()
.map(|obj| obj.embedding)
.ok_or_else(|| {
ArgentorError::Agent("Voyage response contains no embedding data".to_string())
})
}
/// Embedding provider configured for an OpenAI-style embeddings endpoint.
///
/// Holds the API key, the model name, the dimension reported to callers and
/// the endpoint URL that requests are posted to.
pub struct OpenAiEmbeddingProvider {
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    model: String,
    dimensions: usize,
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}

impl OpenAiEmbeddingProvider {
    /// Creates a provider that targets `https://api.openai.com/v1/embeddings`.
    ///
    /// When `model` is `None`, `"text-embedding-3-small"` is used. The
    /// dimension is derived from the model name.
    pub fn new(api_key: impl Into<String>, model: Option<String>) -> Self {
        Self::with_base_url(api_key, model, "https://api.openai.com/v1/embeddings")
    }

    /// Creates a provider that posts to a caller-supplied endpoint URL.
    ///
    /// Model defaulting and dimension lookup work exactly as in [`Self::new`].
    pub fn with_base_url(
        api_key: impl Into<String>,
        model: Option<String>,
        base_url: impl Into<String>,
    ) -> Self {
        let model = match model {
            Some(name) => name,
            None => String::from("text-embedding-3-small"),
        };
        let dimensions = Self::default_dimensions(&model);
        let api_key = api_key.into();
        let base_url = base_url.into();
        Self {
            api_key,
            model,
            dimensions,
            base_url,
        }
    }

    /// Overrides the dimension this provider reports.
    pub fn with_dimensions(self, dimensions: usize) -> Self {
        Self { dimensions, ..self }
    }

    /// Default vector size for a model name: 3072 for
    /// `"text-embedding-3-large"`, 1536 for everything else.
    fn default_dimensions(model: &str) -> usize {
        if model == "text-embedding-3-large" {
            3072
        } else {
            1536
        }
    }

    /// Returns the configured model name.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }
}
#[async_trait]
impl EmbeddingProvider for OpenAiEmbeddingProvider {
    /// Requests an embedding for `text` from the configured endpoint.
    ///
    /// Posts a JSON body with `model` and `input`, authenticated with a
    /// bearer token, and parses the reply with
    /// [`parse_openai_embedding_response`].
    ///
    /// When the dimension has been overridden on a `text-embedding-3*` model,
    /// the override is forwarded as the `dimensions` request parameter so the
    /// returned vectors match what [`EmbeddingProvider::dimension`] reports.
    ///
    /// # Errors
    ///
    /// Returns [`ArgentorError::Http`] when the request fails, the status is
    /// not a success, or the body cannot be read as JSON; parse failures are
    /// reported by the response parser.
    #[cfg(feature = "http-embeddings")]
    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
        let mut payload = serde_json::json!({
            "model": self.model,
            "input": text,
        });
        // Only `text-embedding-3*` models accept `dimensions`; sending it to
        // other models (or custom endpoints serving other models) could be
        // rejected, so the parameter is added only where it is supported and
        // only when it differs from the model default.
        if self.model.starts_with("text-embedding-3")
            && self.dimensions != Self::default_dimensions(&self.model)
        {
            payload["dimensions"] = serde_json::json!(self.dimensions);
        }

        let client = reqwest::Client::new();
        let response = client
            .post(&self.base_url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&payload)
            .send()
            .await
            .map_err(|e| ArgentorError::Http(format!("OpenAI embedding request failed: {e}")))?;

        let status = response.status();
        if !status.is_success() {
            // Best effort: an unreadable error body is reported as empty.
            let body = response.text().await.unwrap_or_default();
            return Err(ArgentorError::Http(format!(
                "OpenAI API error {status}: {body}"
            )));
        }

        let json: serde_json::Value = response.json().await.map_err(|e| {
            ArgentorError::Http(format!("Failed to read OpenAI response body: {e}"))
        })?;
        parse_openai_embedding_response(&json)
    }

    /// Fallback used when the `http-embeddings` feature is disabled: always
    /// returns an [`ArgentorError::Http`] explaining how to enable it.
    #[cfg(not(feature = "http-embeddings"))]
    async fn embed(&self, _text: &str) -> ArgentorResult<Vec<f32>> {
        Err(ArgentorError::Http(
            "HTTP embeddings not enabled. Enable the 'http-embeddings' feature flag \
             or use LocalEmbedding for offline embeddings."
                .to_string(),
        ))
    }

    /// Returns the configured embedding dimension.
    fn dimension(&self) -> usize {
        self.dimensions
    }
}
/// Embedding provider configured for the Cohere embed API.
///
/// Holds the API key, the model name, the dimension reported to callers and
/// the `input_type` sent with every request.
pub struct CohereEmbeddingProvider {
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    model: String,
    dimensions: usize,
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    input_type: String,
}

impl CohereEmbeddingProvider {
    /// Creates a provider; `model` defaults to `"embed-english-v3.0"` and the
    /// input type starts out as `"search_document"`.
    pub fn new(api_key: impl Into<String>, model: Option<String>) -> Self {
        let model = match model {
            Some(name) => name,
            None => String::from("embed-english-v3.0"),
        };
        let dimensions = Self::default_dimensions(&model);
        Self {
            api_key: api_key.into(),
            model,
            dimensions,
            input_type: String::from("search_document"),
        }
    }

    /// Replaces the `input_type` sent with requests.
    pub fn with_input_type(self, input_type: impl Into<String>) -> Self {
        Self {
            input_type: input_type.into(),
            ..self
        }
    }

    /// Overrides the dimension this provider reports.
    pub fn with_dimensions(self, dimensions: usize) -> Self {
        Self { dimensions, ..self }
    }

    /// Default vector size for a model name: 384 for the two `light` v3
    /// models, 1024 for everything else.
    fn default_dimensions(model: &str) -> usize {
        let is_light = matches!(
            model,
            "embed-english-light-v3.0" | "embed-multilingual-light-v3.0"
        );
        if is_light {
            384
        } else {
            1024
        }
    }

    /// Returns the configured model name.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }

    /// Returns the configured input type.
    pub fn input_type(&self) -> &str {
        self.input_type.as_str()
    }
}
#[async_trait]
impl EmbeddingProvider for CohereEmbeddingProvider {
    /// Requests a float embedding for `text` from
    /// `https://api.cohere.com/v2/embed`.
    ///
    /// The request carries the model, the single text, the configured input
    /// type and `embedding_types: ["float"]`; the reply is parsed with
    /// [`parse_cohere_embedding_response`].
    ///
    /// # Errors
    ///
    /// Returns [`ArgentorError::Http`] when the request fails, the status is
    /// not a success, or the body cannot be read as JSON.
    #[cfg(feature = "http-embeddings")]
    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
        let bearer = format!("Bearer {}", self.api_key);
        let request_body = serde_json::json!({
            "model": self.model,
            "texts": [text],
            "input_type": self.input_type,
            "embedding_types": ["float"],
        });

        let reply = reqwest::Client::new()
            .post("https://api.cohere.com/v2/embed")
            .header("Authorization", bearer)
            .json(&request_body)
            .send()
            .await
            .map_err(|e| ArgentorError::Http(format!("Cohere embedding request failed: {e}")))?;

        let status = reply.status();
        if status.is_success() {
            let parsed = reply.json::<serde_json::Value>().await.map_err(|e| {
                ArgentorError::Http(format!("Failed to read Cohere response body: {e}"))
            })?;
            parse_cohere_embedding_response(&parsed)
        } else {
            // Best effort: an unreadable error body is reported as empty.
            let body = reply.text().await.unwrap_or_default();
            Err(ArgentorError::Http(format!(
                "Cohere API error {status}: {body}"
            )))
        }
    }

    /// Fallback used when the `http-embeddings` feature is disabled: always
    /// returns an [`ArgentorError::Http`] explaining how to enable it.
    #[cfg(not(feature = "http-embeddings"))]
    async fn embed(&self, _text: &str) -> ArgentorResult<Vec<f32>> {
        let message = "HTTP embeddings not enabled. Enable the 'http-embeddings' feature flag \
                       or use LocalEmbedding for offline embeddings.";
        Err(ArgentorError::Http(message.to_string()))
    }

    /// Returns the configured embedding dimension.
    fn dimension(&self) -> usize {
        self.dimensions
    }
}
/// Embedding provider configured for the Voyage embeddings API.
///
/// Holds the API key, the model name and the dimension reported to callers.
pub struct VoyageEmbeddingProvider {
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    model: String,
    dimensions: usize,
}

impl VoyageEmbeddingProvider {
    /// Creates a provider; `model` defaults to `"voyage-2"` and the dimension
    /// is derived from the model name.
    pub fn new(api_key: impl Into<String>, model: Option<String>) -> Self {
        let model = match model {
            Some(name) => name,
            None => String::from("voyage-2"),
        };
        let dimensions = Self::default_dimensions(&model);
        Self {
            api_key: api_key.into(),
            model,
            dimensions,
        }
    }

    /// Overrides the dimension this provider reports.
    pub fn with_dimensions(self, dimensions: usize) -> Self {
        Self { dimensions, ..self }
    }

    /// Default vector size for a model name: 1536 for `"voyage-code-2"`,
    /// 1024 for everything else.
    fn default_dimensions(model: &str) -> usize {
        if model == "voyage-code-2" {
            1536
        } else {
            1024
        }
    }

    /// Returns the configured model name.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }
}
#[async_trait]
impl EmbeddingProvider for VoyageEmbeddingProvider {
    /// Requests an embedding for `text` from
    /// `https://api.voyageai.com/v1/embeddings`.
    ///
    /// The request carries the model and a one-element `input` list; the
    /// reply is parsed with [`parse_voyage_embedding_response`].
    ///
    /// # Errors
    ///
    /// Returns [`ArgentorError::Http`] when the request fails, the status is
    /// not a success, or the body cannot be read as JSON.
    #[cfg(feature = "http-embeddings")]
    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
        let bearer = format!("Bearer {}", self.api_key);
        let request_body = serde_json::json!({
            "model": self.model,
            "input": [text],
        });

        let reply = reqwest::Client::new()
            .post("https://api.voyageai.com/v1/embeddings")
            .header("Authorization", bearer)
            .json(&request_body)
            .send()
            .await
            .map_err(|e| ArgentorError::Http(format!("Voyage embedding request failed: {e}")))?;

        let status = reply.status();
        if status.is_success() {
            let parsed = reply.json::<serde_json::Value>().await.map_err(|e| {
                ArgentorError::Http(format!("Failed to read Voyage response body: {e}"))
            })?;
            parse_voyage_embedding_response(&parsed)
        } else {
            // Best effort: an unreadable error body is reported as empty.
            let body = reply.text().await.unwrap_or_default();
            Err(ArgentorError::Http(format!(
                "Voyage API error {status}: {body}"
            )))
        }
    }

    /// Fallback used when the `http-embeddings` feature is disabled: always
    /// returns an [`ArgentorError::Http`] explaining how to enable it.
    #[cfg(not(feature = "http-embeddings"))]
    async fn embed(&self, _text: &str) -> ArgentorResult<Vec<f32>> {
        let message = "HTTP embeddings not enabled. Enable the 'http-embeddings' feature flag \
                       or use LocalEmbedding for offline embeddings.";
        Err(ArgentorError::Http(message.to_string()))
    }

    /// Returns the configured embedding dimension.
    fn dimension(&self) -> usize {
        self.dimensions
    }
}
/// Counters describing how a [`CachedEmbeddingProvider`] has been used.
///
/// All fields start at zero via `Default`.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
/// Number of lookups answered from the cache.
pub hits: u64,
/// Number of lookups that had to call the wrapped provider.
pub misses: u64,
/// Number of cached entries as of the last insert or clear.
pub size: usize,
}
/// Wraps another [`EmbeddingProvider`] and memoizes its results in memory.
///
/// Entries are keyed by the 64-bit FNV-1a hash of the input text, so two
/// different texts whose hashes collide would share one cache entry.
pub struct CachedEmbeddingProvider {
// Provider that computes embeddings on a cache miss.
inner: Arc<dyn EmbeddingProvider>,
// Text hash -> embedding vector.
cache: Arc<RwLock<HashMap<u64, Vec<f32>>>>,
// Entry count at which an existing entry is evicted before inserting.
max_cache_size: usize,
// Hit/miss/size counters, guarded separately from the cache map.
stats: Arc<RwLock<CacheStats>>,
}
impl CachedEmbeddingProvider {
    /// Wraps `inner` with an initially empty cache limited to
    /// `max_cache_size` entries.
    pub fn new(inner: Arc<dyn EmbeddingProvider>, max_cache_size: usize) -> Self {
        Self {
            inner,
            cache: Arc::new(RwLock::new(HashMap::new())),
            max_cache_size,
            stats: Arc::new(RwLock::new(CacheStats::default())),
        }
    }

    /// Returns a snapshot of the current hit/miss/size counters.
    pub async fn cache_stats(&self) -> CacheStats {
        self.stats.read().await.clone()
    }

    /// Removes every cached embedding and resets the reported size to zero.
    ///
    /// Hit and miss counters are left untouched.
    pub async fn clear(&self) {
        // Keep the cache write lock while updating the stats so a concurrent
        // insert cannot slip in between the two steps and have its size
        // overwritten with 0. Lock order (cache, then stats) matches `embed`.
        let mut cache = self.cache.write().await;
        cache.clear();
        let mut stats = self.stats.write().await;
        stats.size = 0;
    }

    /// Cache key for `text`: the FNV-1a hash of its UTF-8 bytes.
    fn text_hash(text: &str) -> u64 {
        fnv1a_hash(text.as_bytes())
    }
}
#[async_trait]
impl EmbeddingProvider for CachedEmbeddingProvider {
    /// Returns the embedding for `text`, consulting the cache first.
    ///
    /// On a hit the cached vector is cloned and the hit counter incremented.
    /// On a miss the wrapped provider is called, the result is stored (after
    /// evicting one arbitrary entry if the cache is full) and the miss
    /// counter incremented. With `max_cache_size == 0` nothing is stored.
    ///
    /// # Errors
    ///
    /// Propagates any error from the wrapped provider; failed calls are not
    /// cached and do not change the counters.
    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
        let key = Self::text_hash(text);

        {
            let cache = self.cache.read().await;
            if let Some(cached) = cache.get(&key) {
                let mut stats = self.stats.write().await;
                stats.hits += 1;
                return Ok(cached.clone());
            }
        }

        // No lock is held while the (potentially slow) inner call runs.
        let embedding = self.inner.embed(text).await?;

        {
            let mut cache = self.cache.write().await;
            // A zero-capacity cache must never hold an entry.
            if self.max_cache_size > 0 {
                // Another task may have inserted this key while we were
                // computing; overwriting it does not grow the map, so only
                // evict when a genuinely new key would exceed the limit.
                if !cache.contains_key(&key) && cache.len() >= self.max_cache_size {
                    // HashMap iteration order is unspecified, so the victim
                    // is effectively arbitrary.
                    if let Some(evict_key) = cache.keys().next().copied() {
                        cache.remove(&evict_key);
                    }
                }
                cache.insert(key, embedding.clone());
            }
            let mut stats = self.stats.write().await;
            stats.misses += 1;
            stats.size = cache.len();
        }

        Ok(embedding)
    }

    /// Returns the dimension reported by the wrapped provider.
    fn dimension(&self) -> usize {
        self.inner.dimension()
    }
}
/// Thin wrapper around a shared [`EmbeddingProvider`] that exposes batch
/// embedding as an inherent method.
pub struct BatchEmbeddingProvider {
// Provider every call is forwarded to.
inner: Arc<dyn EmbeddingProvider>,
}
impl BatchEmbeddingProvider {
/// Wraps `inner` without any additional state.
pub fn new(inner: Arc<dyn EmbeddingProvider>) -> Self {
Self { inner }
}
/// Embeds every text in `texts` by forwarding to the wrapped provider's
/// `embed_batch`, returning its result unchanged.
pub async fn embed_batch(&self, texts: &[&str]) -> ArgentorResult<Vec<Vec<f32>>> {
self.inner.embed_batch(texts).await
}
}
// Trait implementation that forwards every method to the wrapped provider
// unchanged.
#[async_trait]
impl EmbeddingProvider for BatchEmbeddingProvider {
/// Forwards to the wrapped provider's `embed`.
async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
self.inner.embed(text).await
}
/// Forwards to the wrapped provider's `embed_batch`.
async fn embed_batch(&self, texts: &[&str]) -> ArgentorResult<Vec<Vec<f32>>> {
self.inner.embed_batch(texts).await
}
/// Returns the dimension reported by the wrapped provider.
fn dimension(&self) -> usize {
self.inner.dimension()
}
}
/// Stateless factory that builds embedding providers from a provider name.
pub struct EmbeddingProviderFactory;

impl EmbeddingProviderFactory {
    /// Builds the provider registered under `provider_name`.
    ///
    /// Recognised names are `"openai"`, `"cohere"`, `"voyage"` and `"local"`.
    /// For `"local"`, `model` is interpreted as the vector dimension; when it
    /// is absent or not a valid `usize`, 256 is used.
    ///
    /// # Errors
    ///
    /// Returns [`ArgentorError::Config`] for any other provider name.
    pub fn create(
        provider_name: &str,
        api_key: impl Into<String>,
        model: Option<String>,
    ) -> ArgentorResult<Box<dyn EmbeddingProvider>> {
        let api_key = api_key.into();

        let provider: Box<dyn EmbeddingProvider> = match provider_name {
            "openai" => Box::new(OpenAiEmbeddingProvider::new(api_key, model)),
            "cohere" => Box::new(CohereEmbeddingProvider::new(api_key, model)),
            "voyage" => Box::new(VoyageEmbeddingProvider::new(api_key, model)),
            "local" => {
                let parsed_dim = model.as_deref().and_then(|raw| raw.parse::<usize>().ok());
                let dim = match parsed_dim {
                    Some(value) => value,
                    None => 256,
                };
                Box::new(LocalEmbedding::new(dim))
            }
            other => {
                return Err(ArgentorError::Config(format!(
                    "Unknown embedding provider: {other}"
                )));
            }
        };

        Ok(provider)
    }

    /// Lists the provider names accepted by [`Self::create`].
    pub fn available_providers() -> &'static [&'static str] {
        const NAMES: &[&str] = &["openai", "cohere", "voyage", "local"];
        NAMES
    }
}
/// Serializable configuration from which [`EmbeddingConfig::build`] creates
/// a provider.
///
/// Only `provider` is required when deserializing; every other field falls
/// back to its default (empty string or `None`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
/// Provider name: `"openai"`, `"cohere"`, `"voyage"` or `"local"`.
pub provider: String,
/// API key handed to the remote providers.
#[serde(default)]
pub api_key: String,
/// Model name; each provider applies its own default when `None`.
#[serde(default)]
pub model: Option<String>,
/// Dimension override; for `"local"` it selects the vector size
/// (256 when `None`).
#[serde(default)]
pub dimensions: Option<usize>,
/// Custom endpoint URL; `build` only reads it for the `"openai"` provider.
#[serde(default)]
pub base_url: Option<String>,
/// When set, the built provider is wrapped in a cache of this size.
#[serde(default)]
pub cache_size: Option<usize>,
}
impl EmbeddingConfig {
    /// Builds the provider described by this configuration.
    ///
    /// * `"openai"` honours `base_url`, `model` and `dimensions`.
    /// * `"cohere"` and `"voyage"` honour `model` and `dimensions`.
    /// * `"local"` uses `dimensions` as the vector size (256 by default).
    ///
    /// When `cache_size` is set, the provider is wrapped in a
    /// [`CachedEmbeddingProvider`] with that capacity.
    ///
    /// # Errors
    ///
    /// Returns [`ArgentorError::Config`] when `provider` is not one of the
    /// names above.
    pub fn build(&self) -> ArgentorResult<Arc<dyn EmbeddingProvider>> {
        let provider: Box<dyn EmbeddingProvider> = match self.provider.as_str() {
            "openai" => {
                let mut p = if let Some(ref url) = self.base_url {
                    OpenAiEmbeddingProvider::with_base_url(&self.api_key, self.model.clone(), url)
                } else {
                    OpenAiEmbeddingProvider::new(&self.api_key, self.model.clone())
                };
                if let Some(dim) = self.dimensions {
                    p = p.with_dimensions(dim);
                }
                Box::new(p)
            }
            "cohere" => {
                let mut p = CohereEmbeddingProvider::new(&self.api_key, self.model.clone());
                if let Some(dim) = self.dimensions {
                    p = p.with_dimensions(dim);
                }
                Box::new(p)
            }
            "voyage" => {
                let mut p = VoyageEmbeddingProvider::new(&self.api_key, self.model.clone());
                if let Some(dim) = self.dimensions {
                    p = p.with_dimensions(dim);
                }
                Box::new(p)
            }
            "local" => {
                let dim = self.dimensions.unwrap_or(256);
                Box::new(LocalEmbedding::new(dim))
            }
            other => {
                return Err(ArgentorError::Config(format!(
                    "Unknown embedding provider: {other}"
                )));
            }
        };

        // Convert the box into a shared handle so it can be wrapped by the cache.
        let arc: Arc<dyn EmbeddingProvider> = Arc::from(provider);
        if let Some(cache_size) = self.cache_size {
            Ok(Arc::new(CachedEmbeddingProvider::new(arc, cache_size)))
        } else {
            Ok(arc)
        }
    }
}
impl Default for EmbeddingConfig {
    /// Default configuration: the `"local"` provider with an empty API key
    /// and no model, dimension, base URL or cache overrides.
    fn default() -> Self {
        let provider = String::from("local");
        let api_key = String::new();
        Self {
            provider,
            api_key,
            model: None,
            dimensions: None,
            base_url: None,
            cache_size: None,
        }
    }
}
/// Produces a deterministic placeholder embedding for `text`.
///
/// The vector has `dimensions` slots (at least one). Each byte of `text`,
/// scaled to `[0, 1]` by dividing by 255, is added to the slot at its byte
/// position modulo the vector length. The result is scaled to unit L2 norm
/// unless it is all zeros (e.g. for empty text).
#[cfg_attr(
    all(feature = "http-embeddings", not(test)),
    allow(dead_code)
)]
fn stub_embedding(text: &str, dimensions: usize) -> Vec<f32> {
    let slots = dimensions.max(1);
    let mut vector = vec![0.0f32; slots];

    // Walk the slots cyclically in step with the bytes of the text.
    for (slot, byte) in (0..slots).cycle().zip(text.bytes()) {
        vector[slot] += f32::from(byte) / 255.0;
    }

    let magnitude = vector.iter().map(|c| c * c).sum::<f32>().sqrt();
    if magnitude > 0.0 {
        vector.iter_mut().for_each(|c| *c /= magnitude);
    }
    vector
}
/// Embedding provider configured for the Jina embeddings endpoint.
///
/// Holds the API key, the model name, the dimension reported to callers and
/// the endpoint URL.
pub struct JinaEmbeddingProvider {
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    model: String,
    dimensions: usize,
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}

impl JinaEmbeddingProvider {
    /// Creates a provider for `"jina-embeddings-v3"` with 1024 dimensions.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self::with_model(api_key, "jina-embeddings-v3", 1024)
    }

    /// Creates a provider for an explicit model and dimension, targeting
    /// `https://api.jina.ai/v1/embeddings`.
    pub fn with_model(
        api_key: impl Into<String>,
        model: impl Into<String>,
        dimensions: usize,
    ) -> Self {
        let api_key = api_key.into();
        let model = model.into();
        let base_url = String::from("https://api.jina.ai/v1/embeddings");
        Self {
            api_key,
            model,
            dimensions,
            base_url,
        }
    }

    /// Replaces the endpoint URL.
    pub fn with_base_url(self, base_url: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            ..self
        }
    }

    /// Returns the configured model name.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }

    /// Builds the JSON request body: the model name and the `input` texts.
    pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
        let body = serde_json::json!({
            "model": self.model,
            "input": texts,
        });
        body
    }
}
#[async_trait]
impl EmbeddingProvider for JinaEmbeddingProvider {
    /// Requests an embedding for `text` from the configured endpoint.
    ///
    /// Posts the payload from `build_payload` with a bearer token and parses
    /// the reply with [`parse_openai_embedding_response`].
    ///
    /// # Errors
    ///
    /// Returns [`ArgentorError::Http`] when the request fails, the status is
    /// not a success, or the body cannot be read as JSON.
    #[cfg(feature = "http-embeddings")]
    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
        let request_body = self.build_payload(&[text.to_string()]);
        let bearer = format!("Bearer {}", self.api_key);

        let reply = reqwest::Client::new()
            .post(&self.base_url)
            .header("Authorization", bearer)
            .json(&request_body)
            .send()
            .await
            .map_err(|e| ArgentorError::Http(format!("Jina embedding request failed: {e}")))?;

        let status = reply.status();
        if status.is_success() {
            let parsed = reply.json::<serde_json::Value>().await.map_err(|e| {
                ArgentorError::Http(format!("Failed to read Jina response body: {e}"))
            })?;
            parse_openai_embedding_response(&parsed)
        } else {
            // Best effort: an unreadable error body is reported as empty.
            let body = reply.text().await.unwrap_or_default();
            Err(ArgentorError::Http(format!(
                "Jina API error {status}: {body}"
            )))
        }
    }

    /// Fallback used when the `http-embeddings` feature is disabled: returns
    /// the deterministic vector from `stub_embedding` instead of an error.
    #[cfg(not(feature = "http-embeddings"))]
    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
        let placeholder = stub_embedding(text, self.dimensions);
        Ok(placeholder)
    }

    /// Returns the configured embedding dimension.
    fn dimension(&self) -> usize {
        self.dimensions
    }
}
/// Embedding provider configured for the Mistral embeddings endpoint.
///
/// Holds the API key, the model name, the dimension reported to callers and
/// the endpoint URL.
pub struct MistralEmbedProvider {
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    model: String,
    dimensions: usize,
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}

impl MistralEmbedProvider {
    /// Creates a provider for `"mistral-embed"` with 1024 dimensions.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self::with_model(api_key, "mistral-embed", 1024)
    }

    /// Creates a provider for an explicit model and dimension, targeting
    /// `https://api.mistral.ai/v1/embeddings`.
    pub fn with_model(
        api_key: impl Into<String>,
        model: impl Into<String>,
        dimensions: usize,
    ) -> Self {
        let api_key = api_key.into();
        let model = model.into();
        let base_url = String::from("https://api.mistral.ai/v1/embeddings");
        Self {
            api_key,
            model,
            dimensions,
            base_url,
        }
    }

    /// Replaces the endpoint URL.
    pub fn with_base_url(self, base_url: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            ..self
        }
    }

    /// Returns the configured model name.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }

    /// Builds the JSON request body: the model name and the `input` texts.
    pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
        let body = serde_json::json!({
            "model": self.model,
            "input": texts,
        });
        body
    }
}
#[async_trait]
impl EmbeddingProvider for MistralEmbedProvider {
    /// Requests an embedding for `text` from the configured endpoint.
    ///
    /// Posts the payload from `build_payload` with a bearer token and parses
    /// the reply with [`parse_openai_embedding_response`].
    ///
    /// # Errors
    ///
    /// Returns [`ArgentorError::Http`] when the request fails, the status is
    /// not a success, or the body cannot be read as JSON.
    #[cfg(feature = "http-embeddings")]
    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
        let request_body = self.build_payload(&[text.to_string()]);
        let bearer = format!("Bearer {}", self.api_key);

        let reply = reqwest::Client::new()
            .post(&self.base_url)
            .header("Authorization", bearer)
            .json(&request_body)
            .send()
            .await
            .map_err(|e| ArgentorError::Http(format!("Mistral embedding request failed: {e}")))?;

        let status = reply.status();
        if status.is_success() {
            let parsed = reply.json::<serde_json::Value>().await.map_err(|e| {
                ArgentorError::Http(format!("Failed to read Mistral response body: {e}"))
            })?;
            parse_openai_embedding_response(&parsed)
        } else {
            // Best effort: an unreadable error body is reported as empty.
            let body = reply.text().await.unwrap_or_default();
            Err(ArgentorError::Http(format!(
                "Mistral API error {status}: {body}"
            )))
        }
    }

    /// Fallback used when the `http-embeddings` feature is disabled: returns
    /// the deterministic vector from `stub_embedding` instead of an error.
    #[cfg(not(feature = "http-embeddings"))]
    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
        let placeholder = stub_embedding(text, self.dimensions);
        Ok(placeholder)
    }

    /// Returns the configured embedding dimension.
    fn dimension(&self) -> usize {
        self.dimensions
    }
}
/// Embedding provider configured for the Nomic Atlas text-embedding endpoint.
///
/// Holds the API key, the model name, the dimension reported to callers, the
/// endpoint URL and the `task_type` sent with every request.
pub struct NomicEmbedProvider {
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    model: String,
    dimensions: usize,
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
    task_type: String,
}

impl NomicEmbedProvider {
    /// Creates a provider for `"nomic-embed-text-v1.5"` with 768 dimensions.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self::with_model(api_key, "nomic-embed-text-v1.5", 768)
    }

    /// Creates a provider for an explicit model and dimension, targeting
    /// `https://api-atlas.nomic.ai/v1/embedding/text` with the task type
    /// `"search_document"`.
    pub fn with_model(
        api_key: impl Into<String>,
        model: impl Into<String>,
        dimensions: usize,
    ) -> Self {
        let api_key = api_key.into();
        let model = model.into();
        let base_url = String::from("https://api-atlas.nomic.ai/v1/embedding/text");
        let task_type = String::from("search_document");
        Self {
            api_key,
            model,
            dimensions,
            base_url,
            task_type,
        }
    }

    /// Replaces the task type sent with requests.
    pub fn with_task_type(self, task_type: impl Into<String>) -> Self {
        Self {
            task_type: task_type.into(),
            ..self
        }
    }

    /// Returns the configured model name.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }

    /// Returns the configured task type.
    pub fn task_type(&self) -> &str {
        self.task_type.as_str()
    }

    /// Builds the JSON request body: the model name, the `texts` and the
    /// task type.
    pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
        let body = serde_json::json!({
            "model": self.model,
            "texts": texts,
            "task_type": self.task_type,
        });
        body
    }
}
#[async_trait]
impl EmbeddingProvider for NomicEmbedProvider {
    /// Requests an embedding for `text` from the configured endpoint.
    ///
    /// Posts the payload from `build_payload` with a bearer token, then reads
    /// the first entry of the reply's top-level `embeddings` array.
    ///
    /// # Errors
    ///
    /// Returns [`ArgentorError::Http`] when the request fails, the status is
    /// not a success, or the body cannot be read as JSON, and
    /// [`ArgentorError::Agent`] when `embeddings` is missing, empty, or its
    /// first entry is not a list of numbers.
    #[cfg(feature = "http-embeddings")]
    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
        let request_body = self.build_payload(&[text.to_string()]);
        let bearer = format!("Bearer {}", self.api_key);

        let reply = reqwest::Client::new()
            .post(&self.base_url)
            .header("Authorization", bearer)
            .json(&request_body)
            .send()
            .await
            .map_err(|e| ArgentorError::Http(format!("Nomic embedding request failed: {e}")))?;

        let status = reply.status();
        if !status.is_success() {
            // Best effort: an unreadable error body is reported as empty.
            let body = reply.text().await.unwrap_or_default();
            return Err(ArgentorError::Http(format!(
                "Nomic API error {status}: {body}"
            )));
        }

        let parsed = reply.json::<serde_json::Value>().await.map_err(|e| {
            ArgentorError::Http(format!("Failed to read Nomic response body: {e}"))
        })?;

        let Some(rows) = parsed.get("embeddings").and_then(|v| v.as_array()) else {
            return Err(ArgentorError::Agent(
                "Nomic response missing 'embeddings' array".to_string(),
            ));
        };
        let Some(first_row) = rows.first() else {
            return Err(ArgentorError::Agent(
                "Nomic response contains no embedding vectors".to_string(),
            ));
        };

        serde_json::from_value::<Vec<f32>>(first_row.clone()).map_err(|e| {
            ArgentorError::Agent(format!("Failed to parse Nomic embedding vector: {e}"))
        })
    }

    /// Fallback used when the `http-embeddings` feature is disabled: returns
    /// the deterministic vector from `stub_embedding` instead of an error.
    #[cfg(not(feature = "http-embeddings"))]
    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
        let placeholder = stub_embedding(text, self.dimensions);
        Ok(placeholder)
    }

    /// Returns the configured embedding dimension.
    fn dimension(&self) -> usize {
        self.dimensions
    }
}
/// Embedding provider configured for the HuggingFace inference
/// feature-extraction pipeline.
///
/// Holds the API key, the model name, the dimension reported to callers and
/// the endpoint URL (which by default embeds the model name).
pub struct SentenceTransformersProvider {
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    model: String,
    dimensions: usize,
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}

impl SentenceTransformersProvider {
    /// Creates a provider for `"sentence-transformers/all-MiniLM-L6-v2"`
    /// with 384 dimensions.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self::with_model(api_key, "sentence-transformers/all-MiniLM-L6-v2", 384)
    }

    /// Creates a provider for an explicit model and dimension. The endpoint
    /// URL is the feature-extraction pipeline path followed by the model name.
    pub fn with_model(
        api_key: impl Into<String>,
        model: impl Into<String>,
        dimensions: usize,
    ) -> Self {
        let model: String = model.into();
        let base_url =
            format!("https://api-inference.huggingface.co/pipeline/feature-extraction/{model}");
        let api_key = api_key.into();
        Self {
            api_key,
            model,
            dimensions,
            base_url,
        }
    }

    /// Replaces the endpoint URL.
    pub fn with_base_url(self, base_url: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            ..self
        }
    }

    /// Returns the configured model name.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }

    /// Default vector size for a model name: 768 for the two listed `mpnet`
    /// models, 384 for everything else.
    pub fn default_dimensions(model: &str) -> usize {
        let is_mpnet = matches!(
            model,
            "sentence-transformers/all-mpnet-base-v2"
                | "sentence-transformers/multi-qa-mpnet-base-dot-v1"
        );
        if is_mpnet {
            768
        } else {
            384
        }
    }

    /// Builds the JSON request body: the `inputs` texts plus the
    /// `wait_for_model` option set to `true`.
    pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
        let body = serde_json::json!({
            "inputs": texts,
            "options": { "wait_for_model": true },
        });
        body
    }
}
#[async_trait]
impl EmbeddingProvider for SentenceTransformersProvider {
    /// Requests an embedding for `text` from the configured endpoint.
    ///
    /// Posts the payload from `build_payload` with a bearer token. The reply
    /// must be a JSON array: when its first element is itself an array, that
    /// element is taken as the vector; otherwise the whole array is.
    ///
    /// # Errors
    ///
    /// Returns [`ArgentorError::Http`] when the request fails, the status is
    /// not a success, or the body cannot be read as JSON, and
    /// [`ArgentorError::Agent`] when the reply is not an array or cannot be
    /// parsed as a list of numbers.
    #[cfg(feature = "http-embeddings")]
    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
        let request_body = self.build_payload(&[text.to_string()]);
        let bearer = format!("Bearer {}", self.api_key);

        let reply = reqwest::Client::new()
            .post(&self.base_url)
            .header("Authorization", bearer)
            .json(&request_body)
            .send()
            .await
            .map_err(|e| {
                ArgentorError::Http(format!("HuggingFace embedding request failed: {e}"))
            })?;

        let status = reply.status();
        if !status.is_success() {
            // Best effort: an unreadable error body is reported as empty.
            let body = reply.text().await.unwrap_or_default();
            return Err(ArgentorError::Http(format!(
                "HuggingFace API error {status}: {body}"
            )));
        }

        let parsed = reply.json::<serde_json::Value>().await.map_err(|e| {
            ArgentorError::Http(format!("Failed to read HuggingFace response body: {e}"))
        })?;

        // Decide first whether the reply is nested (`[[...]]`) or flat (`[...]`).
        let is_nested = match &parsed {
            serde_json::Value::Array(items) => {
                items.first().is_some_and(serde_json::Value::is_array)
            }
            _ => {
                return Err(ArgentorError::Agent(
                    "HuggingFace response is not an array".to_string(),
                ));
            }
        };

        let vector_json = if is_nested {
            parsed.get(0).cloned().ok_or_else(|| {
                ArgentorError::Agent("HuggingFace response empty".to_string())
            })?
        } else {
            parsed
        };

        serde_json::from_value(vector_json)
            .map_err(|e| ArgentorError::Agent(format!("Failed to parse HF vector: {e}")))
    }

    /// Fallback used when the `http-embeddings` feature is disabled: returns
    /// the deterministic vector from `stub_embedding` instead of an error.
    #[cfg(not(feature = "http-embeddings"))]
    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
        let placeholder = stub_embedding(text, self.dimensions);
        Ok(placeholder)
    }

    /// Returns the configured embedding dimension.
    fn dimension(&self) -> usize {
        self.dimensions
    }
}
/// Embedding provider configured for the Together embeddings endpoint.
///
/// Holds the API key, the model name, the dimension reported to callers and
/// the endpoint URL.
pub struct TogetherEmbedProvider {
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    model: String,
    dimensions: usize,
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}

impl TogetherEmbedProvider {
    /// Creates a provider for `"togethercomputer/m2-bert-80M-32k-retrieval"`
    /// with 768 dimensions.
    pub fn new(api_key: impl Into<String>) -> Self {
        const DEFAULT_MODEL: &str = "togethercomputer/m2-bert-80M-32k-retrieval";
        Self::with_model(api_key, DEFAULT_MODEL, 768)
    }

    /// Creates a provider for an explicit model and dimension, targeting
    /// `https://api.together.xyz/v1/embeddings`.
    pub fn with_model(
        api_key: impl Into<String>,
        model: impl Into<String>,
        dimensions: usize,
    ) -> Self {
        let api_key = api_key.into();
        let model = model.into();
        let base_url = String::from("https://api.together.xyz/v1/embeddings");
        Self {
            api_key,
            model,
            dimensions,
            base_url,
        }
    }

    /// Replaces the endpoint URL.
    pub fn with_base_url(self, base_url: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            ..self
        }
    }

    /// Returns the configured model name.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }

    /// Builds the JSON request body: the model name and the `input` texts.
    pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
        let body = serde_json::json!({
            "model": self.model,
            "input": texts,
        });
        body
    }
}
#[async_trait]
impl EmbeddingProvider for TogetherEmbedProvider {
    /// Requests an embedding for `text` from the configured endpoint.
    ///
    /// Posts the payload from `build_payload` with a bearer token and parses
    /// the reply with [`parse_openai_embedding_response`].
    ///
    /// # Errors
    ///
    /// Returns [`ArgentorError::Http`] when the request fails, the status is
    /// not a success, or the body cannot be read as JSON.
    #[cfg(feature = "http-embeddings")]
    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
        let request_body = self.build_payload(&[text.to_string()]);
        let bearer = format!("Bearer {}", self.api_key);

        let reply = reqwest::Client::new()
            .post(&self.base_url)
            .header("Authorization", bearer)
            .json(&request_body)
            .send()
            .await
            .map_err(|e| ArgentorError::Http(format!("Together embedding request failed: {e}")))?;

        let status = reply.status();
        if status.is_success() {
            let parsed = reply.json::<serde_json::Value>().await.map_err(|e| {
                ArgentorError::Http(format!("Failed to read Together response body: {e}"))
            })?;
            parse_openai_embedding_response(&parsed)
        } else {
            // Best effort: an unreadable error body is reported as empty.
            let body = reply.text().await.unwrap_or_default();
            Err(ArgentorError::Http(format!(
                "Together API error {status}: {body}"
            )))
        }
    }

    /// Fallback used when the `http-embeddings` feature is disabled: returns
    /// the deterministic vector from `stub_embedding` instead of an error.
    #[cfg(not(feature = "http-embeddings"))]
    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
        let placeholder = stub_embedding(text, self.dimensions);
        Ok(placeholder)
    }

    /// Returns the configured embedding dimension.
    fn dimension(&self) -> usize {
        self.dimensions
    }
}
/// Embedding provider for the Cohere embed endpoint with a configurable URL
/// and input type.
///
/// Holds the API key, the model name, the dimension reported to callers, the
/// `input_type` sent with every request and the endpoint URL.
pub struct CohereEmbedV4Provider {
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    api_key: String,
    model: String,
    dimensions: usize,
    input_type: String,
    #[cfg_attr(not(feature = "http-embeddings"), allow(dead_code))]
    base_url: String,
}

impl CohereEmbedV4Provider {
    /// Creates a provider for `"embed-english-v3.0"` with 1024 dimensions.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self::with_model(api_key, "embed-english-v3.0", 1024)
    }

    /// Creates a provider for an explicit model and dimension, targeting
    /// `https://api.cohere.com/v2/embed` with the input type
    /// `"search_document"`.
    pub fn with_model(
        api_key: impl Into<String>,
        model: impl Into<String>,
        dimensions: usize,
    ) -> Self {
        let api_key = api_key.into();
        let model = model.into();
        let input_type = String::from("search_document");
        let base_url = String::from("https://api.cohere.com/v2/embed");
        Self {
            api_key,
            model,
            dimensions,
            input_type,
            base_url,
        }
    }

    /// Sets the input type to `"search_document"`.
    pub fn for_search_document(self) -> Self {
        self.with_input_type("search_document")
    }

    /// Sets the input type to `"search_query"`.
    pub fn for_search_query(self) -> Self {
        self.with_input_type("search_query")
    }

    /// Sets an arbitrary input type.
    pub fn with_input_type(self, input_type: impl Into<String>) -> Self {
        Self {
            input_type: input_type.into(),
            ..self
        }
    }

    /// Replaces the endpoint URL.
    pub fn with_base_url(self, base_url: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            ..self
        }
    }

    /// Returns the configured model name.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }

    /// Returns the configured input type.
    pub fn input_type(&self) -> &str {
        self.input_type.as_str()
    }

    /// Builds the JSON request body: the model name, the `texts`, the input
    /// type and `embedding_types: ["float"]`.
    pub fn build_payload(&self, texts: &[String]) -> serde_json::Value {
        let body = serde_json::json!({
            "model": self.model,
            "texts": texts,
            "input_type": self.input_type,
            "embedding_types": ["float"],
        });
        body
    }
}
#[async_trait]
impl EmbeddingProvider for CohereEmbedV4Provider {
    /// Requests a float embedding for `text` from the configured endpoint.
    ///
    /// Posts the payload from `build_payload` with a bearer token and parses
    /// the reply with [`parse_cohere_embedding_response`].
    ///
    /// # Errors
    ///
    /// Returns [`ArgentorError::Http`] when the request fails, the status is
    /// not a success, or the body cannot be read as JSON.
    #[cfg(feature = "http-embeddings")]
    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
        let request_body = self.build_payload(&[text.to_string()]);
        let bearer = format!("Bearer {}", self.api_key);

        let reply = reqwest::Client::new()
            .post(&self.base_url)
            .header("Authorization", bearer)
            .json(&request_body)
            .send()
            .await
            .map_err(|e| {
                ArgentorError::Http(format!("Cohere v4 embedding request failed: {e}"))
            })?;

        let status = reply.status();
        if status.is_success() {
            let parsed = reply.json::<serde_json::Value>().await.map_err(|e| {
                ArgentorError::Http(format!("Failed to read Cohere v4 response body: {e}"))
            })?;
            parse_cohere_embedding_response(&parsed)
        } else {
            // Best effort: an unreadable error body is reported as empty.
            let body = reply.text().await.unwrap_or_default();
            Err(ArgentorError::Http(format!(
                "Cohere v4 API error {status}: {body}"
            )))
        }
    }

    /// Fallback used when the `http-embeddings` feature is disabled: returns
    /// the deterministic vector from `stub_embedding` instead of an error.
    #[cfg(not(feature = "http-embeddings"))]
    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
        let placeholder = stub_embedding(text, self.dimensions);
        Ok(placeholder)
    }

    /// Returns the configured embedding dimension.
    fn dimension(&self) -> usize {
        self.dimensions
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
#[test]
fn test_openai_provider_default_model() {
let p = OpenAiEmbeddingProvider::new("sk-test", None);
assert_eq!(p.model(), "text-embedding-3-small");
assert_eq!(p.dimension(), 1536);
}
#[test]
fn test_openai_provider_large_model() {
let p = OpenAiEmbeddingProvider::new("sk-test", Some("text-embedding-3-large".into()));
assert_eq!(p.dimension(), 3072);
}
#[test]
fn test_openai_provider_custom_dimensions() {
let p = OpenAiEmbeddingProvider::new("sk-test", None).with_dimensions(512);
assert_eq!(p.dimension(), 512);
}
#[test]
fn test_openai_provider_custom_base_url() {
let p = OpenAiEmbeddingProvider::with_base_url(
"sk-test",
None,
"https://my-azure.openai.azure.com/openai/deployments/embed",
);
assert_eq!(p.dimension(), 1536);
}
#[cfg(not(feature = "http-embeddings"))]
#[tokio::test]
async fn test_openai_provider_returns_feature_error() {
let p = OpenAiEmbeddingProvider::new("sk-test", None);
let err = p.embed("hello").await.unwrap_err();
let msg = format!("{err}");
assert!(msg.contains("HTTP embeddings not enabled"), "got: {msg}");
}
#[test]
fn test_cohere_provider_default() {
let p = CohereEmbeddingProvider::new("key", None);
assert_eq!(p.model(), "embed-english-v3.0");
assert_eq!(p.dimension(), 1024);
assert_eq!(p.input_type(), "search_document");
}
#[test]
fn test_cohere_provider_query_input_type() {
let p = CohereEmbeddingProvider::new("key", None).with_input_type("search_query");
assert_eq!(p.input_type(), "search_query");
}
#[test]
fn test_cohere_provider_light_model() {
let p = CohereEmbeddingProvider::new("key", Some("embed-english-light-v3.0".into()));
assert_eq!(p.dimension(), 384);
}
#[cfg(not(feature = "http-embeddings"))]
#[tokio::test]
async fn test_cohere_provider_returns_feature_error() {
let p = CohereEmbeddingProvider::new("key", None);
let err = p.embed("hello").await.unwrap_err();
let msg = format!("{err}");
assert!(msg.contains("HTTP embeddings not enabled"), "got: {msg}");
}
#[test]
fn test_voyage_provider_default() {
let p = VoyageEmbeddingProvider::new("key", None);
assert_eq!(p.model(), "voyage-2");
assert_eq!(p.dimension(), 1024);
}
#[test]
fn test_voyage_provider_code_model() {
let p = VoyageEmbeddingProvider::new("key", Some("voyage-code-2".into()));
assert_eq!(p.dimension(), 1536);
}
#[cfg(not(feature = "http-embeddings"))]
#[tokio::test]
async fn test_voyage_provider_returns_feature_error() {
let p = VoyageEmbeddingProvider::new("key", None);
let err = p.embed("hello").await.unwrap_err();
let msg = format!("{err}");
assert!(msg.contains("HTTP embeddings not enabled"), "got: {msg}");
}
#[test]
fn test_parse_openai_embedding_response_valid() {
let json = serde_json::json!({
"data": [
{
"embedding": [0.1, 0.2, 0.3, 0.4],
"index": 0
}
],
"model": "text-embedding-3-small"
});
let result = parse_openai_embedding_response(&json).unwrap();
assert_eq!(result, vec![0.1, 0.2, 0.3, 0.4]);
}
#[test]
fn test_parse_openai_embedding_response_empty_data() {
let json = serde_json::json!({
"data": [],
"model": "text-embedding-3-small"
});
let err = parse_openai_embedding_response(&json).unwrap_err();
let msg = format!("{err}");
assert!(msg.contains("no embedding data"), "got: {msg}");
}
#[test]
fn test_parse_openai_embedding_response_invalid_shape() {
let json = serde_json::json!({ "error": "bad request" });
let err = parse_openai_embedding_response(&json).unwrap_err();
let msg = format!("{err}");
assert!(msg.contains("Failed to parse"), "got: {msg}");
}
/// When the OpenAI response carries several data entries, the first entry's
/// embedding is the one returned.
#[test]
fn test_parse_openai_embedding_response_multiple_picks_first() {
    let first_entry = serde_json::json!({ "embedding": [1.0, 2.0], "index": 0 });
    let second_entry = serde_json::json!({ "embedding": [3.0, 4.0], "index": 1 });
    let payload = serde_json::json!({
        "data": [first_entry, second_entry],
        "model": "text-embedding-3-small"
    });
    let embedding = parse_openai_embedding_response(&payload).unwrap();
    assert_eq!(embedding, vec![1.0, 2.0]);
}
/// A well-formed Cohere response yields the single vector under `embeddings.float`.
#[test]
fn test_parse_cohere_embedding_response_valid() {
    let payload = serde_json::json!({
        "embeddings": { "float": [[0.5, 0.6, 0.7]] }
    });
    let expected = vec![0.5, 0.6, 0.7];
    let embedding = parse_cohere_embedding_response(&payload).unwrap();
    assert_eq!(embedding, expected);
}
/// A Cohere response with an empty `float` list is rejected with a
/// "no float embeddings" error.
#[test]
fn test_parse_cohere_embedding_response_empty_float() {
    let payload = serde_json::json!({ "embeddings": { "float": [] } });
    let msg = parse_cohere_embedding_response(&payload)
        .unwrap_err()
        .to_string();
    assert!(msg.contains("no float embeddings"), "got: {msg}");
}
/// JSON lacking the `embeddings` object produces a "Failed to parse" error.
#[test]
fn test_parse_cohere_embedding_response_invalid_shape() {
    let payload = serde_json::json!({ "message": "unauthorized" });
    let msg = parse_cohere_embedding_response(&payload)
        .unwrap_err()
        .to_string();
    assert!(msg.contains("Failed to parse"), "got: {msg}");
}
/// An `embeddings` object without a `float` key is reported as
/// "no float embeddings" rather than as a parse failure.
#[test]
fn test_parse_cohere_embedding_response_missing_float_key() {
    let payload = serde_json::json!({ "embeddings": {} });
    let msg = parse_cohere_embedding_response(&payload)
        .unwrap_err()
        .to_string();
    assert!(msg.contains("no float embeddings"), "got: {msg}");
}
/// A well-formed Voyage response with one data entry yields that entry's embedding.
#[test]
fn test_parse_voyage_embedding_response_valid() {
    let payload = serde_json::json!({
        "data": [{ "embedding": [0.9, 0.8, 0.7, 0.6, 0.5], "index": 0 }]
    });
    let expected = vec![0.9, 0.8, 0.7, 0.6, 0.5];
    let embedding = parse_voyage_embedding_response(&payload).unwrap();
    assert_eq!(embedding, expected);
}
/// A Voyage response whose `data` array is empty is rejected with a
/// "no embedding data" error.
#[test]
fn test_parse_voyage_embedding_response_empty_data() {
    let payload = serde_json::json!({ "data": [] });
    let msg = parse_voyage_embedding_response(&payload)
        .unwrap_err()
        .to_string();
    assert!(msg.contains("no embedding data"), "got: {msg}");
}
/// JSON that does not match the Voyage response shape produces a
/// "Failed to parse" error.
#[test]
fn test_parse_voyage_embedding_response_invalid_shape() {
    let payload = serde_json::json!({ "error": "invalid key" });
    let msg = parse_voyage_embedding_response(&payload)
        .unwrap_err()
        .to_string();
    assert!(msg.contains("Failed to parse"), "got: {msg}");
}
/// Embedding the same text twice returns equal vectors and the cache stats
/// show one miss, one hit and a single stored entry.
#[tokio::test]
async fn test_cache_hit() {
    let backend = Arc::new(LocalEmbedding::new(64));
    let provider = CachedEmbeddingProvider::new(backend, 100);

    let first = provider.embed("hello world").await.unwrap();
    let second = provider.embed("hello world").await.unwrap();
    assert_eq!(first, second);

    let stats = provider.cache_stats().await;
    assert_eq!(stats.hits, 1);
    assert_eq!(stats.misses, 1);
    assert_eq!(stats.size, 1);
}
/// Two distinct texts each count as a miss and each occupy a cache slot.
#[tokio::test]
async fn test_cache_miss_different_texts() {
    let backend = Arc::new(LocalEmbedding::new(64));
    let provider = CachedEmbeddingProvider::new(backend, 100);

    for text in ["alpha", "bravo"] {
        let _ = provider.embed(text).await.unwrap();
    }

    let stats = provider.cache_stats().await;
    assert_eq!(stats.misses, 2);
    assert_eq!(stats.hits, 0);
    assert_eq!(stats.size, 2);
}
/// With a capacity of 2, embedding three distinct texts records three misses
/// while the reported cache size stays at or below 2.
#[tokio::test]
async fn test_cache_eviction() {
    let backend = Arc::new(LocalEmbedding::new(64));
    let provider = CachedEmbeddingProvider::new(backend, 2);

    for text in ["one", "two", "three"] {
        let _ = provider.embed(text).await.unwrap();
    }

    let stats = provider.cache_stats().await;
    assert!(stats.size <= 2, "size={} should be <= 2", stats.size);
    assert_eq!(stats.misses, 3);
}
/// After `clear`, the cache stats report a size of zero.
#[tokio::test]
async fn test_cache_clear() {
    let backend = Arc::new(LocalEmbedding::new(64));
    let provider = CachedEmbeddingProvider::new(backend, 100);

    let _ = provider.embed("text").await.unwrap();
    provider.clear().await;

    let size_after_clear = provider.cache_stats().await.size;
    assert_eq!(size_after_clear, 0);
}
/// The caching wrapper reports the same dimension (128) as the provider it wraps.
#[tokio::test]
async fn test_cache_dimension_delegates() {
    let backend = Arc::new(LocalEmbedding::new(128));
    let provider = CachedEmbeddingProvider::new(backend, 10);
    assert_eq!(provider.dimension(), 128);
}
/// Batch-embedding three texts returns three vectors, each 64 elements long.
#[tokio::test]
async fn test_batch_embed() {
    let backend = Arc::new(LocalEmbedding::new(64));
    let batcher = BatchEmbeddingProvider::new(backend);

    let embeddings = batcher
        .embed_batch(&["hello", "world", "test"])
        .await
        .unwrap();

    assert_eq!(embeddings.len(), 3);
    embeddings
        .iter()
        .for_each(|embedding| assert_eq!(embedding.len(), 64));
}
/// A single `embed` call through the batch wrapper returns a 64-element vector.
#[tokio::test]
async fn test_batch_single_embed_delegates() {
    let backend = Arc::new(LocalEmbedding::new(64));
    let batcher = BatchEmbeddingProvider::new(backend);
    let embedding = batcher.embed("hello").await.unwrap();
    assert_eq!(embedding.len(), 64);
}
/// Batch-embedding an empty slice succeeds and returns no vectors.
#[tokio::test]
async fn test_batch_empty() {
    let backend = Arc::new(LocalEmbedding::new(64));
    let batcher = BatchEmbeddingProvider::new(backend);
    let results = batcher.embed_batch(&[]).await.unwrap();
    assert!(results.is_empty());
}
/// The batch wrapper reports the same dimension (200) as the provider it wraps.
#[tokio::test]
async fn test_batch_dimension_delegates() {
    let backend = Arc::new(LocalEmbedding::new(200));
    let batcher = BatchEmbeddingProvider::new(backend);
    assert_eq!(batcher.dimension(), 200);
}
/// The factory's "local" provider, created without a third argument,
/// reports 256 dimensions.
#[test]
fn test_factory_create_local() {
    let provider = EmbeddingProviderFactory::create("local", "", None).unwrap();
    assert_eq!(provider.dimension(), 256);
}
/// Passing "128" as the third factory argument gives a "local" provider
/// with 128 dimensions.
#[test]
fn test_factory_create_local_custom_dim() {
    let created = EmbeddingProviderFactory::create("local", "", Some("128".into()));
    let provider = created.unwrap();
    assert_eq!(provider.dimension(), 128);
}
/// The factory's default "openai" provider reports 1536 dimensions.
#[test]
fn test_factory_create_openai() {
    let provider = EmbeddingProviderFactory::create("openai", "sk-test", None).unwrap();
    assert_eq!(provider.dimension(), 1536);
}
/// The factory's default "cohere" provider reports 1024 dimensions.
#[test]
fn test_factory_create_cohere() {
    let provider = EmbeddingProviderFactory::create("cohere", "key", None).unwrap();
    assert_eq!(provider.dimension(), 1024);
}
/// The factory's default "voyage" provider reports 1024 dimensions.
#[test]
fn test_factory_create_voyage() {
    let provider = EmbeddingProviderFactory::create("voyage", "key", None).unwrap();
    assert_eq!(provider.dimension(), 1024);
}
/// Asking the factory for an unrecognised provider name yields an `Err`.
#[test]
fn test_factory_unknown_provider() {
    let outcome = EmbeddingProviderFactory::create("unknown", "", None);
    assert!(outcome.is_err(), "Unknown provider should return Err");
}
/// The factory's provider listing includes the four names checked below:
/// "openai", "cohere", "voyage" and "local".
#[test]
fn test_factory_available_providers() {
    let names = EmbeddingProviderFactory::available_providers();

    // Hosted providers.
    assert!(names.contains(&"openai"));
    assert!(names.contains(&"cohere"));
    assert!(names.contains(&"voyage"));

    // The local provider.
    assert!(names.contains(&"local"));
}
/// The default config selects the "local" provider, has an empty API key and
/// leaves every optional field unset.
#[test]
fn test_config_default() {
    let cfg: EmbeddingConfig = Default::default();

    // Required fields.
    assert_eq!(cfg.provider, "local");
    assert!(cfg.api_key.is_empty());

    // Optional fields all start as `None`.
    assert!(cfg.model.is_none());
    assert!(cfg.dimensions.is_none());
    assert!(cfg.base_url.is_none());
    assert!(cfg.cache_size.is_none());
}
/// A config survives a JSON round trip: provider, API key, dimensions and
/// cache size come back with the values that were serialized.
#[test]
fn test_config_serialize_deserialize() {
    let original = EmbeddingConfig {
        provider: String::from("openai"),
        api_key: String::from("sk-123"),
        model: Some(String::from("text-embedding-3-small")),
        dimensions: Some(1536),
        base_url: None,
        cache_size: Some(500),
    };

    let encoded = serde_json::to_string(&original).unwrap();
    let decoded: EmbeddingConfig = serde_json::from_str(&encoded).unwrap();

    assert_eq!(decoded.provider, "openai");
    assert_eq!(decoded.api_key, "sk-123");
    assert_eq!(decoded.dimensions, Some(1536));
    assert_eq!(decoded.cache_size, Some(500));
}
/// JSON containing only `provider` deserializes successfully and leaves the
/// API key empty.
#[test]
fn test_config_deserialize_minimal() {
    let cfg: EmbeddingConfig = serde_json::from_str(r#"{"provider":"local"}"#).unwrap();
    assert_eq!(cfg.provider, "local");
    assert!(cfg.api_key.is_empty());
}
/// Building the default config gives a provider with 256 dimensions whose
/// embeddings are 256 elements long.
#[tokio::test]
async fn test_config_build_local() {
    let config = EmbeddingConfig::default();
    let provider = config.build().unwrap();
    assert_eq!(provider.dimension(), 256);

    let embedding = provider.embed("test text").await.unwrap();
    assert_eq!(embedding.len(), 256);
}
#[tokio::test]
async fn test_config_build_local_with_cache() {
let cfg = EmbeddingConfig {
provider: "local".to_string(),
cache_size: Some(50),
..Default::default()
};
let provider = cfg.build().unwrap();
assert_eq!(provider.dimension(), 256);
let v1 = provider.embed("cached text").await.unwrap();
let v2 = provider.embed("cached text").await.unwrap();
assert_eq!(v1, v2);
}
#[tokio::test]
async fn test_config_build_local_custom_dimensions() {
let cfg = EmbeddingConfig {
provider: "local".to_string(),
dimensions: Some(512),
..Default::default()
};
let provider = cfg.build().unwrap();
assert_eq!(provider.dimension(), 512);
}
/// Building a config that names an unrecognised provider returns an error.
#[test]
fn test_config_build_unknown_provider() {
    let cfg = EmbeddingConfig {
        provider: String::from("imaginary"),
        ..EmbeddingConfig::default()
    };
    assert!(cfg.build().is_err());
}
/// Hashing the same bytes twice with `fnv1a_hash` gives the same value.
#[test]
fn test_fnv_hash_deterministic() {
    let input = b"hello world";
    let first = fnv1a_hash(input);
    let second = fnv1a_hash(input);
    assert_eq!(first, second);
}
/// The inputs "alpha" and "bravo" hash to different values.
#[test]
fn test_fnv_hash_different_inputs() {
    let alpha_hash = fnv1a_hash(b"alpha");
    let bravo_hash = fnv1a_hash(b"bravo");
    assert_ne!(alpha_hash, bravo_hash);
}
/// A stub embedding requested with 128 dimensions has 128 elements.
#[test]
fn test_stub_embedding_length() {
    let embedding = stub_embedding("hello", 128);
    assert_eq!(embedding.len(), 128);
}
/// Two stub embeddings of the same text and dimension are equal.
#[test]
fn test_stub_embedding_deterministic() {
    let first = stub_embedding("same input", 64);
    let second = stub_embedding("same input", 64);
    assert_eq!(first, second);
}
/// The Euclidean norm of a stub embedding is within 0.01 of 1.0.
#[test]
fn test_stub_embedding_normalized() {
    let embedding = stub_embedding("the quick brown fox", 256);
    let squared_sum: f32 = embedding.iter().map(|c| c * c).sum();
    let norm = squared_sum.sqrt();
    assert!((norm - 1.0).abs() < 0.01, "norm={norm}");
}
/// Stub embeddings of "alpha" and "bravo" are not equal.
#[test]
fn test_stub_embedding_different_inputs_differ() {
    let alpha = stub_embedding("alpha", 64);
    let bravo = stub_embedding("bravo", 64);
    assert_ne!(alpha, bravo);
}
/// Empty text still yields a vector of the requested length (32), and every
/// element of it is zero.
#[test]
fn test_stub_embedding_empty_text_zeroes() {
    let v = stub_embedding("", 32);

    // Length is honoured even though there is nothing to embed.
    assert_eq!(v.len(), 32);
    assert!(v.iter().all(|&x| x == 0.0));
}
/// Requesting zero dimensions returns a one-element vector.
#[test]
fn test_stub_embedding_zero_dimension_safe() {
    let embedding = stub_embedding("hi", 0);
    assert_eq!(embedding.len(), 1);
}
/// A default Jina provider reports model "jina-embeddings-v3" and 1024 dimensions.
#[test]
fn test_jina_default_construction() {
    let provider = JinaEmbeddingProvider::new("jina-key");
    assert_eq!(provider.model(), "jina-embeddings-v3");
    assert_eq!(provider.dimension(), 1024);
}
/// `with_model` stores the given model name ("jina-clip-v2") and dimension (768).
#[test]
fn test_jina_with_model_clip() {
    let provider = JinaEmbeddingProvider::with_model("k", "jina-clip-v2", 768);
    assert_eq!(provider.model(), "jina-clip-v2");
    assert_eq!(provider.dimension(), 768);
}
/// Overriding the base URL leaves the default model name unchanged.
#[test]
fn test_jina_with_base_url() {
    let provider = JinaEmbeddingProvider::new("k").with_base_url("https://custom.jina/v1");
    assert_eq!(provider.model(), "jina-embeddings-v3");
}
/// The Jina request payload carries the model name under `model` and the
/// texts, in order, under `input`.
#[test]
fn test_jina_build_payload_shape() {
    let provider = JinaEmbeddingProvider::new("k");
    let texts = ["hello".to_string(), "world".to_string()];
    let payload = provider.build_payload(&texts);

    assert_eq!(payload["model"], "jina-embeddings-v3");
    assert_eq!(payload["input"][0], "hello");
    assert_eq!(payload["input"][1], "world");
}
/// The Jina provider reports 1024 dimensions; when built without the
/// `http-embeddings` feature, an embedding is also checked to be 1024 long.
#[tokio::test]
async fn test_jina_embed_length_matches_dimension() {
    let provider = JinaEmbeddingProvider::new("k");

    // The embed call is only exercised when the HTTP feature is compiled out.
    #[cfg(not(feature = "http-embeddings"))]
    {
        let embedding = provider.embed("hello jina").await.unwrap();
        assert_eq!(embedding.len(), 1024);
    }

    assert_eq!(provider.dimension(), 1024);
}
/// Without the `http-embeddings` feature, the Jina embedding has a Euclidean
/// norm within 0.01 of 1.0.
#[cfg(not(feature = "http-embeddings"))]
#[tokio::test]
async fn test_jina_stub_is_normalized() {
    let provider = JinaEmbeddingProvider::new("k");
    let embedding = provider.embed("some input").await.unwrap();
    let squared_sum: f32 = embedding.iter().map(|c| c * c).sum();
    let norm = squared_sum.sqrt();
    assert!((norm - 1.0).abs() < 0.01);
}
/// Without the `http-embeddings` feature, embedding the same text twice with
/// the Jina provider returns equal vectors.
#[cfg(not(feature = "http-embeddings"))]
#[tokio::test]
async fn test_jina_stub_deterministic() {
    let provider = JinaEmbeddingProvider::new("k");
    let first = provider.embed("consistent").await.unwrap();
    let second = provider.embed("consistent").await.unwrap();
    assert_eq!(first, second);
}
/// A default Mistral provider reports model "mistral-embed" and 1024 dimensions.
#[test]
fn test_mistral_default_construction() {
    let provider = MistralEmbedProvider::new("mistral-key");
    assert_eq!(provider.model(), "mistral-embed");
    assert_eq!(provider.dimension(), 1024);
}
/// `with_model` stores the given model name ("mistral-embed-large") and
/// dimension (2048).
#[test]
fn test_mistral_with_model_and_dimensions() {
    let provider = MistralEmbedProvider::with_model("k", "mistral-embed-large", 2048);
    assert_eq!(provider.model(), "mistral-embed-large");
    assert_eq!(provider.dimension(), 2048);
}
/// The Mistral request payload carries the model name under `model` and the
/// text under `input`.
#[test]
fn test_mistral_build_payload_shape() {
    let provider = MistralEmbedProvider::new("k");
    let texts = ["alpha".to_string()];
    let payload = provider.build_payload(&texts);

    assert_eq!(payload["model"], "mistral-embed");
    assert_eq!(payload["input"][0], "alpha");
}
/// Overriding the base URL leaves the default dimension (1024) unchanged.
#[test]
fn test_mistral_with_base_url() {
    let provider = MistralEmbedProvider::new("k").with_base_url("https://custom.mistral/v1");
    assert_eq!(provider.dimension(), 1024);
}
/// Without the `http-embeddings` feature, a Mistral embedding is 1024 elements long.
#[cfg(not(feature = "http-embeddings"))]
#[tokio::test]
async fn test_mistral_embed_length() {
    let provider = MistralEmbedProvider::new("k");
    let embedding = provider.embed("hello mistral").await.unwrap();
    assert_eq!(embedding.len(), 1024);
}
/// Without the `http-embeddings` feature, the Mistral embedding has a
/// Euclidean norm within 0.01 of 1.0.
#[cfg(not(feature = "http-embeddings"))]
#[tokio::test]
async fn test_mistral_stub_normalized() {
    let provider = MistralEmbedProvider::new("k");
    let embedding = provider.embed("normalized?").await.unwrap();
    let squared_sum: f32 = embedding.iter().map(|c| c * c).sum();
    let norm = squared_sum.sqrt();
    assert!((norm - 1.0).abs() < 0.01);
}
/// A default Nomic provider reports model "nomic-embed-text-v1.5",
/// 768 dimensions and task type "search_document".
#[test]
fn test_nomic_default_construction() {
    let provider = NomicEmbedProvider::new("nomic-key");
    assert_eq!(provider.model(), "nomic-embed-text-v1.5");
    assert_eq!(provider.dimension(), 768);
    assert_eq!(provider.task_type(), "search_document");
}
/// `with_task_type` replaces the task type reported by the provider.
#[test]
fn test_nomic_with_task_type() {
    let provider = NomicEmbedProvider::new("k").with_task_type("search_query");
    assert_eq!(provider.task_type(), "search_query");
}
/// The Nomic request payload carries the model under `model`, the texts in
/// order under `texts`, and the configured task type under `task_type`.
#[test]
fn test_nomic_build_payload_shape() {
    let provider = NomicEmbedProvider::new("k").with_task_type("clustering");
    let texts = ["doc a".to_string(), "doc b".to_string()];
    let payload = provider.build_payload(&texts);

    assert_eq!(payload["model"], "nomic-embed-text-v1.5");
    assert_eq!(payload["texts"][0], "doc a");
    assert_eq!(payload["texts"][1], "doc b");
    assert_eq!(payload["task_type"], "clustering");
}
/// `with_model` stores the dimension it is given (512).
#[test]
fn test_nomic_with_model_custom_dims() {
    let provider = NomicEmbedProvider::with_model("k", "custom-nomic", 512);
    assert_eq!(provider.dimension(), 512);
}
/// Without the `http-embeddings` feature, a Nomic embedding is 768 elements long.
#[cfg(not(feature = "http-embeddings"))]
#[tokio::test]
async fn test_nomic_embed_length() {
    let provider = NomicEmbedProvider::new("k");
    let embedding = provider.embed("nomic test").await.unwrap();
    assert_eq!(embedding.len(), 768);
}
/// Without the `http-embeddings` feature, the Nomic embedding has a
/// Euclidean norm within 0.01 of 1.0.
#[cfg(not(feature = "http-embeddings"))]
#[tokio::test]
async fn test_nomic_embed_normalized() {
    let provider = NomicEmbedProvider::new("k");
    let embedding = provider.embed("some text").await.unwrap();
    let squared_sum: f32 = embedding.iter().map(|c| c * c).sum();
    let norm = squared_sum.sqrt();
    assert!((norm - 1.0).abs() < 0.01);
}
/// A default SentenceTransformers provider reports model
/// "sentence-transformers/all-MiniLM-L6-v2" and 384 dimensions.
#[test]
fn test_sentence_transformers_default_construction() {
    let provider = SentenceTransformersProvider::new("hf-key");
    assert_eq!(provider.model(), "sentence-transformers/all-MiniLM-L6-v2");
    assert_eq!(provider.dimension(), 384);
}
/// `default_dimensions` maps the all-mpnet-base-v2 model to 768.
#[test]
fn test_sentence_transformers_mpnet_dims() {
    let model = "sentence-transformers/all-mpnet-base-v2";
    let dimensions = SentenceTransformersProvider::default_dimensions(model);
    assert_eq!(dimensions, 768);
}
/// `default_dimensions` maps the multi-qa-mpnet-base-dot-v1 model to 768.
#[test]
fn test_sentence_transformers_multi_qa_dims() {
    let model = "sentence-transformers/multi-qa-mpnet-base-dot-v1";
    let dimensions = SentenceTransformersProvider::default_dimensions(model);
    assert_eq!(dimensions, 768);
}
/// `default_dimensions` returns 384 for a model name it does not recognise.
#[test]
fn test_sentence_transformers_unknown_model_fallback() {
    let model = "sentence-transformers/unknown";
    let dimensions = SentenceTransformersProvider::default_dimensions(model);
    assert_eq!(dimensions, 384);
}
/// `with_model` stores the given model name and dimension (768).
#[test]
fn test_sentence_transformers_with_model() {
    let model_name = "sentence-transformers/all-mpnet-base-v2";
    let provider = SentenceTransformersProvider::with_model("k", model_name, 768);
    assert_eq!(provider.model(), "sentence-transformers/all-mpnet-base-v2");
    assert_eq!(provider.dimension(), 768);
}
/// The SentenceTransformers payload carries the text under `inputs` and sets
/// `options.wait_for_model` to true.
#[test]
fn test_sentence_transformers_build_payload_shape() {
    let provider = SentenceTransformersProvider::new("k");
    let texts = ["hi".to_string()];
    let payload = provider.build_payload(&texts);

    assert_eq!(payload["inputs"][0], "hi");
    assert_eq!(payload["options"]["wait_for_model"], true);
}
/// Overriding the base URL leaves the default dimension (384) unchanged.
#[test]
fn test_sentence_transformers_with_base_url() {
    let provider =
        SentenceTransformersProvider::new("k").with_base_url("https://self-hosted.hf/embed");
    assert_eq!(provider.dimension(), 384);
}
/// Without the `http-embeddings` feature, a SentenceTransformers embedding is
/// 384 elements long.
#[cfg(not(feature = "http-embeddings"))]
#[tokio::test]
async fn test_sentence_transformers_embed_length() {
    let provider = SentenceTransformersProvider::new("k");
    let embedding = provider.embed("minilm test").await.unwrap();
    assert_eq!(embedding.len(), 384);
}
/// Without the `http-embeddings` feature, the SentenceTransformers embedding
/// has a Euclidean norm within 0.01 of 1.0.
#[cfg(not(feature = "http-embeddings"))]
#[tokio::test]
async fn test_sentence_transformers_embed_normalized() {
    let provider = SentenceTransformersProvider::new("k");
    let embedding = provider.embed("some input").await.unwrap();
    let squared_sum: f32 = embedding.iter().map(|c| c * c).sum();
    let norm = squared_sum.sqrt();
    assert!((norm - 1.0).abs() < 0.01);
}
/// A default Together provider reports model
/// "togethercomputer/m2-bert-80M-32k-retrieval" and 768 dimensions.
#[test]
fn test_together_default_construction() {
    let provider = TogetherEmbedProvider::new("together-key");
    assert_eq!(provider.model(), "togethercomputer/m2-bert-80M-32k-retrieval");
    assert_eq!(provider.dimension(), 768);
}
/// `with_model` stores the given model name ("togethercomputer/custom") and
/// dimension (1024).
#[test]
fn test_together_with_model() {
    let provider = TogetherEmbedProvider::with_model("k", "togethercomputer/custom", 1024);
    assert_eq!(provider.model(), "togethercomputer/custom");
    assert_eq!(provider.dimension(), 1024);
}
/// The Together request payload carries the model name under `model` and the
/// texts, in order, under `input`.
#[test]
fn test_together_build_payload_shape() {
    let provider = TogetherEmbedProvider::new("k");
    let texts = ["x".to_string(), "y".to_string()];
    let payload = provider.build_payload(&texts);

    assert_eq!(payload["model"], "togethercomputer/m2-bert-80M-32k-retrieval");
    assert_eq!(payload["input"][0], "x");
    assert_eq!(payload["input"][1], "y");
}
/// Overriding the base URL leaves the default dimension (768) unchanged.
#[test]
fn test_together_with_base_url() {
    let provider = TogetherEmbedProvider::new("k").with_base_url("https://custom.together/v1");
    assert_eq!(provider.dimension(), 768);
}
/// Without the `http-embeddings` feature, a Together embedding is 768 elements long.
#[cfg(not(feature = "http-embeddings"))]
#[tokio::test]
async fn test_together_embed_length() {
    let provider = TogetherEmbedProvider::new("k");
    let embedding = provider.embed("together test").await.unwrap();
    assert_eq!(embedding.len(), 768);
}
/// Without the `http-embeddings` feature, the Together embedding has a
/// Euclidean norm within 0.01 of 1.0.
#[cfg(not(feature = "http-embeddings"))]
#[tokio::test]
async fn test_together_embed_normalized() {
    let provider = TogetherEmbedProvider::new("k");
    let embedding = provider.embed("text").await.unwrap();
    let squared_sum: f32 = embedding.iter().map(|c| c * c).sum();
    let norm = squared_sum.sqrt();
    assert!((norm - 1.0).abs() < 0.01);
}
/// A default `CohereEmbedV4Provider` reports model "embed-english-v3.0",
/// 1024 dimensions and input type "search_document".
#[test]
fn test_cohere_v4_default_construction() {
    let provider = CohereEmbedV4Provider::new("cohere-key");
    assert_eq!(provider.model(), "embed-english-v3.0");
    assert_eq!(provider.dimension(), 1024);
    assert_eq!(provider.input_type(), "search_document");
}
/// `with_model` stores the given model name ("embed-multilingual-v3.0") and
/// dimension (1024).
#[test]
fn test_cohere_v4_multilingual_model() {
    let provider = CohereEmbedV4Provider::with_model("k", "embed-multilingual-v3.0", 1024);
    assert_eq!(provider.model(), "embed-multilingual-v3.0");
    assert_eq!(provider.dimension(), 1024);
}
/// `for_search_document` gives a provider whose input type is "search_document".
#[test]
fn test_cohere_v4_for_search_document() {
    let provider = CohereEmbedV4Provider::new("k").for_search_document();
    assert_eq!(provider.input_type(), "search_document");
}
/// `for_search_query` gives a provider whose input type is "search_query".
#[test]
fn test_cohere_v4_for_search_query() {
    let provider = CohereEmbedV4Provider::new("k").for_search_query();
    assert_eq!(provider.input_type(), "search_query");
}
/// `with_input_type` stores the input type string it is given.
#[test]
fn test_cohere_v4_with_input_type() {
    let provider = CohereEmbedV4Provider::new("k").with_input_type("classification");
    assert_eq!(provider.input_type(), "classification");
}
/// For a search-document provider, the payload carries the model, the text
/// under `texts`, input type "search_document" and "float" as the first
/// embedding type.
#[test]
fn test_cohere_v4_build_payload_shape_document() {
    let provider = CohereEmbedV4Provider::new("k").for_search_document();
    let texts = ["doc".to_string()];
    let payload = provider.build_payload(&texts);

    assert_eq!(payload["model"], "embed-english-v3.0");
    assert_eq!(payload["texts"][0], "doc");
    assert_eq!(payload["input_type"], "search_document");
    assert_eq!(payload["embedding_types"][0], "float");
}
/// For a search-query provider, the payload's `input_type` is "search_query".
#[test]
fn test_cohere_v4_build_payload_shape_query() {
    let provider = CohereEmbedV4Provider::new("k").for_search_query();
    let texts = ["q".to_string()];
    let payload = provider.build_payload(&texts);
    assert_eq!(payload["input_type"], "search_query");
}
/// Overriding the base URL leaves the default dimension (1024) unchanged.
#[test]
fn test_cohere_v4_with_base_url() {
    let provider =
        CohereEmbedV4Provider::new("k").with_base_url("https://custom.cohere/v2/embed");
    assert_eq!(provider.dimension(), 1024);
}
/// Without the `http-embeddings` feature, a Cohere v4 embedding is 1024
/// elements long.
#[cfg(not(feature = "http-embeddings"))]
#[tokio::test]
async fn test_cohere_v4_embed_length() {
    let provider = CohereEmbedV4Provider::new("k");
    let embedding = provider.embed("cohere v4 test").await.unwrap();
    assert_eq!(embedding.len(), 1024);
}
/// Without the `http-embeddings` feature, the Cohere v4 embedding has a
/// Euclidean norm within 0.01 of 1.0.
#[cfg(not(feature = "http-embeddings"))]
#[tokio::test]
async fn test_cohere_v4_embed_normalized() {
    let provider = CohereEmbedV4Provider::new("k");
    let embedding = provider.embed("x").await.unwrap();
    let squared_sum: f32 = embedding.iter().map(|c| c * c).sum();
    let norm = squared_sum.sqrt();
    assert!((norm - 1.0).abs() < 0.01);
}
/// Without the `http-embeddings` feature, embedding the same text twice with
/// the Cohere v4 provider returns equal vectors.
#[cfg(not(feature = "http-embeddings"))]
#[tokio::test]
async fn test_cohere_v4_embed_deterministic() {
    let provider = CohereEmbedV4Provider::new("k");
    let first = provider.embed("same").await.unwrap();
    let second = provider.embed("same").await.unwrap();
    assert_eq!(first, second);
}
/// Compile-time check: each of the six providers below can be boxed as a
/// `dyn EmbeddingProvider`. Nothing is asserted at run time.
#[test]
fn test_all_new_providers_implement_embedding_provider_trait() {
    let mut boxed: Vec<Box<dyn EmbeddingProvider>> = Vec::with_capacity(6);
    boxed.push(Box::new(JinaEmbeddingProvider::new("k")));
    boxed.push(Box::new(MistralEmbedProvider::new("k")));
    boxed.push(Box::new(NomicEmbedProvider::new("k")));
    boxed.push(Box::new(SentenceTransformersProvider::new("k")));
    boxed.push(Box::new(TogetherEmbedProvider::new("k")));
    boxed.push(Box::new(CohereEmbedV4Provider::new("k")));
}
/// Each of the six providers below reports its expected default dimension.
#[test]
fn test_new_providers_have_expected_dimensions() {
    let jina = JinaEmbeddingProvider::new("k");
    assert_eq!(jina.dimension(), 1024);

    let mistral = MistralEmbedProvider::new("k");
    assert_eq!(mistral.dimension(), 1024);

    let nomic = NomicEmbedProvider::new("k");
    assert_eq!(nomic.dimension(), 768);

    let sentence_transformers = SentenceTransformersProvider::new("k");
    assert_eq!(sentence_transformers.dimension(), 384);

    let together = TogetherEmbedProvider::new("k");
    assert_eq!(together.dimension(), 768);

    let cohere_v4 = CohereEmbedV4Provider::new("k");
    assert_eq!(cohere_v4.dimension(), 1024);
}
}