// oxirs_vec/huggingface.rs
//! HuggingFace Transformers integration for embedding generation

use crate::{EmbeddableContent, EmbeddingConfig, Vector};
use anyhow::{anyhow, Result};
use scirs2_core::random::Random;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// HuggingFace model configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HuggingFaceConfig {
    /// Model identifier on the HuggingFace Hub,
    /// e.g. "sentence-transformers/all-MiniLM-L6-v2".
    pub model_name: String,
    /// Optional local directory for cached model files; `None` = default location.
    pub cache_dir: Option<String>,
    /// Inference device as a plain string (e.g. "cpu").
    pub device: String,
    /// Number of texts processed per chunk in `embed_batch`.
    pub batch_size: usize,
    /// Maximum token sequence length; surfaced as `max_sequence_length` in `ModelInfo`.
    pub max_length: usize,
    /// How per-token embeddings are pooled into a single vector.
    pub pooling_strategy: PoolingStrategy,
    /// Whether to allow models that ship custom code from the Hub.
    pub trust_remote_code: bool,
}
20
/// Pooling strategies for transformer outputs.
///
/// Selects how a sequence of token embeddings is reduced to one vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PoolingStrategy {
    /// Use `[CLS]` token embedding
    Cls,
    /// Mean pooling of all token embeddings
    Mean,
    /// Max pooling of all token embeddings
    Max,
    /// Weighted mean pooling based on attention weights
    AttentionWeighted,
}
33
/// Defaults target a compact, widely-used sentence-transformer running on
/// CPU with mean pooling.
impl Default for HuggingFaceConfig {
    fn default() -> Self {
        Self {
            // 384-dimensional model (see the table in `get_model_info`).
            model_name: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
            cache_dir: None,
            device: "cpu".to_string(),
            batch_size: 32,
            max_length: 512,
            pooling_strategy: PoolingStrategy::Mean,
            // Never execute Hub-provided custom code by default.
            trust_remote_code: false,
        }
    }
}
47
/// HuggingFace transformer model for embedding generation
#[derive(Debug)]
pub struct HuggingFaceEmbedder {
    // Configuration used for model selection, batching and pooling.
    config: HuggingFaceConfig,
    // Metadata for models already loaded, keyed by model name.
    model_cache: HashMap<String, ModelInfo>,
}
54
/// Model information and metadata
#[derive(Debug, Clone)]
struct ModelInfo {
    // Embedding dimensionality reported for the model.
    dimensions: usize,
    // Copied from `HuggingFaceConfig::max_length` at load time.
    max_sequence_length: usize,
    // NOTE(review): `model_type` and `loaded` are set but never read in this
    // file -- presumably informational for future use; confirm before removing.
    model_type: String,
    loaded: bool,
}
63
64impl HuggingFaceEmbedder {
65    /// Create a new HuggingFace embedder
66    pub fn new(config: HuggingFaceConfig) -> Result<Self> {
67        Ok(Self {
68            config,
69            model_cache: HashMap::new(),
70        })
71    }
72
73    /// Create embedder with default configuration
74    pub fn with_default_config() -> Result<Self> {
75        Self::new(HuggingFaceConfig::default())
76    }
77
78    /// Load a model and prepare it for inference
79    pub async fn load_model(&mut self, model_name: &str) -> Result<()> {
80        if self.model_cache.contains_key(model_name) {
81            return Ok(());
82        }
83
84        // Check if model exists in cache directory
85        let model_info = self.get_model_info(model_name).await?;
86        self.model_cache.insert(model_name.to_string(), model_info);
87
88        tracing::info!("Loaded HuggingFace model: {}", model_name);
89        Ok(())
90    }
91
92    /// Get model information from HuggingFace Hub
93    async fn get_model_info(&self, model_name: &str) -> Result<ModelInfo> {
94        // Simulate fetching model info from HuggingFace Hub
95        // In a real implementation, this would use the HuggingFace API
96        let dimensions = match model_name {
97            "sentence-transformers/all-MiniLM-L6-v2" => 384,
98            "sentence-transformers/all-mpnet-base-v2" => 768,
99            "microsoft/DialoGPT-medium" => 1024,
100            "bert-base-uncased" => 768,
101            "distilbert-base-uncased" => 768,
102            _ => 768, // Default dimension
103        };
104
105        Ok(ModelInfo {
106            dimensions,
107            max_sequence_length: self.config.max_length,
108            model_type: "transformer".to_string(),
109            loaded: true,
110        })
111    }
112
113    /// Generate embeddings for a batch of content
114    pub async fn embed_batch(&mut self, contents: &[EmbeddableContent]) -> Result<Vec<Vector>> {
115        if contents.is_empty() {
116            return Ok(vec![]);
117        }
118
119        // Load model if not already loaded
120        let model_name = self.config.model_name.clone();
121        self.load_model(&model_name).await?;
122
123        let model_info = self
124            .model_cache
125            .get(&self.config.model_name)
126            .ok_or_else(|| anyhow!("Model not loaded: {}", self.config.model_name))?;
127
128        let mut embeddings = Vec::with_capacity(contents.len());
129
130        // Process in batches
131        for chunk in contents.chunks(self.config.batch_size) {
132            let texts: Vec<String> = chunk
133                .iter()
134                .map(|content| self.content_to_text(content))
135                .collect();
136
137            let batch_embeddings = self.generate_embeddings(&texts, model_info).await?;
138            embeddings.extend(batch_embeddings);
139        }
140
141        Ok(embeddings)
142    }
143
144    /// Generate a single embedding
145    pub async fn embed(&mut self, content: &EmbeddableContent) -> Result<Vector> {
146        let embeddings = self.embed_batch(std::slice::from_ref(content)).await?;
147        embeddings
148            .into_iter()
149            .next()
150            .ok_or_else(|| anyhow!("Failed to generate embedding"))
151    }
152
153    /// Convert embeddable content to text
154    fn content_to_text(&self, content: &EmbeddableContent) -> String {
155        match content {
156            EmbeddableContent::Text(text) => text.clone(),
157            EmbeddableContent::RdfResource {
158                uri,
159                label,
160                description,
161                properties,
162            } => {
163                let mut text_parts = vec![uri.clone()];
164
165                if let Some(label) = label {
166                    text_parts.push(label.clone());
167                }
168
169                if let Some(desc) = description {
170                    text_parts.push(desc.clone());
171                }
172
173                for (prop, values) in properties {
174                    text_parts.push(format!("{}: {}", prop, values.join(", ")));
175                }
176
177                text_parts.join(" ")
178            }
179            EmbeddableContent::SparqlQuery(query) => query.clone(),
180            EmbeddableContent::GraphPattern(pattern) => pattern.clone(),
181        }
182    }
183
184    /// Generate embeddings using transformer model
185    async fn generate_embeddings(
186        &self,
187        texts: &[String],
188        model_info: &ModelInfo,
189    ) -> Result<Vec<Vector>> {
190        // In a real implementation, this would use actual HuggingFace transformers
191        // For now, simulate embedding generation
192        let mut embeddings = Vec::with_capacity(texts.len());
193
194        for text in texts {
195            let embedding = self.simulate_embedding(text, model_info.dimensions)?;
196            embeddings.push(embedding);
197        }
198
199        Ok(embeddings)
200    }
201
202    /// Simulate embedding generation (placeholder for actual transformer inference)
203    fn simulate_embedding(&self, text: &str, dimensions: usize) -> Result<Vector> {
204        // Simple hash-based embedding simulation
205        use std::collections::hash_map::DefaultHasher;
206        use std::hash::{Hash, Hasher};
207
208        let mut hasher = DefaultHasher::new();
209        text.hash(&mut hasher);
210        let seed = hasher.finish();
211
212        let mut rng = Random::seed(seed);
213
214        let mut embedding = vec![0.0f32; dimensions];
215        for value in embedding.iter_mut().take(dimensions) {
216            *value = rng.gen_range(-1.0..1.0); // Random values between -1 and 1
217        }
218
219        // Normalize if required
220        if matches!(self.config.pooling_strategy, PoolingStrategy::Mean) {
221            let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
222            if norm > 0.0 {
223                for x in &mut embedding {
224                    *x /= norm;
225                }
226            }
227        }
228
229        Ok(Vector::new(embedding))
230    }
231
232    /// Get available models from cache
233    pub fn get_cached_models(&self) -> Vec<String> {
234        self.model_cache.keys().cloned().collect()
235    }
236
237    /// Clear model cache
238    pub fn clear_cache(&mut self) {
239        self.model_cache.clear();
240    }
241
242    /// Get model dimensions
243    pub fn get_model_dimensions(&self, model_name: &str) -> Option<usize> {
244        self.model_cache.get(model_name).map(|info| info.dimensions)
245    }
246}
247
/// HuggingFace model manager for multiple models
#[derive(Debug)]
pub struct HuggingFaceModelManager {
    // One embedder per registered model name.
    embedders: HashMap<String, HuggingFaceEmbedder>,
    // Name used by `embed` when no model is specified explicitly.
    default_model: String,
}
254
255impl HuggingFaceModelManager {
256    /// Create a new model manager
257    pub fn new(default_model: String) -> Self {
258        Self {
259            embedders: HashMap::new(),
260            default_model,
261        }
262    }
263
264    /// Add a model to the manager
265    pub fn add_model(&mut self, name: String, config: HuggingFaceConfig) -> Result<()> {
266        let embedder = HuggingFaceEmbedder::new(config)?;
267        self.embedders.insert(name, embedder);
268        Ok(())
269    }
270
271    /// Get embeddings using specified model
272    pub async fn embed_with_model(
273        &mut self,
274        model_name: &str,
275        content: &EmbeddableContent,
276    ) -> Result<Vector> {
277        let embedder = self
278            .embedders
279            .get_mut(model_name)
280            .ok_or_else(|| anyhow!("Model not found: {}", model_name))?;
281        embedder.embed(content).await
282    }
283
284    /// Get embeddings using default model
285    pub async fn embed(&mut self, content: &EmbeddableContent) -> Result<Vector> {
286        self.embed_with_model(&self.default_model.clone(), content)
287            .await
288    }
289
290    /// List available models
291    pub fn list_models(&self) -> Vec<String> {
292        self.embedders.keys().cloned().collect()
293    }
294}
295
296/// Integration with existing embedding config
297impl From<EmbeddingConfig> for HuggingFaceConfig {
298    fn from(config: EmbeddingConfig) -> Self {
299        Self {
300            model_name: config.model_name,
301            cache_dir: None,
302            device: "cpu".to_string(),
303            batch_size: 32,
304            max_length: config.max_sequence_length,
305            pooling_strategy: if config.normalize {
306                PoolingStrategy::Mean
307            } else {
308                PoolingStrategy::Cls
309            },
310            trust_remote_code: false,
311        }
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    /// Building an embedder from the default config must succeed.
    /// (Plain `#[test]`: no async work happens here.)
    #[test]
    fn test_huggingface_embedder_creation() {
        let embedder = HuggingFaceEmbedder::with_default_config();
        assert!(embedder.is_ok());
    }

    /// Loading a known model caches it and exposes its dimensionality.
    #[tokio::test]
    async fn test_model_loading() -> Result<()> {
        let mut embedder = HuggingFaceEmbedder::with_default_config()?;
        let result = embedder
            .load_model("sentence-transformers/all-MiniLM-L6-v2")
            .await;
        assert!(result.is_ok());

        let dimensions = embedder.get_model_dimensions("sentence-transformers/all-MiniLM-L6-v2");
        assert_eq!(dimensions, Some(384));
        Ok(())
    }

    /// Plain text embeds into a vector of the model's dimensionality.
    #[tokio::test]
    async fn test_text_embedding() -> Result<()> {
        let mut embedder = HuggingFaceEmbedder::with_default_config()?;
        let content = EmbeddableContent::Text("Hello, world!".to_string());

        let embedding = embedder.embed(&content).await?;
        assert_eq!(embedding.dimensions, 384);
        Ok(())
    }

    /// RDF resources (URI + label + description + properties) embed cleanly.
    #[tokio::test]
    async fn test_rdf_resource_embedding() -> Result<()> {
        let mut embedder = HuggingFaceEmbedder::with_default_config()?;
        let mut properties = HashMap::new();
        properties.insert("type".to_string(), vec!["Person".to_string()]);

        let content = EmbeddableContent::RdfResource {
            uri: "http://example.org/person/1".to_string(),
            label: Some("John Doe".to_string()),
            description: Some("A person in the knowledge graph".to_string()),
            properties,
        };

        assert!(embedder.embed(&content).await.is_ok());
        Ok(())
    }

    /// Batch embedding returns one vector per input, in order.
    #[tokio::test]
    async fn test_batch_embedding() -> Result<()> {
        let mut embedder = HuggingFaceEmbedder::with_default_config()?;
        let contents = vec![
            EmbeddableContent::Text("First text".to_string()),
            EmbeddableContent::Text("Second text".to_string()),
            EmbeddableContent::Text("Third text".to_string()),
        ];

        let embeddings = embedder.embed_batch(&contents).await?;
        assert_eq!(embeddings.len(), 3);
        Ok(())
    }

    /// Models registered with the manager appear in its listing.
    /// (Plain `#[test]`: `add_model`/`list_models` are synchronous.)
    #[test]
    fn test_model_manager() {
        let mut manager = HuggingFaceModelManager::new("default".to_string());
        let config = HuggingFaceConfig::default();

        assert!(manager.add_model("default".to_string(), config).is_ok());

        let models = manager.list_models();
        assert!(models.contains(&"default".to_string()));
    }

    /// `EmbeddingConfig` converts into an equivalent `HuggingFaceConfig`.
    #[test]
    fn test_config_conversion() {
        let embedding_config = EmbeddingConfig {
            model_name: "test-model".to_string(),
            dimensions: 768,
            max_sequence_length: 512,
            normalize: true,
        };

        let hf_config: HuggingFaceConfig = embedding_config.into();
        assert_eq!(hf_config.model_name, "test-model");
        assert_eq!(hf_config.max_length, 512);
        assert!(matches!(hf_config.pooling_strategy, PoolingStrategy::Mean));
    }

    /// Each strategy stored in a config is preserved as that exact variant.
    ///
    /// The previous assertion matched the union of ALL four variants, so it
    /// could never fail; comparing the derived Debug form pins the variant
    /// actually stored in the config.
    #[test]
    fn test_pooling_strategies() {
        let cases = [
            (PoolingStrategy::Cls, "Cls"),
            (PoolingStrategy::Mean, "Mean"),
            (PoolingStrategy::Max, "Max"),
            (PoolingStrategy::AttentionWeighted, "AttentionWeighted"),
        ];

        for (strategy, expected) in cases {
            let config = HuggingFaceConfig {
                pooling_strategy: strategy,
                ..Default::default()
            };
            assert_eq!(format!("{:?}", config.pooling_strategy), expected);
        }
    }
}