mem0_rust/embeddings/
huggingface.rs

1//! HuggingFace Inference API embeddings provider.
2
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6
7use super::traits::Embedder;
8use crate::config::HuggingFaceEmbedderConfig;
9use crate::errors::EmbeddingError;
10
11/// HuggingFace embeddings provider
12pub struct HuggingFaceEmbedder {
13    client: Client,
14    api_key: String,
15    model: String,
16    dimensions: usize,
17    api_url: String,
18}
19
20impl HuggingFaceEmbedder {
21    /// Create a new HuggingFace embedder
22    pub fn new(config: HuggingFaceEmbedderConfig) -> Result<Self, EmbeddingError> {
23        let api_key = config
24            .api_key
25            .or_else(|| std::env::var("HF_TOKEN").ok())
26            .ok_or_else(|| EmbeddingError::Api("HF_TOKEN not set".to_string()))?;
27
28        let api_url = config.api_url.unwrap_or_else(|| {
29            format!(
30                "https://api-inference.huggingface.co/pipeline/feature-extraction/{}",
31                config.model
32            )
33        });
34
35        Ok(Self {
36            client: Client::new(),
37            api_key,
38            model: config.model,
39            dimensions: config.dimensions,
40            api_url,
41        })
42    }
43}
44
45#[derive(Debug, Serialize)]
46struct HFRequest {
47    inputs: Vec<String>,
48    options: HFOptions,
49}
50
51#[derive(Debug, Serialize)]
52struct HFOptions {
53    wait_for_model: bool,
54}
55
56#[derive(Debug, Deserialize)]
57#[serde(untagged)]
58enum HFResponse {
59    Single(Vec<f32>),
60    Batch(Vec<Vec<f32>>),
61    // Some models return nested arrays (token-level embeddings)
62    Nested(Vec<Vec<Vec<f32>>>),
63}
64
65#[async_trait]
66impl Embedder for HuggingFaceEmbedder {
67    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
68        let request = HFRequest {
69            inputs: vec![text.to_string()],
70            options: HFOptions {
71                wait_for_model: true,
72            },
73        };
74
75        let response = self
76            .client
77            .post(&self.api_url)
78            .header("Authorization", format!("Bearer {}", self.api_key))
79            .json(&request)
80            .send()
81            .await
82            .map_err(|e| EmbeddingError::Network(e.to_string()))?;
83
84        if !response.status().is_success() {
85            let error_text = response.text().await.unwrap_or_default();
86            return Err(EmbeddingError::Api(format!(
87                "HuggingFace API error: {}",
88                error_text
89            )));
90        }
91
92        let result: HFResponse = response
93            .json()
94            .await
95            .map_err(|e| EmbeddingError::InvalidResponse(e.to_string()))?;
96
97        match result {
98            HFResponse::Single(embedding) => Ok(embedding),
99            HFResponse::Batch(embeddings) => embeddings
100                .into_iter()
101                .next()
102                .ok_or_else(|| EmbeddingError::InvalidResponse("Empty response".to_string())),
103            HFResponse::Nested(nested) => {
104                // Some models return [[embedding]] format - apply mean pooling
105                nested
106                    .into_iter()
107                    .next()
108                    .and_then(|inner| {
109                        if inner.is_empty() {
110                            return None;
111                        }
112                        let dim = inner[0].len();
113                        let mut pooled = vec![0.0f32; dim];
114                        for token_emb in &inner {
115                            for (i, v) in token_emb.iter().enumerate() {
116                                if i < dim {
117                                    pooled[i] += v;
118                                }
119                            }
120                        }
121                        let n = inner.len() as f32;
122                        for v in &mut pooled {
123                            *v /= n;
124                        }
125                        Some(pooled)
126                    })
127                    .ok_or_else(|| {
128                        EmbeddingError::InvalidResponse("Empty nested response".to_string())
129                    })
130            }
131        }
132    }
133
134    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
135        let request = HFRequest {
136            inputs: texts.iter().map(|s| s.to_string()).collect(),
137            options: HFOptions {
138                wait_for_model: true,
139            },
140        };
141
142        let response = self
143            .client
144            .post(&self.api_url)
145            .header("Authorization", format!("Bearer {}", self.api_key))
146            .json(&request)
147            .send()
148            .await
149            .map_err(|e| EmbeddingError::Network(e.to_string()))?;
150
151        if !response.status().is_success() {
152            let error_text = response.text().await.unwrap_or_default();
153            return Err(EmbeddingError::Api(format!(
154                "HuggingFace API error: {}",
155                error_text
156            )));
157        }
158
159        let result: HFResponse = response
160            .json()
161            .await
162            .map_err(|e| EmbeddingError::InvalidResponse(e.to_string()))?;
163
164        match result {
165            HFResponse::Single(embedding) => Ok(vec![embedding]),
166            HFResponse::Batch(embeddings) => Ok(embeddings),
167            HFResponse::Nested(nested) => {
168                // Mean pooling for each text
169                nested
170                    .into_iter()
171                    .map(|inner| {
172                        if inner.is_empty() {
173                            return Err(EmbeddingError::InvalidResponse(
174                                "Empty nested response".to_string(),
175                            ));
176                        }
177                        let dim = inner[0].len();
178                        let mut pooled = vec![0.0f32; dim];
179                        for token_emb in &inner {
180                            for (i, v) in token_emb.iter().enumerate() {
181                                if i < dim {
182                                    pooled[i] += v;
183                                }
184                            }
185                        }
186                        let n = inner.len() as f32;
187                        for v in &mut pooled {
188                            *v /= n;
189                        }
190                        Ok(pooled)
191                    })
192                    .collect()
193            }
194        }
195    }
196
197    fn dimensions(&self) -> usize {
198        self.dimensions
199    }
200
201    fn model_name(&self) -> &str {
202        &self.model
203    }
204}