Skip to main content

synaptic_huggingface/
lib.rs

1pub mod reranker;
2pub use reranker::{BgeRerankerModel, HuggingFaceReranker};
3
4use async_trait::async_trait;
5use synaptic_core::{Embeddings, SynapticError};
6
/// Configuration for the Hugging Face embeddings client.
///
/// Start from [`HuggingFaceEmbeddingsConfig::new`] and adjust with the `with_*` builder methods.
/// The `Debug` output replaces the value of `api_key` with `<redacted>`, so a config can be
/// logged without exposing the token.
#[derive(Clone)]
pub struct HuggingFaceEmbeddingsConfig {
    /// Model identifier appended to `base_url` (joined with `/`), e.g. `"BAAI/bge-small-en-v1.5"`.
    pub model: String,
    /// Optional access token; when set, requests carry an `Authorization: Bearer <token>` header.
    pub api_key: Option<String>,
    /// Endpoint prefix the model id is appended to.
    pub base_url: String,
    /// When `true` (the default), each request sends the `x-wait-for-model: true` header.
    pub wait_for_model: bool,
}

impl HuggingFaceEmbeddingsConfig {
    /// Base URL used by [`HuggingFaceEmbeddingsConfig::new`].
    pub const DEFAULT_BASE_URL: &'static str = "https://api-inference.huggingface.co/models";

    /// Creates a config for `model` with no API key, [`Self::DEFAULT_BASE_URL`] as the base URL,
    /// and `wait_for_model` enabled.
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            api_key: None,
            base_url: Self::DEFAULT_BASE_URL.to_string(),
            wait_for_model: true,
        }
    }

    /// Sets the access token sent as a bearer token.
    #[must_use]
    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// Overrides the endpoint prefix (for example, a self-hosted inference server).
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }

    /// Enables or disables sending the `x-wait-for-model: true` header.
    #[must_use]
    pub fn with_wait_for_model(mut self, wait: bool) -> Self {
        self.wait_for_model = wait;
        self
    }
}

impl std::fmt::Debug for HuggingFaceEmbeddingsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HuggingFaceEmbeddingsConfig")
            .field("model", &self.model)
            // Never print the secret itself; only reveal whether one is configured.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("base_url", &self.base_url)
            .field("wait_for_model", &self.wait_for_model)
            .finish()
    }
}
37
38pub struct HuggingFaceEmbeddings {
39    config: HuggingFaceEmbeddingsConfig,
40    client: reqwest::Client,
41}
42
43impl HuggingFaceEmbeddings {
44    pub fn new(config: HuggingFaceEmbeddingsConfig) -> Self {
45        Self {
46            config,
47            client: reqwest::Client::new(),
48        }
49    }
50    pub fn with_client(config: HuggingFaceEmbeddingsConfig, client: reqwest::Client) -> Self {
51        Self { config, client }
52    }
53
54    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
55        if texts.is_empty() {
56            return Ok(Vec::new());
57        }
58        let url = format!("{}/{}", self.config.base_url, self.config.model);
59        let body = serde_json::json!({ "inputs": texts });
60        let mut request = self
61            .client
62            .post(&url)
63            .header("Content-Type", "application/json");
64        if let Some(ref key) = self.config.api_key {
65            request = request.header("Authorization", format!("Bearer {key}"));
66        }
67        if self.config.wait_for_model {
68            request = request.header("x-wait-for-model", "true");
69        }
70        let response = request
71            .json(&body)
72            .send()
73            .await
74            .map_err(|e| SynapticError::Embedding(format!("HuggingFace request: {e}")))?;
75        let status = response.status();
76        if status.is_client_error() || status.is_server_error() {
77            let code = status.as_u16();
78            let text = response.text().await.unwrap_or_default();
79            return Err(SynapticError::Embedding(format!(
80                "HuggingFace API error ({code}): {text}"
81            )));
82        }
83        let resp: serde_json::Value = response
84            .json()
85            .await
86            .map_err(|e| SynapticError::Embedding(format!("HuggingFace parse: {e}")))?;
87        parse_hf_response(&resp)
88    }
89}
90
91fn parse_hf_response(resp: &serde_json::Value) -> Result<Vec<Vec<f32>>, SynapticError> {
92    let array = if let Some(arr) = resp.as_array() {
93        arr
94    } else if let Some(arr) = resp.get("embeddings").and_then(|e| e.as_array()) {
95        arr
96    } else {
97        return Err(SynapticError::Embedding(
98            "unexpected HuggingFace response format".to_string(),
99        ));
100    };
101    let mut result = Vec::with_capacity(array.len());
102    for item in array {
103        let embedding: Vec<f32> = item
104            .as_array()
105            .ok_or_else(|| SynapticError::Embedding("embedding item is not array".to_string()))?
106            .iter()
107            .map(|v| v.as_f64().unwrap_or(0.0) as f32)
108            .collect();
109        result.push(embedding);
110    }
111    Ok(result)
112}
113
114#[async_trait]
115impl Embeddings for HuggingFaceEmbeddings {
116    async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
117        self.embed_batch(texts).await
118    }
119    async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError> {
120        let mut results = self.embed_batch(&[text]).await?;
121        results
122            .pop()
123            .ok_or_else(|| SynapticError::Embedding("empty HuggingFace response".to_string()))
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn config_defaults() {
        let c = HuggingFaceEmbeddingsConfig::new("BAAI/bge-small-en-v1.5");
        assert_eq!(c.model, "BAAI/bge-small-en-v1.5");
        assert_eq!(c.api_key, None);
        assert_eq!(c.base_url, "https://api-inference.huggingface.co/models");
        assert!(c.wait_for_model);
    }

    #[test]
    fn config_builder() {
        let c = HuggingFaceEmbeddingsConfig::new("model")
            .with_api_key("hf_test")
            .with_base_url("http://localhost:8080")
            .with_wait_for_model(false);
        assert_eq!(c.api_key, Some("hf_test".to_string()));
        assert_eq!(c.base_url, "http://localhost:8080");
        assert!(!c.wait_for_model);
    }

    #[test]
    fn parse_direct_array() {
        let resp = serde_json::json!([[0.1_f32, 0.2_f32]]);
        let result = parse_hf_response(&resp).unwrap();
        assert_eq!(result, vec![vec![0.1_f32, 0.2_f32]]);
    }

    #[test]
    fn parse_embeddings_object() {
        let resp = serde_json::json!({ "embeddings": [[1.0, 2.0], [3.0, 4.0]] });
        let result = parse_hf_response(&resp).unwrap();
        assert_eq!(result, vec![vec![1.0_f32, 2.0], vec![3.0, 4.0]]);
    }

    #[test]
    fn parse_empty_array() {
        let result = parse_hf_response(&serde_json::json!([])).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn parse_rejects_unexpected_shapes() {
        // Neither a bare array nor an object with an "embeddings" array.
        assert!(parse_hf_response(&serde_json::json!({ "error": "Model is loading" })).is_err());
        // Items must themselves be arrays.
        assert!(parse_hf_response(&serde_json::json!([0.1, 0.2])).is_err());
    }
}