Skip to main content

synaptic_huggingface/
lib.rs

1use async_trait::async_trait;
2use synaptic_core::{Embeddings, SynapticError};
3
/// Connection settings for [`HuggingFaceEmbeddings`].
///
/// Build one with [`HuggingFaceEmbeddingsConfig::new`] and chain the `with_*`
/// methods to override individual fields.
#[derive(Clone)]
pub struct HuggingFaceEmbeddingsConfig {
    /// Model identifier appended to `base_url`, e.g. `"BAAI/bge-small-en-v1.5"`.
    pub model: String,
    /// Optional API token, sent as a bearer token when present.
    pub api_key: Option<String>,
    /// Endpoint prefix; the request URL is `"{base_url}/{model}"`.
    pub base_url: String,
    /// When `true`, requests carry the `x-wait-for-model: true` header.
    pub wait_for_model: bool,
}

// Hand-written instead of derived so the API token never ends up in logs or
// panic messages that format the config with `{:?}`.
impl std::fmt::Debug for HuggingFaceEmbeddingsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HuggingFaceEmbeddingsConfig")
            .field("model", &self.model)
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("base_url", &self.base_url)
            .field("wait_for_model", &self.wait_for_model)
            .finish()
    }
}

impl HuggingFaceEmbeddingsConfig {
    /// Creates a config for `model` with no API key, the public Inference API
    /// base URL, and `wait_for_model` enabled.
    #[must_use]
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            api_key: None,
            base_url: "https://api-inference.huggingface.co/models".to_string(),
            wait_for_model: true,
        }
    }

    /// Sets the API token used to authenticate requests.
    #[must_use]
    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// Overrides the endpoint prefix (e.g. for a self-hosted inference server).
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }

    /// Enables or disables the `x-wait-for-model` request header.
    #[must_use]
    pub fn with_wait_for_model(mut self, wait: bool) -> Self {
        self.wait_for_model = wait;
        self
    }
}
34
35pub struct HuggingFaceEmbeddings {
36    config: HuggingFaceEmbeddingsConfig,
37    client: reqwest::Client,
38}
39
40impl HuggingFaceEmbeddings {
41    pub fn new(config: HuggingFaceEmbeddingsConfig) -> Self {
42        Self {
43            config,
44            client: reqwest::Client::new(),
45        }
46    }
47    pub fn with_client(config: HuggingFaceEmbeddingsConfig, client: reqwest::Client) -> Self {
48        Self { config, client }
49    }
50
51    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
52        if texts.is_empty() {
53            return Ok(Vec::new());
54        }
55        let url = format!("{}/{}", self.config.base_url, self.config.model);
56        let body = serde_json::json!({ "inputs": texts });
57        let mut request = self
58            .client
59            .post(&url)
60            .header("Content-Type", "application/json");
61        if let Some(ref key) = self.config.api_key {
62            request = request.header("Authorization", format!("Bearer {key}"));
63        }
64        if self.config.wait_for_model {
65            request = request.header("x-wait-for-model", "true");
66        }
67        let response = request
68            .json(&body)
69            .send()
70            .await
71            .map_err(|e| SynapticError::Embedding(format!("HuggingFace request: {e}")))?;
72        let status = response.status();
73        if status.is_client_error() || status.is_server_error() {
74            let code = status.as_u16();
75            let text = response.text().await.unwrap_or_default();
76            return Err(SynapticError::Embedding(format!(
77                "HuggingFace API error ({code}): {text}"
78            )));
79        }
80        let resp: serde_json::Value = response
81            .json()
82            .await
83            .map_err(|e| SynapticError::Embedding(format!("HuggingFace parse: {e}")))?;
84        parse_hf_response(&resp)
85    }
86}
87
88fn parse_hf_response(resp: &serde_json::Value) -> Result<Vec<Vec<f32>>, SynapticError> {
89    let array = if let Some(arr) = resp.as_array() {
90        arr
91    } else if let Some(arr) = resp.get("embeddings").and_then(|e| e.as_array()) {
92        arr
93    } else {
94        return Err(SynapticError::Embedding(
95            "unexpected HuggingFace response format".to_string(),
96        ));
97    };
98    let mut result = Vec::with_capacity(array.len());
99    for item in array {
100        let embedding: Vec<f32> = item
101            .as_array()
102            .ok_or_else(|| SynapticError::Embedding("embedding item is not array".to_string()))?
103            .iter()
104            .map(|v| v.as_f64().unwrap_or(0.0) as f32)
105            .collect();
106        result.push(embedding);
107    }
108    Ok(result)
109}
110
111#[async_trait]
112impl Embeddings for HuggingFaceEmbeddings {
113    async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
114        self.embed_batch(texts).await
115    }
116    async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError> {
117        let mut results = self.embed_batch(&[text]).await?;
118        results
119            .pop()
120            .ok_or_else(|| SynapticError::Embedding("empty HuggingFace response".to_string()))
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn config_defaults() {
        let c = HuggingFaceEmbeddingsConfig::new("BAAI/bge-small-en-v1.5");
        assert_eq!(c.model, "BAAI/bge-small-en-v1.5");
        assert_eq!(c.api_key, None);
        assert_eq!(c.base_url, "https://api-inference.huggingface.co/models");
        assert!(c.wait_for_model);
    }

    #[test]
    fn config_builder() {
        let c = HuggingFaceEmbeddingsConfig::new("model")
            .with_api_key("hf_test")
            .with_base_url("http://localhost:8080")
            .with_wait_for_model(false);
        assert_eq!(c.api_key, Some("hf_test".to_string()));
        assert_eq!(c.base_url, "http://localhost:8080");
        assert!(!c.wait_for_model);
    }

    #[test]
    fn parse_direct_array() {
        let resp = serde_json::json!([[0.1_f32, 0.2_f32]]);
        let result = parse_hf_response(&resp).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 2);
        assert!((result[0][0] - 0.1).abs() < 1e-6);
    }

    #[test]
    fn parse_embeddings_object() {
        let resp = serde_json::json!({ "embeddings": [[1.0, 2.0], [3.0, 4.0]] });
        let result = parse_hf_response(&resp).unwrap();
        assert_eq!(result, vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
    }

    #[test]
    fn parse_empty_array() {
        let resp = serde_json::json!([]);
        assert!(parse_hf_response(&resp).unwrap().is_empty());
    }

    #[test]
    fn parse_rejects_unknown_shapes() {
        // Neither a top-level array nor an object with an "embeddings" array.
        assert!(parse_hf_response(&serde_json::json!({ "error": "loading" })).is_err());
        // A flat vector: items are numbers, not arrays.
        assert!(parse_hf_response(&serde_json::json!([1.0, 2.0])).is_err());
    }
}