use crate::real_time_embedding_pipeline::traits::{
    ContentItem, EmbeddingGenerator, GeneratorStatistics, ProcessingResult, ProcessingStatus,
};
use crate::Vector;
use anyhow::{anyhow, Result};
use scirs2_core::random::Random;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::{Duration, Instant};

/// Configuration for a PyTorch-backed embedding model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyTorchConfig {
    pub model_path: PathBuf,
    pub device: PyTorchDevice,
    pub batch_size: usize,
    pub num_workers: usize,
    pub pin_memory: bool,
    pub mixed_precision: bool,
    pub compile_mode: CompileMode,
    pub optimization_level: usize,
}

/// Compute device on which inference runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PyTorchDevice {
    Cpu,
    Cuda { device_id: usize },
    Mps,
    Auto,
}

/// Model compilation mode.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CompileMode {
    None,
    Default,
    Reduce,
    Max,
    Custom(String),
}

impl Default for PyTorchConfig {
    fn default() -> Self {
        Self {
            model_path: PathBuf::from("./models/pytorch_model.pt"),
            device: PyTorchDevice::Auto,
            batch_size: 32,
            num_workers: 4,
            pin_memory: true,
            mixed_precision: false,
            compile_mode: CompileMode::Default,
            optimization_level: 1,
        }
    }
}

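/// Embedding generator backed by a PyTorch model. The current implementation
/// simulates inference: `load_model` validates the path and fills in
/// placeholder metadata, and the forward pass produces deterministic
/// pseudo-embeddings.
///
/// A minimal usage sketch (marked `ignore` because it assumes a model file
/// exists at the configured default path):
///
/// ```ignore
/// let mut embedder = PyTorchEmbedder::new(PyTorchConfig::default())?;
/// embedder.load_model()?;
/// let vector = embedder.embed_text("hello world")?;
/// ```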
#[derive(Debug)]
pub struct PyTorchEmbedder {
    config: PyTorchConfig,
    model_loaded: bool,
    model_metadata: Option<PyTorchModelMetadata>,
    tokenizer: Option<PyTorchTokenizer>,
}

/// Descriptive metadata about a loaded model.
#[derive(Debug, Clone)]
pub struct PyTorchModelMetadata {
    pub model_name: String,
    pub model_version: String,
    pub input_shape: Vec<i64>,
    pub output_shape: Vec<i64>,
    pub embedding_dimension: usize,
    pub vocab_size: Option<usize>,
    pub max_sequence_length: usize,
    pub architecture_type: ArchitectureType,
}

/// Model architecture family.
#[derive(Debug, Clone)]
pub enum ArchitectureType {
    Transformer,
    Cnn,
    Rnn,
    Lstm,
    Gru,
    Bert,
    Roberta,
    Gpt,
    T5,
    Custom(String),
}

/// Simple vocabulary-based tokenizer with BERT-style special tokens.
#[derive(Debug, Clone)]
pub struct PyTorchTokenizer {
    pub vocab: HashMap<String, i32>,
    pub special_tokens: HashMap<String, i32>,
    pub max_length: usize,
    pub padding_token: String,
    pub unknown_token: String,
    pub cls_token: Option<String>,
    pub sep_token: Option<String>,
}

impl Default for PyTorchTokenizer {
    fn default() -> Self {
        let mut special_tokens = HashMap::new();
        special_tokens.insert("[PAD]".to_string(), 0);
        special_tokens.insert("[UNK]".to_string(), 1);
        special_tokens.insert("[CLS]".to_string(), 2);
        special_tokens.insert("[SEP]".to_string(), 3);

        Self {
            vocab: HashMap::new(),
            special_tokens,
            max_length: 512,
            padding_token: "[PAD]".to_string(),
            unknown_token: "[UNK]".to_string(),
            cls_token: Some("[CLS]".to_string()),
            sep_token: Some("[SEP]".to_string()),
        }
    }
}

impl PyTorchEmbedder {
    /// Creates an embedder; the model itself is not loaded until
    /// `load_model` is called.
    pub fn new(config: PyTorchConfig) -> Result<Self> {
        Ok(Self {
            config,
            model_loaded: false,
            model_metadata: None,
            tokenizer: Some(PyTorchTokenizer::default()),
        })
    }

    /// Validates the model path and records metadata. Model deserialization
    /// is simulated: the metadata below is a fixed placeholder rather than
    /// values read from the file.
    pub fn load_model(&mut self) -> Result<()> {
        if !self.config.model_path.exists() {
            return Err(anyhow!(
                "Model file not found: {:?}",
                self.config.model_path
            ));
        }

        let metadata = PyTorchModelMetadata {
            model_name: "pytorch_embedder".to_string(),
            model_version: "1.0.0".to_string(),
            input_shape: vec![-1, 512],
            output_shape: vec![-1, 768],
            embedding_dimension: 768,
            vocab_size: Some(30000),
            max_sequence_length: 512,
            architecture_type: ArchitectureType::Transformer,
        };

        self.model_metadata = Some(metadata);
        self.model_loaded = true;
        Ok(())
    }

    /// Embeds a single text. Fails if the model has not been loaded.
    pub fn embed_text(&self, text: &str) -> Result<Vector> {
        if !self.model_loaded {
            return Err(anyhow!("Model not loaded. Call load_model() first."));
        }

        let tokens = self.tokenize_text(text)?;
        let embedding = self.forward_pass(&tokens)?;
        Ok(Vector::new(embedding))
    }

    /// Embeds a slice of texts, processing them in `batch_size`-sized chunks.
    pub fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vector>> {
        if !self.model_loaded {
            return Err(anyhow!("Model not loaded"));
        }

        let mut results = Vec::with_capacity(texts.len());

        for chunk in texts.chunks(self.config.batch_size) {
            let mut batch_tokens = Vec::with_capacity(chunk.len());
            for text in chunk {
                batch_tokens.push(self.tokenize_text(text)?);
            }

            let batch_embeddings = self.forward_pass_batch(&batch_tokens)?;
            for embedding in batch_embeddings {
                results.push(Vector::new(embedding));
            }
        }

        Ok(results)
    }

    /// Whitespace-tokenizes `text`, maps each word through the vocabulary
    /// (falling back to the unknown token), wraps the sequence in CLS/SEP
    /// markers when configured, and truncates or pads to exactly
    /// `max_length` ids.
    fn tokenize_text(&self, text: &str) -> Result<Vec<i32>> {
        let tokenizer = self
            .tokenizer
            .as_ref()
            .ok_or_else(|| anyhow!("Tokenizer not available"))?;

        let mut tokens = Vec::new();

        if let Some(cls_token) = &tokenizer.cls_token {
            if let Some(&token_id) = tokenizer.special_tokens.get(cls_token) {
                tokens.push(token_id);
            }
        }

        for word in text.split_whitespace() {
            let token_id = tokenizer
                .vocab
                .get(word)
                .or_else(|| tokenizer.special_tokens.get(&tokenizer.unknown_token))
                .copied()
                .unwrap_or(1); // 1 is the default [UNK] id
            tokens.push(token_id);
        }

        if let Some(sep_token) = &tokenizer.sep_token {
            if let Some(&token_id) = tokenizer.special_tokens.get(sep_token) {
                tokens.push(token_id);
            }
        }

        // Fix the sequence length; note that truncation can drop the
        // trailing SEP token.
        if tokens.len() > tokenizer.max_length {
            tokens.truncate(tokenizer.max_length);
        } else {
            let pad_token_id = tokenizer
                .special_tokens
                .get(&tokenizer.padding_token)
                .copied()
                .unwrap_or(0);
            tokens.resize(tokenizer.max_length, pad_token_id);
        }

        Ok(tokens)
    }

    /// Placeholder forward pass: no real model is invoked. A deterministic
    /// pseudo-embedding is generated by seeding the RNG with the sum of the
    /// token ids, then standardized to zero mean and unit variance.
    fn forward_pass(&self, tokens: &[i32]) -> Result<Vec<f32>> {
        let metadata = self
            .model_metadata
            .as_ref()
            .ok_or_else(|| anyhow!("Model metadata not available"))?;

        let mut rng = Random::seed(tokens.iter().map(|&t| t as u64).sum::<u64>());

        let mut embedding = vec![0.0f32; metadata.embedding_dimension];
        for value in &mut embedding {
            *value = rng.gen_range(-1.0..1.0);
        }

        // Standardize: subtract the mean, divide by the standard deviation.
        let mean = embedding.iter().sum::<f32>() / embedding.len() as f32;
        let variance =
            embedding.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / embedding.len() as f32;
        let std_dev = variance.sqrt();

        if std_dev > 0.0 {
            for x in &mut embedding {
                *x = (*x - mean) / std_dev;
            }
        }

        Ok(embedding)
    }

    /// Runs `forward_pass` over each token sequence in the batch.
    fn forward_pass_batch(&self, batch_tokens: &[Vec<i32>]) -> Result<Vec<Vec<f32>>> {
        let mut results = Vec::with_capacity(batch_tokens.len());
        for tokens in batch_tokens {
            results.push(self.forward_pass(tokens)?);
        }
        Ok(results)
    }

    pub fn get_metadata(&self) -> Option<&PyTorchModelMetadata> {
        self.model_metadata.as_ref()
    }

    pub fn get_dimensions(&self) -> Option<usize> {
        self.model_metadata.as_ref().map(|m| m.embedding_dimension)
    }

    pub fn set_tokenizer(&mut self, tokenizer: PyTorchTokenizer) {
        self.tokenizer = Some(tokenizer);
    }

    pub fn supports_mixed_precision(&self) -> bool {
        self.config.mixed_precision
    }

    pub fn get_device(&self) -> &PyTorchDevice {
        &self.config.device
    }
}

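/// Registry of named `PyTorchEmbedder` instances sharing a `DeviceManager`.
///
/// A minimal sketch ("default" here is an arbitrary example name):
///
/// ```ignore
/// let mut manager = PyTorchModelManager::new("default".to_string());
/// let embedder = PyTorchEmbedder::new(PyTorchConfig::default())?;
/// manager.register_model("default".to_string(), embedder)?; // also loads the model
/// let vectors = manager.embed(&["hello".to_string()])?;
/// ```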
#[derive(Debug)]
pub struct PyTorchModelManager {
    models: HashMap<String, PyTorchEmbedder>,
    default_model: String,
    device_manager: DeviceManager,
}

/// Tracks which compute devices are available and per-device memory usage.
#[derive(Debug)]
pub struct DeviceManager {
    available_devices: Vec<PyTorchDevice>,
    current_device: PyTorchDevice,
    memory_usage: HashMap<String, usize>,
}

impl DeviceManager {
    /// Builds a manager whose current device is the first detected one,
    /// falling back to CPU.
    pub fn new() -> Self {
        let available_devices = Self::detect_available_devices();
        let current_device = available_devices
            .first()
            .cloned()
            .unwrap_or(PyTorchDevice::Cpu);

        Self {
            available_devices,
            current_device,
            memory_usage: HashMap::new(),
        }
    }

    /// Placeholder detection: optimistically lists CPU, CUDA device 0, and
    /// MPS without probing the host for actual hardware support.
    fn detect_available_devices() -> Vec<PyTorchDevice> {
        let mut devices = vec![PyTorchDevice::Cpu];

        devices.push(PyTorchDevice::Cuda { device_id: 0 });
        devices.push(PyTorchDevice::Mps);

        devices
    }

    pub fn get_optimal_device(&self) -> &PyTorchDevice {
        &self.current_device
    }

    pub fn update_memory_usage(&mut self, device: String, usage: usize) {
        self.memory_usage.insert(device, usage);
    }

    pub fn get_memory_usage(&self) -> &HashMap<String, usize> {
        &self.memory_usage
    }
}

impl Default for DeviceManager {
    fn default() -> Self {
        Self::new()
    }
}

impl PyTorchModelManager {
    pub fn new(default_model: String) -> Self {
        Self {
            models: HashMap::new(),
            default_model,
            device_manager: DeviceManager::new(),
        }
    }

    /// Loads the embedder's model and registers it under `name`.
    pub fn register_model(&mut self, name: String, mut embedder: PyTorchEmbedder) -> Result<()> {
        embedder.load_model()?;
        self.models.insert(name, embedder);
        Ok(())
    }

    pub fn list_models(&self) -> Vec<String> {
        self.models.keys().cloned().collect()
    }

    pub fn embed_with_model(&self, model_name: &str, texts: &[String]) -> Result<Vec<Vector>> {
        let model = self
            .models
            .get(model_name)
            .ok_or_else(|| anyhow!("Model not found: {}", model_name))?;

        model.embed_batch(texts)
    }

    /// Embeds with the configured default model.
    pub fn embed(&self, texts: &[String]) -> Result<Vec<Vector>> {
        self.embed_with_model(&self.default_model, texts)
    }

    pub fn get_device_manager(&self) -> &DeviceManager {
        &self.device_manager
    }

    pub fn update_device_manager(&mut self, device_manager: DeviceManager) {
        self.device_manager = device_manager;
    }
}

impl EmbeddingGenerator for PyTorchEmbedder {
    fn generate_embedding(&self, content: &ContentItem) -> Result<Vector> {
        self.embed_text(&content.content)
    }

    fn generate_batch_embeddings(&self, content: &[ContentItem]) -> Result<Vec<ProcessingResult>> {
        let mut results = Vec::with_capacity(content.len());

        for item in content {
            let start_time = Instant::now();
            let vector_result = self.generate_embedding(item);
            let duration = start_time.elapsed();

            let result = match vector_result {
                Ok(vector) => ProcessingResult {
                    item: item.clone(),
                    vector: Some(vector),
                    status: ProcessingStatus::Completed,
                    duration,
                    error: None,
                    metadata: HashMap::new(),
                },
                Err(e) => ProcessingResult {
                    item: item.clone(),
                    vector: None,
                    status: ProcessingStatus::Failed {
                        reason: e.to_string(),
                    },
                    duration,
                    error: Some(e.to_string()),
                    metadata: HashMap::new(),
                },
            };

            results.push(result);
        }

        Ok(results)
    }

    fn embedding_dimensions(&self) -> usize {
        // Falls back to 768 before the model (and its metadata) is loaded.
        self.get_dimensions().unwrap_or(768)
    }

    fn get_config(&self) -> serde_json::Value {
        serde_json::to_value(&self.config).unwrap_or_default()
    }

    fn is_ready(&self) -> bool {
        self.model_loaded
    }

    fn get_statistics(&self) -> GeneratorStatistics {
        // Statistics tracking is not implemented yet; all counters are zero.
        GeneratorStatistics {
            total_embeddings: 0,
            total_processing_time: Duration::from_millis(0),
            average_processing_time: Duration::from_millis(0),
            error_count: 0,
            last_error: None,
        }
    }
}

#[cfg(test)]
#[allow(clippy::useless_vec)]
mod tests {
    use super::*;

    #[test]
    fn test_pytorch_config_creation() {
        let config = PyTorchConfig::default();
        assert_eq!(config.batch_size, 32);
        assert_eq!(config.num_workers, 4);
        assert!(config.pin_memory);
    }

    #[test]
    fn test_pytorch_embedder_creation() {
        let config = PyTorchConfig::default();
        let embedder = PyTorchEmbedder::new(config);
        assert!(embedder.is_ok());
        assert!(!embedder.unwrap().model_loaded);
    }
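
    // Added check, not in the original suite: embed_text should fail with an
    // error (rather than panic) before load_model() has been called, per the
    // model_loaded guard in embed_text.
    #[test]
    fn test_embed_text_requires_loaded_model() {
        let embedder = PyTorchEmbedder::new(PyTorchConfig::default()).unwrap();
        assert!(embedder.embed_text("hello world").is_err());
    }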

    #[test]
    fn test_tokenizer_creation() {
        let tokenizer = PyTorchTokenizer::default();
        assert_eq!(tokenizer.max_length, 512);
        assert_eq!(tokenizer.padding_token, "[PAD]");
        assert!(tokenizer.special_tokens.contains_key("[CLS]"));
    }
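
    // Added sketch of the tokenize_text contract: with the default (empty)
    // vocabulary every word falls back to [UNK] = 1, the sequence is wrapped
    // in [CLS] = 2 and [SEP] = 3, and the output is padded with [PAD] = 0 to
    // exactly max_length ids.
    #[test]
    fn test_tokenize_text_pads_to_max_length() {
        let embedder = PyTorchEmbedder::new(PyTorchConfig::default()).unwrap();
        let tokens = embedder.tokenize_text("hello world").unwrap();
        assert_eq!(tokens.len(), 512);
        assert_eq!(tokens[..4], [2, 1, 1, 3]);
        assert_eq!(tokens[511], 0);
    }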

    #[test]
    fn test_model_metadata() {
        let metadata = PyTorchModelMetadata {
            model_name: "test".to_string(),
            model_version: "1.0".to_string(),
            input_shape: vec![-1, 512],
            output_shape: vec![-1, 768],
            embedding_dimension: 768,
            vocab_size: Some(30000),
            max_sequence_length: 512,
            architecture_type: ArchitectureType::Transformer,
        };

        assert_eq!(metadata.embedding_dimension, 768);
        assert_eq!(metadata.vocab_size, Some(30000));
    }

    #[test]
    fn test_device_manager_creation() {
        let device_manager = DeviceManager::new();
        assert!(!device_manager.available_devices.is_empty());
        assert!(matches!(device_manager.current_device, PyTorchDevice::Cpu));
    }

    #[test]
    fn test_model_manager_creation() {
        let manager = PyTorchModelManager::new("default".to_string());
        assert_eq!(manager.default_model, "default");
        assert!(manager.list_models().is_empty());
    }
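
    // Added check, not in the original suite: requesting an unregistered
    // model name should surface the "Model not found" error.
    #[test]
    fn test_embed_with_unknown_model_errors() {
        let manager = PyTorchModelManager::new("default".to_string());
        assert!(manager
            .embed_with_model("missing", &["hello".to_string()])
            .is_err());
    }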

    #[test]
    fn test_architecture_types() {
        let arch_types = vec![
            ArchitectureType::Transformer,
            ArchitectureType::Bert,
            ArchitectureType::Gpt,
            ArchitectureType::Custom("MyModel".to_string()),
        ];
        assert_eq!(arch_types.len(), 4);
    }

    #[test]
    fn test_device_types() {
        let devices = vec![
            PyTorchDevice::Cpu,
            PyTorchDevice::Cuda { device_id: 0 },
            PyTorchDevice::Mps,
            PyTorchDevice::Auto,
        ];
        assert_eq!(devices.len(), 4);
    }

    #[test]
    fn test_compile_modes() {
        let modes = vec![
            CompileMode::None,
            CompileMode::Default,
            CompileMode::Max,
            CompileMode::Custom("custom".to_string()),
        ];
        assert_eq!(modes.len(), 4);
    }

    #[test]
    fn test_tokenizer_special_tokens() {
        let tokenizer = PyTorchTokenizer::default();
        assert!(tokenizer.special_tokens.contains_key("[PAD]"));
        assert!(tokenizer.special_tokens.contains_key("[UNK]"));
        assert!(tokenizer.special_tokens.contains_key("[CLS]"));
        assert!(tokenizer.special_tokens.contains_key("[SEP]"));
    }
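
    // Added sketch of the placeholder forward pass, assuming Random::seed
    // yields a reproducible sequence for a given seed: identical token ids
    // should produce identical embeddings of the metadata's dimension. The
    // private fields are set directly so no model file is needed on disk.
    #[test]
    fn test_forward_pass_is_deterministic() {
        let mut embedder = PyTorchEmbedder::new(PyTorchConfig::default()).unwrap();
        embedder.model_metadata = Some(PyTorchModelMetadata {
            model_name: "test".to_string(),
            model_version: "1.0".to_string(),
            input_shape: vec![-1, 512],
            output_shape: vec![-1, 768],
            embedding_dimension: 768,
            vocab_size: Some(30000),
            max_sequence_length: 512,
            architecture_type: ArchitectureType::Transformer,
        });
        embedder.model_loaded = true;

        let tokens = embedder.tokenize_text("hello world").unwrap();
        let a = embedder.forward_pass(&tokens).unwrap();
        let b = embedder.forward_pass(&tokens).unwrap();
        assert_eq!(a.len(), 768);
        assert_eq!(a, b);
    }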
}