use crate::real_time_embedding_pipeline::traits::{
    ContentItem, EmbeddingGenerator, GeneratorStatistics, ProcessingResult, ProcessingStatus,
};
use crate::Vector;
use anyhow::{anyhow, Result};
use scirs2_core::random::Random;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::{Duration, Instant};

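/// Configuration for loading and running a PyTorch embedding model.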
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyTorchConfig {
    pub model_path: PathBuf,
    pub device: PyTorchDevice,
    pub batch_size: usize,
    pub num_workers: usize,
    pub pin_memory: bool,
    pub mixed_precision: bool,
    pub compile_mode: CompileMode,
    pub optimization_level: usize,
}

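/// Compute device on which the model executes.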
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PyTorchDevice {
    Cpu,
    Cuda { device_id: usize },
    Mps,
    Auto,
}

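/// Compilation mode applied to the model, loosely mirroring `torch.compile` presets.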
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CompileMode {
    None,
    Default,
    Reduce,
    Max,
    Custom(String),
}

impl Default for PyTorchConfig {
    fn default() -> Self {
        Self {
            model_path: PathBuf::from("./models/pytorch_model.pt"),
            device: PyTorchDevice::Auto,
            batch_size: 32,
            num_workers: 4,
            pin_memory: true,
            mixed_precision: false,
            compile_mode: CompileMode::Default,
            optimization_level: 1,
        }
    }
}

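/// Text embedder backed by a PyTorch model.
///
/// Note: model loading and the forward pass are currently simulated
/// (see `load_model` and `forward_pass`); no libtorch bindings are invoked.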
#[derive(Debug)]
pub struct PyTorchEmbedder {
    config: PyTorchConfig,
    model_loaded: bool,
    model_metadata: Option<PyTorchModelMetadata>,
    tokenizer: Option<PyTorchTokenizer>,
}

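/// Metadata describing the loaded model; `-1` in a shape denotes a dynamic (batch) dimension.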
#[derive(Debug, Clone)]
pub struct PyTorchModelMetadata {
    pub model_name: String,
    pub model_version: String,
    pub input_shape: Vec<i64>,
    pub output_shape: Vec<i64>,
    pub embedding_dimension: usize,
    pub vocab_size: Option<usize>,
    pub max_sequence_length: usize,
    pub architecture_type: ArchitectureType,
}

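/// Known model architecture families, with `Custom` as an escape hatch.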
#[derive(Debug, Clone)]
pub enum ArchitectureType {
    Transformer,
    Cnn,
    Rnn,
    Lstm,
    Gru,
    Bert,
    Roberta,
    Gpt,
    T5,
    Custom(String),
}

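/// Simple whitespace tokenizer with BERT-style special tokens ([CLS], [SEP], [PAD], [UNK]).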
#[derive(Debug, Clone)]
pub struct PyTorchTokenizer {
    pub vocab: HashMap<String, i32>,
    pub special_tokens: HashMap<String, i32>,
    pub max_length: usize,
    pub padding_token: String,
    pub unknown_token: String,
    pub cls_token: Option<String>,
    pub sep_token: Option<String>,
}

impl Default for PyTorchTokenizer {
    fn default() -> Self {
        let mut special_tokens = HashMap::new();
        special_tokens.insert("[PAD]".to_string(), 0);
        special_tokens.insert("[UNK]".to_string(), 1);
        special_tokens.insert("[CLS]".to_string(), 2);
        special_tokens.insert("[SEP]".to_string(), 3);

        Self {
            vocab: HashMap::new(),
            special_tokens,
            max_length: 512,
            padding_token: "[PAD]".to_string(),
            unknown_token: "[UNK]".to_string(),
            cls_token: Some("[CLS]".to_string()),
            sep_token: Some("[SEP]".to_string()),
        }
    }
}

impl PyTorchEmbedder {
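    /// Creates a new embedder. The model is not loaded until `load_model` is called.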
    pub fn new(config: PyTorchConfig) -> Result<Self> {
        Ok(Self {
            config,
            model_loaded: false,
            model_metadata: None,
            tokenizer: Some(PyTorchTokenizer::default()),
        })
    }

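    /// Checks that the model file exists and records its metadata.
    /// Deserialization is currently simulated with fixed, BERT-base-like values.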
    pub fn load_model(&mut self) -> Result<()> {
        if !self.config.model_path.exists() {
            return Err(anyhow!(
                "Model file not found: {:?}",
                self.config.model_path
            ));
        }

        let metadata = PyTorchModelMetadata {
            model_name: "pytorch_embedder".to_string(),
            model_version: "1.0.0".to_string(),
            input_shape: vec![-1, 512],
            output_shape: vec![-1, 768],
            embedding_dimension: 768,
            vocab_size: Some(30000),
            max_sequence_length: 512,
            architecture_type: ArchitectureType::Transformer,
        };

        self.model_metadata = Some(metadata);
        self.model_loaded = true;
        Ok(())
    }

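    /// Embeds a single text into a vector. Fails if the model has not been loaded.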
    pub fn embed_text(&self, text: &str) -> Result<Vector> {
        if !self.model_loaded {
            return Err(anyhow!("Model not loaded. Call load_model() first."));
        }

        let tokens = self.tokenize_text(text)?;
        let embedding = self.forward_pass(&tokens)?;
        Ok(Vector::new(embedding))
    }

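    /// Embeds a slice of texts, processing them in chunks of `config.batch_size`.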
    pub fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vector>> {
        if !self.model_loaded {
            return Err(anyhow!("Model not loaded"));
        }

        let mut results = Vec::new();

        for chunk in texts.chunks(self.config.batch_size) {
            let mut batch_tokens = Vec::new();
            for text in chunk {
                batch_tokens.push(self.tokenize_text(text)?);
            }

            let batch_embeddings = self.forward_pass_batch(&batch_tokens)?;
            for embedding in batch_embeddings {
                results.push(Vector::new(embedding));
            }
        }

        Ok(results)
    }

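    /// Converts text into a fixed-length sequence of token ids:
    /// optional [CLS], whitespace-split words, optional [SEP],
    /// then truncation or [PAD]-padding to `max_length`.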
    fn tokenize_text(&self, text: &str) -> Result<Vec<i32>> {
        let tokenizer = self
            .tokenizer
            .as_ref()
            .ok_or_else(|| anyhow!("Tokenizer not available"))?;

        let mut tokens = Vec::new();

        if let Some(cls_token) = &tokenizer.cls_token {
            if let Some(&token_id) = tokenizer.special_tokens.get(cls_token) {
                tokens.push(token_id);
            }
        }

        let words: Vec<&str> = text.split_whitespace().collect();
        for word in words {
            let token_id = tokenizer
                .vocab
                .get(word)
                .or_else(|| tokenizer.special_tokens.get(&tokenizer.unknown_token))
                .copied()
                .unwrap_or(1); // last-resort fallback to the [UNK] id
            tokens.push(token_id);
        }

        if let Some(sep_token) = &tokenizer.sep_token {
            if let Some(&token_id) = tokenizer.special_tokens.get(sep_token) {
                tokens.push(token_id);
            }
        }

        if tokens.len() > tokenizer.max_length {
            tokens.truncate(tokenizer.max_length);
        } else {
            let pad_token_id = tokenizer
                .special_tokens
                .get(&tokenizer.padding_token)
                .copied()
                .unwrap_or(0);
            tokens.resize(tokenizer.max_length, pad_token_id);
        }

        Ok(tokens)
    }

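    /// Stand-in for real model inference: emits a deterministic pseudo-embedding
    /// seeded by the token ids, standardized to zero mean and unit variance.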
    fn forward_pass(&self, tokens: &[i32]) -> Result<Vec<f32>> {
        let metadata = self
            .model_metadata
            .as_ref()
            .ok_or_else(|| anyhow!("Model metadata not available"))?;

        use scirs2_core::random::Rng;
        // Seed from the token ids so identical inputs yield identical embeddings.
        let mut rng = Random::seed(tokens.iter().map(|&t| t as u64).sum::<u64>());

        let mut embedding = vec![0.0f32; metadata.embedding_dimension];
        for value in &mut embedding {
            *value = rng.gen_range(-1.0..1.0);
        }

        // Standardize to zero mean and unit variance.
        let mean = embedding.iter().sum::<f32>() / embedding.len() as f32;
        let variance =
            embedding.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / embedding.len() as f32;
        let std_dev = variance.sqrt();

        if std_dev > 0.0 {
            for x in &mut embedding {
                *x = (*x - mean) / std_dev;
            }
        }

        Ok(embedding)
    }

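    /// Runs `forward_pass` over each tokenized input sequentially.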
    fn forward_pass_batch(&self, batch_tokens: &[Vec<i32>]) -> Result<Vec<Vec<f32>>> {
        let mut results = Vec::new();
        for tokens in batch_tokens {
            results.push(self.forward_pass(tokens)?);
        }
        Ok(results)
    }

    pub fn get_metadata(&self) -> Option<&PyTorchModelMetadata> {
        self.model_metadata.as_ref()
    }

    pub fn get_dimensions(&self) -> Option<usize> {
        self.model_metadata.as_ref().map(|m| m.embedding_dimension)
    }

    pub fn set_tokenizer(&mut self, tokenizer: PyTorchTokenizer) {
        self.tokenizer = Some(tokenizer);
    }

    pub fn supports_mixed_precision(&self) -> bool {
        self.config.mixed_precision
    }

    pub fn get_device(&self) -> &PyTorchDevice {
        &self.config.device
    }
}

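/// Registry of named [`PyTorchEmbedder`]s with a default model and shared
/// device bookkeeping.
///
/// A minimal usage sketch (illustrative only; `register_model` requires the
/// configured model file to exist on disk):
///
/// ```ignore
/// let mut manager = PyTorchModelManager::new("default".to_string());
/// let embedder = PyTorchEmbedder::new(PyTorchConfig::default())?;
/// manager.register_model("default".to_string(), embedder)?;
/// let vectors = manager.embed(&["hello world".to_string()])?;
/// ```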
#[derive(Debug)]
pub struct PyTorchModelManager {
    models: HashMap<String, PyTorchEmbedder>,
    default_model: String,
    device_manager: DeviceManager,
}

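/// Tracks the devices visible to the process and per-device memory usage.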
#[derive(Debug)]
pub struct DeviceManager {
    available_devices: Vec<PyTorchDevice>,
    current_device: PyTorchDevice,
    memory_usage: HashMap<String, usize>,
}

impl DeviceManager {
    pub fn new() -> Self {
        let available_devices = Self::detect_available_devices();
        let current_device = available_devices
            .first()
            .cloned()
            .unwrap_or(PyTorchDevice::Cpu);

        Self {
            available_devices,
            current_device,
            memory_usage: HashMap::new(),
        }
    }

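    /// Lists candidate devices. Availability is currently assumed rather than
    /// probed; a real implementation would query CUDA/MPS support at runtime.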
    fn detect_available_devices() -> Vec<PyTorchDevice> {
        let mut devices = vec![PyTorchDevice::Cpu];

        devices.push(PyTorchDevice::Cuda { device_id: 0 });
        devices.push(PyTorchDevice::Mps);

        devices
    }

    pub fn get_optimal_device(&self) -> &PyTorchDevice {
        &self.current_device
    }

    pub fn update_memory_usage(&mut self, device: String, usage: usize) {
        self.memory_usage.insert(device, usage);
    }

    pub fn get_memory_usage(&self) -> &HashMap<String, usize> {
        &self.memory_usage
    }
}

impl Default for DeviceManager {
    fn default() -> Self {
        Self::new()
    }
}

impl PyTorchModelManager {
    pub fn new(default_model: String) -> Self {
        Self {
            models: HashMap::new(),
            default_model,
            device_manager: DeviceManager::new(),
        }
    }

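    /// Registers an embedder under `name`, eagerly loading its model;
    /// registration fails if the model file cannot be found.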
    pub fn register_model(&mut self, name: String, mut embedder: PyTorchEmbedder) -> Result<()> {
        embedder.load_model()?;
        self.models.insert(name, embedder);
        Ok(())
    }

    pub fn list_models(&self) -> Vec<String> {
        self.models.keys().cloned().collect()
    }

    pub fn embed_with_model(&self, model_name: &str, texts: &[String]) -> Result<Vec<Vector>> {
        let model = self
            .models
            .get(model_name)
            .ok_or_else(|| anyhow!("Model not found: {}", model_name))?;

        model.embed_batch(texts)
    }

    pub fn embed(&self, texts: &[String]) -> Result<Vec<Vector>> {
        self.embed_with_model(&self.default_model, texts)
    }

    pub fn get_device_manager(&self) -> &DeviceManager {
        &self.device_manager
    }

    pub fn update_device_manager(&mut self, device_manager: DeviceManager) {
        self.device_manager = device_manager;
    }
}

impl EmbeddingGenerator for PyTorchEmbedder {
    fn generate_embedding(&self, content: &ContentItem) -> Result<Vector> {
        self.embed_text(&content.content)
    }

    fn generate_batch_embeddings(&self, content: &[ContentItem]) -> Result<Vec<ProcessingResult>> {
        let mut results = Vec::new();

        for item in content {
            let start_time = Instant::now();
            let vector_result = self.generate_embedding(item);
            let duration = start_time.elapsed();

            let result = match vector_result {
                Ok(vector) => ProcessingResult {
                    item: item.clone(),
                    vector: Some(vector),
                    status: ProcessingStatus::Completed,
                    duration,
                    error: None,
                    metadata: HashMap::new(),
                },
                Err(e) => ProcessingResult {
                    item: item.clone(),
                    vector: None,
                    status: ProcessingStatus::Failed {
                        reason: e.to_string(),
                    },
                    duration,
                    error: Some(e.to_string()),
                    metadata: HashMap::new(),
                },
            };

            results.push(result);
        }

        Ok(results)
    }

    fn embedding_dimensions(&self) -> usize {
        self.get_dimensions().unwrap_or(768)
    }

    fn get_config(&self) -> serde_json::Value {
        serde_json::to_value(&self.config).unwrap_or_default()
    }

    fn is_ready(&self) -> bool {
        self.model_loaded
    }

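    // Per-generator statistics are not tracked yet; zeroed counters are returned.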
    fn get_statistics(&self) -> GeneratorStatistics {
        GeneratorStatistics {
            total_embeddings: 0,
            total_processing_time: Duration::from_millis(0),
            average_processing_time: Duration::from_millis(0),
            error_count: 0,
            last_error: None,
        }
    }
}

#[cfg(test)]
#[allow(clippy::useless_vec)]
mod tests {
    use super::*;

    #[test]
    fn test_pytorch_config_creation() {
        let config = PyTorchConfig::default();
        assert_eq!(config.batch_size, 32);
        assert_eq!(config.num_workers, 4);
        assert!(config.pin_memory);
    }

    #[test]
    fn test_pytorch_embedder_creation() {
        let config = PyTorchConfig::default();
        let embedder = PyTorchEmbedder::new(config);
        assert!(embedder.is_ok());
        assert!(!embedder.unwrap().model_loaded);
    }

    #[test]
    fn test_tokenizer_creation() {
        let tokenizer = PyTorchTokenizer::default();
        assert_eq!(tokenizer.max_length, 512);
        assert_eq!(tokenizer.padding_token, "[PAD]");
        assert!(tokenizer.special_tokens.contains_key("[CLS]"));
    }

    #[test]
    fn test_model_metadata() {
        let metadata = PyTorchModelMetadata {
            model_name: "test".to_string(),
            model_version: "1.0".to_string(),
            input_shape: vec![-1, 512],
            output_shape: vec![-1, 768],
            embedding_dimension: 768,
            vocab_size: Some(30000),
            max_sequence_length: 512,
            architecture_type: ArchitectureType::Transformer,
        };

        assert_eq!(metadata.embedding_dimension, 768);
        assert_eq!(metadata.vocab_size, Some(30000));
    }

    #[test]
    fn test_device_manager_creation() {
        let device_manager = DeviceManager::new();
        assert!(!device_manager.available_devices.is_empty());
        assert!(matches!(device_manager.current_device, PyTorchDevice::Cpu));
    }

    #[test]
    fn test_model_manager_creation() {
        let manager = PyTorchModelManager::new("default".to_string());
        assert_eq!(manager.default_model, "default");
        assert!(manager.list_models().is_empty());
    }

    #[test]
    fn test_architecture_types() {
        let arch_types = vec![
            ArchitectureType::Transformer,
            ArchitectureType::Bert,
            ArchitectureType::Gpt,
            ArchitectureType::Custom("MyModel".to_string()),
        ];
        assert_eq!(arch_types.len(), 4);
    }

    #[test]
    fn test_device_types() {
        let devices = vec![
            PyTorchDevice::Cpu,
            PyTorchDevice::Cuda { device_id: 0 },
            PyTorchDevice::Mps,
            PyTorchDevice::Auto,
        ];
        assert_eq!(devices.len(), 4);
    }

    #[test]
    fn test_compile_modes() {
        let modes = vec![
            CompileMode::None,
            CompileMode::Default,
            CompileMode::Max,
            CompileMode::Custom("custom".to_string()),
        ];
        assert_eq!(modes.len(), 4);
    }

    #[test]
    fn test_tokenizer_special_tokens() {
        let tokenizer = PyTorchTokenizer::default();
        assert!(tokenizer.special_tokens.contains_key("[PAD]"));
        assert!(tokenizer.special_tokens.contains_key("[UNK]"));
        assert!(tokenizer.special_tokens.contains_key("[CLS]"));
        assert!(tokenizer.special_tokens.contains_key("[SEP]"));
    }
}
595}