synaptic_huggingface/
lib.rs1pub mod reranker;
2pub use reranker::{BgeRerankerModel, HuggingFaceReranker};
3
4use async_trait::async_trait;
5use synaptic_core::{Embeddings, SynapticError};
6
/// Connection settings for the HuggingFace embeddings client.
///
/// Build one with [`HuggingFaceEmbeddingsConfig::new`] and refine it with the
/// chainable `with_*` methods.
///
/// `Debug` is implemented by hand so that the API key is never written to
/// logs or panic messages; only whether a key is present is shown.
#[derive(Clone)]
pub struct HuggingFaceEmbeddingsConfig {
    /// Model identifier appended to `base_url` to form the request URL.
    pub model: String,
    /// Optional token sent as a `Bearer` credential when present.
    pub api_key: Option<String>,
    /// URL prefix the model identifier is joined onto.
    pub base_url: String,
    /// When `true`, requests carry an `x-wait-for-model: true` header.
    pub wait_for_model: bool,
}

impl std::fmt::Debug for HuggingFaceEmbeddingsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the shape of the derived output, but never print the secret.
        let api_key = self.api_key.as_ref().map(|_| "<redacted>");
        f.debug_struct("HuggingFaceEmbeddingsConfig")
            .field("model", &self.model)
            .field("api_key", &api_key)
            .field("base_url", &self.base_url)
            .field("wait_for_model", &self.wait_for_model)
            .finish()
    }
}

impl HuggingFaceEmbeddingsConfig {
    /// Creates a configuration for `model` with the defaults: no API key,
    /// the public inference endpoint as base URL, and `wait_for_model` on.
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            api_key: None,
            base_url: "https://api-inference.huggingface.co/models".to_string(),
            wait_for_model: true,
        }
    }

    /// Sets the API key used for authentication.
    #[must_use]
    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// Overrides the base URL the model identifier is appended to.
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }

    /// Enables or disables the `x-wait-for-model` request header.
    #[must_use]
    pub fn with_wait_for_model(mut self, wait: bool) -> Self {
        self.wait_for_model = wait;
        self
    }
}
37
38pub struct HuggingFaceEmbeddings {
39 config: HuggingFaceEmbeddingsConfig,
40 client: reqwest::Client,
41}
42
43impl HuggingFaceEmbeddings {
44 pub fn new(config: HuggingFaceEmbeddingsConfig) -> Self {
45 Self {
46 config,
47 client: reqwest::Client::new(),
48 }
49 }
50 pub fn with_client(config: HuggingFaceEmbeddingsConfig, client: reqwest::Client) -> Self {
51 Self { config, client }
52 }
53
54 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
55 if texts.is_empty() {
56 return Ok(Vec::new());
57 }
58 let url = format!("{}/{}", self.config.base_url, self.config.model);
59 let body = serde_json::json!({ "inputs": texts });
60 let mut request = self
61 .client
62 .post(&url)
63 .header("Content-Type", "application/json");
64 if let Some(ref key) = self.config.api_key {
65 request = request.header("Authorization", format!("Bearer {key}"));
66 }
67 if self.config.wait_for_model {
68 request = request.header("x-wait-for-model", "true");
69 }
70 let response = request
71 .json(&body)
72 .send()
73 .await
74 .map_err(|e| SynapticError::Embedding(format!("HuggingFace request: {e}")))?;
75 let status = response.status();
76 if status.is_client_error() || status.is_server_error() {
77 let code = status.as_u16();
78 let text = response.text().await.unwrap_or_default();
79 return Err(SynapticError::Embedding(format!(
80 "HuggingFace API error ({code}): {text}"
81 )));
82 }
83 let resp: serde_json::Value = response
84 .json()
85 .await
86 .map_err(|e| SynapticError::Embedding(format!("HuggingFace parse: {e}")))?;
87 parse_hf_response(&resp)
88 }
89}
90
91fn parse_hf_response(resp: &serde_json::Value) -> Result<Vec<Vec<f32>>, SynapticError> {
92 let array = if let Some(arr) = resp.as_array() {
93 arr
94 } else if let Some(arr) = resp.get("embeddings").and_then(|e| e.as_array()) {
95 arr
96 } else {
97 return Err(SynapticError::Embedding(
98 "unexpected HuggingFace response format".to_string(),
99 ));
100 };
101 let mut result = Vec::with_capacity(array.len());
102 for item in array {
103 let embedding: Vec<f32> = item
104 .as_array()
105 .ok_or_else(|| SynapticError::Embedding("embedding item is not array".to_string()))?
106 .iter()
107 .map(|v| v.as_f64().unwrap_or(0.0) as f32)
108 .collect();
109 result.push(embedding);
110 }
111 Ok(result)
112}
113
114#[async_trait]
115impl Embeddings for HuggingFaceEmbeddings {
116 async fn embed_documents(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, SynapticError> {
117 self.embed_batch(texts).await
118 }
119 async fn embed_query(&self, text: &str) -> Result<Vec<f32>, SynapticError> {
120 let mut results = self.embed_batch(&[text]).await?;
121 results
122 .pop()
123 .ok_or_else(|| SynapticError::Embedding("empty HuggingFace response".to_string()))
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` must populate every default, not only the model identifier.
    #[test]
    fn config_defaults() {
        let c = HuggingFaceEmbeddingsConfig::new("BAAI/bge-small-en-v1.5");
        assert_eq!(c.model, "BAAI/bge-small-en-v1.5");
        assert_eq!(c.api_key, None);
        assert_eq!(c.base_url, "https://api-inference.huggingface.co/models");
        assert!(c.wait_for_model);
    }

    /// Each builder method must write through to its own field.
    #[test]
    fn config_builder() {
        let c = HuggingFaceEmbeddingsConfig::new("model")
            .with_api_key("hf_test")
            .with_base_url("http://localhost:8080")
            .with_wait_for_model(false);
        assert_eq!(c.model, "model");
        assert_eq!(c.api_key, Some("hf_test".to_string()));
        assert_eq!(c.base_url, "http://localhost:8080");
        assert!(!c.wait_for_model);
    }

    /// A bare array of arrays is decoded value-for-value.
    #[test]
    fn parse_direct_array() {
        let resp = serde_json::json!([[0.1_f32, 0.2_f32]]);
        let result = parse_hf_response(&resp).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result, vec![vec![0.1_f32, 0.2_f32]]);
    }

    /// The `{"embeddings": [...]}` wrapper is accepted as an alternative layout.
    #[test]
    fn parse_embeddings_object() {
        let resp = serde_json::json!({ "embeddings": [[1.0, 2.0], [3.0, 4.0]] });
        let result = parse_hf_response(&resp).unwrap();
        assert_eq!(result, vec![vec![1.0_f32, 2.0], vec![3.0, 4.0]]);
    }

    /// Responses that match neither layout are rejected.
    #[test]
    fn parse_rejects_unexpected_shape() {
        let resp = serde_json::json!({ "error": "model loading" });
        assert!(parse_hf_response(&resp).is_err());
    }

    /// Items of the outer array must themselves be arrays.
    #[test]
    fn parse_rejects_non_array_item() {
        let resp = serde_json::json!([1.0, 2.0]);
        assert!(parse_hf_response(&resp).is_err());
    }
}