Skip to main content

batuta/
numpy_converter.rs

1//! NumPy to Trueno conversion module (BATUTA-008)
2//!
3//! Converts Python NumPy operations to Rust Trueno operations with
4//! automatic backend selection via MoE routing.
5//!
6//! # Conversion Strategy
7//!
8//! NumPy operations are mapped to equivalent Trueno operations:
9//! - `np.array(...)` → `Vector::from_slice(...)` or `Matrix::from_slice(...)`
10//! - `np.add(a, b)` → `a.add(&b)`
11//! - `np.dot(a, b)` → `a.dot(&b)` or `a.matmul(&b)`
12//! - `np.sum(a)` → `a.sum()`
13//! - Element-wise ops automatically use MoE routing
14//!
15//! # Example
16//!
17//! ```python
18//! # Python NumPy code
19//! import numpy as np
20//! a = np.array([1.0, 2.0, 3.0])
21//! b = np.array([4.0, 5.0, 6.0])
22//! c = np.add(a, b)
23//! ```
24//!
25//! Converts to:
26//!
27//! ```rust,ignore
28//! use trueno::Vector;
29//! let a = Vector::from_slice(&[1.0, 2.0, 3.0]);
30//! let b = Vector::from_slice(&[4.0, 5.0, 6.0]);
31//! let c = a.add(&b).unwrap();
32//! ```
33
34use std::collections::HashMap;
35
/// A NumPy operation recognised by the converter.
///
/// Each variant's doc lists the NumPy spellings it stands for.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum NumPyOp {
    /// Array construction — `np.array`, `np.zeros`, `np.ones`.
    Array,
    /// Element-wise addition — `np.add`, `a + b`.
    Add,
    /// Element-wise subtraction — `np.subtract`, `a - b`.
    Subtract,
    /// Element-wise multiplication — `np.multiply`, `a * b`.
    Multiply,
    /// Element-wise division — `np.divide`, `a / b`.
    Divide,
    /// Dot product / matrix multiplication — `np.dot`, `np.matmul`, `a @ b`.
    Dot,
    /// Sum reduction — `np.sum`.
    Sum,
    /// Mean reduction — `np.mean`.
    Mean,
    /// Maximum reduction — `np.max`.
    Max,
    /// Minimum reduction — `np.min`.
    Min,
    /// Reshape — `np.reshape`.
    Reshape,
    /// Transpose — `np.transpose`, `a.T`.
    Transpose,
}
64
65impl NumPyOp {
66    /// Get the operation complexity for MoE routing
67    pub fn complexity(&self) -> crate::backend::OpComplexity {
68        use crate::backend::OpComplexity;
69
70        match self {
71            // Element-wise operations are Low complexity (memory-bound)
72            NumPyOp::Add | NumPyOp::Subtract | NumPyOp::Multiply | NumPyOp::Divide => {
73                OpComplexity::Low
74            }
75            // Reductions are Medium complexity
76            NumPyOp::Sum | NumPyOp::Mean | NumPyOp::Max | NumPyOp::Min => OpComplexity::Medium,
77            // Dot product and matrix ops are High complexity
78            NumPyOp::Dot => OpComplexity::High,
79            // Structural operations don't need backend selection
80            NumPyOp::Array | NumPyOp::Reshape | NumPyOp::Transpose => OpComplexity::Low,
81        }
82    }
83}
84
85/// Trueno equivalent operation
86#[derive(Debug, Clone)]
87pub struct TruenoOp {
88    /// Rust code template for the operation
89    pub code_template: String,
90    /// Required imports
91    pub imports: Vec<String>,
92    /// Operation complexity
93    pub complexity: crate::backend::OpComplexity,
94}
95
96/// NumPy to Trueno converter
97pub struct NumPyConverter {
98    /// Operation mapping
99    op_map: HashMap<NumPyOp, TruenoOp>,
100    /// Backend selector for MoE routing
101    backend_selector: crate::backend::BackendSelector,
102}
103
104impl Default for NumPyConverter {
105    fn default() -> Self {
106        Self::new()
107    }
108}
109
110impl NumPyConverter {
111    /// Create a new NumPy converter with default mappings
112    pub fn new() -> Self {
113        let mut op_map = HashMap::new();
114
115        // Array operations
116        op_map.insert(
117            NumPyOp::Array,
118            TruenoOp {
119                code_template: "Vector::from_slice(&[{values}])".to_string(),
120                imports: vec!["use trueno::Vector;".to_string()],
121                complexity: crate::backend::OpComplexity::Low,
122            },
123        );
124
125        // Element-wise operations
126        op_map.insert(
127            NumPyOp::Add,
128            TruenoOp {
129                code_template: "{lhs}.add(&{rhs}).unwrap()".to_string(),
130                imports: vec!["use trueno::Vector;".to_string()],
131                complexity: crate::backend::OpComplexity::Low,
132            },
133        );
134
135        op_map.insert(
136            NumPyOp::Subtract,
137            TruenoOp {
138                code_template: "{lhs}.sub(&{rhs}).unwrap()".to_string(),
139                imports: vec!["use trueno::Vector;".to_string()],
140                complexity: crate::backend::OpComplexity::Low,
141            },
142        );
143
144        op_map.insert(
145            NumPyOp::Multiply,
146            TruenoOp {
147                code_template: "{lhs}.mul(&{rhs}).unwrap()".to_string(),
148                imports: vec!["use trueno::Vector;".to_string()],
149                complexity: crate::backend::OpComplexity::Low,
150            },
151        );
152
153        // Reductions
154        op_map.insert(
155            NumPyOp::Sum,
156            TruenoOp {
157                code_template: "{array}.sum()".to_string(),
158                imports: vec!["use trueno::Vector;".to_string()],
159                complexity: crate::backend::OpComplexity::Medium,
160            },
161        );
162
163        op_map.insert(
164            NumPyOp::Dot,
165            TruenoOp {
166                code_template: "{lhs}.dot(&{rhs}).unwrap()".to_string(),
167                imports: vec!["use trueno::Vector;".to_string()],
168                complexity: crate::backend::OpComplexity::High,
169            },
170        );
171
172        Self { op_map, backend_selector: crate::backend::BackendSelector::new() }
173    }
174
175    /// Convert a NumPy operation to Trueno
176    pub fn convert(&self, op: &NumPyOp) -> Option<&TruenoOp> {
177        self.op_map.get(op)
178    }
179
180    /// Get recommended backend for an operation
181    pub fn recommend_backend(&self, op: &NumPyOp, data_size: usize) -> crate::backend::Backend {
182        self.backend_selector.select_with_moe(op.complexity(), data_size)
183    }
184
185    /// Get all available conversions
186    pub fn available_ops(&self) -> Vec<&NumPyOp> {
187        self.op_map.keys().collect()
188    }
189
190    /// Generate conversion report
191    pub fn conversion_report(&self) -> String {
192        let mut report = String::from("NumPy → Trueno Conversion Map\n");
193        report.push_str("================================\n\n");
194
195        for (op, trueno_op) in &self.op_map {
196            report.push_str(&format!("{:?}:\n", op));
197            report.push_str(&format!("  Complexity: {:?}\n", trueno_op.complexity));
198            report.push_str(&format!("  Template: {}\n", trueno_op.code_template));
199            report.push_str(&format!("  Imports: {}\n\n", trueno_op.imports.join(", ")));
200        }
201
202        report
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Every `NumPyOp` variant, in declaration order.
    static ALL_OPS: [NumPyOp; 12] = [
        NumPyOp::Array,
        NumPyOp::Add,
        NumPyOp::Subtract,
        NumPyOp::Multiply,
        NumPyOp::Divide,
        NumPyOp::Dot,
        NumPyOp::Sum,
        NumPyOp::Mean,
        NumPyOp::Max,
        NumPyOp::Min,
        NumPyOp::Reshape,
        NumPyOp::Transpose,
    ];

    #[test]
    fn test_converter_creation() {
        let converter = NumPyConverter::new();
        assert!(!converter.available_ops().is_empty());
    }

    #[test]
    fn test_operation_complexity() {
        assert_eq!(NumPyOp::Add.complexity(), crate::backend::OpComplexity::Low);
        assert_eq!(NumPyOp::Sum.complexity(), crate::backend::OpComplexity::Medium);
        assert_eq!(NumPyOp::Dot.complexity(), crate::backend::OpComplexity::High);
    }

    #[test]
    fn test_add_conversion() {
        let converter = NumPyConverter::new();
        let trueno_op = converter.convert(&NumPyOp::Add).expect("conversion failed");
        assert!(trueno_op.code_template.contains("add"));
        assert!(trueno_op.imports.iter().any(|i| i.contains("Vector")));
    }

    #[test]
    fn test_backend_recommendation() {
        let converter = NumPyConverter::new();

        // Small element-wise operation should use Scalar
        let backend = converter.recommend_backend(&NumPyOp::Add, 100);
        assert_eq!(backend, crate::backend::Backend::Scalar);

        // Large element-wise should use SIMD
        let backend = converter.recommend_backend(&NumPyOp::Add, 2_000_000);
        assert_eq!(backend, crate::backend::Backend::SIMD);

        // Large matrix operation should use GPU
        let backend = converter.recommend_backend(&NumPyOp::Dot, 50_000);
        assert_eq!(backend, crate::backend::Backend::GPU);
    }

    #[test]
    fn test_conversion_report() {
        let converter = NumPyConverter::new();
        let report = converter.conversion_report();
        assert!(report.contains("NumPy → Trueno"));
        assert!(report.contains("Add"));
        assert!(report.contains("Complexity"));
    }

    // ============================================================================
    // NUMPY OP ENUM TESTS
    // ============================================================================

    #[test]
    fn test_all_numpy_ops_exist() {
        // All 12 variants can be constructed and are pairwise distinct.
        let unique: HashSet<&NumPyOp> = ALL_OPS.iter().collect();
        assert_eq!(unique.len(), 12);
    }

    #[test]
    fn test_op_equality() {
        assert_eq!(NumPyOp::Add, NumPyOp::Add);
        assert_ne!(NumPyOp::Add, NumPyOp::Multiply);
    }

    #[test]
    fn test_op_clone() {
        let op1 = NumPyOp::Dot;
        let op2 = op1.clone();
        assert_eq!(op1, op2);
    }

    #[test]
    fn test_complexity_low_ops() {
        let low_ops = [
            NumPyOp::Add,
            NumPyOp::Subtract,
            NumPyOp::Multiply,
            NumPyOp::Divide,
            NumPyOp::Array,
            NumPyOp::Reshape,
            NumPyOp::Transpose,
        ];

        for op in low_ops {
            assert_eq!(op.complexity(), crate::backend::OpComplexity::Low, "{:?}", op);
        }
    }

    #[test]
    fn test_complexity_medium_ops() {
        let medium_ops = [NumPyOp::Sum, NumPyOp::Mean, NumPyOp::Max, NumPyOp::Min];

        for op in medium_ops {
            assert_eq!(op.complexity(), crate::backend::OpComplexity::Medium, "{:?}", op);
        }
    }

    #[test]
    fn test_complexity_high_ops() {
        assert_eq!(NumPyOp::Dot.complexity(), crate::backend::OpComplexity::High);
    }

    // ============================================================================
    // TRUENO OP STRUCT TESTS
    // ============================================================================

    #[test]
    fn test_trueno_op_construction() {
        let op = TruenoOp {
            code_template: "test_template".to_string(),
            imports: vec!["use test;".to_string()],
            complexity: crate::backend::OpComplexity::Medium,
        };

        assert_eq!(op.code_template, "test_template");
        assert_eq!(op.imports.len(), 1);
        assert_eq!(op.complexity, crate::backend::OpComplexity::Medium);
    }

    #[test]
    fn test_trueno_op_clone() {
        let op1 = TruenoOp {
            code_template: "template".to_string(),
            imports: vec!["import".to_string()],
            complexity: crate::backend::OpComplexity::High,
        };

        let op2 = op1.clone();
        assert_eq!(op1.code_template, op2.code_template);
        assert_eq!(op1.imports, op2.imports);
        assert_eq!(op1.complexity, op2.complexity);
    }

    // ============================================================================
    // NUMPY CONVERTER TESTS
    // ============================================================================

    #[test]
    fn test_converter_default() {
        let converter = NumPyConverter::default();
        assert!(!converter.available_ops().is_empty());
    }

    #[test]
    fn test_convert_all_mapped_ops() {
        let converter = NumPyConverter::new();

        // Test all operations that should have mappings
        let mapped_ops = [
            NumPyOp::Array,
            NumPyOp::Add,
            NumPyOp::Subtract,
            NumPyOp::Multiply,
            NumPyOp::Sum,
            NumPyOp::Dot,
        ];

        for op in mapped_ops {
            assert!(converter.convert(&op).is_some(), "Missing mapping for {:?}", op);
        }
    }

    #[test]
    fn test_convert_unmapped_op() {
        let converter = NumPyConverter::new();
        let available = converter.available_ops();

        // Some ops (e.g. Divide, Mean) may have no mapping. `convert` must
        // report that as `None`, and it must agree with `available_ops`.
        for op in ALL_OPS.iter() {
            assert_eq!(
                converter.convert(op).is_some(),
                available.contains(&op),
                "convert/available_ops disagree for {:?}",
                op
            );
        }
    }

    #[test]
    fn test_array_conversion() {
        let converter = NumPyConverter::new();
        let op = converter.convert(&NumPyOp::Array).expect("conversion failed");

        assert!(op.code_template.contains("Vector"));
        assert!(op.code_template.contains("from_slice"));
        assert!(op.imports.iter().any(|i| i.contains("Vector")));
        assert_eq!(op.complexity, crate::backend::OpComplexity::Low);
    }

    #[test]
    fn test_subtract_conversion() {
        let converter = NumPyConverter::new();
        let op = converter.convert(&NumPyOp::Subtract).expect("conversion failed");

        assert!(op.code_template.contains("sub"));
        assert!(op.imports.iter().any(|i| i.contains("Vector")));
        assert_eq!(op.complexity, crate::backend::OpComplexity::Low);
    }

    #[test]
    fn test_multiply_conversion() {
        let converter = NumPyConverter::new();
        let op = converter.convert(&NumPyOp::Multiply).expect("conversion failed");

        assert!(op.code_template.contains("mul"));
        assert!(op.imports.iter().any(|i| i.contains("Vector")));
    }

    #[test]
    fn test_sum_conversion() {
        let converter = NumPyConverter::new();
        let op = converter.convert(&NumPyOp::Sum).expect("conversion failed");

        assert!(op.code_template.contains("sum"));
        assert_eq!(op.complexity, crate::backend::OpComplexity::Medium);
    }

    #[test]
    fn test_dot_conversion() {
        let converter = NumPyConverter::new();
        let op = converter.convert(&NumPyOp::Dot).expect("conversion failed");

        assert!(op.code_template.contains("dot"));
        assert_eq!(op.complexity, crate::backend::OpComplexity::High);
    }

    #[test]
    fn test_available_ops() {
        let converter = NumPyConverter::new();
        let ops = converter.available_ops();

        assert!(!ops.is_empty());
        // Should have at least the mapped operations
        assert!(ops.len() >= 6);
    }

    #[test]
    fn test_recommend_backend_element_wise_small() {
        let converter = NumPyConverter::new();

        // Small element-wise operations should use Scalar
        let backend = converter.recommend_backend(&NumPyOp::Add, 10);
        assert_eq!(backend, crate::backend::Backend::Scalar);
    }

    #[test]
    fn test_recommend_backend_element_wise_large() {
        let converter = NumPyConverter::new();

        // Large element-wise operations should use SIMD
        let backend = converter.recommend_backend(&NumPyOp::Multiply, 2_000_000);
        assert_eq!(backend, crate::backend::Backend::SIMD);
    }

    #[test]
    fn test_recommend_backend_reduction_medium() {
        let converter = NumPyConverter::new();

        // Medium-sized reductions should use SIMD
        let backend = converter.recommend_backend(&NumPyOp::Sum, 50_000);
        assert_eq!(backend, crate::backend::Backend::SIMD);
    }

    #[test]
    fn test_recommend_backend_reduction_large() {
        let converter = NumPyConverter::new();

        // Large reductions should use GPU
        let backend = converter.recommend_backend(&NumPyOp::Sum, 500_000);
        assert_eq!(backend, crate::backend::Backend::GPU);
    }

    #[test]
    fn test_recommend_backend_dot_product() {
        let converter = NumPyConverter::new();

        // Dot product with large data should use GPU
        let backend = converter.recommend_backend(&NumPyOp::Dot, 100_000);
        assert_eq!(backend, crate::backend::Backend::GPU);
    }

    #[test]
    fn test_conversion_report_structure() {
        let converter = NumPyConverter::new();
        let report = converter.conversion_report();

        // Check report contains expected sections
        assert!(report.contains("NumPy → Trueno"));
        assert!(report.contains("==="));
        assert!(report.contains("Complexity:"));
        assert!(report.contains("Template:"));
        assert!(report.contains("Imports:"));
    }

    #[test]
    fn test_conversion_report_has_all_ops() {
        let converter = NumPyConverter::new();
        let report = converter.conversion_report();
        let ops = converter.available_ops();

        assert!(!ops.is_empty());
        // Every mapped operation gets its own `Name:` heading in the report.
        for op in ops {
            let heading = format!("{:?}:", op);
            assert!(report.contains(&heading), "Report missing entry for {:?}", op);
        }
    }

    #[test]
    fn test_all_conversions_not_empty() {
        let converter = NumPyConverter::new();

        for op in converter.available_ops() {
            let trueno_op = converter.convert(op).expect("available op must convert");
            assert!(!trueno_op.code_template.is_empty(), "Empty code template for {:?}", op);
            assert!(!trueno_op.imports.is_empty(), "Empty imports for {:?}", op);
        }
    }

    #[test]
    fn test_imports_are_valid_rust() {
        let converter = NumPyConverter::new();

        for op in converter.available_ops() {
            let trueno_op = converter.convert(op).expect("available op must convert");
            for import in &trueno_op.imports {
                assert!(import.starts_with("use "), "Invalid import syntax: {}", import);
                assert!(import.ends_with(';'), "Import missing semicolon: {}", import);
            }
        }
    }

    #[test]
    fn test_all_ops_use_vector_import() {
        let converter = NumPyConverter::new();

        for op in converter.available_ops() {
            let trueno_op = converter.convert(op).expect("available op must convert");
            assert!(
                trueno_op.imports.iter().any(|i| i.contains("Vector")),
                "Operation {:?} should import Vector",
                op
            );
        }
    }

    #[test]
    fn test_element_wise_ops_have_unwrap() {
        let converter = NumPyConverter::new();

        for op in [NumPyOp::Add, NumPyOp::Subtract, NumPyOp::Multiply] {
            let trueno_op = converter.convert(&op).expect("element-wise op must be mapped");
            assert!(
                trueno_op.code_template.contains("unwrap"),
                "Element-wise op {:?} should have unwrap() for error handling",
                op
            );
        }
    }

    #[test]
    fn test_complexity_matches_enum() {
        let converter = NumPyConverter::new();

        // Every mapped TruenoOp must carry the same complexity as its NumPyOp,
        // since MoE routing uses `NumPyOp::complexity`.
        for op in converter.available_ops() {
            let trueno_op = converter.convert(op).expect("available op must convert");
            assert_eq!(trueno_op.complexity, op.complexity(), "Complexity mismatch for {:?}", op);
        }
    }
}