1use std::collections::HashMap;
35
/// NumPy operations known to the converter.
///
/// Each variant has a complexity class (see [`NumPyOp::complexity`]). The
/// type is a plain fieldless tag, so it is `Copy` and cheap to pass by value
/// or to use as a hash-map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NumPyOp {
    /// Array construction.
    Array,
    /// Element-wise addition.
    Add,
    /// Element-wise subtraction.
    Subtract,
    /// Element-wise multiplication.
    Multiply,
    /// Element-wise division.
    Divide,
    /// Dot product.
    Dot,
    /// Sum reduction.
    Sum,
    /// Mean reduction.
    Mean,
    /// Maximum reduction.
    Max,
    /// Minimum reduction.
    Min,
    /// Reshape.
    Reshape,
    /// Transpose.
    Transpose,
}
64
65impl NumPyOp {
66 pub fn complexity(&self) -> crate::backend::OpComplexity {
68 use crate::backend::OpComplexity;
69
70 match self {
71 NumPyOp::Add | NumPyOp::Subtract | NumPyOp::Multiply | NumPyOp::Divide => {
73 OpComplexity::Low
74 }
75 NumPyOp::Sum | NumPyOp::Mean | NumPyOp::Max | NumPyOp::Min => OpComplexity::Medium,
77 NumPyOp::Dot => OpComplexity::High,
79 NumPyOp::Array | NumPyOp::Reshape | NumPyOp::Transpose => OpComplexity::Low,
81 }
82 }
83}
84
85#[derive(Debug, Clone)]
87pub struct TruenoOp {
88 pub code_template: String,
90 pub imports: Vec<String>,
92 pub complexity: crate::backend::OpComplexity,
94}
95
96pub struct NumPyConverter {
98 op_map: HashMap<NumPyOp, TruenoOp>,
100 backend_selector: crate::backend::BackendSelector,
102}
103
104impl Default for NumPyConverter {
105 fn default() -> Self {
106 Self::new()
107 }
108}
109
110impl NumPyConverter {
111 pub fn new() -> Self {
113 let mut op_map = HashMap::new();
114
115 op_map.insert(
117 NumPyOp::Array,
118 TruenoOp {
119 code_template: "Vector::from_slice(&[{values}])".to_string(),
120 imports: vec!["use trueno::Vector;".to_string()],
121 complexity: crate::backend::OpComplexity::Low,
122 },
123 );
124
125 op_map.insert(
127 NumPyOp::Add,
128 TruenoOp {
129 code_template: "{lhs}.add(&{rhs}).unwrap()".to_string(),
130 imports: vec!["use trueno::Vector;".to_string()],
131 complexity: crate::backend::OpComplexity::Low,
132 },
133 );
134
135 op_map.insert(
136 NumPyOp::Subtract,
137 TruenoOp {
138 code_template: "{lhs}.sub(&{rhs}).unwrap()".to_string(),
139 imports: vec!["use trueno::Vector;".to_string()],
140 complexity: crate::backend::OpComplexity::Low,
141 },
142 );
143
144 op_map.insert(
145 NumPyOp::Multiply,
146 TruenoOp {
147 code_template: "{lhs}.mul(&{rhs}).unwrap()".to_string(),
148 imports: vec!["use trueno::Vector;".to_string()],
149 complexity: crate::backend::OpComplexity::Low,
150 },
151 );
152
153 op_map.insert(
155 NumPyOp::Sum,
156 TruenoOp {
157 code_template: "{array}.sum()".to_string(),
158 imports: vec!["use trueno::Vector;".to_string()],
159 complexity: crate::backend::OpComplexity::Medium,
160 },
161 );
162
163 op_map.insert(
164 NumPyOp::Dot,
165 TruenoOp {
166 code_template: "{lhs}.dot(&{rhs}).unwrap()".to_string(),
167 imports: vec!["use trueno::Vector;".to_string()],
168 complexity: crate::backend::OpComplexity::High,
169 },
170 );
171
172 Self { op_map, backend_selector: crate::backend::BackendSelector::new() }
173 }
174
175 pub fn convert(&self, op: &NumPyOp) -> Option<&TruenoOp> {
177 self.op_map.get(op)
178 }
179
180 pub fn recommend_backend(&self, op: &NumPyOp, data_size: usize) -> crate::backend::Backend {
182 self.backend_selector.select_with_moe(op.complexity(), data_size)
183 }
184
185 pub fn available_ops(&self) -> Vec<&NumPyOp> {
187 self.op_map.keys().collect()
188 }
189
190 pub fn conversion_report(&self) -> String {
192 let mut report = String::from("NumPy → Trueno Conversion Map\n");
193 report.push_str("================================\n\n");
194
195 for (op, trueno_op) in &self.op_map {
196 report.push_str(&format!("{:?}:\n", op));
197 report.push_str(&format!(" Complexity: {:?}\n", trueno_op.complexity));
198 report.push_str(&format!(" Template: {}\n", trueno_op.code_template));
199 report.push_str(&format!(" Imports: {}\n\n", trueno_op.imports.join(", ")));
200 }
201
202 report
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::{Backend, OpComplexity};

    /// Every `NumPyOp` variant, for tests that must cover the whole enum.
    const ALL_OPS: [NumPyOp; 12] = [
        NumPyOp::Array,
        NumPyOp::Add,
        NumPyOp::Subtract,
        NumPyOp::Multiply,
        NumPyOp::Divide,
        NumPyOp::Dot,
        NumPyOp::Sum,
        NumPyOp::Mean,
        NumPyOp::Max,
        NumPyOp::Min,
        NumPyOp::Reshape,
        NumPyOp::Transpose,
    ];

    /// Fetches the mapping for `op`, failing the test loudly if it is missing.
    fn mapped<'a>(converter: &'a NumPyConverter, op: &NumPyOp) -> &'a TruenoOp {
        converter.convert(op).unwrap_or_else(|| panic!("Missing mapping for {:?}", op))
    }

    #[test]
    fn test_converter_creation() {
        let converter = NumPyConverter::new();
        assert!(!converter.available_ops().is_empty());
    }

    #[test]
    fn test_operation_complexity() {
        assert_eq!(NumPyOp::Add.complexity(), OpComplexity::Low);
        assert_eq!(NumPyOp::Sum.complexity(), OpComplexity::Medium);
        assert_eq!(NumPyOp::Dot.complexity(), OpComplexity::High);
    }

    #[test]
    fn test_add_conversion() {
        let converter = NumPyConverter::new();
        let trueno_op = converter.convert(&NumPyOp::Add).expect("conversion failed");
        assert!(trueno_op.code_template.contains("add"));
        assert!(trueno_op.imports.iter().any(|i| i.contains("Vector")));
    }

    #[test]
    fn test_backend_recommendation() {
        let converter = NumPyConverter::new();

        let backend = converter.recommend_backend(&NumPyOp::Add, 100);
        assert_eq!(backend, Backend::Scalar);

        let backend = converter.recommend_backend(&NumPyOp::Add, 2_000_000);
        assert_eq!(backend, Backend::SIMD);

        let backend = converter.recommend_backend(&NumPyOp::Dot, 50_000);
        assert_eq!(backend, Backend::GPU);
    }

    #[test]
    fn test_conversion_report() {
        let converter = NumPyConverter::new();
        let report = converter.conversion_report();
        assert!(report.contains("NumPy → Trueno"));
        assert!(report.contains("Add"));
        assert!(report.contains("Complexity"));
    }

    #[test]
    fn test_all_numpy_ops_exist() {
        // All twelve variants must exist and be pairwise distinct.
        let distinct: std::collections::HashSet<&NumPyOp> = ALL_OPS.iter().collect();
        assert_eq!(distinct.len(), 12);
    }

    #[test]
    fn test_op_equality() {
        assert_eq!(NumPyOp::Add, NumPyOp::Add);
        assert_ne!(NumPyOp::Add, NumPyOp::Multiply);
    }

    #[test]
    // The explicit `.clone()` is the point of this test, so keep clippy from
    // objecting should `NumPyOp` be (or become) `Copy`.
    #[allow(clippy::clone_on_copy)]
    fn test_op_clone() {
        let op1 = NumPyOp::Dot;
        let op2 = op1.clone();
        assert_eq!(op1, op2);
    }

    #[test]
    fn test_complexity_low_ops() {
        let low_ops = [
            NumPyOp::Add,
            NumPyOp::Subtract,
            NumPyOp::Multiply,
            NumPyOp::Divide,
            NumPyOp::Array,
            NumPyOp::Reshape,
            NumPyOp::Transpose,
        ];

        for op in low_ops {
            assert_eq!(op.complexity(), OpComplexity::Low, "{:?} should be Low", op);
        }
    }

    #[test]
    fn test_complexity_medium_ops() {
        let medium_ops = [NumPyOp::Sum, NumPyOp::Mean, NumPyOp::Max, NumPyOp::Min];

        for op in medium_ops {
            assert_eq!(op.complexity(), OpComplexity::Medium, "{:?} should be Medium", op);
        }
    }

    #[test]
    fn test_complexity_high_ops() {
        let high_ops = [NumPyOp::Dot];

        for op in high_ops {
            assert_eq!(op.complexity(), OpComplexity::High, "{:?} should be High", op);
        }
    }

    #[test]
    fn test_trueno_op_construction() {
        let op = TruenoOp {
            code_template: "test_template".to_string(),
            imports: vec!["use test;".to_string()],
            complexity: OpComplexity::Medium,
        };

        assert_eq!(op.code_template, "test_template");
        assert_eq!(op.imports.len(), 1);
        assert_eq!(op.complexity, OpComplexity::Medium);
    }

    #[test]
    fn test_trueno_op_clone() {
        let op1 = TruenoOp {
            code_template: "template".to_string(),
            imports: vec!["import".to_string()],
            complexity: OpComplexity::High,
        };

        let op2 = op1.clone();
        assert_eq!(op1.code_template, op2.code_template);
        assert_eq!(op1.imports, op2.imports);
        assert_eq!(op1.complexity, op2.complexity);
    }

    #[test]
    fn test_converter_default() {
        let converter = NumPyConverter::default();
        assert!(!converter.available_ops().is_empty());
    }

    #[test]
    fn test_convert_all_mapped_ops() {
        let converter = NumPyConverter::new();

        let mapped_ops = [
            NumPyOp::Array,
            NumPyOp::Add,
            NumPyOp::Subtract,
            NumPyOp::Multiply,
            NumPyOp::Sum,
            NumPyOp::Dot,
        ];

        for op in mapped_ops {
            assert!(converter.convert(&op).is_some(), "Missing mapping for {:?}", op);
        }
    }

    #[test]
    fn test_convert_unmapped_op() {
        let converter = NumPyConverter::new();
        let advertised = converter.available_ops();

        // Whether or not a given op (e.g. `Divide`) has a template, `convert`
        // and `available_ops` must tell the same story about it.
        for op in ALL_OPS.iter() {
            assert_eq!(
                converter.convert(op).is_some(),
                advertised.contains(&op),
                "convert() and available_ops() disagree about {:?}",
                op
            );
        }
    }

    #[test]
    fn test_array_conversion() {
        let converter = NumPyConverter::new();
        let op = converter.convert(&NumPyOp::Array).expect("conversion failed");

        assert!(op.code_template.contains("Vector"));
        assert!(op.code_template.contains("from_slice"));
        assert!(op.imports.iter().any(|i| i.contains("Vector")));
        assert_eq!(op.complexity, OpComplexity::Low);
    }

    #[test]
    fn test_subtract_conversion() {
        let converter = NumPyConverter::new();
        let op = converter.convert(&NumPyOp::Subtract).expect("conversion failed");

        assert!(op.code_template.contains("sub"));
        assert!(op.imports.iter().any(|i| i.contains("Vector")));
        assert_eq!(op.complexity, OpComplexity::Low);
    }

    #[test]
    fn test_multiply_conversion() {
        let converter = NumPyConverter::new();
        let op = converter.convert(&NumPyOp::Multiply).expect("conversion failed");

        assert!(op.code_template.contains("mul"));
        assert!(op.imports.iter().any(|i| i.contains("Vector")));
    }

    #[test]
    fn test_sum_conversion() {
        let converter = NumPyConverter::new();
        let op = converter.convert(&NumPyOp::Sum).expect("conversion failed");

        assert!(op.code_template.contains("sum"));
        assert_eq!(op.complexity, OpComplexity::Medium);
    }

    #[test]
    fn test_dot_conversion() {
        let converter = NumPyConverter::new();
        let op = converter.convert(&NumPyOp::Dot).expect("conversion failed");

        assert!(op.code_template.contains("dot"));
        assert_eq!(op.complexity, OpComplexity::High);
    }

    #[test]
    fn test_available_ops() {
        let converter = NumPyConverter::new();
        let ops = converter.available_ops();

        assert!(!ops.is_empty());
        assert!(ops.len() >= 6);
    }

    #[test]
    fn test_recommend_backend_element_wise_small() {
        let converter = NumPyConverter::new();

        let backend = converter.recommend_backend(&NumPyOp::Add, 10);
        assert_eq!(backend, Backend::Scalar);
    }

    #[test]
    fn test_recommend_backend_element_wise_large() {
        let converter = NumPyConverter::new();

        let backend = converter.recommend_backend(&NumPyOp::Multiply, 2_000_000);
        assert_eq!(backend, Backend::SIMD);
    }

    #[test]
    fn test_recommend_backend_reduction_medium() {
        let converter = NumPyConverter::new();

        let backend = converter.recommend_backend(&NumPyOp::Sum, 50_000);
        assert_eq!(backend, Backend::SIMD);
    }

    #[test]
    fn test_recommend_backend_reduction_large() {
        let converter = NumPyConverter::new();

        let backend = converter.recommend_backend(&NumPyOp::Sum, 500_000);
        assert_eq!(backend, Backend::GPU);
    }

    #[test]
    fn test_recommend_backend_dot_product() {
        let converter = NumPyConverter::new();

        let backend = converter.recommend_backend(&NumPyOp::Dot, 100_000);
        assert_eq!(backend, Backend::GPU);
    }

    #[test]
    fn test_conversion_report_structure() {
        let converter = NumPyConverter::new();
        let report = converter.conversion_report();

        assert!(report.contains("NumPy → Trueno"));
        assert!(report.contains("==="));
        assert!(report.contains("Complexity:"));
        assert!(report.contains("Template:"));
        assert!(report.contains("Imports:"));
    }

    #[test]
    fn test_conversion_report_has_all_ops() {
        let converter = NumPyConverter::new();
        let report = converter.conversion_report();

        // Every mapped operation must have its own section, not just one of them.
        for op in converter.available_ops() {
            let heading = format!("{:?}:", op);
            assert!(report.contains(&heading), "Report is missing a section for {:?}", op);
        }
    }

    #[test]
    fn test_all_conversions_not_empty() {
        let converter = NumPyConverter::new();

        for op in converter.available_ops() {
            let trueno_op = mapped(&converter, op);
            assert!(!trueno_op.code_template.is_empty(), "Empty code template for {:?}", op);
            assert!(!trueno_op.imports.is_empty(), "Empty imports for {:?}", op);
        }
    }

    #[test]
    fn test_imports_are_valid_rust() {
        let converter = NumPyConverter::new();

        for op in converter.available_ops() {
            for import in &mapped(&converter, op).imports {
                assert!(import.starts_with("use "), "Invalid import syntax: {}", import);
                assert!(import.ends_with(';'), "Import missing semicolon: {}", import);
            }
        }
    }

    #[test]
    fn test_all_ops_use_vector_import() {
        let converter = NumPyConverter::new();

        for op in converter.available_ops() {
            assert!(
                mapped(&converter, op).imports.iter().any(|i| i.contains("Vector")),
                "Operation {:?} should import Vector",
                op
            );
        }
    }

    #[test]
    fn test_element_wise_ops_have_unwrap() {
        let converter = NumPyConverter::new();

        let element_wise = [NumPyOp::Add, NumPyOp::Subtract, NumPyOp::Multiply];

        for op in element_wise {
            assert!(
                mapped(&converter, &op).code_template.contains("unwrap"),
                "Element-wise op {:?} should have unwrap() for error handling",
                op
            );
        }
    }

    #[test]
    fn test_complexity_matches_enum() {
        let converter = NumPyConverter::new();

        // The complexity stored in the table must agree with the enum for
        // every mapped operation, not only a hand-picked few.
        for op in converter.available_ops() {
            assert_eq!(
                mapped(&converter, op).complexity,
                op.complexity(),
                "Stored complexity for {:?} disagrees with NumPyOp::complexity()",
                op
            );
        }
    }
}