//! neomemx 0.1.2 — a high-performance memory library for AI agents with
//! semantic search.
//!
//! HuggingFace Inference API embeddings

use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::{debug, error};

use super::base::EmbeddingBase;
use crate::config::HuggingFaceEmbedderConfig;
use crate::error::{NeomemxError, Result};

/// HuggingFace embedding client backed by the Inference API
pub struct HuggingFaceEmbedding {
    // Bearer token sent in the Authorization header of every request.
    api_key: String,
    // Full request URL, built as `<base_url>/<model>` in `new`.
    url: String,
    // Model identifier, e.g. "BAAI/bge-small-en-v1.5".
    model: String,
    // Expected dimensionality of returned vectors; every parsed embedding
    // is validated against this value.
    embedding_dims: usize,
    // Reusable HTTP client with connection pooling (configured in `new`).
    client: Client,
}

/// JSON request body sent to the Inference API.
#[derive(Debug, Serialize)]
struct EmbeddingRequest<'a> {
    // One entry per text to embed; borrows from the caller's slice.
    inputs: Vec<&'a str>,
    // Optional request flags; omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    options: Option<RequestOptions>,
}

/// Optional flags understood by the HuggingFace Inference API.
#[derive(Debug, Serialize)]
struct RequestOptions {
    /// Instruct HF to wait until the model is ready instead of returning 503
    wait_for_model: bool,
}

/// Minimal shape of a HuggingFace error body. The `error` field may be a
/// string or any other JSON value, so it is kept as a raw `serde_json::Value`.
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    error: Option<serde_json::Value>,
}

impl HuggingFaceEmbedding {
    /// Creates a new HuggingFace embedding client.
    ///
    /// # Errors
    ///
    /// Returns [`NeomemxError::EmbeddingError`] when no API key is available
    /// from the configuration or environment, or when the HTTP client cannot
    /// be constructed.
    pub fn new(config: HuggingFaceEmbedderConfig) -> Result<Self> {
        let api_key = match config.get_api_key() {
            Some(key) => key,
            None => {
                return Err(NeomemxError::EmbeddingError(
                    "HuggingFace API key not found. Set HUGGINGFACE_API_KEY environment variable or \
                     provide it in the configuration."
                        .to_string(),
                ))
            }
        };

        // Pooled client with keepalive so repeated embedding calls reuse
        // connections; 30s overall request timeout.
        let client = Client::builder()
            .pool_max_idle_per_host(16)
            .pool_idle_timeout(std::time::Duration::from_secs(90))
            .tcp_keepalive(std::time::Duration::from_secs(60))
            .timeout(std::time::Duration::from_secs(30))
            .build()
            .map_err(|e| {
                NeomemxError::EmbeddingError(format!("Failed to create HTTP client: {}", e))
            })?;

        // Request URL is `<base>/<model>`; trim a trailing slash on the base
        // so we never produce a double slash.
        let url = format!("{}/{}", config.base_url.trim_end_matches('/'), config.model);

        Ok(Self {
            api_key,
            url,
            model: config.model,
            embedding_dims: config.embedding_dims,
            client,
        })
    }

    /// Checks that `embedding` has exactly `self.embedding_dims` components.
    fn validate_dims(&self, embedding: &[f32]) -> Result<()> {
        let actual = embedding.len();
        if actual == self.embedding_dims {
            Ok(())
        } else {
            Err(NeomemxError::EmbeddingError(format!(
                "HuggingFace returned embedding with dimension {} but expected {}",
                actual, self.embedding_dims
            )))
        }
    }

    /// Parses the raw JSON `body` returned by the Inference API into one
    /// embedding vector per input.
    ///
    /// The response shape varies by model/pipeline, so four layouts are
    /// accepted:
    /// - 1D numeric array: a single sentence embedding (one input);
    /// - 2D array with one input: token-level embeddings, mean-pooled;
    /// - 2D array for a batch: one sentence embedding per input;
    /// - 3D array for a batch: token-level embeddings per input, mean-pooled.
    ///
    /// Every resulting vector is validated against `self.embedding_dims`;
    /// any other layout yields an `EmbeddingError`.
    fn parse_embeddings(&self, body: &str, expected_inputs: usize) -> Result<Vec<Vec<f32>>> {
        use serde_json::Value;

        let value: Value = serde_json::from_str(body).map_err(|e| {
            NeomemxError::EmbeddingError(format!("Failed to parse HuggingFace response: {}", e))
        })?;

        // Converts a JSON array of numbers into an f32 vector; any
        // non-numeric element is an error.
        fn as_f32_slice(values: &[serde_json::Value]) -> Result<Vec<f32>> {
            values
                .iter()
                .map(|v| {
                    v.as_f64().map(|n| n as f32).ok_or_else(|| {
                        NeomemxError::EmbeddingError(
                            "Non-numeric value in HuggingFace embedding response".to_string(),
                        )
                    })
                })
                .collect()
        }

        // Mean-pools token-level embeddings: each token must be a numeric
        // array of length `dims`; the result is the element-wise average.
        fn average_token_embeddings(tokens: &[serde_json::Value], dims: usize) -> Result<Vec<f32>> {
            if tokens.is_empty() {
                return Err(NeomemxError::EmbeddingError(
                    "HuggingFace returned empty token embeddings".to_string(),
                ));
            }

            let mut sums = vec![0f32; dims];
            let mut count = 0usize;

            for token in tokens {
                let token_values = token.as_array().ok_or_else(|| {
                    NeomemxError::EmbeddingError(
                        "Unexpected token structure in HuggingFace response".to_string(),
                    )
                })?;

                // Reject ragged token matrices up front rather than averaging
                // mismatched vectors.
                if token_values.len() != dims {
                    return Err(NeomemxError::EmbeddingError(format!(
                        "Token embedding dimension {} does not match expected {}",
                        token_values.len(),
                        dims
                    )));
                }

                for (i, value) in token_values.iter().enumerate() {
                    let val = value.as_f64().ok_or_else(|| {
                        NeomemxError::EmbeddingError(
                            "Non-numeric value in HuggingFace embedding response".to_string(),
                        )
                    })?;
                    sums[i] += val as f32;
                }
                count += 1;
            }

            // Divide the accumulated sums by the token count (non-zero: the
            // empty case was rejected above).
            for sum in &mut sums {
                *sum /= count as f32;
            }

            Ok(sums)
        }

        match value {
            // Could be a single vector (1D) or a set of token embeddings (2D)
            serde_json::Value::Array(arr) if expected_inputs == 1 => {
                if arr.iter().all(|v| v.is_number()) {
                    let embedding = as_f32_slice(&arr)?;
                    self.validate_dims(&embedding)?;
                    return Ok(vec![embedding]);
                }

                if let Some(first) = arr.first() {
                    if first.is_array() {
                        // Token-level embeddings for a single input
                        let embedding = average_token_embeddings(&arr, self.embedding_dims)?;
                        self.validate_dims(&embedding)?;
                        return Ok(vec![embedding]);
                    }
                }
            }
            // Could be batch embeddings (2D) or batch token embeddings (3D)
            serde_json::Value::Array(arr) => {
                if arr.len() != expected_inputs {
                    return Err(NeomemxError::EmbeddingError(format!(
                        "HuggingFace returned {} embeddings but expected {}",
                        arr.len(),
                        expected_inputs
                    )));
                }

                if let Some(first) = arr.first() {
                    // Batch token embeddings: first element is an array whose
                    // first element is itself an array (3D layout).
                    if first
                        .as_array()
                        .and_then(|a| a.first())
                        .map(|v| v.is_array())
                        .unwrap_or(false)
                    {
                        let mut embeddings = Vec::with_capacity(arr.len());
                        for token_set in arr {
                            let tokens = token_set.as_array().ok_or_else(|| {
                                NeomemxError::EmbeddingError(
                                    "Unexpected token structure in HuggingFace response"
                                        .to_string(),
                                )
                            })?;
                            let embedding = average_token_embeddings(tokens, self.embedding_dims)?;
                            self.validate_dims(&embedding)?;
                            embeddings.push(embedding);
                        }
                        return Ok(embeddings);
                    }

                    // Batch embeddings (2D): first element is a flat numeric
                    // array — one sentence embedding per input.
                    if first
                        .as_array()
                        .map(|inner| inner.iter().all(|v| v.is_number()))
                        .unwrap_or(false)
                    {
                        let mut embeddings = Vec::with_capacity(arr.len());
                        for entry in arr {
                            let embedding_values = entry.as_array().ok_or_else(|| {
                                NeomemxError::EmbeddingError(
                                    "Unexpected embedding structure in HuggingFace response"
                                        .to_string(),
                                )
                            })?;
                            let embedding = as_f32_slice(embedding_values)?;
                            self.validate_dims(&embedding)?;
                            embeddings.push(embedding);
                        }
                        return Ok(embeddings);
                    }
                }
            }
            _ => {}
        }

        // Anything that fell through the recognized layouts above.
        Err(NeomemxError::EmbeddingError(
            "Unsupported HuggingFace embedding response format".to_string(),
        ))
    }

    async fn embed_texts(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let request = EmbeddingRequest {
            inputs: texts.iter().map(|t| t.trim()).collect(),
            options: Some(RequestOptions {
                wait_for_model: true,
            }),
        };

        debug!("HuggingFace batch embedding: {} texts", texts.len());

        let response = self
            .client
            .post(&self.url)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&request)
            .send()
            .await?;

        let status = response.status();
        let body = response.text().await?;

        if !status.is_success() {
            let error_msg = serde_json::from_str::<ErrorResponse>(&body)
                .ok()
                .and_then(|e| e.error)
                .map(|v| {
                    v.as_str()
                        .map(|s| s.to_string())
                        .unwrap_or_else(|| v.to_string())
                })
                .unwrap_or_else(|| body.clone());
            error!("HuggingFace API error: {}", error_msg);
            return Err(NeomemxError::EmbeddingError(format!(
                "HuggingFace API error: {}",
                error_msg
            )));
        }

        self.parse_embeddings(&body, texts.len())
    }
}

#[async_trait]
impl EmbeddingBase for HuggingFaceEmbedding {
    /// Embeds a single text by delegating to the batch endpoint.
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let mut batch = self.embed_texts(&[text]).await?;
        match batch.pop() {
            Some(embedding) => Ok(embedding),
            None => Err(NeomemxError::EmbeddingError(
                "No embeddings in HuggingFace response".to_string(),
            )),
        }
    }

    /// Embeds many texts with a single API call.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        self.embed_texts(texts).await
    }

    /// Configured dimensionality of the vectors this client produces.
    fn embedding_dims(&self) -> usize {
        self.embedding_dims
    }
}

#[cfg(test)]
mod tests {
    use crate::config::HuggingFaceEmbedderConfig;

    /// The default configuration targets BGE-small with 384-dim vectors on
    /// the public Inference API.
    #[test]
    fn test_default_config() {
        let config = HuggingFaceEmbedderConfig::default();
        assert_eq!(
            config.base_url,
            "https://api-inference.huggingface.co/models"
        );
        assert_eq!(config.embedding_dims, 384);
        assert_eq!(config.model, "BAAI/bge-small-en-v1.5");
    }

    /// Overriding only the model keeps the remaining defaults intact.
    #[test]
    fn test_custom_model() {
        let config = HuggingFaceEmbedderConfig {
            model: String::from("sentence-transformers/all-MiniLM-L6-v2"),
            ..Default::default()
        };
        assert_eq!(config.embedding_dims, 384);
        assert_eq!(config.model, "sentence-transformers/all-MiniLM-L6-v2");
    }
}