1#[cfg(feature = "wasm")]
36use wasm_bindgen::prelude::*;
37
38#[cfg(feature = "wasm")]
39use serde::{Deserialize, Serialize};
40
41#[cfg(feature = "wasm")]
42use crate::backend::{BackendSelector, OpComplexity};
43#[cfg(feature = "wasm")]
44use crate::numpy_converter::{NumPyConverter, NumPyOp};
45#[cfg(feature = "wasm")]
46use crate::pytorch_converter::{PyTorchConverter, PyTorchOperation};
47#[cfg(feature = "wasm")]
48use crate::sklearn_converter::{SklearnAlgorithm, SklearnConverter};
49
#[cfg(feature = "wasm")]
/// Module start hook.
///
/// Marked `#[wasm_bindgen(start)]`, so it runs when the WASM module is
/// instantiated; it only writes a one-line banner to the console.
#[wasm_bindgen(start)]
pub fn wasm_init() {
    let banner = JsValue::from_str("Batuta WASM module initialized");
    web_sys::console::log_1(&banner);
}
60
#[cfg(feature = "wasm")]
/// Summary of a source snippet produced by [`analyze_code`].
///
/// Exposed to JavaScript through read-only getters and `toJSON`.
#[wasm_bindgen]
// The `#[wasm_bindgen]` expansion adds glue code containing `unsafe` for this
// type, which trips the pedantic `unsafe_derive_deserialize` lint. The derived
// `Deserialize` only fills plain owned fields, so the allow is intentional.
#[allow(clippy::unsafe_derive_deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnalysisResult {
    /// Detected language label, e.g. "Python", "C/C++", "Rust", "Shell" or
    /// "Unknown".
    language: String,
    /// Whether the snippet appears to reference NumPy (substring heuristic).
    has_numpy: bool,
    /// Whether the snippet appears to reference scikit-learn.
    has_sklearn: bool,
    /// Whether the snippet appears to reference PyTorch / transformers.
    has_pytorch: bool,
    /// Number of lines in the analyzed snippet.
    lines_of_code: usize,
}
73
#[cfg(feature = "wasm")]
#[wasm_bindgen]
impl AnalysisResult {
    /// Detected language label (owned copy for the JS side).
    #[wasm_bindgen(getter)]
    pub fn language(&self) -> String {
        self.language.clone()
    }

    /// `true` when NumPy usage was detected.
    #[wasm_bindgen(getter)]
    pub fn has_numpy(&self) -> bool {
        self.has_numpy
    }

    /// `true` when scikit-learn usage was detected.
    #[wasm_bindgen(getter)]
    pub fn has_sklearn(&self) -> bool {
        self.has_sklearn
    }

    /// `true` when PyTorch / transformers usage was detected.
    #[wasm_bindgen(getter)]
    pub fn has_pytorch(&self) -> bool {
        self.has_pytorch
    }

    /// Line count of the analyzed snippet.
    #[wasm_bindgen(getter)]
    pub fn lines_of_code(&self) -> usize {
        self.lines_of_code
    }

    /// Serializes this result to a JSON string (exported to JS as `toJSON`).
    ///
    /// # Errors
    ///
    /// Returns a JS string value prefixed with "Serialization error: " when
    /// `serde_json` fails to serialize the struct.
    #[wasm_bindgen(js_name = toJSON)]
    pub fn to_json(&self) -> Result<String, JsValue> {
        match serde_json::to_string(self) {
            Ok(json) => Ok(json),
            Err(err) => Err(JsValue::from_str(&format!("Serialization error: {}", err))),
        }
    }
}
105
#[cfg(feature = "wasm")]
/// Output of the `convert_*` functions: the original snippet, its Rust
/// counterpart, and the backend/complexity assessment.
///
/// Exposed to JavaScript through read-only getters and `toJSON`.
#[wasm_bindgen]
// The `#[wasm_bindgen]` expansion adds glue code containing `unsafe` for this
// type, which trips the pedantic `unsafe_derive_deserialize` lint. The derived
// `Deserialize` only fills plain owned `String` fields, so the allow is
// intentional.
#[allow(clippy::unsafe_derive_deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversionResult {
    /// The input snippet, copied verbatim.
    original_code: String,
    /// Rust code template returned by the converter.
    rust_code: String,
    /// `use` lines required by `rust_code`, joined with `\n`.
    imports: String,
    /// `Debug` rendering of the recommended compute backend.
    backend_recommendation: String,
    /// `Debug` rendering of the operation's complexity class.
    complexity: String,
}
118
#[cfg(feature = "wasm")]
#[wasm_bindgen]
impl ConversionResult {
    /// The snippet that was converted.
    #[wasm_bindgen(getter)]
    pub fn original_code(&self) -> String {
        self.original_code.clone()
    }

    /// Generated Rust code.
    #[wasm_bindgen(getter)]
    pub fn rust_code(&self) -> String {
        self.rust_code.clone()
    }

    /// Newline-separated `use` lines for the generated code.
    #[wasm_bindgen(getter)]
    pub fn imports(&self) -> String {
        self.imports.clone()
    }

    /// Recommended backend, as text.
    #[wasm_bindgen(getter)]
    pub fn backend_recommendation(&self) -> String {
        self.backend_recommendation.clone()
    }

    /// Complexity class, as text.
    #[wasm_bindgen(getter)]
    pub fn complexity(&self) -> String {
        self.complexity.clone()
    }

    /// Serializes this result to a JSON string (exported to JS as `toJSON`).
    ///
    /// # Errors
    ///
    /// Returns a JS string value prefixed with "Serialization error: " when
    /// `serde_json` fails to serialize the struct.
    #[wasm_bindgen(js_name = toJSON)]
    pub fn to_json(&self) -> Result<String, JsValue> {
        match serde_json::to_string(self) {
            Ok(json) => Ok(json),
            Err(err) => Err(JsValue::from_str(&format!("Serialization error: {}", err))),
        }
    }
}
150
#[cfg(feature = "wasm")]
/// Heuristically analyzes a source snippet.
///
/// Reports the detected language, whether NumPy / scikit-learn / PyTorch
/// appear to be referenced (plain substring checks), and the line count.
///
/// Language detection is first-match over a table ordered from the most
/// specific markers to the least specific:
/// 1. a shell shebang;
/// 2. C/C++ markers;
/// 3. Python keywords;
/// 4. Rust keywords.
///
/// Anything else is `"Unknown"`.
///
/// # Errors
///
/// Never returns `Err` at present; the `Result` return type is part of the
/// exported JS API.
#[wasm_bindgen]
pub fn analyze_code(code: &str) -> Result<AnalysisResult, JsValue> {
    // Order matters. The Python needles are loose substrings that also occur
    // in other languages:
    // - C preprocessor lines such as `#ifndef X` and `#undef X` contain "def ";
    // - C++ uses "class ".
    // Checking the shebang and the C/C++ markers first keeps those inputs from
    // being labelled Python.
    const LANG_PATTERNS: &[(&[&str], &str)] = &[
        (&["#!/bin/bash", "#!/bin/sh"], "Shell"),
        (&["#include", "int main"], "C/C++"),
        (&["import ", "def ", "class "], "Python"),
        (&["fn ", "struct "], "Rust"),
    ];

    let language = LANG_PATTERNS
        .iter()
        .find(|(pats, _)| pats.iter().any(|p| code.contains(p)))
        .map_or("Unknown", |(_, lang)| *lang);

    // Library usage is detected by simple substring matching.
    let has_numpy = code.contains("numpy") || code.contains("np.");
    let has_sklearn = code.contains("sklearn");
    let has_pytorch = code.contains("torch") || code.contains("transformers");

    Ok(AnalysisResult {
        language: language.to_string(),
        has_numpy,
        has_sklearn,
        has_pytorch,
        // Count lines directly rather than collecting them into a Vec first.
        lines_of_code: code.lines().count(),
    })
}
188
#[cfg(feature = "wasm")]
/// Converts a NumPy snippet using [`NumPyConverter`].
///
/// The first table entry with a needle that appears anywhere in `numpy_code`
/// selects the operation, so table order is significant. For example,
/// `np.array(..).reshape(..)` resolves to `Array`.
///
/// `data_size` defaults to 1000 and is only used for the backend
/// recommendation.
///
/// # Errors
///
/// Returns a JS error string when no pattern matches, or when the converter
/// has no mapping for the detected operation.
#[wasm_bindgen]
pub fn convert_numpy(
    numpy_code: &str,
    data_size: Option<usize>,
) -> Result<ConversionResult, JsValue> {
    const NUMPY_PATTERNS: &[(&[&str], NumPyOp)] = &[
        (&["np.add", "numpy.add"], NumPyOp::Add),
        (&["np.dot", "numpy.dot"], NumPyOp::Dot),
        (&["np.sum", "numpy.sum"], NumPyOp::Sum),
        (&["np.mean", "numpy.mean"], NumPyOp::Mean),
        (&["np.array", "numpy.array"], NumPyOp::Array),
        (&["reshape"], NumPyOp::Reshape),
        (&["transpose", ".T"], NumPyOp::Transpose),
    ];

    let converter = NumPyConverter::new();

    let detected = NUMPY_PATTERNS.iter().find_map(|(needles, candidate)| {
        needles
            .iter()
            .any(|needle| numpy_code.contains(needle))
            .then(|| candidate.clone())
    });
    let op = match detected {
        Some(found) => found,
        None => return Err(JsValue::from_str("Unsupported NumPy operation")),
    };

    let trueno_op = match converter.convert(&op) {
        Some(converted) => converted,
        None => {
            return Err(JsValue::from_str(
                "Conversion failed: operation not supported",
            ))
        }
    };

    let backend = converter.recommend_backend(&op, data_size.unwrap_or(1000));

    Ok(ConversionResult {
        original_code: numpy_code.to_string(),
        rust_code: trueno_op.code_template.clone(),
        imports: trueno_op.imports.join("\n"),
        backend_recommendation: format!("{:?}", backend),
        complexity: format!("{:?}", trueno_op.complexity),
    })
}
236
#[cfg(feature = "wasm")]
/// Converts a scikit-learn snippet using [`SklearnConverter`].
///
/// The first algorithm name in the table that appears in `sklearn_code`
/// selects the algorithm. `data_size` defaults to 1000 and is only used for
/// the backend recommendation.
///
/// # Errors
///
/// Returns a JS error string when no known algorithm name is found, or when
/// the converter has no mapping for it.
#[wasm_bindgen]
pub fn convert_sklearn(
    sklearn_code: &str,
    data_size: Option<usize>,
) -> Result<ConversionResult, JsValue> {
    const SKLEARN_PATTERNS: &[(&str, SklearnAlgorithm)] = &[
        ("LinearRegression", SklearnAlgorithm::LinearRegression),
        ("LogisticRegression", SklearnAlgorithm::LogisticRegression),
        ("KMeans", SklearnAlgorithm::KMeans),
        ("DecisionTreeClassifier", SklearnAlgorithm::DecisionTreeClassifier),
        ("RandomForestClassifier", SklearnAlgorithm::RandomForestClassifier),
        ("StandardScaler", SklearnAlgorithm::StandardScaler),
    ];

    let converter = SklearnConverter::new();

    let detected = SKLEARN_PATTERNS
        .iter()
        .find_map(|(needle, candidate)| sklearn_code.contains(needle).then(|| candidate.clone()));
    let algo = match detected {
        Some(found) => found,
        None => return Err(JsValue::from_str("Unsupported sklearn algorithm")),
    };

    let aprender_algo = match converter.convert(&algo) {
        Some(converted) => converted,
        None => {
            return Err(JsValue::from_str(
                "Conversion failed: algorithm not supported",
            ))
        }
    };

    let backend = converter.recommend_backend(&algo, data_size.unwrap_or(1000));

    Ok(ConversionResult {
        original_code: sklearn_code.to_string(),
        rust_code: aprender_algo.code_template.clone(),
        imports: aprender_algo.imports.join("\n"),
        backend_recommendation: format!("{:?}", backend),
        complexity: format!("{:?}", aprender_algo.complexity),
    })
}
282
#[cfg(feature = "wasm")]
/// Converts a PyTorch / transformers snippet using [`PyTorchConverter`].
///
/// A snippet mentioning both "load" and "model" is treated as model loading.
/// Otherwise the first matching keyword in the pattern table decides the
/// operation. `data_size` defaults to 1,000,000 and is only used for the
/// backend recommendation.
///
/// # Errors
///
/// Returns a JS error string when no operation is recognized, or when the
/// converter has no mapping for it.
#[wasm_bindgen]
pub fn convert_pytorch(
    pytorch_code: &str,
    data_size: Option<usize>,
) -> Result<ConversionResult, JsValue> {
    const PYTORCH_PATTERNS: &[(&[&str], PyTorchOperation)] = &[
        (&["generate"], PyTorchOperation::Generate),
        (&["forward"], PyTorchOperation::Forward),
        (&["encode"], PyTorchOperation::Encode),
        (&["decode"], PyTorchOperation::Decode),
        (&["nn.Linear"], PyTorchOperation::Linear),
        (&["Attention"], PyTorchOperation::Attention),
    ];

    let converter = PyTorchConverter::new();

    // Model loading is recognized by two co-occurring keywords and takes
    // precedence over the single-keyword table.
    let mentions_model_loading = pytorch_code.contains("load") && pytorch_code.contains("model");
    let op = if mentions_model_loading {
        PyTorchOperation::LoadModel
    } else {
        let detected = PYTORCH_PATTERNS.iter().find_map(|(needles, candidate)| {
            needles
                .iter()
                .any(|needle| pytorch_code.contains(needle))
                .then(|| candidate.clone())
        });
        match detected {
            Some(found) => found,
            None => return Err(JsValue::from_str("Unsupported PyTorch operation")),
        }
    };

    let realizar_op = match converter.convert(&op) {
        Some(converted) => converted,
        None => {
            return Err(JsValue::from_str(
                "Conversion failed: operation not supported",
            ))
        }
    };

    let backend = converter.recommend_backend(&op, data_size.unwrap_or(1_000_000));

    Ok(ConversionResult {
        original_code: pytorch_code.to_string(),
        rust_code: realizar_op.code_template.clone(),
        imports: realizar_op.imports.join("\n"),
        backend_recommendation: format!("{:?}", backend),
        complexity: format!("{:?}", realizar_op.complexity),
    })
}
333
#[cfg(feature = "wasm")]
/// Recommends a compute backend for an operation category and data size.
///
/// `operation_type` is matched case-insensitively against:
/// - "element-wise" / "elementwise";
/// - "reduction" / "reduce";
/// - "matmul" / "matrix-multiply".
///
/// The returned string is the `Debug` rendering of the selected backend.
///
/// # Errors
///
/// Returns a JS error string for any other `operation_type`.
#[wasm_bindgen]
pub fn backend_recommend(operation_type: &str, data_size: usize) -> Result<String, JsValue> {
    let selector = BackendSelector::new();
    let normalized = operation_type.to_lowercase();

    let backend = match normalized.as_str() {
        "element-wise" | "elementwise" => selector.select_for_elementwise(data_size),
        "reduction" | "reduce" => {
            // Bucket the workload size into a coarse complexity tier.
            let tier = match data_size {
                0..=9_999 => OpComplexity::Low,
                10_000..=99_999 => OpComplexity::Medium,
                _ => OpComplexity::High,
            };
            selector.select_with_moe(tier, data_size)
        }
        "matmul" | "matrix-multiply" => {
            // data_size is read as the element count of a square matrix:
            // side = floor(sqrt(data_size)), used for all three dimensions.
            let side = (data_size as f64).sqrt() as usize;
            selector.select_for_matmul(side, side, side)
        }
        _ => return Err(JsValue::from_str("Unknown operation type")),
    };

    Ok(format!("{:?}", backend))
}
368
#[cfg(feature = "wasm")]
/// Crate version string, taken from `CARGO_PKG_VERSION` at compile time.
#[wasm_bindgen]
pub fn version() -> String {
    String::from(env!("CARGO_PKG_VERSION"))
}
375
#[cfg(test)]
mod tests {
    //! Unit tests for the WASM bindings.
    //!
    //! All tests are gated on the `wasm` feature. They only exercise code
    //! paths that never construct a `JsValue`, so they can run on a native
    //! target.

    use super::*;

    // ----- Struct construction / serde round-trips -----

    #[test]
    #[cfg(feature = "wasm")]
    fn test_analysis_result_direct_construction() {
        let result = AnalysisResult {
            language: "Python".to_string(),
            has_numpy: true,
            has_sklearn: false,
            has_pytorch: true,
            lines_of_code: 42,
        };

        assert_eq!(result.language, "Python");
        assert!(result.has_numpy);
        assert!(!result.has_sklearn);
        assert!(result.has_pytorch);
        assert_eq!(result.lines_of_code, 42);
    }

    #[test]
    #[cfg(feature = "wasm")]
    fn test_analysis_result_serialization() {
        let result = AnalysisResult {
            language: "C/C++".to_string(),
            has_numpy: false,
            has_sklearn: false,
            has_pytorch: false,
            lines_of_code: 50,
        };

        let json = serde_json::to_string(&result).expect("json serialize failed");
        let deserialized: AnalysisResult =
            serde_json::from_str(&json).expect("json deserialize failed");

        assert_eq!(result.language, deserialized.language);
        assert_eq!(result.has_numpy, deserialized.has_numpy);
        assert_eq!(result.lines_of_code, deserialized.lines_of_code);
    }

    #[test]
    #[cfg(feature = "wasm")]
    fn test_analysis_result_multiple_languages() {
        let rust_result = AnalysisResult {
            language: "Rust".to_string(),
            has_numpy: false,
            has_sklearn: false,
            has_pytorch: false,
            lines_of_code: 100,
        };
        assert_eq!(rust_result.language, "Rust");

        let python_result = AnalysisResult {
            language: "Python".to_string(),
            has_numpy: true,
            has_sklearn: true,
            has_pytorch: false,
            lines_of_code: 200,
        };
        assert_eq!(python_result.language, "Python");
        assert!(python_result.has_numpy);
        assert!(python_result.has_sklearn);
    }

    #[test]
    #[cfg(feature = "wasm")]
    fn test_conversion_result_direct_construction() {
        let result = ConversionResult {
            original_code: "np.add(a, b)".to_string(),
            rust_code: "a.add(&b)".to_string(),
            imports: "use trueno::Vector;".to_string(),
            backend_recommendation: "SIMD".to_string(),
            complexity: "Low".to_string(),
        };

        assert_eq!(result.original_code, "np.add(a, b)");
        assert_eq!(result.rust_code, "a.add(&b)");
        assert_eq!(result.imports, "use trueno::Vector;");
        assert_eq!(result.backend_recommendation, "SIMD");
        assert_eq!(result.complexity, "Low");
    }

    #[test]
    #[cfg(feature = "wasm")]
    fn test_conversion_result_serialization() {
        let result = ConversionResult {
            original_code: "original".to_string(),
            rust_code: "rust".to_string(),
            imports: "imports".to_string(),
            backend_recommendation: "SIMD".to_string(),
            complexity: "Medium".to_string(),
        };

        let json = serde_json::to_string(&result).expect("json serialize failed");
        let deserialized: ConversionResult =
            serde_json::from_str(&json).expect("json deserialize failed");

        assert_eq!(result.original_code, deserialized.original_code);
        assert_eq!(result.rust_code, deserialized.rust_code);
        assert_eq!(result.backend_recommendation, deserialized.backend_recommendation);
    }

    #[test]
    #[cfg(feature = "wasm")]
    fn test_conversion_result_all_backends() {
        let scalar_result = ConversionResult {
            original_code: "test".to_string(),
            rust_code: "test_rust".to_string(),
            imports: "".to_string(),
            backend_recommendation: "Scalar".to_string(),
            complexity: "Low".to_string(),
        };
        assert_eq!(scalar_result.backend_recommendation, "Scalar");

        let simd_result = ConversionResult {
            original_code: "test".to_string(),
            rust_code: "test_rust".to_string(),
            imports: "".to_string(),
            backend_recommendation: "SIMD".to_string(),
            complexity: "Medium".to_string(),
        };
        assert_eq!(simd_result.backend_recommendation, "SIMD");

        let gpu_result = ConversionResult {
            original_code: "test".to_string(),
            rust_code: "test_rust".to_string(),
            imports: "".to_string(),
            backend_recommendation: "GPU".to_string(),
            complexity: "High".to_string(),
        };
        assert_eq!(gpu_result.backend_recommendation, "GPU");
    }

    // ----- analyze_code / getters / toJSON -----

    #[test]
    #[cfg(feature = "wasm")]
    fn test_analyze_code_python_with_numpy() {
        let code = "import numpy as np\narr = np.array([1, 2, 3])\n";
        let result = analyze_code(code).expect("analysis should not fail");

        assert_eq!(result.language(), "Python");
        assert!(result.has_numpy());
        assert!(!result.has_sklearn());
        assert!(!result.has_pytorch());
        assert_eq!(result.lines_of_code(), 2);
    }

    #[test]
    #[cfg(feature = "wasm")]
    fn test_analyze_code_detects_ml_libraries() {
        let code = "from sklearn.linear_model import LinearRegression\nimport torch\n";
        let result = analyze_code(code).expect("analysis should not fail");

        assert_eq!(result.language(), "Python");
        assert!(!result.has_numpy());
        assert!(result.has_sklearn());
        assert!(result.has_pytorch());
    }

    #[test]
    #[cfg(feature = "wasm")]
    fn test_analyze_code_detects_other_languages() {
        let rust = analyze_code("fn main() {}\n").expect("analysis should not fail");
        assert_eq!(rust.language(), "Rust");

        let c = analyze_code("#include <stdio.h>\nint main(void) { return 0; }\n")
            .expect("analysis should not fail");
        assert_eq!(c.language(), "C/C++");

        let shell = analyze_code("#!/bin/bash\necho hello\n").expect("analysis should not fail");
        assert_eq!(shell.language(), "Shell");
        assert_eq!(shell.lines_of_code(), 2);
    }

    #[test]
    #[cfg(feature = "wasm")]
    fn test_analyze_code_empty_input() {
        let result = analyze_code("").expect("analysis should not fail");

        assert_eq!(result.language(), "Unknown");
        assert_eq!(result.lines_of_code(), 0);
        assert!(!result.has_numpy());
        assert!(!result.has_sklearn());
        assert!(!result.has_pytorch());
    }

    #[test]
    #[cfg(feature = "wasm")]
    fn test_analysis_result_to_json_fields() {
        let result = analyze_code("def f():\n    return 1\n").expect("analysis should not fail");
        let json = result.to_json().expect("toJSON should succeed");
        let value: serde_json::Value = serde_json::from_str(&json).expect("valid json");

        assert_eq!(value["language"], "Python");
        assert_eq!(value["lines_of_code"], 2);
        assert_eq!(value["has_numpy"], false);
    }

    #[test]
    #[cfg(feature = "wasm")]
    fn test_conversion_result_getters_and_to_json() {
        let result = ConversionResult {
            original_code: "np.sum(a)".to_string(),
            rust_code: "a.sum()".to_string(),
            imports: "use trueno::Vector;".to_string(),
            backend_recommendation: "SIMD".to_string(),
            complexity: "Medium".to_string(),
        };

        assert_eq!(result.original_code(), "np.sum(a)");
        assert_eq!(result.rust_code(), "a.sum()");
        assert_eq!(result.imports(), "use trueno::Vector;");
        assert_eq!(result.backend_recommendation(), "SIMD");
        assert_eq!(result.complexity(), "Medium");

        let json = result.to_json().expect("toJSON should succeed");
        let value: serde_json::Value = serde_json::from_str(&json).expect("valid json");
        assert_eq!(value["backend_recommendation"], "SIMD");
        assert_eq!(value["complexity"], "Medium");
    }
}