mem0_rust/embeddings/
huggingface.rs1use async_trait::async_trait;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6
7use super::traits::Embedder;
8use crate::config::HuggingFaceEmbedderConfig;
9use crate::errors::EmbeddingError;
10
11pub struct HuggingFaceEmbedder {
13 client: Client,
14 api_key: String,
15 model: String,
16 dimensions: usize,
17 api_url: String,
18}
19
20impl HuggingFaceEmbedder {
21 pub fn new(config: HuggingFaceEmbedderConfig) -> Result<Self, EmbeddingError> {
23 let api_key = config
24 .api_key
25 .or_else(|| std::env::var("HF_TOKEN").ok())
26 .ok_or_else(|| EmbeddingError::Api("HF_TOKEN not set".to_string()))?;
27
28 let api_url = config.api_url.unwrap_or_else(|| {
29 format!(
30 "https://api-inference.huggingface.co/pipeline/feature-extraction/{}",
31 config.model
32 )
33 });
34
35 Ok(Self {
36 client: Client::new(),
37 api_key,
38 model: config.model,
39 dimensions: config.dimensions,
40 api_url,
41 })
42 }
43}
44
45#[derive(Debug, Serialize)]
46struct HFRequest {
47 inputs: Vec<String>,
48 options: HFOptions,
49}
50
51#[derive(Debug, Serialize)]
52struct HFOptions {
53 wait_for_model: bool,
54}
55
56#[derive(Debug, Deserialize)]
57#[serde(untagged)]
58enum HFResponse {
59 Single(Vec<f32>),
60 Batch(Vec<Vec<f32>>),
61 Nested(Vec<Vec<Vec<f32>>>),
63}
64
65#[async_trait]
66impl Embedder for HuggingFaceEmbedder {
67 async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
68 let request = HFRequest {
69 inputs: vec![text.to_string()],
70 options: HFOptions {
71 wait_for_model: true,
72 },
73 };
74
75 let response = self
76 .client
77 .post(&self.api_url)
78 .header("Authorization", format!("Bearer {}", self.api_key))
79 .json(&request)
80 .send()
81 .await
82 .map_err(|e| EmbeddingError::Network(e.to_string()))?;
83
84 if !response.status().is_success() {
85 let error_text = response.text().await.unwrap_or_default();
86 return Err(EmbeddingError::Api(format!(
87 "HuggingFace API error: {}",
88 error_text
89 )));
90 }
91
92 let result: HFResponse = response
93 .json()
94 .await
95 .map_err(|e| EmbeddingError::InvalidResponse(e.to_string()))?;
96
97 match result {
98 HFResponse::Single(embedding) => Ok(embedding),
99 HFResponse::Batch(embeddings) => embeddings
100 .into_iter()
101 .next()
102 .ok_or_else(|| EmbeddingError::InvalidResponse("Empty response".to_string())),
103 HFResponse::Nested(nested) => {
104 nested
106 .into_iter()
107 .next()
108 .and_then(|inner| {
109 if inner.is_empty() {
110 return None;
111 }
112 let dim = inner[0].len();
113 let mut pooled = vec![0.0f32; dim];
114 for token_emb in &inner {
115 for (i, v) in token_emb.iter().enumerate() {
116 if i < dim {
117 pooled[i] += v;
118 }
119 }
120 }
121 let n = inner.len() as f32;
122 for v in &mut pooled {
123 *v /= n;
124 }
125 Some(pooled)
126 })
127 .ok_or_else(|| {
128 EmbeddingError::InvalidResponse("Empty nested response".to_string())
129 })
130 }
131 }
132 }
133
134 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
135 let request = HFRequest {
136 inputs: texts.iter().map(|s| s.to_string()).collect(),
137 options: HFOptions {
138 wait_for_model: true,
139 },
140 };
141
142 let response = self
143 .client
144 .post(&self.api_url)
145 .header("Authorization", format!("Bearer {}", self.api_key))
146 .json(&request)
147 .send()
148 .await
149 .map_err(|e| EmbeddingError::Network(e.to_string()))?;
150
151 if !response.status().is_success() {
152 let error_text = response.text().await.unwrap_or_default();
153 return Err(EmbeddingError::Api(format!(
154 "HuggingFace API error: {}",
155 error_text
156 )));
157 }
158
159 let result: HFResponse = response
160 .json()
161 .await
162 .map_err(|e| EmbeddingError::InvalidResponse(e.to_string()))?;
163
164 match result {
165 HFResponse::Single(embedding) => Ok(vec![embedding]),
166 HFResponse::Batch(embeddings) => Ok(embeddings),
167 HFResponse::Nested(nested) => {
168 nested
170 .into_iter()
171 .map(|inner| {
172 if inner.is_empty() {
173 return Err(EmbeddingError::InvalidResponse(
174 "Empty nested response".to_string(),
175 ));
176 }
177 let dim = inner[0].len();
178 let mut pooled = vec![0.0f32; dim];
179 for token_emb in &inner {
180 for (i, v) in token_emb.iter().enumerate() {
181 if i < dim {
182 pooled[i] += v;
183 }
184 }
185 }
186 let n = inner.len() as f32;
187 for v in &mut pooled {
188 *v /= n;
189 }
190 Ok(pooled)
191 })
192 .collect()
193 }
194 }
195 }
196
197 fn dimensions(&self) -> usize {
198 self.dimensions
199 }
200
201 fn model_name(&self) -> &str {
202 &self.model
203 }
204}