use std::collections::HashMap;
/// A NumPy operation that the converter can classify and, for some variants,
/// translate into Trueno code.
///
/// Not every variant has a Trueno template; the converter reports which ones do.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum NumPyOp {
    /// Build an array from literal values.
    Array,
    /// Element-wise addition.
    Add,
    /// Element-wise subtraction.
    Subtract,
    /// Element-wise multiplication.
    Multiply,
    /// Element-wise division.
    Divide,
    /// Dot product of two arrays.
    Dot,
    /// Sum of all elements.
    Sum,
    /// Arithmetic mean of all elements.
    Mean,
    /// Largest element.
    Max,
    /// Smallest element.
    Min,
    /// Change the array's shape.
    Reshape,
    /// Transpose the array.
    Transpose,
}
impl NumPyOp {
    /// Classifies how expensive this operation is; the converter feeds this
    /// into backend selection.
    ///
    /// * `Dot` is `High`.
    /// * The reductions `Sum`, `Mean`, `Max` and `Min` are `Medium`.
    /// * Everything else (`Array`, the element-wise arithmetic ops, `Reshape`,
    ///   `Transpose`) is `Low`.
    pub fn complexity(&self) -> crate::backend::OpComplexity {
        use crate::backend::OpComplexity;

        match self {
            Self::Dot => OpComplexity::High,
            Self::Sum | Self::Mean | Self::Max | Self::Min => OpComplexity::Medium,
            Self::Array
            | Self::Add
            | Self::Subtract
            | Self::Multiply
            | Self::Divide
            | Self::Reshape
            | Self::Transpose => OpComplexity::Low,
        }
    }
}
/// The Trueno code generated for one NumPy operation.
#[derive(Clone, Debug)]
pub struct TruenoOp {
    /// Rust source snippet with brace-delimited placeholders such as `{lhs}`,
    /// `{rhs}`, `{array}` or `{values}`. Placeholder substitution is not
    /// performed in this file — presumably done by the code generator; confirm.
    pub code_template: String,
    /// Full `use` statements the snippet needs in scope.
    pub imports: Vec<String>,
    /// Cost class of the operation.
    pub complexity: crate::backend::OpComplexity,
}
/// Translates [`NumPyOp`]s into Trueno code templates and recommends a compute
/// backend for running them.
pub struct NumPyConverter {
    /// Template for each supported operation; ops absent from the map have no
    /// conversion.
    op_map: HashMap<NumPyOp, TruenoOp>,
    /// Picks a backend from an operation's complexity and the data size.
    backend_selector: crate::backend::BackendSelector,
}

impl Default for NumPyConverter {
    /// Same as [`NumPyConverter::new`]: a converter with the built-in mappings.
    fn default() -> Self {
        NumPyConverter::new()
    }
}
impl NumPyConverter {
pub fn new() -> Self {
let mut op_map = HashMap::new();
op_map.insert(
NumPyOp::Array,
TruenoOp {
code_template: "Vector::from_slice(&[{values}])".to_string(),
imports: vec!["use trueno::Vector;".to_string()],
complexity: crate::backend::OpComplexity::Low,
},
);
op_map.insert(
NumPyOp::Add,
TruenoOp {
code_template: "{lhs}.add(&{rhs}).unwrap()".to_string(),
imports: vec!["use trueno::Vector;".to_string()],
complexity: crate::backend::OpComplexity::Low,
},
);
op_map.insert(
NumPyOp::Subtract,
TruenoOp {
code_template: "{lhs}.sub(&{rhs}).unwrap()".to_string(),
imports: vec!["use trueno::Vector;".to_string()],
complexity: crate::backend::OpComplexity::Low,
},
);
op_map.insert(
NumPyOp::Multiply,
TruenoOp {
code_template: "{lhs}.mul(&{rhs}).unwrap()".to_string(),
imports: vec!["use trueno::Vector;".to_string()],
complexity: crate::backend::OpComplexity::Low,
},
);
op_map.insert(
NumPyOp::Sum,
TruenoOp {
code_template: "{array}.sum()".to_string(),
imports: vec!["use trueno::Vector;".to_string()],
complexity: crate::backend::OpComplexity::Medium,
},
);
op_map.insert(
NumPyOp::Dot,
TruenoOp {
code_template: "{lhs}.dot(&{rhs}).unwrap()".to_string(),
imports: vec!["use trueno::Vector;".to_string()],
complexity: crate::backend::OpComplexity::High,
},
);
Self { op_map, backend_selector: crate::backend::BackendSelector::new() }
}
pub fn convert(&self, op: &NumPyOp) -> Option<&TruenoOp> {
self.op_map.get(op)
}
pub fn recommend_backend(&self, op: &NumPyOp, data_size: usize) -> crate::backend::Backend {
self.backend_selector.select_with_moe(op.complexity(), data_size)
}
pub fn available_ops(&self) -> Vec<&NumPyOp> {
self.op_map.keys().collect()
}
pub fn conversion_report(&self) -> String {
let mut report = String::from("NumPy → Trueno Conversion Map\n");
report.push_str("================================\n\n");
for (op, trueno_op) in &self.op_map {
report.push_str(&format!("{:?}:\n", op));
report.push_str(&format!(" Complexity: {:?}\n", trueno_op.complexity));
report.push_str(&format!(" Template: {}\n", trueno_op.code_template));
report.push_str(&format!(" Imports: {}\n\n", trueno_op.imports.join(", ")));
}
report
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `NumPyOp` variant, in declaration order.
    fn all_ops() -> [NumPyOp; 12] {
        [
            NumPyOp::Array,
            NumPyOp::Add,
            NumPyOp::Subtract,
            NumPyOp::Multiply,
            NumPyOp::Divide,
            NumPyOp::Dot,
            NumPyOp::Sum,
            NumPyOp::Mean,
            NumPyOp::Max,
            NumPyOp::Min,
            NumPyOp::Reshape,
            NumPyOp::Transpose,
        ]
    }

    #[test]
    fn test_converter_creation() {
        let converter = NumPyConverter::new();
        assert!(!converter.available_ops().is_empty());
    }

    #[test]
    fn test_operation_complexity() {
        assert_eq!(NumPyOp::Add.complexity(), crate::backend::OpComplexity::Low);
        assert_eq!(NumPyOp::Sum.complexity(), crate::backend::OpComplexity::Medium);
        assert_eq!(NumPyOp::Dot.complexity(), crate::backend::OpComplexity::High);
    }

    #[test]
    fn test_add_conversion() {
        let converter = NumPyConverter::new();
        let trueno_op = converter.convert(&NumPyOp::Add).expect("conversion failed");
        assert!(trueno_op.code_template.contains("add"));
        assert!(trueno_op.imports.iter().any(|i| i.contains("Vector")));
    }

    #[test]
    fn test_backend_recommendation() {
        let converter = NumPyConverter::new();
        let backend = converter.recommend_backend(&NumPyOp::Add, 100);
        assert_eq!(backend, crate::backend::Backend::Scalar);
        let backend = converter.recommend_backend(&NumPyOp::Add, 2_000_000);
        assert_eq!(backend, crate::backend::Backend::SIMD);
        let backend = converter.recommend_backend(&NumPyOp::Dot, 50_000);
        assert_eq!(backend, crate::backend::Backend::GPU);
    }

    #[test]
    fn test_conversion_report() {
        let converter = NumPyConverter::new();
        let report = converter.conversion_report();
        assert!(report.contains("NumPy → Trueno"));
        assert!(report.contains("Add"));
        assert!(report.contains("Complexity"));
    }

    #[test]
    fn test_all_numpy_ops_exist() {
        let ops = all_ops();
        assert_eq!(ops.len(), 12);
        // Every mapped operation must be one of the known variants.
        let converter = NumPyConverter::new();
        for op in converter.available_ops() {
            assert!(ops.contains(op), "Unknown mapped op {:?}", op);
        }
    }

    #[test]
    fn test_op_equality() {
        assert_eq!(NumPyOp::Add, NumPyOp::Add);
        assert_ne!(NumPyOp::Add, NumPyOp::Multiply);
    }

    #[test]
    fn test_op_clone() {
        let op1 = NumPyOp::Dot;
        let op2 = op1.clone();
        assert_eq!(op1, op2);
    }

    #[test]
    fn test_complexity_low_ops() {
        let low_ops = [
            NumPyOp::Add,
            NumPyOp::Subtract,
            NumPyOp::Multiply,
            NumPyOp::Divide,
            NumPyOp::Array,
            NumPyOp::Reshape,
            NumPyOp::Transpose,
        ];
        for op in low_ops {
            assert_eq!(op.complexity(), crate::backend::OpComplexity::Low);
        }
    }

    #[test]
    fn test_complexity_medium_ops() {
        let medium_ops = [NumPyOp::Sum, NumPyOp::Mean, NumPyOp::Max, NumPyOp::Min];
        for op in medium_ops {
            assert_eq!(op.complexity(), crate::backend::OpComplexity::Medium);
        }
    }

    #[test]
    fn test_complexity_high_ops() {
        let high_ops = [NumPyOp::Dot];
        for op in high_ops {
            assert_eq!(op.complexity(), crate::backend::OpComplexity::High);
        }
    }

    #[test]
    fn test_trueno_op_construction() {
        let op = TruenoOp {
            code_template: "test_template".to_string(),
            imports: vec!["use test;".to_string()],
            complexity: crate::backend::OpComplexity::Medium,
        };
        assert_eq!(op.code_template, "test_template");
        assert_eq!(op.imports.len(), 1);
        assert_eq!(op.complexity, crate::backend::OpComplexity::Medium);
    }

    #[test]
    fn test_trueno_op_clone() {
        let op1 = TruenoOp {
            code_template: "template".to_string(),
            imports: vec!["import".to_string()],
            complexity: crate::backend::OpComplexity::High,
        };
        let op2 = op1.clone();
        assert_eq!(op1.code_template, op2.code_template);
        assert_eq!(op1.imports, op2.imports);
        assert_eq!(op1.complexity, op2.complexity);
    }

    #[test]
    fn test_converter_default() {
        let converter = NumPyConverter::default();
        assert!(!converter.available_ops().is_empty());
    }

    #[test]
    fn test_convert_all_mapped_ops() {
        let converter = NumPyConverter::new();
        let mapped_ops = [
            NumPyOp::Array,
            NumPyOp::Add,
            NumPyOp::Subtract,
            NumPyOp::Multiply,
            NumPyOp::Sum,
            NumPyOp::Dot,
        ];
        for op in mapped_ops {
            assert!(converter.convert(&op).is_some(), "Missing mapping for {:?}", op);
        }
    }

    #[test]
    fn test_convert_unmapped_op() {
        // `convert` must succeed exactly for the ops reported by `available_ops`.
        // Ops without a template must yield `None` rather than panicking.
        let converter = NumPyConverter::new();
        let available = converter.available_ops();
        for op in all_ops() {
            assert_eq!(
                converter.convert(&op).is_some(),
                available.contains(&&op),
                "convert/available_ops disagree for {:?}",
                op
            );
        }
    }

    #[test]
    fn test_array_conversion() {
        let converter = NumPyConverter::new();
        let op = converter.convert(&NumPyOp::Array).expect("conversion failed");
        assert!(op.code_template.contains("Vector"));
        assert!(op.code_template.contains("from_slice"));
        assert!(op.imports.iter().any(|i| i.contains("Vector")));
        assert_eq!(op.complexity, crate::backend::OpComplexity::Low);
    }

    #[test]
    fn test_subtract_conversion() {
        let converter = NumPyConverter::new();
        let op = converter.convert(&NumPyOp::Subtract).expect("conversion failed");
        assert!(op.code_template.contains("sub"));
        assert!(op.imports.iter().any(|i| i.contains("Vector")));
        assert_eq!(op.complexity, crate::backend::OpComplexity::Low);
    }

    #[test]
    fn test_multiply_conversion() {
        let converter = NumPyConverter::new();
        let op = converter.convert(&NumPyOp::Multiply).expect("conversion failed");
        assert!(op.code_template.contains("mul"));
        assert!(op.imports.iter().any(|i| i.contains("Vector")));
    }

    #[test]
    fn test_sum_conversion() {
        let converter = NumPyConverter::new();
        let op = converter.convert(&NumPyOp::Sum).expect("conversion failed");
        assert!(op.code_template.contains("sum"));
        assert_eq!(op.complexity, crate::backend::OpComplexity::Medium);
    }

    #[test]
    fn test_dot_conversion() {
        let converter = NumPyConverter::new();
        let op = converter.convert(&NumPyOp::Dot).expect("conversion failed");
        assert!(op.code_template.contains("dot"));
        assert_eq!(op.complexity, crate::backend::OpComplexity::High);
    }

    #[test]
    fn test_available_ops() {
        let converter = NumPyConverter::new();
        let ops = converter.available_ops();
        assert!(!ops.is_empty());
        assert!(ops.len() >= 6);
    }

    #[test]
    fn test_recommend_backend_element_wise_small() {
        let converter = NumPyConverter::new();
        let backend = converter.recommend_backend(&NumPyOp::Add, 10);
        assert_eq!(backend, crate::backend::Backend::Scalar);
    }

    #[test]
    fn test_recommend_backend_element_wise_large() {
        let converter = NumPyConverter::new();
        let backend = converter.recommend_backend(&NumPyOp::Multiply, 2_000_000);
        assert_eq!(backend, crate::backend::Backend::SIMD);
    }

    #[test]
    fn test_recommend_backend_reduction_medium() {
        let converter = NumPyConverter::new();
        let backend = converter.recommend_backend(&NumPyOp::Sum, 50_000);
        assert_eq!(backend, crate::backend::Backend::SIMD);
    }

    #[test]
    fn test_recommend_backend_reduction_large() {
        let converter = NumPyConverter::new();
        let backend = converter.recommend_backend(&NumPyOp::Sum, 500_000);
        assert_eq!(backend, crate::backend::Backend::GPU);
    }

    #[test]
    fn test_recommend_backend_dot_product() {
        let converter = NumPyConverter::new();
        let backend = converter.recommend_backend(&NumPyOp::Dot, 100_000);
        assert_eq!(backend, crate::backend::Backend::GPU);
    }

    #[test]
    fn test_conversion_report_structure() {
        let converter = NumPyConverter::new();
        let report = converter.conversion_report();
        assert!(report.contains("NumPy → Trueno"));
        assert!(report.contains("==="));
        assert!(report.contains("Complexity:"));
        assert!(report.contains("Template:"));
        assert!(report.contains("Imports:"));
    }

    #[test]
    fn test_conversion_report_has_all_ops() {
        // Every mapped operation must get its own `Name:` header in the report.
        let converter = NumPyConverter::new();
        let report = converter.conversion_report();
        for op in converter.available_ops() {
            let header = format!("{:?}:", op);
            assert!(report.contains(&header), "Report is missing {:?}", op);
        }
    }

    #[test]
    fn test_all_conversions_not_empty() {
        let converter = NumPyConverter::new();
        for op in converter.available_ops() {
            let trueno_op = converter.convert(op).expect("available op must convert");
            assert!(!trueno_op.code_template.is_empty(), "Empty code template for {:?}", op);
            assert!(!trueno_op.imports.is_empty(), "Empty imports for {:?}", op);
        }
    }

    #[test]
    fn test_imports_are_valid_rust() {
        let converter = NumPyConverter::new();
        for op in converter.available_ops() {
            let trueno_op = converter.convert(op).expect("available op must convert");
            for import in &trueno_op.imports {
                assert!(import.starts_with("use "), "Invalid import syntax: {}", import);
                assert!(import.ends_with(';'), "Import missing semicolon: {}", import);
            }
        }
    }

    #[test]
    fn test_all_ops_use_vector_import() {
        let converter = NumPyConverter::new();
        for op in converter.available_ops() {
            let trueno_op = converter.convert(op).expect("available op must convert");
            assert!(
                trueno_op.imports.iter().any(|i| i.contains("Vector")),
                "Operation {:?} should import Vector",
                op
            );
        }
    }

    #[test]
    fn test_element_wise_ops_have_unwrap() {
        let converter = NumPyConverter::new();
        let element_wise = [NumPyOp::Add, NumPyOp::Subtract, NumPyOp::Multiply];
        for op in element_wise {
            let trueno_op = converter.convert(&op).expect("element-wise op must convert");
            assert!(
                trueno_op.code_template.contains("unwrap"),
                "Element-wise op {:?} should have unwrap() for error handling",
                op
            );
        }
    }

    #[test]
    fn test_complexity_matches_enum() {
        // The complexity stored with each template must agree with the enum's
        // classification, which is what backend selection uses.
        let converter = NumPyConverter::new();
        for op in converter.available_ops() {
            let trueno_op = converter.convert(op).expect("available op must convert");
            assert_eq!(trueno_op.complexity, op.complexity(), "Complexity mismatch for {:?}", op);
        }
    }
}