// sc/embeddings/huggingface.rs
//! HuggingFace Inference API embedding provider.
//!
//! Uses HuggingFace's hosted inference API for embedding generation.
//! Requires a HuggingFace API token (HF_TOKEN).

use crate::error::{Error, Result};
use serde::{Deserialize, Serialize};

use super::config::{resolve_hf_endpoint, resolve_hf_model, resolve_hf_token};
use super::provider::EmbeddingProvider;
use super::types::{huggingface_models, ProviderInfo};

13/// HuggingFace Inference API embedding provider.
14pub struct HuggingFaceProvider {
15    client: reqwest::Client,
16    endpoint: String,
17    model: String,
18    token: String,
19    dimensions: usize,
20    max_chars: usize,
21}
22
23impl HuggingFaceProvider {
24    /// Create a new HuggingFace provider with default configuration.
25    ///
26    /// Returns `None` if no API token is configured.
27    pub fn new() -> Option<Self> {
28        Self::with_config(None, None, None)
29    }
30
31    /// Create a new HuggingFace provider with custom configuration.
32    ///
33    /// Returns `None` if no API token is available.
34    pub fn with_config(
35        endpoint: Option<String>,
36        model: Option<String>,
37        token: Option<String>,
38    ) -> Option<Self> {
39        let token = token.or_else(resolve_hf_token)?;
40        let endpoint = endpoint.unwrap_or_else(resolve_hf_endpoint);
41        let model = model.unwrap_or_else(resolve_hf_model);
42        let config = huggingface_models::get_config(&model);
43
44        Some(Self {
45            client: reqwest::Client::new(),
46            endpoint,
47            model,
48            token,
49            dimensions: config.dimensions,
50            max_chars: config.max_chars,
51        })
52    }
53}
54
55/// HuggingFace API request for feature extraction.
56#[derive(Debug, Serialize)]
57struct HfEmbedRequest<'a> {
58    inputs: HfInputs<'a>,
59    options: HfOptions,
60}
61
62#[derive(Debug, Serialize)]
63#[serde(untagged)]
64enum HfInputs<'a> {
65    Single(&'a str),
66    Batch(Vec<&'a str>),
67}
68
69#[derive(Debug, Serialize)]
70struct HfOptions {
71    wait_for_model: bool,
72}
73
74/// HuggingFace API response - can be single or batch embeddings.
75#[derive(Debug, Deserialize)]
76#[serde(untagged)]
77enum HfEmbedResponse {
78    /// Single embedding (nested array for sentence-transformers)
79    Single(Vec<Vec<f32>>),
80    /// Batch embeddings
81    Batch(Vec<Vec<Vec<f32>>>),
82    /// Direct embedding (some models)
83    Direct(Vec<f32>),
84}
85
86impl EmbeddingProvider for HuggingFaceProvider {
87    fn info(&self) -> ProviderInfo {
88        ProviderInfo {
89            name: "huggingface".to_string(),
90            model: self.model.clone(),
91            dimensions: self.dimensions,
92            max_chars: self.max_chars,
93            available: false,
94        }
95    }
96
97    async fn is_available(&self) -> bool {
98        // HuggingFace is available if we have a token
99        // We could also ping the API, but that uses rate limit quota
100        !self.token.is_empty()
101    }
102
103    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
104        let url = format!("{}/models/{}/pipeline/feature-extraction", self.endpoint, self.model);
105
106        let request = HfEmbedRequest {
107            inputs: HfInputs::Single(text),
108            options: HfOptions { wait_for_model: true },
109        };
110
111        let response = self.client
112            .post(&url)
113            .header("Authorization", format!("Bearer {}", self.token))
114            .json(&request)
115            .send()
116            .await
117            .map_err(|e| Error::Embedding(format!("HuggingFace request failed: {e}")))?;
118
119        if !response.status().is_success() {
120            let status = response.status();
121            let error = response.text().await.unwrap_or_default();
122            return Err(Error::Embedding(format!(
123                "HuggingFace API error ({status}): {error}"
124            )));
125        }
126
127        let data: HfEmbedResponse = response.json().await
128            .map_err(|e| Error::Embedding(format!("Failed to parse HuggingFace response: {e}")))?;
129
130        // Handle different response formats
131        match data {
132            HfEmbedResponse::Single(nested) => {
133                // sentence-transformers returns [[embedding]]
134                nested.into_iter().next()
135                    .ok_or_else(|| Error::Embedding("No embeddings in response".into()))
136            }
137            HfEmbedResponse::Direct(embedding) => Ok(embedding),
138            HfEmbedResponse::Batch(batch) => {
139                batch.into_iter()
140                    .next()
141                    .and_then(|nested| nested.into_iter().next())
142                    .ok_or_else(|| Error::Embedding("No embeddings in batch response".into()))
143            }
144        }
145    }
146
147    async fn generate_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
148        let url = format!("{}/models/{}/pipeline/feature-extraction", self.endpoint, self.model);
149
150        let request = HfEmbedRequest {
151            inputs: HfInputs::Batch(texts.to_vec()),
152            options: HfOptions { wait_for_model: true },
153        };
154
155        let response = self.client
156            .post(&url)
157            .header("Authorization", format!("Bearer {}", self.token))
158            .json(&request)
159            .send()
160            .await
161            .map_err(|e| Error::Embedding(format!("HuggingFace batch request failed: {e}")))?;
162
163        if !response.status().is_success() {
164            let status = response.status();
165            let error = response.text().await.unwrap_or_default();
166            return Err(Error::Embedding(format!(
167                "HuggingFace API error ({status}): {error}"
168            )));
169        }
170
171        let data: HfEmbedResponse = response.json().await
172            .map_err(|e| Error::Embedding(format!("Failed to parse HuggingFace response: {e}")))?;
173
174        // Handle different response formats
175        match data {
176            HfEmbedResponse::Batch(batch) => {
177                // sentence-transformers returns [[[embedding1]], [[embedding2]], ...]
178                Ok(batch.into_iter()
179                    .filter_map(|nested| nested.into_iter().next())
180                    .collect())
181            }
182            HfEmbedResponse::Single(nested) => {
183                // Single response for batch of 1
184                Ok(nested)
185            }
186            HfEmbedResponse::Direct(embedding) => {
187                // Direct embedding for batch of 1
188                Ok(vec![embedding])
189            }
190        }
191    }
192}
193
194#[cfg(test)]
195mod tests {
196    use super::*;
197
198    #[test]
199    fn test_huggingface_provider_with_explicit_none_token() {
200        // When explicitly passing None for token and no config/env, behavior depends
201        // on config file. The key assertion is that if a provider IS created,
202        // it must have a valid (non-empty) token.
203        let provider = HuggingFaceProvider::with_config(None, None, None);
204        // Can't assert None because there might be a config file or env var with token
205        if let Some(p) = provider {
206            assert!(!p.token.is_empty(), "Provider token should not be empty");
207        }
208    }
209
210    #[test]
211    fn test_huggingface_provider_with_token() {
212        let provider = HuggingFaceProvider::with_config(
213            None,
214            Some("sentence-transformers/all-MiniLM-L6-v2".to_string()),
215            Some("test-token".to_string()),
216        );
217        assert!(provider.is_some());
218        let p = provider.unwrap();
219        let info = p.info();
220        assert_eq!(info.name, "huggingface");
221        assert_eq!(info.dimensions, 384);
222        assert_eq!(p.token, "test-token");
223    }
224
225    #[test]
226    fn test_huggingface_provider_uses_custom_model() {
227        let provider = HuggingFaceProvider::with_config(
228            None,
229            Some("sentence-transformers/all-mpnet-base-v2".to_string()),
230            Some("test-token".to_string()),
231        );
232        assert!(provider.is_some());
233        let info = provider.unwrap().info();
234        assert_eq!(info.model, "sentence-transformers/all-mpnet-base-v2");
235        assert_eq!(info.dimensions, 768); // mpnet-base-v2 has 768 dimensions
236    }
237}