1use crate::{EmbeddableContent, EmbeddingConfig, Vector};
4use anyhow::{anyhow, Result};
5use scirs2_core::random::Random;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct HuggingFaceConfig {
12 pub model_name: String,
13 pub cache_dir: Option<String>,
14 pub device: String,
15 pub batch_size: usize,
16 pub max_length: usize,
17 pub pooling_strategy: PoolingStrategy,
18 pub trust_remote_code: bool,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub enum PoolingStrategy {
24 Cls,
26 Mean,
28 Max,
30 AttentionWeighted,
32}
33
34impl Default for HuggingFaceConfig {
35 fn default() -> Self {
36 Self {
37 model_name: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
38 cache_dir: None,
39 device: "cpu".to_string(),
40 batch_size: 32,
41 max_length: 512,
42 pooling_strategy: PoolingStrategy::Mean,
43 trust_remote_code: false,
44 }
45 }
46}
47
48#[derive(Debug)]
50pub struct HuggingFaceEmbedder {
51 config: HuggingFaceConfig,
52 model_cache: HashMap<String, ModelInfo>,
53}
54
/// Metadata recorded for a model once it has been registered in the cache.
#[derive(Clone, Debug)]
struct ModelInfo {
    /// Length of the vectors this model produces.
    dimensions: usize,
    /// Maximum sequence length, copied from the config's `max_length`.
    /// NOTE(review): written but not read anywhere in this module.
    max_sequence_length: usize,
    /// Free-form model family label (currently always `"transformer"`).
    /// NOTE(review): written but not read anywhere in this module.
    model_type: String,
    /// Whether the model is considered loaded (currently always `true`).
    /// NOTE(review): written but not read anywhere in this module.
    loaded: bool,
}
63
64impl HuggingFaceEmbedder {
65 pub fn new(config: HuggingFaceConfig) -> Result<Self> {
67 Ok(Self {
68 config,
69 model_cache: HashMap::new(),
70 })
71 }
72
73 pub fn with_default_config() -> Result<Self> {
75 Self::new(HuggingFaceConfig::default())
76 }
77
78 pub async fn load_model(&mut self, model_name: &str) -> Result<()> {
80 if self.model_cache.contains_key(model_name) {
81 return Ok(());
82 }
83
84 let model_info = self.get_model_info(model_name).await?;
86 self.model_cache.insert(model_name.to_string(), model_info);
87
88 tracing::info!("Loaded HuggingFace model: {}", model_name);
89 Ok(())
90 }
91
92 async fn get_model_info(&self, model_name: &str) -> Result<ModelInfo> {
94 let dimensions = match model_name {
97 "sentence-transformers/all-MiniLM-L6-v2" => 384,
98 "sentence-transformers/all-mpnet-base-v2" => 768,
99 "microsoft/DialoGPT-medium" => 1024,
100 "bert-base-uncased" => 768,
101 "distilbert-base-uncased" => 768,
102 _ => 768, };
104
105 Ok(ModelInfo {
106 dimensions,
107 max_sequence_length: self.config.max_length,
108 model_type: "transformer".to_string(),
109 loaded: true,
110 })
111 }
112
113 pub async fn embed_batch(&mut self, contents: &[EmbeddableContent]) -> Result<Vec<Vector>> {
115 if contents.is_empty() {
116 return Ok(vec![]);
117 }
118
119 let model_name = self.config.model_name.clone();
121 self.load_model(&model_name).await?;
122
123 let model_info = self
124 .model_cache
125 .get(&self.config.model_name)
126 .ok_or_else(|| anyhow!("Model not loaded: {}", self.config.model_name))?;
127
128 let mut embeddings = Vec::with_capacity(contents.len());
129
130 for chunk in contents.chunks(self.config.batch_size) {
132 let texts: Vec<String> = chunk
133 .iter()
134 .map(|content| self.content_to_text(content))
135 .collect();
136
137 let batch_embeddings = self.generate_embeddings(&texts, model_info).await?;
138 embeddings.extend(batch_embeddings);
139 }
140
141 Ok(embeddings)
142 }
143
144 pub async fn embed(&mut self, content: &EmbeddableContent) -> Result<Vector> {
146 let embeddings = self.embed_batch(std::slice::from_ref(content)).await?;
147 embeddings
148 .into_iter()
149 .next()
150 .ok_or_else(|| anyhow!("Failed to generate embedding"))
151 }
152
153 fn content_to_text(&self, content: &EmbeddableContent) -> String {
155 match content {
156 EmbeddableContent::Text(text) => text.clone(),
157 EmbeddableContent::RdfResource {
158 uri,
159 label,
160 description,
161 properties,
162 } => {
163 let mut text_parts = vec![uri.clone()];
164
165 if let Some(label) = label {
166 text_parts.push(label.clone());
167 }
168
169 if let Some(desc) = description {
170 text_parts.push(desc.clone());
171 }
172
173 for (prop, values) in properties {
174 text_parts.push(format!("{}: {}", prop, values.join(", ")));
175 }
176
177 text_parts.join(" ")
178 }
179 EmbeddableContent::SparqlQuery(query) => query.clone(),
180 EmbeddableContent::GraphPattern(pattern) => pattern.clone(),
181 }
182 }
183
184 async fn generate_embeddings(
186 &self,
187 texts: &[String],
188 model_info: &ModelInfo,
189 ) -> Result<Vec<Vector>> {
190 let mut embeddings = Vec::with_capacity(texts.len());
193
194 for text in texts {
195 let embedding = self.simulate_embedding(text, model_info.dimensions)?;
196 embeddings.push(embedding);
197 }
198
199 Ok(embeddings)
200 }
201
202 fn simulate_embedding(&self, text: &str, dimensions: usize) -> Result<Vector> {
204 use std::collections::hash_map::DefaultHasher;
206 use std::hash::{Hash, Hasher};
207
208 let mut hasher = DefaultHasher::new();
209 text.hash(&mut hasher);
210 let seed = hasher.finish();
211
212 let mut rng = Random::seed(seed);
213
214 let mut embedding = vec![0.0f32; dimensions];
215 for value in embedding.iter_mut().take(dimensions) {
216 *value = rng.gen_range(-1.0..1.0); }
218
219 if matches!(self.config.pooling_strategy, PoolingStrategy::Mean) {
221 let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
222 if norm > 0.0 {
223 for x in &mut embedding {
224 *x /= norm;
225 }
226 }
227 }
228
229 Ok(Vector::new(embedding))
230 }
231
232 pub fn get_cached_models(&self) -> Vec<String> {
234 self.model_cache.keys().cloned().collect()
235 }
236
237 pub fn clear_cache(&mut self) {
239 self.model_cache.clear();
240 }
241
242 pub fn get_model_dimensions(&self, model_name: &str) -> Option<usize> {
244 self.model_cache.get(model_name).map(|info| info.dimensions)
245 }
246}
247
248#[derive(Debug)]
250pub struct HuggingFaceModelManager {
251 embedders: HashMap<String, HuggingFaceEmbedder>,
252 default_model: String,
253}
254
255impl HuggingFaceModelManager {
256 pub fn new(default_model: String) -> Self {
258 Self {
259 embedders: HashMap::new(),
260 default_model,
261 }
262 }
263
264 pub fn add_model(&mut self, name: String, config: HuggingFaceConfig) -> Result<()> {
266 let embedder = HuggingFaceEmbedder::new(config)?;
267 self.embedders.insert(name, embedder);
268 Ok(())
269 }
270
271 pub async fn embed_with_model(
273 &mut self,
274 model_name: &str,
275 content: &EmbeddableContent,
276 ) -> Result<Vector> {
277 let embedder = self
278 .embedders
279 .get_mut(model_name)
280 .ok_or_else(|| anyhow!("Model not found: {}", model_name))?;
281 embedder.embed(content).await
282 }
283
284 pub async fn embed(&mut self, content: &EmbeddableContent) -> Result<Vector> {
286 self.embed_with_model(&self.default_model.clone(), content)
287 .await
288 }
289
290 pub fn list_models(&self) -> Vec<String> {
292 self.embedders.keys().cloned().collect()
293 }
294}
295
296impl From<EmbeddingConfig> for HuggingFaceConfig {
298 fn from(config: EmbeddingConfig) -> Self {
299 Self {
300 model_name: config.model_name,
301 cache_dir: None,
302 device: "cpu".to_string(),
303 batch_size: 32,
304 max_length: config.max_sequence_length,
305 pooling_strategy: if config.normalize {
306 PoolingStrategy::Mean
307 } else {
308 PoolingStrategy::Cls
309 },
310 trust_remote_code: false,
311 }
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    /// The model used by `HuggingFaceConfig::default()`.
    const MINILM: &str = "sentence-transformers/all-MiniLM-L6-v2";

    #[tokio::test]
    async fn test_huggingface_embedder_creation() {
        assert!(HuggingFaceEmbedder::with_default_config().is_ok());
    }

    #[tokio::test]
    async fn test_model_loading() -> Result<()> {
        let mut embedder = HuggingFaceEmbedder::with_default_config()?;
        assert!(embedder.load_model(MINILM).await.is_ok());
        assert_eq!(embedder.get_model_dimensions(MINILM), Some(384));
        Ok(())
    }

    #[tokio::test]
    async fn test_text_embedding() -> Result<()> {
        let mut embedder = HuggingFaceEmbedder::with_default_config()?;
        let greeting = EmbeddableContent::Text("Hello, world!".to_string());

        let outcome = embedder.embed(&greeting).await;
        assert!(outcome.is_ok());

        let vector = outcome?;
        assert_eq!(vector.dimensions, 384);
        Ok(())
    }

    #[tokio::test]
    async fn test_rdf_resource_embedding() -> Result<()> {
        let mut embedder = HuggingFaceEmbedder::with_default_config()?;

        let mut properties = HashMap::new();
        properties.insert("type".to_string(), vec!["Person".to_string()]);

        let resource = EmbeddableContent::RdfResource {
            uri: "http://example.org/person/1".to_string(),
            label: Some("John Doe".to_string()),
            description: Some("A person in the knowledge graph".to_string()),
            properties,
        };

        assert!(embedder.embed(&resource).await.is_ok());
        Ok(())
    }

    #[tokio::test]
    async fn test_batch_embedding() -> Result<()> {
        let mut embedder = HuggingFaceEmbedder::with_default_config()?;
        let contents: Vec<EmbeddableContent> = ["First text", "Second text", "Third text"]
            .iter()
            .map(|text| EmbeddableContent::Text(text.to_string()))
            .collect();

        let outcome = embedder.embed_batch(&contents).await;
        assert!(outcome.is_ok());
        assert_eq!(outcome?.len(), 3);
        Ok(())
    }

    #[tokio::test]
    async fn test_model_manager() {
        let mut manager = HuggingFaceModelManager::new("default".to_string());

        let registered = manager.add_model("default".to_string(), HuggingFaceConfig::default());
        assert!(registered.is_ok());

        assert!(manager.list_models().iter().any(|name| name == "default"));
    }

    #[test]
    fn test_config_conversion() {
        let source = EmbeddingConfig {
            model_name: "test-model".to_string(),
            dimensions: 768,
            max_sequence_length: 512,
            normalize: true,
        };

        let converted = HuggingFaceConfig::from(source);
        assert_eq!(converted.model_name, "test-model");
        assert_eq!(converted.max_length, 512);
        assert!(matches!(converted.pooling_strategy, PoolingStrategy::Mean));
    }

    #[test]
    fn test_pooling_strategies() {
        let all_strategies = vec![
            PoolingStrategy::Cls,
            PoolingStrategy::Mean,
            PoolingStrategy::Max,
            PoolingStrategy::AttentionWeighted,
        ];

        for pooling_strategy in all_strategies {
            let config = HuggingFaceConfig {
                pooling_strategy,
                ..Default::default()
            };
            assert!(matches!(
                config.pooling_strategy,
                PoolingStrategy::Cls
                    | PoolingStrategy::Mean
                    | PoolingStrategy::Max
                    | PoolingStrategy::AttentionWeighted
            ));
        }
    }
}