//! HuggingFace Inference API embedding provider (src/embeddings/huggingface.rs).
use crate::error::{Error, Result};
use serde::{Deserialize, Serialize};

use super::config::{resolve_hf_endpoint, resolve_hf_model, resolve_hf_token};
use super::provider::EmbeddingProvider;
use super::types::{huggingface_models, ProviderInfo};
/// Embedding provider backed by the HuggingFace Inference API.
pub struct HuggingFaceProvider {
    // Reused across requests so reqwest can pool connections.
    client: reqwest::Client,
    // Base URL of the inference endpoint.
    endpoint: String,
    // Model id, e.g. "sentence-transformers/all-MiniLM-L6-v2".
    model: String,
    // API token, sent as a Bearer Authorization header.
    token: String,
    // Embedding dimensionality for `model` (looked up from the model config).
    dimensions: usize,
    // Maximum input length in characters for `model`.
    max_chars: usize,
}
22
23impl HuggingFaceProvider {
24 pub fn new() -> Option<Self> {
28 Self::with_config(None, None, None)
29 }
30
31 pub fn with_config(
35 endpoint: Option<String>,
36 model: Option<String>,
37 token: Option<String>,
38 ) -> Option<Self> {
39 let token = token.or_else(resolve_hf_token)?;
40 let endpoint = endpoint.unwrap_or_else(resolve_hf_endpoint);
41 let model = model.unwrap_or_else(resolve_hf_model);
42 let config = huggingface_models::get_config(&model);
43
44 Some(Self {
45 client: reqwest::Client::new(),
46 endpoint,
47 model,
48 token,
49 dimensions: config.dimensions,
50 max_chars: config.max_chars,
51 })
52 }
53}
54
/// JSON request body for the HF feature-extraction pipeline.
#[derive(Debug, Serialize)]
struct HfEmbedRequest<'a> {
    inputs: HfInputs<'a>,
    options: HfOptions,
}
61
/// Input payload: a single string or a batch of strings. Serialized
/// untagged, so the JSON is either `"text"` or `["a", "b"]`.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum HfInputs<'a> {
    Single(&'a str),
    Batch(Vec<&'a str>),
}
68
/// Inference options. `wait_for_model` asks the API to block while a cold
/// model loads rather than failing immediately.
#[derive(Debug, Serialize)]
struct HfOptions {
    wait_for_model: bool,
}
73
/// Possible response shapes from the feature-extraction pipeline.
///
/// `#[serde(untagged)]` tries variants in declaration order, so ordering
/// matters here. The three shapes are mutually exclusive: a flat vector
/// (`Direct`) cannot parse as `Single`, and a triply nested list (`Batch`)
/// cannot parse as either of the others.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum HfEmbedResponse {
    /// Single input: a list containing one embedding vector.
    Single(Vec<Vec<f32>>),
    /// Batch input: one nested list of vectors per input text.
    Batch(Vec<Vec<Vec<f32>>>),
    /// Some models return the embedding vector directly.
    Direct(Vec<f32>),
}
85
86impl EmbeddingProvider for HuggingFaceProvider {
87 fn info(&self) -> ProviderInfo {
88 ProviderInfo {
89 name: "huggingface".to_string(),
90 model: self.model.clone(),
91 dimensions: self.dimensions,
92 max_chars: self.max_chars,
93 available: false,
94 }
95 }
96
97 async fn is_available(&self) -> bool {
98 !self.token.is_empty()
101 }
102
103 async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
104 let url = format!("{}/models/{}/pipeline/feature-extraction", self.endpoint, self.model);
105
106 let request = HfEmbedRequest {
107 inputs: HfInputs::Single(text),
108 options: HfOptions { wait_for_model: true },
109 };
110
111 let response = self.client
112 .post(&url)
113 .header("Authorization", format!("Bearer {}", self.token))
114 .json(&request)
115 .send()
116 .await
117 .map_err(|e| Error::Embedding(format!("HuggingFace request failed: {e}")))?;
118
119 if !response.status().is_success() {
120 let status = response.status();
121 let error = response.text().await.unwrap_or_default();
122 return Err(Error::Embedding(format!(
123 "HuggingFace API error ({status}): {error}"
124 )));
125 }
126
127 let data: HfEmbedResponse = response.json().await
128 .map_err(|e| Error::Embedding(format!("Failed to parse HuggingFace response: {e}")))?;
129
130 match data {
132 HfEmbedResponse::Single(nested) => {
133 nested.into_iter().next()
135 .ok_or_else(|| Error::Embedding("No embeddings in response".into()))
136 }
137 HfEmbedResponse::Direct(embedding) => Ok(embedding),
138 HfEmbedResponse::Batch(batch) => {
139 batch.into_iter()
140 .next()
141 .and_then(|nested| nested.into_iter().next())
142 .ok_or_else(|| Error::Embedding("No embeddings in batch response".into()))
143 }
144 }
145 }
146
147 async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
148 let url = format!("{}/models/{}/pipeline/feature-extraction", self.endpoint, self.model);
149
150 let request = HfEmbedRequest {
151 inputs: HfInputs::Batch(texts.to_vec()),
152 options: HfOptions { wait_for_model: true },
153 };
154
155 let response = self.client
156 .post(&url)
157 .header("Authorization", format!("Bearer {}", self.token))
158 .json(&request)
159 .send()
160 .await
161 .map_err(|e| Error::Embedding(format!("HuggingFace batch request failed: {e}")))?;
162
163 if !response.status().is_success() {
164 let status = response.status();
165 let error = response.text().await.unwrap_or_default();
166 return Err(Error::Embedding(format!(
167 "HuggingFace API error ({status}): {error}"
168 )));
169 }
170
171 let data: HfEmbedResponse = response.json().await
172 .map_err(|e| Error::Embedding(format!("Failed to parse HuggingFace response: {e}")))?;
173
174 match data {
176 HfEmbedResponse::Batch(batch) => {
177 Ok(batch.into_iter()
179 .filter_map(|nested| nested.into_iter().next())
180 .collect())
181 }
182 HfEmbedResponse::Single(nested) => {
183 Ok(nested)
185 }
186 HfEmbedResponse::Direct(embedding) => {
187 Ok(vec![embedding])
189 }
190 }
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_huggingface_provider_with_explicit_none_token() {
        // With no explicit token this falls back to the environment, so the
        // provider may or may not exist; when it does, its token is non-empty.
        if let Some(p) = HuggingFaceProvider::with_config(None, None, None) {
            assert!(!p.token.is_empty(), "Provider token should not be empty");
        }
    }

    #[test]
    fn test_huggingface_provider_with_token() {
        let provider = HuggingFaceProvider::with_config(
            None,
            Some("sentence-transformers/all-MiniLM-L6-v2".to_string()),
            Some("test-token".to_string()),
        );
        assert!(provider.is_some());
        let p = provider.unwrap();
        assert_eq!(p.token, "test-token");
        let info = p.info();
        assert_eq!(info.name, "huggingface");
        assert_eq!(info.dimensions, 384);
    }

    #[test]
    fn test_huggingface_provider_uses_custom_model() {
        let provider = HuggingFaceProvider::with_config(
            None,
            Some("sentence-transformers/all-mpnet-base-v2".to_string()),
            Some("test-token".to_string()),
        );
        assert!(provider.is_some());
        let info = provider.unwrap().info();
        assert_eq!(info.model, "sentence-transformers/all-mpnet-base-v2");
        assert_eq!(info.dimensions, 768);
    }
}