use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::{debug, error};
use super::base::EmbeddingBase;
use crate::config::JinaEmbedderConfig;
use crate::error::{NeomemxError, Result};
/// Embedding provider backed by the Jina AI embeddings HTTP API.
///
/// Construct it with `JinaEmbedding::new`. The `Debug` implementation deliberately omits the
/// API key so instances can be logged without leaking credentials.
pub struct JinaEmbedding {
    /// Secret credential sent with every request; never printed by `Debug`.
    api_key: String,
    /// Fully-resolved embeddings endpoint URL.
    url: String,
    /// Model identifier sent in each request body.
    model: String,
    /// Dimensionality reported through `EmbeddingBase::embedding_dims`.
    embedding_dims: usize,
    /// HTTP client built once and reused for every request.
    client: Client,
}

// Hand-written instead of derived so the API key can never appear in debug output or logs.
impl std::fmt::Debug for JinaEmbedding {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JinaEmbedding")
            .field("url", &self.url)
            .field("model", &self.model)
            .field("embedding_dims", &self.embedding_dims)
            .finish_non_exhaustive()
    }
}
/// JSON request body POSTed to the embeddings endpoint.
///
/// Field names are serialized verbatim, producing `{"model": ..., "input": [...]}`.
#[derive(Debug, Serialize)]
struct EmbeddingRequest<'a> {
    /// Model identifier to embed with.
    model: &'a str,
    /// Batch of texts to embed, borrowed from the caller.
    input: Vec<&'a str>,
}
/// Successful response body from the embeddings endpoint.
///
/// Only the `data` array is read; any other fields in the JSON are ignored by serde.
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    /// One entry per embedded input.
    data: Vec<EmbeddingData>,
}
/// A single entry of the response `data` array.
#[derive(Debug, Deserialize)]
struct EmbeddingData {
    /// The embedding vector for one input text.
    embedding: Vec<f32>,
}
/// Error body shape tried when the API answers with a non-success status.
///
/// Either field may carry the human-readable message; both are optional because the exact
/// error schema is not guaranteed. If the body fails to parse as this shape, callers fall
/// back to the raw body text.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    /// Message under a `detail` key, if present and a string.
    detail: Option<String>,
    /// Message under an `error` key, if present and a string.
    error: Option<String>,
}
impl JinaEmbedding {
pub fn new(config: JinaEmbedderConfig) -> Result<Self> {
let api_key = config.get_api_key().ok_or_else(|| {
NeomemxError::EmbeddingError(
"Jina API key not found. Get your FREE key at: https://jina.ai/embeddings/ \
Then set JINA_API_KEY environment variable."
.to_string(),
)
})?;
let url = {
let base = config.base_url.trim_end_matches('/');
if base.ends_with("embeddings") {
base.to_string()
} else {
format!("{}/embeddings", base)
}
};
let client = Client::builder()
.pool_max_idle_per_host(16)
.pool_idle_timeout(std::time::Duration::from_secs(90))
.tcp_keepalive(std::time::Duration::from_secs(60))
.timeout(std::time::Duration::from_secs(30))
.no_proxy()
.build()
.map_err(|e| {
NeomemxError::EmbeddingError(format!("Failed to create HTTP client: {}", e))
})?;
Ok(Self {
api_key,
url,
model: config.model,
embedding_dims: config.embedding_dims,
client,
})
}
async fn embed_texts(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
let request = EmbeddingRequest {
model: &self.model,
input: texts.iter().map(|t| t.trim()).collect(),
};
debug!("Jina batch embedding: {} texts", texts.len());
let response = self
.client
.post(&self.url)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&request)
.send()
.await?;
let status = response.status();
let body = response.text().await?;
if !status.is_success() {
let error_msg = serde_json::from_str::<ErrorResponse>(&body)
.ok()
.and_then(|e| e.detail.or(e.error))
.unwrap_or_else(|| body.clone());
error!("Jina API error: {}", error_msg);
return Err(NeomemxError::EmbeddingError(format!(
"Jina API error: {}",
error_msg
)));
}
let parsed: EmbeddingResponse = serde_json::from_str(&body).map_err(|e| {
NeomemxError::EmbeddingError(format!("Failed to parse Jina response: {}", e))
})?;
if parsed.data.len() != texts.len() {
return Err(NeomemxError::EmbeddingError(format!(
"Jina returned {} embeddings but expected {}",
parsed.data.len(),
texts.len()
)));
}
Ok(parsed.data.into_iter().map(|d| d.embedding).collect())
}
}
#[async_trait]
impl EmbeddingBase for JinaEmbedding {
    /// Embeds one text by issuing a single-element batch request.
    ///
    /// Fails with `NeomemxError::EmbeddingError` if the batch comes back empty, and
    /// propagates any error from the underlying batch call.
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let single = [text];
        let mut vectors = self.embed_texts(&single).await?;
        match vectors.pop() {
            Some(vector) => Ok(vector),
            None => Err(NeomemxError::EmbeddingError(
                "No embeddings in Jina response".to_string(),
            )),
        }
    }

    /// Embeds many texts in one request; delegates directly to the batch helper.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let vectors = self.embed_texts(texts).await?;
        Ok(vectors)
    }

    /// Reports the dimensionality captured from the configuration at construction time.
    fn embedding_dims(&self) -> usize {
        let dims = self.embedding_dims;
        dims
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config should select the v2 base English model at 768 dimensions.
    #[test]
    fn test_default_config() {
        let cfg = JinaEmbedderConfig::default();
        assert_eq!(
            (cfg.model.as_str(), cfg.embedding_dims),
            ("jina-embeddings-v2-base-en", 768)
        );
    }

    /// `with_model` should swap the model name while leaving `embedding_dims` at its default.
    // NOTE(review): jina-embeddings-v3 is documented by Jina as producing 1024-dim vectors by
    // default; confirm that keeping 768 here is the intended config behavior.
    #[test]
    fn test_custom_model() {
        let cfg = JinaEmbedderConfig::with_model("jina-embeddings-v3");
        assert_eq!(
            (cfg.model.as_str(), cfg.embedding_dims),
            ("jina-embeddings-v3", 768)
        );
    }
}