// batuta/pytorch_converter.rs
1//! PyTorch to Realizar conversion module (BATUTA-010)
2//!
3//! Converts Python PyTorch inference code to Rust Realizar equivalents
4//! with automatic backend selection for CPU/GPU/WASM execution.
5//!
6//! # Conversion Strategy
7//!
8//! PyTorch inference patterns are mapped to Realizar equivalents:
9//! - `torch.load(path)` → GGUF/SafeTensors model loading
10//! - `model.forward(x)` → Realizar inference pipeline
11//! - `torch.Tensor` → `realizar::tensor::Tensor`
12//! - `torch.nn.Linear` → `realizar::layers::LinearLayer`
13//! - Tokenization → `realizar::tokenizer`
14//! - Generation → `realizar::generate`
15//!
16//! # Example
17//!
18//! ```python
19//! # Python PyTorch inference code
20//! import torch
21//! from transformers import AutoModelForCausalLM, AutoTokenizer
22//!
23//! model = AutoModelForCausalLM.from_pretrained("model_name")
24//! tokenizer = AutoTokenizer.from_pretrained("model_name")
25//! inputs = tokenizer("Hello, world!", return_tensors="pt")
26//! outputs = model.generate(**inputs, max_length=50)
27//! text = tokenizer.decode(outputs[0])
28//! ```
29//!
30//! Converts to:
31//!
32//! ```rust,ignore
33//! use realizar::gguf::GGUFModel;
34//! use realizar::tokenizer::Tokenizer;
35//! use realizar::generate::generate_text;
36//!
37//! let model = GGUFModel::from_file("model.gguf")?;
38//! let tokenizer = Tokenizer::from_file("tokenizer.json")?;
39//! let tokens = tokenizer.encode("Hello, world!")?;
40//! let output = generate_text(&model, &tokens, 50)?;
41//! let text = tokenizer.decode(&output)?;
42//! ```
43
44use std::collections::HashMap;
45
/// PyTorch operation types recognized by the converter (inference-focused).
///
/// Each variant corresponds to a family of PyTorch / `transformers` API calls
/// that the converter can map onto a Realizar equivalent.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[allow(clippy::upper_case_acronyms)] // ReLU/GELU deliberately mirror PyTorch's naming
pub enum PyTorchOperation {
    // Model Loading
    /// `torch.load()`, `AutoModel*.from_pretrained()`
    LoadModel,
    /// `torch.save()`
    SaveModel,
    /// `AutoTokenizer.from_pretrained()`
    LoadTokenizer,

    // Inference Operations
    /// `model(x)`, `model.forward(x)`
    Forward,
    /// `model.generate()`
    Generate,
    /// `model.predict()`
    Predict,

    // Tensor Operations
    /// `torch.tensor()`, `torch.zeros()`
    TensorCreation,
    /// `tensor.view()`, `tensor.reshape()`
    TensorReshape,
    /// `tensor[start:end]`
    TensorSlice,

    // Layer Types
    /// `nn.Linear`
    Linear,
    /// `nn.Embedding`
    Embedding,
    /// `nn.LayerNorm`
    LayerNorm,
    /// `nn.MultiheadAttention`
    Attention,

    // Activation Functions
    /// `nn.ReLU`
    ReLU,
    /// `nn.GELU`
    GELU,
    /// `nn.Softmax`
    Softmax,

    // Tokenization
    /// `tokenizer.encode()`
    Encode,
    /// `tokenizer.decode()`
    Decode,

    // Utilities
    /// `torch.no_grad()`
    NoGrad,
    /// `model.eval()`
    Eval,
}
84
85impl PyTorchOperation {
86    /// Get the computational complexity for MoE routing
87    pub fn complexity(&self) -> crate::backend::OpComplexity {
88        use crate::backend::OpComplexity;
89
90        match self {
91            // Simple operations are Low complexity
92            PyTorchOperation::TensorCreation
93            | PyTorchOperation::TensorReshape
94            | PyTorchOperation::TensorSlice
95            | PyTorchOperation::Encode
96            | PyTorchOperation::Decode
97            | PyTorchOperation::NoGrad
98            | PyTorchOperation::Eval => OpComplexity::Low,
99
100            // Layer operations and activations are Medium complexity
101            PyTorchOperation::Linear
102            | PyTorchOperation::Embedding
103            | PyTorchOperation::LayerNorm
104            | PyTorchOperation::ReLU
105            | PyTorchOperation::GELU
106            | PyTorchOperation::Softmax
107            | PyTorchOperation::LoadModel
108            | PyTorchOperation::SaveModel
109            | PyTorchOperation::LoadTokenizer => OpComplexity::Medium,
110
111            // Inference and generation are High complexity
112            PyTorchOperation::Forward
113            | PyTorchOperation::Generate
114            | PyTorchOperation::Predict
115            | PyTorchOperation::Attention => OpComplexity::High,
116        }
117    }
118
119    /// Get the PyTorch module path
120    pub fn pytorch_module(&self) -> &str {
121        match self {
122            PyTorchOperation::LoadModel
123            | PyTorchOperation::SaveModel
124            | PyTorchOperation::TensorCreation
125            | PyTorchOperation::TensorReshape
126            | PyTorchOperation::TensorSlice
127            | PyTorchOperation::NoGrad => "torch",
128
129            PyTorchOperation::Linear
130            | PyTorchOperation::Embedding
131            | PyTorchOperation::LayerNorm
132            | PyTorchOperation::Attention
133            | PyTorchOperation::ReLU
134            | PyTorchOperation::GELU
135            | PyTorchOperation::Softmax => "torch.nn",
136
137            PyTorchOperation::LoadTokenizer
138            | PyTorchOperation::Encode
139            | PyTorchOperation::Decode => "transformers",
140
141            PyTorchOperation::Forward
142            | PyTorchOperation::Generate
143            | PyTorchOperation::Predict
144            | PyTorchOperation::Eval => "torch.nn.Module",
145        }
146    }
147}
148
/// Realizar equivalent of a single [`PyTorchOperation`].
#[derive(Debug, Clone)]
pub struct RealizarOperation {
    /// Rust code template; `{placeholder}` slots (e.g. `{model_path}`) are
    /// substituted by the caller when emitting code.
    pub code_template: String,
    /// `use` statements required for the generated code to compile.
    pub imports: Vec<String>,
    /// Computational complexity, consumed by backend (MoE) routing.
    pub complexity: crate::backend::OpComplexity,
    /// Ready-to-read example snippet demonstrating typical usage.
    pub usage_pattern: String,
}
161
/// PyTorch to Realizar converter.
///
/// Holds the static operation mapping plus a backend selector used to
/// recommend an execution backend per operation and data size.
pub struct PyTorchConverter {
    /// Mapping from recognized PyTorch operations to Realizar templates.
    /// Not every [`PyTorchOperation`] variant has an entry (see `new`).
    operation_map: HashMap<PyTorchOperation, RealizarOperation>,
    /// Backend selector for MoE routing (CPU scalar / SIMD / GPU).
    backend_selector: crate::backend::BackendSelector,
}
169
170impl Default for PyTorchConverter {
171    fn default() -> Self {
172        Self::new()
173    }
174}
175
176impl PyTorchConverter {
177    /// Create a new PyTorch converter with default mappings
178    pub fn new() -> Self {
179        let mut operation_map = HashMap::new();
180
181        // Model Loading
182        operation_map.insert(
183            PyTorchOperation::LoadModel,
184            RealizarOperation {
185                code_template: "GGUFModel::from_file(\"{model_path}\")".to_string(),
186                imports: vec!["use realizar::gguf::GGUFModel;".to_string()],
187                complexity: crate::backend::OpComplexity::Medium,
188                usage_pattern: "let model = GGUFModel::from_file(\"model.gguf\")?;".to_string(),
189            },
190        );
191
192        operation_map.insert(
193            PyTorchOperation::LoadTokenizer,
194            RealizarOperation {
195                code_template: "Tokenizer::from_file(\"{tokenizer_path}\")".to_string(),
196                imports: vec!["use realizar::tokenizer::Tokenizer;".to_string()],
197                complexity: crate::backend::OpComplexity::Medium,
198                usage_pattern: "let tokenizer = Tokenizer::from_file(\"tokenizer.json\")?;"
199                    .to_string(),
200            },
201        );
202
203        // Inference Operations
204        operation_map.insert(
205            PyTorchOperation::Forward,
206            RealizarOperation {
207                code_template: "model.forward(&{input})".to_string(),
208                imports: vec!["use realizar::gguf::GGUFModel;".to_string()],
209                complexity: crate::backend::OpComplexity::High,
210                usage_pattern: "let output = model.forward(&input_tensor)?;".to_string(),
211            },
212        );
213
214        operation_map.insert(
215            PyTorchOperation::Generate,
216            RealizarOperation {
217                code_template: "generate_text(&model, &{tokens}, {max_length})".to_string(),
218                imports: vec![
219                    "use realizar::generate::generate_text;".to_string(),
220                ],
221                complexity: crate::backend::OpComplexity::High,
222                usage_pattern: "let output = generate_text(&model, &input_tokens, 50)?;\nlet text = tokenizer.decode(&output)?;".to_string(),
223            },
224        );
225
226        // Tensor Operations
227        operation_map.insert(
228            PyTorchOperation::TensorCreation,
229            RealizarOperation {
230                code_template: "Tensor::from_vec({data})".to_string(),
231                imports: vec!["use realizar::tensor::Tensor;".to_string()],
232                complexity: crate::backend::OpComplexity::Low,
233                usage_pattern: "let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0])?;".to_string(),
234            },
235        );
236
237        // Layer Types
238        operation_map.insert(
239            PyTorchOperation::Linear,
240            RealizarOperation {
241                code_template: "LinearLayer::new({in_features}, {out_features})".to_string(),
242                imports: vec!["use realizar::layers::LinearLayer;".to_string()],
243                complexity: crate::backend::OpComplexity::Medium,
244                usage_pattern:
245                    "let layer = LinearLayer::new(768, 512)?;\nlet output = layer.forward(&input)?;"
246                        .to_string(),
247            },
248        );
249
250        operation_map.insert(
251            PyTorchOperation::Attention,
252            RealizarOperation {
253                code_template: "AttentionLayer::new({embed_dim}, {num_heads})".to_string(),
254                imports: vec!["use realizar::layers::AttentionLayer;".to_string()],
255                complexity: crate::backend::OpComplexity::High,
256                usage_pattern:
257                    "let attn = AttentionLayer::new(512, 8)?;\nlet output = attn.forward(&input)?;"
258                        .to_string(),
259            },
260        );
261
262        // Activation Functions
263        operation_map.insert(
264            PyTorchOperation::GELU,
265            RealizarOperation {
266                code_template: "gelu(&{input})".to_string(),
267                imports: vec!["use realizar::layers::activations::gelu;".to_string()],
268                complexity: crate::backend::OpComplexity::Medium,
269                usage_pattern: "let activated = gelu(&input_tensor)?;".to_string(),
270            },
271        );
272
273        // Tokenization
274        operation_map.insert(
275            PyTorchOperation::Encode,
276            RealizarOperation {
277                code_template: "tokenizer.encode(\"{text}\")".to_string(),
278                imports: vec!["use realizar::tokenizer::Tokenizer;".to_string()],
279                complexity: crate::backend::OpComplexity::Low,
280                usage_pattern: "let tokens = tokenizer.encode(\"Hello, world!\")?;".to_string(),
281            },
282        );
283
284        operation_map.insert(
285            PyTorchOperation::Decode,
286            RealizarOperation {
287                code_template: "tokenizer.decode(&{tokens})".to_string(),
288                imports: vec!["use realizar::tokenizer::Tokenizer;".to_string()],
289                complexity: crate::backend::OpComplexity::Low,
290                usage_pattern: "let text = tokenizer.decode(&output_tokens)?;".to_string(),
291            },
292        );
293
294        Self { operation_map, backend_selector: crate::backend::BackendSelector::new() }
295    }
296
297    /// Convert a PyTorch operation to Realizar
298    pub fn convert(&self, operation: &PyTorchOperation) -> Option<&RealizarOperation> {
299        self.operation_map.get(operation)
300    }
301
302    /// Get recommended backend for an operation
303    pub fn recommend_backend(
304        &self,
305        operation: &PyTorchOperation,
306        data_size: usize,
307    ) -> crate::backend::Backend {
308        self.backend_selector.select_with_moe(operation.complexity(), data_size)
309    }
310
311    /// Get all available conversions
312    pub fn available_operations(&self) -> Vec<&PyTorchOperation> {
313        self.operation_map.keys().collect()
314    }
315
316    /// Generate conversion report
317    pub fn conversion_report(&self) -> String {
318        let mut report = String::from("PyTorch → Realizar Conversion Map\n");
319        report.push_str("====================================\n\n");
320
321        // Group by module
322        let mut by_module: HashMap<&str, Vec<(&PyTorchOperation, &RealizarOperation)>> =
323            HashMap::new();
324
325        for (op, realizar_op) in &self.operation_map {
326            by_module.entry(op.pytorch_module()).or_default().push((op, realizar_op));
327        }
328
329        for (module, operations) in &by_module {
330            report.push_str(&format!("## {}\n\n", module));
331
332            for (op, realizar_op) in operations {
333                report.push_str(&format!("{:?}:\n", op));
334                report.push_str(&format!("  Template: {}\n", realizar_op.code_template));
335                report.push_str(&format!("  Complexity: {:?}\n", realizar_op.complexity));
336                report.push_str(&format!("  Imports: {}\n", realizar_op.imports.join(", ")));
337                report.push_str(&format!(
338                    "  Usage:\n    {}\n\n",
339                    realizar_op.usage_pattern.replace('\n', "\n    ")
340                ));
341            }
342            report.push('\n');
343        }
344
345        report
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_converter_creation() {
        let converter = PyTorchConverter::new();
        assert!(!converter.available_operations().is_empty());
    }

    #[test]
    fn test_operation_complexity() {
        assert_eq!(
            PyTorchOperation::TensorCreation.complexity(),
            crate::backend::OpComplexity::Low
        );
        assert_eq!(PyTorchOperation::Linear.complexity(), crate::backend::OpComplexity::Medium);
        assert_eq!(PyTorchOperation::Generate.complexity(), crate::backend::OpComplexity::High);
    }

    #[test]
    fn test_load_model_conversion() {
        let converter = PyTorchConverter::new();
        let realizar_op =
            converter.convert(&PyTorchOperation::LoadModel).expect("conversion failed");
        assert!(realizar_op.code_template.contains("GGUFModel"));
        assert!(realizar_op.imports.iter().any(|i| i.contains("gguf")));
    }

    #[test]
    fn test_generate_conversion() {
        let converter = PyTorchConverter::new();
        let realizar_op =
            converter.convert(&PyTorchOperation::Generate).expect("conversion failed");
        assert!(realizar_op.code_template.contains("generate_text"));
        assert!(realizar_op.imports.iter().any(|i| i.contains("generate")));
    }

    #[test]
    fn test_backend_recommendation() {
        let converter = PyTorchConverter::new();

        // Small tensor operation should use Scalar
        let backend = converter.recommend_backend(&PyTorchOperation::TensorCreation, 100);
        assert_eq!(backend, crate::backend::Backend::Scalar);

        // Medium-sized linear layer should use SIMD
        let backend = converter.recommend_backend(&PyTorchOperation::Linear, 50_000);
        assert_eq!(backend, crate::backend::Backend::SIMD);

        // Large generation task should use GPU
        let backend = converter.recommend_backend(&PyTorchOperation::Generate, 100_000);
        assert_eq!(backend, crate::backend::Backend::GPU);
    }

    #[test]
    fn test_pytorch_module_paths() {
        assert_eq!(PyTorchOperation::LoadModel.pytorch_module(), "torch");
        assert_eq!(PyTorchOperation::Linear.pytorch_module(), "torch.nn");
        assert_eq!(PyTorchOperation::LoadTokenizer.pytorch_module(), "transformers");
    }

    #[test]
    fn test_conversion_report() {
        let converter = PyTorchConverter::new();
        let report = converter.conversion_report();
        assert!(report.contains("PyTorch → Realizar"));
        assert!(report.contains("LoadModel"));
        assert!(report.contains("Complexity"));
    }

    // ============================================================================
    // PYTORCH OPERATION ENUM TESTS
    // ============================================================================

    #[test]
    fn test_all_pytorch_operations_exist() {
        // All 20 variants of the enum can be constructed; keep this list in
        // sync with PyTorchOperation when variants are added or removed.
        let ops = vec![
            PyTorchOperation::LoadModel,
            PyTorchOperation::SaveModel,
            PyTorchOperation::LoadTokenizer,
            PyTorchOperation::Forward,
            PyTorchOperation::Generate,
            PyTorchOperation::Predict,
            PyTorchOperation::TensorCreation,
            PyTorchOperation::TensorReshape,
            PyTorchOperation::TensorSlice,
            PyTorchOperation::Linear,
            PyTorchOperation::Embedding,
            PyTorchOperation::LayerNorm,
            PyTorchOperation::Attention,
            PyTorchOperation::ReLU,
            PyTorchOperation::GELU,
            PyTorchOperation::Softmax,
            PyTorchOperation::Encode,
            PyTorchOperation::Decode,
            PyTorchOperation::NoGrad,
            PyTorchOperation::Eval,
        ];
        assert_eq!(ops.len(), 20); // 20 operations tested
    }

    #[test]
    fn test_operation_equality() {
        assert_eq!(PyTorchOperation::LoadModel, PyTorchOperation::LoadModel);
        assert_ne!(PyTorchOperation::LoadModel, PyTorchOperation::Generate);
    }

    #[test]
    fn test_operation_clone() {
        let op1 = PyTorchOperation::Forward;
        let op2 = op1.clone();
        assert_eq!(op1, op2);
    }

    #[test]
    fn test_complexity_low_operations() {
        let low_ops = vec![
            PyTorchOperation::TensorCreation,
            PyTorchOperation::TensorReshape,
            PyTorchOperation::TensorSlice,
            PyTorchOperation::Encode,
            PyTorchOperation::Decode,
            PyTorchOperation::NoGrad,
            PyTorchOperation::Eval,
        ];

        for op in low_ops {
            assert_eq!(op.complexity(), crate::backend::OpComplexity::Low);
        }
    }

    #[test]
    fn test_complexity_medium_operations() {
        let medium_ops = vec![
            PyTorchOperation::Linear,
            PyTorchOperation::Embedding,
            PyTorchOperation::LayerNorm,
            PyTorchOperation::ReLU,
            PyTorchOperation::GELU,
            PyTorchOperation::Softmax,
            PyTorchOperation::LoadModel,
            PyTorchOperation::SaveModel,
            PyTorchOperation::LoadTokenizer,
        ];

        for op in medium_ops {
            assert_eq!(op.complexity(), crate::backend::OpComplexity::Medium);
        }
    }

    #[test]
    fn test_complexity_high_operations() {
        let high_ops = vec![
            PyTorchOperation::Forward,
            PyTorchOperation::Generate,
            PyTorchOperation::Predict,
            PyTorchOperation::Attention,
        ];

        for op in high_ops {
            assert_eq!(op.complexity(), crate::backend::OpComplexity::High);
        }
    }

    #[test]
    fn test_pytorch_module_torch() {
        let torch_ops = vec![
            PyTorchOperation::LoadModel,
            PyTorchOperation::SaveModel,
            PyTorchOperation::TensorCreation,
            PyTorchOperation::TensorReshape,
            PyTorchOperation::TensorSlice,
            PyTorchOperation::NoGrad,
        ];

        for op in torch_ops {
            assert_eq!(op.pytorch_module(), "torch");
        }
    }

    #[test]
    fn test_pytorch_module_torch_nn() {
        let nn_ops = vec![
            PyTorchOperation::Linear,
            PyTorchOperation::Embedding,
            PyTorchOperation::LayerNorm,
            PyTorchOperation::Attention,
            PyTorchOperation::ReLU,
            PyTorchOperation::GELU,
            PyTorchOperation::Softmax,
        ];

        for op in nn_ops {
            assert_eq!(op.pytorch_module(), "torch.nn");
        }
    }

    #[test]
    fn test_pytorch_module_transformers() {
        let transformers_ops = vec![
            PyTorchOperation::LoadTokenizer,
            PyTorchOperation::Encode,
            PyTorchOperation::Decode,
        ];

        for op in transformers_ops {
            assert_eq!(op.pytorch_module(), "transformers");
        }
    }

    #[test]
    fn test_pytorch_module_torch_nn_module() {
        let module_ops = vec![
            PyTorchOperation::Forward,
            PyTorchOperation::Generate,
            PyTorchOperation::Predict,
            PyTorchOperation::Eval,
        ];

        for op in module_ops {
            assert_eq!(op.pytorch_module(), "torch.nn.Module");
        }
    }

    // ============================================================================
    // REALIZAR OPERATION STRUCT TESTS
    // ============================================================================

    #[test]
    fn test_realizar_operation_construction() {
        let op = RealizarOperation {
            code_template: "test_template".to_string(),
            imports: vec!["use test;".to_string()],
            complexity: crate::backend::OpComplexity::Medium,
            usage_pattern: "let x = test();".to_string(),
        };

        assert_eq!(op.code_template, "test_template");
        assert_eq!(op.imports.len(), 1);
        assert_eq!(op.complexity, crate::backend::OpComplexity::Medium);
        assert!(op.usage_pattern.contains("test()"));
    }

    #[test]
    fn test_realizar_operation_clone() {
        let op1 = RealizarOperation {
            code_template: "template".to_string(),
            imports: vec!["import".to_string()],
            complexity: crate::backend::OpComplexity::High,
            usage_pattern: "usage".to_string(),
        };

        let op2 = op1.clone();
        assert_eq!(op1.code_template, op2.code_template);
        assert_eq!(op1.imports, op2.imports);
        assert_eq!(op1.complexity, op2.complexity);
    }

    // ============================================================================
    // PYTORCH CONVERTER TESTS
    // ============================================================================

    #[test]
    fn test_converter_default() {
        let converter = PyTorchConverter::default();
        assert!(!converter.available_operations().is_empty());
    }

    #[test]
    fn test_convert_all_mapped_operations() {
        let converter = PyTorchConverter::new();

        // Test all operations that should have mappings
        let mapped_ops = vec![
            PyTorchOperation::LoadModel,
            PyTorchOperation::LoadTokenizer,
            PyTorchOperation::Forward,
            PyTorchOperation::Generate,
            PyTorchOperation::TensorCreation,
            PyTorchOperation::Linear,
            PyTorchOperation::Attention,
            PyTorchOperation::GELU,
            PyTorchOperation::Encode,
            PyTorchOperation::Decode,
        ];

        for op in mapped_ops {
            assert!(converter.convert(&op).is_some(), "Missing mapping for {:?}", op);
        }
    }

    #[test]
    fn test_convert_unmapped_operation() {
        let converter = PyTorchConverter::new();

        // SaveModel, Predict, etc. might not be mapped
        // Just verify the function handles missing ops gracefully
        let result = converter.convert(&PyTorchOperation::SaveModel);
        // It's ok if this is None - we're testing the API works
        let _ = result;
    }

    #[test]
    fn test_forward_conversion() {
        let converter = PyTorchConverter::new();
        let op = converter.convert(&PyTorchOperation::Forward).expect("conversion failed");

        assert!(op.code_template.contains("forward"));
        assert!(op.imports.iter().any(|i| i.contains("gguf")));
        assert_eq!(op.complexity, crate::backend::OpComplexity::High);
    }

    #[test]
    fn test_tokenizer_conversion() {
        let converter = PyTorchConverter::new();
        let op = converter.convert(&PyTorchOperation::LoadTokenizer).expect("conversion failed");

        assert!(op.code_template.contains("Tokenizer"));
        assert!(op.imports.iter().any(|i| i.contains("tokenizer")));
    }

    #[test]
    fn test_encode_decode_conversions() {
        let converter = PyTorchConverter::new();

        let encode_op = converter.convert(&PyTorchOperation::Encode).expect("conversion failed");
        assert!(encode_op.code_template.contains("encode"));

        let decode_op = converter.convert(&PyTorchOperation::Decode).expect("conversion failed");
        assert!(decode_op.code_template.contains("decode"));
    }

    #[test]
    fn test_tensor_operation_conversion() {
        let converter = PyTorchConverter::new();
        let op = converter.convert(&PyTorchOperation::TensorCreation).expect("unexpected failure");

        assert!(op.code_template.contains("Tensor"));
        assert!(op.imports.iter().any(|i| i.contains("tensor")));
        assert_eq!(op.complexity, crate::backend::OpComplexity::Low);
    }

    #[test]
    fn test_linear_layer_conversion() {
        let converter = PyTorchConverter::new();
        let op = converter.convert(&PyTorchOperation::Linear).expect("conversion failed");

        assert!(op.code_template.contains("LinearLayer"));
        assert!(op.imports.iter().any(|i| i.contains("layers")));
    }

    #[test]
    fn test_attention_layer_conversion() {
        let converter = PyTorchConverter::new();
        let op = converter.convert(&PyTorchOperation::Attention).expect("conversion failed");

        assert!(op.code_template.contains("AttentionLayer"));
        assert_eq!(op.complexity, crate::backend::OpComplexity::High);
    }

    #[test]
    fn test_gelu_activation_conversion() {
        let converter = PyTorchConverter::new();
        let op = converter.convert(&PyTorchOperation::GELU).expect("conversion failed");

        assert!(op.code_template.contains("gelu"));
        assert!(op.imports.iter().any(|i| i.contains("activations")));
    }

    #[test]
    fn test_available_operations() {
        let converter = PyTorchConverter::new();
        let ops = converter.available_operations();

        assert!(!ops.is_empty());
        // Should have at least the mapped operations
        assert!(ops.len() >= 10);
    }

    #[test]
    fn test_recommend_backend_low_complexity() {
        let converter = PyTorchConverter::new();

        // Small data size with low complexity should use Scalar
        let backend = converter.recommend_backend(&PyTorchOperation::TensorCreation, 10);
        assert_eq!(backend, crate::backend::Backend::Scalar);
    }

    #[test]
    fn test_recommend_backend_medium_complexity() {
        let converter = PyTorchConverter::new();

        // Medium data size with medium complexity should use SIMD
        let backend = converter.recommend_backend(&PyTorchOperation::Linear, 50_000);
        assert_eq!(backend, crate::backend::Backend::SIMD);
    }

    #[test]
    fn test_recommend_backend_high_complexity() {
        let converter = PyTorchConverter::new();

        // Large data size with high complexity should use GPU
        let backend = converter.recommend_backend(&PyTorchOperation::Forward, 500_000);
        assert_eq!(backend, crate::backend::Backend::GPU);
    }

    #[test]
    fn test_recommend_backend_generation() {
        let converter = PyTorchConverter::new();

        // Generation is high complexity, large size should use GPU
        let backend = converter.recommend_backend(&PyTorchOperation::Generate, 1_000_000);
        assert_eq!(backend, crate::backend::Backend::GPU);
    }

    #[test]
    fn test_conversion_report_structure() {
        let converter = PyTorchConverter::new();
        let report = converter.conversion_report();

        // Check report contains expected sections
        assert!(report.contains("PyTorch → Realizar"));
        assert!(report.contains("===="));
        assert!(report.contains("##")); // Module headers
        assert!(report.contains("Template:"));
        assert!(report.contains("Imports:"));
        assert!(report.contains("Usage:"));
    }

    #[test]
    fn test_conversion_report_has_modules() {
        let converter = PyTorchConverter::new();
        let report = converter.conversion_report();

        // Should group by PyTorch modules
        assert!(report.contains("torch") || report.contains("transformers"));
    }

    #[test]
    fn test_conversion_report_has_all_operations() {
        let converter = PyTorchConverter::new();
        let report = converter.conversion_report();

        // Spot check a few operations appear in report
        assert!(
            report.contains("LoadModel")
                || report.contains("Generate")
                || report.contains("Forward")
        );
    }

    #[test]
    fn test_usage_patterns_not_empty() {
        let converter = PyTorchConverter::new();

        // Every registered mapping must carry a non-empty template,
        // at least one import, and a usage example.
        for op in converter.available_operations() {
            if let Some(realizar_op) = converter.convert(op) {
                assert!(!realizar_op.usage_pattern.is_empty(), "Empty usage pattern for {:?}", op);
                assert!(!realizar_op.code_template.is_empty(), "Empty code template for {:?}", op);
                assert!(!realizar_op.imports.is_empty(), "Empty imports for {:?}", op);
            }
        }
    }

    #[test]
    fn test_imports_are_valid_rust() {
        let converter = PyTorchConverter::new();

        for op in converter.available_operations() {
            if let Some(realizar_op) = converter.convert(op) {
                for import in &realizar_op.imports {
                    assert!(import.starts_with("use "), "Invalid import syntax: {}", import);
                    assert!(import.ends_with(';'), "Import missing semicolon: {}", import);
                }
            }
        }
    }
}