use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use thiserror::Error;
/// Failures that can occur while configuring or calling a remote embedding provider.
#[derive(Debug, Error)]
pub enum RemoteEmbeddingError {
    /// The HTTP client could not be built (e.g. from the configured timeout).
    #[error("HTTP client error: {0}")]
    HttpClient(String),

    /// A request-level failure: a provider/config mismatch, an unsupported provider,
    /// a transport error, or a non-success HTTP status other than 429.
    #[error("API request failed: {0}")]
    ApiError(String),

    /// The response body could not be parsed or did not contain the expected embeddings.
    #[error("Invalid API response: {0}")]
    InvalidResponse(String),

    /// The configuration carries no API key for the named provider.
    #[error("API key not configured for provider: {0}")]
    ApiKeyNotFound(String),

    /// The named provider answered with HTTP 429.
    #[error("Rate limit exceeded for provider: {0}")]
    RateLimitExceeded(String),

    /// An embedding's length did not match the expected dimension.
    #[error("Invalid embedding dimension: expected {expected}, got {got}")]
    InvalidDimension { expected: usize, got: usize },

    /// Remote embeddings were requested without the `remote-embeddings` feature.
    #[error("Feature not enabled: remote-embeddings feature is required")]
    FeatureNotEnabled,
}
/// Which remote embedding service to call, together with its model or endpoint.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum RemoteProvider {
    /// OpenAI embeddings API; `model` is sent as the request's model name.
    OpenAI { model: String },
    /// Cohere embed API; `model` is sent as the request's model name.
    Cohere { model: String },
    /// A user-supplied endpoint (rejected by `GenericRemoteProvider::from_config` for now).
    Custom { endpoint: String },
}
impl Default for RemoteProvider {
fn default() -> Self {
Self::OpenAI {
model: "text-embedding-3-small".to_string(),
}
}
}
impl RemoteProvider {
    /// Returns the embedding length this provider's model produces by default.
    ///
    /// Known model names map to their documented output sizes. Unrecognised OpenAI
    /// models fall back to 1536 and unrecognised Cohere models to 1024. `Custom`
    /// endpoints have no model information here and are assumed to produce 1536.
    pub fn default_dimension(&self) -> usize {
        match self {
            Self::OpenAI { model } => match model.as_str() {
                "text-embedding-3-small" => 1536,
                "text-embedding-3-large" => 3072,
                "text-embedding-ada-002" => 1536,
                _ => 1536,
            },
            Self::Cohere { model } => match model.as_str() {
                "embed-english-v3.0" | "embed-multilingual-v3.0" => 1024,
                // The "light" v3 models emit much smaller vectors than the full v3 models.
                "embed-english-light-v3.0" | "embed-multilingual-light-v3.0" => 384,
                "embed-english-v2.0" => 4096,
                "embed-english-light-v2.0" => 1024,
                "embed-multilingual-v2.0" => 768,
                _ => 1024,
            },
            Self::Custom { .. } => 1536,
        }
    }
}
/// Connection settings for a remote embedding provider.
///
/// `Debug` is implemented by hand so the API key never ends up in logs or panic messages.
#[derive(Clone)]
pub struct RemoteEmbeddingConfig {
    /// Which service to call, and with which model or endpoint.
    pub provider: RemoteProvider,
    /// Secret sent as a bearer token; the OpenAI and Cohere providers refuse to build without it.
    pub api_key: Option<String>,
    /// Per-request timeout, in seconds, applied to the HTTP client.
    pub timeout_secs: u64,
    /// Retry budget. NOTE(review): not consulted by the providers in this file — confirm usage.
    pub max_retries: usize,
    /// API base URL, e.g. `https://api.cohere.ai/v1` as set by `RemoteEmbeddingConfig::cohere`.
    pub base_url: Option<String>,
}

impl std::fmt::Debug for RemoteEmbeddingConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RemoteEmbeddingConfig")
            .field("provider", &self.provider)
            // Only reveal whether a key is present, never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("timeout_secs", &self.timeout_secs)
            .field("max_retries", &self.max_retries)
            .field("base_url", &self.base_url)
            .finish()
    }
}
impl Default for RemoteEmbeddingConfig {
    /// Default provider (OpenAI), no API key, 30 s timeout, 3 retries, no base URL.
    fn default() -> Self {
        let provider = RemoteProvider::default();
        Self {
            base_url: None,
            max_retries: 3,
            timeout_secs: 30,
            api_key: None,
            provider,
        }
    }
}
impl RemoteEmbeddingConfig {
    /// Builds an OpenAI configuration with the given API key.
    ///
    /// `model` defaults to `text-embedding-3-small`. Every other field comes from
    /// `Default`, so `base_url` stays `None`.
    pub fn openai(api_key: String, model: Option<String>) -> Self {
        let model = model.unwrap_or_else(|| "text-embedding-3-small".to_string());
        Self {
            provider: RemoteProvider::OpenAI { model },
            api_key: Some(api_key),
            ..Self::default()
        }
    }

    /// Builds a Cohere configuration with the given API key.
    ///
    /// `model` defaults to `embed-english-v3.0`, and `base_url` is set to
    /// `https://api.cohere.ai/v1`. Every other field comes from `Default`.
    pub fn cohere(api_key: String, model: Option<String>) -> Self {
        let model = model.unwrap_or_else(|| "embed-english-v3.0".to_string());
        let base_url = Some("https://api.cohere.ai/v1".to_string());
        Self {
            provider: RemoteProvider::Cohere { model },
            api_key: Some(api_key),
            base_url,
            ..Self::default()
        }
    }

    /// Builds a configuration for a custom endpoint; the API key is optional.
    pub fn custom(endpoint: String, api_key: Option<String>) -> Self {
        let provider = RemoteProvider::Custom { endpoint };
        Self {
            provider,
            api_key,
            ..Self::default()
        }
    }
}
/// An embedding backend reached over the network.
///
/// The `Send + Sync` bound lets implementations be shared as `Arc<dyn RemoteEmbeddingProvider>`.
#[async_trait]
pub trait RemoteEmbeddingProvider: Send + Sync {
    /// Embeds a single piece of text.
    async fn embed(&self, text: &str) -> Result<Vec<f32>, RemoteEmbeddingError>;

    /// Embeds a batch of texts, returning the vectors produced for them.
    async fn embed_batch(
        &self,
        texts: Vec<&str>,
    ) -> Result<Vec<Vec<f32>>, RemoteEmbeddingError>;

    /// Expected length of the vectors this provider produces.
    fn dimension(&self) -> usize;
}
/// Embedding provider backed by OpenAI's embeddings endpoint.
///
/// Build it with `OpenAIEmbeddingProvider::new`, which rejects configs without an API key.
pub struct OpenAIEmbeddingProvider {
    /// HTTP client carrying the timeout taken from `config`.
    client: Client,
    /// Provider settings; `provider` is expected to be `RemoteProvider::OpenAI`.
    config: RemoteEmbeddingConfig,
}
impl OpenAIEmbeddingProvider {
    /// Creates an OpenAI provider from `config`.
    ///
    /// # Errors
    /// Returns `ApiKeyNotFound("OpenAI")` when `config.api_key` is `None`, and
    /// `HttpClient` when the HTTP client cannot be built.
    pub fn new(config: RemoteEmbeddingConfig) -> Result<Self, RemoteEmbeddingError> {
        if config.api_key.is_some() {
            let timeout = std::time::Duration::from_secs(config.timeout_secs);
            match Client::builder().timeout(timeout).build() {
                Ok(client) => Ok(Self { client, config }),
                Err(err) => Err(RemoteEmbeddingError::HttpClient(err.to_string())),
            }
        } else {
            Err(RemoteEmbeddingError::ApiKeyNotFound("OpenAI".to_string()))
        }
    }
}
#[async_trait]
impl RemoteEmbeddingProvider for OpenAIEmbeddingProvider {
async fn embed(&self, text: &str) -> Result<Vec<f32>, RemoteEmbeddingError> {
let embeddings = self.embed_batch(vec![text]).await?;
embeddings.into_iter().next().ok_or_else(|| {
RemoteEmbeddingError::InvalidResponse("No embedding returned".to_string())
})
}
async fn embed_batch(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>, RemoteEmbeddingError> {
let model_name = match &self.config.provider {
RemoteProvider::OpenAI { model } => model.clone(),
_ => {
return Err(RemoteEmbeddingError::ApiError(
"Invalid provider".to_string(),
))
}
};
#[derive(Serialize)]
struct OpenAIRequest<'a> {
model: String,
input: Vec<&'a str>,
encoding_format: String,
}
#[derive(Deserialize)]
struct OpenAIResponse {
data: Vec<OpenAIEmbedding>,
}
#[derive(Deserialize)]
struct OpenAIEmbedding {
embedding: Vec<f32>,
}
let request = OpenAIRequest {
model: model_name,
input: texts,
encoding_format: "float".to_string(),
};
let api_key = self.config.api_key.as_ref().unwrap();
let response = self
.client
.post("https://api.openai.com/v1/embeddings")
.header("Authorization", format!("Bearer {}", api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(|e| RemoteEmbeddingError::ApiError(e.to_string()))?;
let status = response.status();
let response_text = response
.text()
.await
.map_err(|e| RemoteEmbeddingError::ApiError(e.to_string()))?;
if !status.is_success() {
if status.as_u16() == 429 {
return Err(RemoteEmbeddingError::RateLimitExceeded(
"OpenAI".to_string(),
));
}
return Err(RemoteEmbeddingError::ApiError(format!(
"API returned {}: {}",
status, response_text
)));
}
let openai_response: OpenAIResponse = serde_json::from_str(&response_text)
.map_err(|e| RemoteEmbeddingError::InvalidResponse(e.to_string()))?;
Ok(openai_response
.data
.into_iter()
.map(|e| e.embedding)
.collect())
}
fn dimension(&self) -> usize {
self.config.provider.default_dimension()
}
}
/// Embedding provider backed by Cohere's `/embed` endpoint.
///
/// Build it with `CohereEmbeddingProvider::new`, which rejects configs without an API key.
pub struct CohereEmbeddingProvider {
    /// HTTP client carrying the timeout taken from `config`.
    client: Client,
    /// Provider settings; `provider` is expected to be `RemoteProvider::Cohere`.
    config: RemoteEmbeddingConfig,
}
impl CohereEmbeddingProvider {
pub fn new(config: RemoteEmbeddingConfig) -> Result<Self, RemoteEmbeddingError> {
if config.api_key.is_none() {
return Err(RemoteEmbeddingError::ApiKeyNotFound("Cohere".to_string()));
}
let client = Client::builder()
.timeout(std::time::Duration::from_secs(config.timeout_secs))
.build()
.map_err(|e| RemoteEmbeddingError::HttpClient(e.to_string()))?;
Ok(Self { client, config })
}
}
#[async_trait]
impl RemoteEmbeddingProvider for CohereEmbeddingProvider {
async fn embed(&self, text: &str) -> Result<Vec<f32>, RemoteEmbeddingError> {
let embeddings = self.embed_batch(vec![text]).await?;
embeddings.into_iter().next().ok_or_else(|| {
RemoteEmbeddingError::InvalidResponse("No embedding returned".to_string())
})
}
async fn embed_batch(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>, RemoteEmbeddingError> {
let model_name = match &self.config.provider {
RemoteProvider::Cohere { model } => model.clone(),
_ => {
return Err(RemoteEmbeddingError::ApiError(
"Invalid provider".to_string(),
))
}
};
let base_url = self.config.base_url.as_ref().unwrap();
#[derive(Serialize)]
struct CohereRequest<'a> {
model: String,
texts: Vec<&'a str>,
input_type: String,
}
#[derive(Deserialize)]
struct CohereResponse {
embeddings: Vec<CohereEmbedding>,
}
#[derive(Deserialize)]
struct CohereEmbedding {
embedding: Vec<f32>,
}
let request = CohereRequest {
model: model_name,
texts,
input_type: "search_document".to_string(),
};
let api_key = self.config.api_key.as_ref().unwrap();
let url = format!("{}/embed", base_url);
let response = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", api_key))
.header("Content-Type", "application/json")
.header("X-Client-Name", "leindex")
.json(&request)
.send()
.await
.map_err(|e| RemoteEmbeddingError::ApiError(e.to_string()))?;
let status = response.status();
let response_text = response
.text()
.await
.map_err(|e| RemoteEmbeddingError::ApiError(e.to_string()))?;
if !status.is_success() {
if status.as_u16() == 429 {
return Err(RemoteEmbeddingError::RateLimitExceeded(
"Cohere".to_string(),
));
}
return Err(RemoteEmbeddingError::ApiError(format!(
"API returned {}: {}",
status, response_text
)));
}
let cohere_response: CohereResponse = serde_json::from_str(&response_text)
.map_err(|e| RemoteEmbeddingError::InvalidResponse(e.to_string()))?;
Ok(cohere_response
.embeddings
.into_iter()
.map(|e| e.embedding)
.collect())
}
fn dimension(&self) -> usize {
self.config.provider.default_dimension()
}
}
/// Type-erased wrapper that dispatches to whichever concrete provider a config selects.
///
/// Cloning is cheap: clones share the same inner provider through the `Arc`.
#[derive(Clone)]
pub struct GenericRemoteProvider {
    /// The concrete provider chosen by `GenericRemoteProvider::from_config`.
    provider: Arc<dyn RemoteEmbeddingProvider>,
}
impl std::fmt::Debug for GenericRemoteProvider {
    /// Prints a fixed placeholder for the inner provider, which is a trait object
    /// whose trait has no `Debug` bound.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const PLACEHOLDER: &str = "<RemoteEmbeddingProvider>";
        let mut builder = f.debug_struct("GenericRemoteProvider");
        builder.field("provider", &PLACEHOLDER);
        builder.finish()
    }
}
impl GenericRemoteProvider {
pub fn from_config(config: RemoteEmbeddingConfig) -> Result<Self, RemoteEmbeddingError> {
let provider: Arc<dyn RemoteEmbeddingProvider> = match &config.provider {
RemoteProvider::OpenAI { .. } => Arc::new(OpenAIEmbeddingProvider::new(config)?),
RemoteProvider::Cohere { .. } => Arc::new(CohereEmbeddingProvider::new(config)?),
RemoteProvider::Custom { .. } => {
return Err(RemoteEmbeddingError::ApiError(
"Custom provider not yet implemented".to_string(),
));
}
};
Ok(Self { provider })
}
}
#[async_trait]
impl RemoteEmbeddingProvider for GenericRemoteProvider {
    /// Forwards to the wrapped provider's `embed`.
    async fn embed(&self, text: &str) -> Result<Vec<f32>, RemoteEmbeddingError> {
        let inner: &dyn RemoteEmbeddingProvider = &*self.provider;
        inner.embed(text).await
    }

    /// Forwards to the wrapped provider's `embed_batch`.
    async fn embed_batch(&self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>, RemoteEmbeddingError> {
        let inner: &dyn RemoteEmbeddingProvider = &*self.provider;
        inner.embed_batch(texts).await
    }

    /// Reports the wrapped provider's dimension.
    fn dimension(&self) -> usize {
        let inner: &dyn RemoteEmbeddingProvider = &*self.provider;
        inner.dimension()
    }
}
#[cfg(test)]
mod tests {
    //! Offline unit tests: none of these perform network requests.
    use super::*;

    #[test]
    fn test_remote_provider_default_dimension() {
        let provider = RemoteProvider::OpenAI {
            model: "text-embedding-3-small".to_string(),
        };
        assert_eq!(provider.default_dimension(), 1536);

        let large = RemoteProvider::OpenAI {
            model: "text-embedding-3-large".to_string(),
        };
        assert_eq!(large.default_dimension(), 3072);

        let cohere = RemoteProvider::Cohere {
            model: "embed-english-v3.0".to_string(),
        };
        assert_eq!(cohere.default_dimension(), 1024);
    }

    #[test]
    fn test_remote_config_openai() {
        let config = RemoteEmbeddingConfig::openai("test-key".to_string(), None);
        assert_eq!(config.api_key.as_deref(), Some("test-key"));
        assert_eq!(config.provider, RemoteProvider::default());
        assert!(config.base_url.is_none());
    }

    #[test]
    fn test_remote_config_cohere_sets_base_url() {
        let config = RemoteEmbeddingConfig::cohere("test-key".to_string(), None);
        assert_eq!(
            config.provider,
            RemoteProvider::Cohere {
                model: "embed-english-v3.0".to_string()
            }
        );
        assert_eq!(config.base_url.as_deref(), Some("https://api.cohere.ai/v1"));
    }

    #[test]
    fn test_providers_require_api_key() {
        let openai = RemoteEmbeddingConfig::default();
        assert!(matches!(
            OpenAIEmbeddingProvider::new(openai),
            Err(RemoteEmbeddingError::ApiKeyNotFound(_))
        ));

        let cohere = RemoteEmbeddingConfig {
            provider: RemoteProvider::Cohere {
                model: "embed-english-v3.0".to_string(),
            },
            ..Default::default()
        };
        assert!(matches!(
            CohereEmbeddingProvider::new(cohere),
            Err(RemoteEmbeddingError::ApiKeyNotFound(_))
        ));
    }

    #[test]
    fn test_generic_rejects_custom_provider() {
        let config = RemoteEmbeddingConfig::custom(
            "http://localhost:8080".to_string(),
            Some("key".to_string()),
        );
        assert!(matches!(
            GenericRemoteProvider::from_config(config),
            Err(RemoteEmbeddingError::ApiError(_))
        ));
    }
}