use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::{debug, error};
use super::base::EmbeddingBase;
use crate::config::HuggingFaceEmbedderConfig;
use crate::error::{NeomemxError, Result};
/// Embedding client for a HuggingFace-style HTTP inference endpoint.
///
/// Instances are created with `HuggingFaceEmbedding::new`, which resolves the
/// API key, builds the HTTP client and derives the request URL.
pub struct HuggingFaceEmbedding {
/// Token sent as `Authorization: Bearer <api_key>` with every request.
api_key: String,
/// Request URL: the configured base URL (trailing `/` removed) joined with the model id.
url: String,
/// Model identifier copied from the configuration.
model: String,
/// Expected length of every embedding vector; parsed responses are validated against it.
embedding_dims: usize,
/// Reusable HTTP client; pooling, keep-alive and timeout are configured in `new`.
client: Client,
}
/// JSON request body posted to the embedding endpoint.
#[derive(Debug, Serialize)]
struct EmbeddingRequest<'a> {
/// Texts to embed, borrowed from the caller for the duration of the request.
inputs: Vec<&'a str>,
/// Optional request options; the `options` key is omitted from the JSON when this is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
options: Option<RequestOptions>,
}
/// Options object serialized under the `options` key of the request body.
#[derive(Debug, Serialize)]
struct RequestOptions {
/// Serialized as `wait_for_model`.
/// NOTE(review): presumably asks the service to wait for a cold model to load
/// instead of failing fast — confirm against the HuggingFace Inference API docs.
wait_for_model: bool,
}
/// Shape used to pull an error message out of a non-success response body.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
/// Value of the `error` key, if present. Kept as raw JSON because it is not
/// assumed to be a string; the caller stringifies non-string values.
error: Option<serde_json::Value>,
}
impl HuggingFaceEmbedding {
/// Builds an embedder from `config`.
///
/// The request URL is `<base_url>/<model>` with any trailing `/` on the base
/// URL removed. The HTTP client keeps up to 16 idle connections per host,
/// drops idle connections after 90 s, uses a 60 s TCP keep-alive and applies
/// a 30 s overall request timeout.
///
/// # Errors
///
/// Returns `NeomemxError::EmbeddingError` when no API key can be resolved
/// from the configuration or when the HTTP client cannot be constructed.
pub fn new(config: HuggingFaceEmbedderConfig) -> Result<Self> {
    use std::time::Duration;

    // Without a key every request would be rejected, so fail up front.
    let api_key = match config.get_api_key() {
        Some(key) => key,
        None => {
            return Err(NeomemxError::EmbeddingError(
                "HuggingFace API key not found. Set HUGGINGFACE_API_KEY environment variable or \
                 provide it in the configuration."
                    .to_string(),
            ));
        }
    };

    let builder = Client::builder()
        .pool_max_idle_per_host(16)
        .pool_idle_timeout(Duration::from_secs(90))
        .tcp_keepalive(Duration::from_secs(60))
        .timeout(Duration::from_secs(30));
    let client = match builder.build() {
        Ok(client) => client,
        Err(e) => {
            return Err(NeomemxError::EmbeddingError(format!(
                "Failed to create HTTP client: {}",
                e
            )));
        }
    };

    // Strip trailing slashes so the joined URL never contains `//`.
    let base = config.base_url.trim_end_matches('/');
    let url = format!("{}/{}", base, config.model);

    Ok(Self {
        api_key,
        url,
        model: config.model,
        embedding_dims: config.embedding_dims,
        client,
    })
}
/// Checks that `embedding` has exactly the configured number of dimensions.
///
/// # Errors
///
/// Returns `NeomemxError::EmbeddingError` naming the actual and the expected
/// length when they differ.
fn validate_dims(&self, embedding: &[f32]) -> Result<()> {
    let actual = embedding.len();
    let expected = self.embedding_dims;
    if actual == expected {
        return Ok(());
    }
    Err(NeomemxError::EmbeddingError(format!(
        "HuggingFace returned embedding with dimension {} but expected {}",
        actual, expected
    )))
}
fn parse_embeddings(&self, body: &str, expected_inputs: usize) -> Result<Vec<Vec<f32>>> {
use serde_json::Value;
let value: Value = serde_json::from_str(body).map_err(|e| {
NeomemxError::EmbeddingError(format!("Failed to parse HuggingFace response: {}", e))
})?;
fn as_f32_slice(values: &[serde_json::Value]) -> Result<Vec<f32>> {
values
.iter()
.map(|v| {
v.as_f64().map(|n| n as f32).ok_or_else(|| {
NeomemxError::EmbeddingError(
"Non-numeric value in HuggingFace embedding response".to_string(),
)
})
})
.collect()
}
fn average_token_embeddings(tokens: &[serde_json::Value], dims: usize) -> Result<Vec<f32>> {
if tokens.is_empty() {
return Err(NeomemxError::EmbeddingError(
"HuggingFace returned empty token embeddings".to_string(),
));
}
let mut sums = vec![0f32; dims];
let mut count = 0usize;
for token in tokens {
let token_values = token.as_array().ok_or_else(|| {
NeomemxError::EmbeddingError(
"Unexpected token structure in HuggingFace response".to_string(),
)
})?;
if token_values.len() != dims {
return Err(NeomemxError::EmbeddingError(format!(
"Token embedding dimension {} does not match expected {}",
token_values.len(),
dims
)));
}
for (i, value) in token_values.iter().enumerate() {
let val = value.as_f64().ok_or_else(|| {
NeomemxError::EmbeddingError(
"Non-numeric value in HuggingFace embedding response".to_string(),
)
})?;
sums[i] += val as f32;
}
count += 1;
}
for sum in &mut sums {
*sum /= count as f32;
}
Ok(sums)
}
match value {
serde_json::Value::Array(arr) if expected_inputs == 1 => {
if arr.iter().all(|v| v.is_number()) {
let embedding = as_f32_slice(&arr)?;
self.validate_dims(&embedding)?;
return Ok(vec![embedding]);
}
if let Some(first) = arr.first() {
if first.is_array() {
let embedding = average_token_embeddings(&arr, self.embedding_dims)?;
self.validate_dims(&embedding)?;
return Ok(vec![embedding]);
}
}
}
serde_json::Value::Array(arr) => {
if arr.len() != expected_inputs {
return Err(NeomemxError::EmbeddingError(format!(
"HuggingFace returned {} embeddings but expected {}",
arr.len(),
expected_inputs
)));
}
if let Some(first) = arr.first() {
if first
.as_array()
.and_then(|a| a.first())
.map(|v| v.is_array())
.unwrap_or(false)
{
let mut embeddings = Vec::with_capacity(arr.len());
for token_set in arr {
let tokens = token_set.as_array().ok_or_else(|| {
NeomemxError::EmbeddingError(
"Unexpected token structure in HuggingFace response"
.to_string(),
)
})?;
let embedding = average_token_embeddings(tokens, self.embedding_dims)?;
self.validate_dims(&embedding)?;
embeddings.push(embedding);
}
return Ok(embeddings);
}
if first
.as_array()
.map(|inner| inner.iter().all(|v| v.is_number()))
.unwrap_or(false)
{
let mut embeddings = Vec::with_capacity(arr.len());
for entry in arr {
let embedding_values = entry.as_array().ok_or_else(|| {
NeomemxError::EmbeddingError(
"Unexpected embedding structure in HuggingFace response"
.to_string(),
)
})?;
let embedding = as_f32_slice(embedding_values)?;
self.validate_dims(&embedding)?;
embeddings.push(embedding);
}
return Ok(embeddings);
}
}
}
_ => {}
}
Err(NeomemxError::EmbeddingError(
"Unsupported HuggingFace embedding response format".to_string(),
))
}
async fn embed_texts(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
let request = EmbeddingRequest {
inputs: texts.iter().map(|t| t.trim()).collect(),
options: Some(RequestOptions {
wait_for_model: true,
}),
};
debug!("HuggingFace batch embedding: {} texts", texts.len());
let response = self
.client
.post(&self.url)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&request)
.send()
.await?;
let status = response.status();
let body = response.text().await?;
if !status.is_success() {
let error_msg = serde_json::from_str::<ErrorResponse>(&body)
.ok()
.and_then(|e| e.error)
.map(|v| {
v.as_str()
.map(|s| s.to_string())
.unwrap_or_else(|| v.to_string())
})
.unwrap_or_else(|| body.clone());
error!("HuggingFace API error: {}", error_msg);
return Err(NeomemxError::EmbeddingError(format!(
"HuggingFace API error: {}",
error_msg
)));
}
self.parse_embeddings(&body, texts.len())
}
}
#[async_trait]
impl EmbeddingBase for HuggingFaceEmbedding {
    /// Embeds a single text by issuing a one-element batch request.
    ///
    /// Fails with `NeomemxError::EmbeddingError` if the request fails or the
    /// parsed response contains no embedding.
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let embeddings = self.embed_texts(&[text]).await?;
        match embeddings.into_iter().next_back() {
            Some(embedding) => Ok(embedding),
            None => Err(NeomemxError::EmbeddingError(
                "No embeddings in HuggingFace response".to_string(),
            )),
        }
    }

    /// Embeds all `texts` with a single request, returning one vector per text.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let embeddings = self.embed_texts(texts).await?;
        Ok(embeddings)
    }

    /// Returns the embedding dimension this instance was configured with.
    fn embedding_dims(&self) -> usize {
        self.embedding_dims
    }
}
#[cfg(test)]
mod tests {
    use crate::config::HuggingFaceEmbedderConfig;

    /// The default configuration points at the expected model, dimension and base URL.
    #[test]
    fn test_default_config() {
        let defaults = HuggingFaceEmbedderConfig::default();
        let expected_model = "BAAI/bge-small-en-v1.5";
        let expected_dims: usize = 384;
        let expected_url = "https://api-inference.huggingface.co/models";

        assert_eq!(defaults.model, expected_model);
        assert_eq!(defaults.embedding_dims, expected_dims);
        assert_eq!(defaults.base_url, expected_url);
    }

    /// Overriding only the model leaves the remaining defaults untouched.
    #[test]
    fn test_custom_model() {
        let model_id = "sentence-transformers/all-MiniLM-L6-v2";
        let custom = HuggingFaceEmbedderConfig {
            model: model_id.to_string(),
            ..Default::default()
        };

        assert_eq!(custom.model, model_id);
        assert_eq!(custom.embedding_dims, 384);
    }
}