Skip to main content

batuta/
wasm.rs

1//! WebAssembly API for Batuta
2//!
3//! Provides JavaScript-friendly interfaces for core Batuta functionality
4//! that operates on in-memory code without file system access.
5//!
6//! # Features
7//!
8//! - **Language Detection**: Analyze code snippets to detect languages
9//! - **Backend Selection**: Recommend optimal compute backend (SIMD/GPU)
10//! - **NumPy Conversion**: Convert NumPy operations to Trueno
11//! - **sklearn Conversion**: Convert sklearn algorithms to Aprender
12//! - **PyTorch Conversion**: Convert PyTorch operations to Realizar
13//! - **Code Analysis**: PARF pattern detection and analysis
14//!
15//! # Example (JavaScript)
16//!
17//! ```javascript
18//! import init, { analyze_code, convert_numpy, backend_recommend } from './batuta.js';
19//!
20//! await init();
21//!
22//! // Detect language
23//! const analysis = analyze_code("import numpy as np\nx = np.array([1, 2, 3])");
24//! console.log(analysis.language); // "Python"
25//!
26//! // Convert NumPy to Trueno
27//! const conversion = convert_numpy("np.add(a, b)");
28//! console.log(conversion.rust_code);
29//!
30//! // Get backend recommendation
31//! const backend = backend_recommend("matmul", 1024);
32//! console.log(backend); // "SIMD" or "GPU"
33//! ```
34
35#[cfg(feature = "wasm")]
36use wasm_bindgen::prelude::*;
37
38#[cfg(feature = "wasm")]
39use serde::{Deserialize, Serialize};
40
41#[cfg(feature = "wasm")]
42use crate::backend::{BackendSelector, OpComplexity};
43#[cfg(feature = "wasm")]
44use crate::numpy_converter::{NumPyConverter, NumPyOp};
45#[cfg(feature = "wasm")]
46use crate::pytorch_converter::{PyTorchConverter, PyTorchOperation};
47#[cfg(feature = "wasm")]
48use crate::sklearn_converter::{SklearnAlgorithm, SklearnConverter};
49
/// Module start hook for the WASM build.
///
/// Marked `#[wasm_bindgen(start)]`, so it is the module's start function.
/// The only thing it does is write a one-line confirmation to the browser
/// console.
///
/// No panic hook is installed here. Per the original note in this function,
/// `console_error_panic_hook` is not included; to get readable panic messages
/// in the console, add `console_error_panic_hook = "0.1"` to `Cargo.toml` and
/// install the hook from this function.
#[cfg(feature = "wasm")]
#[wasm_bindgen(start)]
pub fn wasm_init() {
    let banner = JsValue::from_str("Batuta WASM module initialized");
    web_sys::console::log_1(&banner);
}
60
/// Analysis result for code snippets
///
/// Returned by `analyze_code`. The fields are private; JavaScript reads them
/// through the getters defined on the `impl` block, or as a JSON string via
/// `toJSON()`.
#[cfg(feature = "wasm")]
#[wasm_bindgen]
// NOTE(review): presumably required because the `#[wasm_bindgen]` expansion
// contains `unsafe` glue, which trips this lint on `Deserialize` types — confirm.
#[allow(clippy::unsafe_derive_deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AnalysisResult {
    /// Detected language label (e.g. "Python"), or "Unknown" when nothing matched.
    language: String,
    /// Whether the snippet contains a NumPy marker (`numpy` or `np.`).
    has_numpy: bool,
    /// Whether the snippet contains the text `sklearn`.
    has_sklearn: bool,
    /// Whether the snippet contains `torch` or `transformers`.
    has_pytorch: bool,
    /// Number of lines in the analysed snippet.
    lines_of_code: usize,
}
73
#[cfg(feature = "wasm")]
#[wasm_bindgen]
impl AnalysisResult {
    /// Detected language label, handed to JavaScript as an owned copy.
    #[wasm_bindgen(getter)]
    pub fn language(&self) -> String {
        self.language.to_owned()
    }

    /// `true` when the analysed snippet contained a NumPy marker.
    #[wasm_bindgen(getter)]
    pub fn has_numpy(&self) -> bool {
        self.has_numpy
    }

    /// `true` when the analysed snippet contained an sklearn marker.
    #[wasm_bindgen(getter)]
    pub fn has_sklearn(&self) -> bool {
        self.has_sklearn
    }

    /// `true` when the analysed snippet contained a PyTorch marker.
    #[wasm_bindgen(getter)]
    pub fn has_pytorch(&self) -> bool {
        self.has_pytorch
    }

    /// Number of lines in the analysed snippet.
    #[wasm_bindgen(getter)]
    pub fn lines_of_code(&self) -> usize {
        self.lines_of_code
    }

    /// Serialize this result to a JSON string (exported to JavaScript as `toJSON`).
    ///
    /// # Errors
    /// Returns a `JsValue` string of the form `Serialization error: …` if
    /// `serde_json` fails.
    ///
    /// NOTE(review): `JSON.stringify` serializes whatever `toJSON()` returns, so
    /// returning an already-encoded string presumably yields a doubly-encoded
    /// JSON string when this object is passed to `JSON.stringify` directly —
    /// confirm that is intended.
    #[wasm_bindgen(js_name = toJSON)]
    pub fn to_json(&self) -> Result<String, JsValue> {
        match serde_json::to_string(self) {
            Ok(json) => Ok(json),
            Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
        }
    }
}
105
/// Conversion result for code transformations
///
/// Returned by `convert_numpy`, `convert_sklearn` and `convert_pytorch`. The
/// fields are private; JavaScript reads them through the getters defined on
/// the `impl` block, or as a JSON string via `toJSON()`.
#[cfg(feature = "wasm")]
#[wasm_bindgen]
// NOTE(review): presumably required because the `#[wasm_bindgen]` expansion
// contains `unsafe` glue, which trips this lint on `Deserialize` types — confirm.
#[allow(clippy::unsafe_derive_deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ConversionResult {
    /// The input snippet exactly as it was passed to the converter.
    original_code: String,
    /// Rust code template produced for the recognised operation.
    rust_code: String,
    /// Required imports, joined with newlines.
    imports: String,
    /// `Debug` rendering of the backend recommended for the operation.
    backend_recommendation: String,
    /// `Debug` rendering of the operation's complexity.
    complexity: String,
}
118
#[cfg(feature = "wasm")]
#[wasm_bindgen]
impl ConversionResult {
    /// The input snippet, handed to JavaScript as an owned copy.
    #[wasm_bindgen(getter)]
    pub fn original_code(&self) -> String {
        self.original_code.to_owned()
    }

    /// The generated Rust code, as an owned copy.
    #[wasm_bindgen(getter)]
    pub fn rust_code(&self) -> String {
        self.rust_code.to_owned()
    }

    /// The required imports, as an owned copy.
    #[wasm_bindgen(getter)]
    pub fn imports(&self) -> String {
        self.imports.to_owned()
    }

    /// The recommended backend label, as an owned copy.
    #[wasm_bindgen(getter)]
    pub fn backend_recommendation(&self) -> String {
        self.backend_recommendation.to_owned()
    }

    /// The complexity label, as an owned copy.
    #[wasm_bindgen(getter)]
    pub fn complexity(&self) -> String {
        self.complexity.to_owned()
    }

    /// Serialize this result to a JSON string (exported to JavaScript as `toJSON`).
    ///
    /// # Errors
    /// Returns a `JsValue` string of the form `Serialization error: …` if
    /// `serde_json` fails.
    ///
    /// NOTE(review): `JSON.stringify` serializes whatever `toJSON()` returns, so
    /// returning an already-encoded string presumably yields a doubly-encoded
    /// JSON string when this object is passed to `JSON.stringify` directly —
    /// confirm that is intended.
    #[wasm_bindgen(js_name = toJSON)]
    pub fn to_json(&self) -> Result<String, JsValue> {
        match serde_json::to_string(self) {
            Ok(json) => Ok(json),
            Err(e) => Err(JsValue::from_str(&format!("Serialization error: {}", e))),
        }
    }
}
150
/// Analyze a code snippet: guess its language and flag ML library usage.
///
/// Detection is a plain substring search over the whole snippet; nothing is
/// parsed.
///
/// # Arguments
/// * `code` - Source code to analyze
///
/// # Returns
/// An [`AnalysisResult`] holding the detected language ("Unknown" when no
/// marker matched), the NumPy / sklearn / PyTorch flags, and the line count.
///
/// # Errors
/// The current implementation never returns `Err`; the `Result` wrapper is
/// part of the exported signature.
#[cfg(feature = "wasm")]
#[wasm_bindgen]
pub fn analyze_code(code: &str) -> Result<AnalysisResult, JsValue> {
    // Ordered table: the first entry with any matching marker wins, so a
    // snippet containing e.g. both "#include" and "class " is reported as
    // "Python" because the Python entry is checked first.
    const LANG_PATTERNS: &[(&[&str], &str)] = &[
        (&["import ", "def ", "class "], "Python"),
        (&["#include", "int main"], "C/C++"),
        (&["fn ", "struct "], "Rust"),
        (&["#!/bin/bash", "#!/bin/sh"], "Shell"),
    ];
    let language = LANG_PATTERNS
        .iter()
        .find(|(markers, _)| markers.iter().any(|marker| code.contains(*marker)))
        .map(|(_, lang)| *lang)
        .unwrap_or("Unknown");

    // Detect ML libraries
    let has_numpy = code.contains("numpy") || code.contains("np.");
    let has_sklearn = code.contains("sklearn");
    let has_pytorch = code.contains("torch") || code.contains("transformers");

    Ok(AnalysisResult {
        language: language.to_string(),
        has_numpy,
        has_sklearn,
        has_pytorch,
        // Counted straight off the iterator; no intermediate Vec is needed.
        lines_of_code: code.lines().count(),
    })
}
188
/// Translate a NumPy snippet into Trueno Rust code.
///
/// The operation is recognised by substring search against an ordered table;
/// the first table entry with a matching marker decides the operation.
///
/// # Arguments
/// * `numpy_code` - NumPy operation string (e.g., "np.add(a, b)")
/// * `data_size` - Optional data size for the backend recommendation;
///   1000 is assumed when omitted
///
/// # Returns
/// A [`ConversionResult`] carrying the code template, its imports, and the
/// `Debug` renderings of the recommended backend and the complexity.
///
/// # Errors
/// * "Unsupported NumPy operation" - no table entry matched the snippet
/// * "Conversion failed: operation not supported" - the converter returned
///   no mapping for the recognised operation
#[cfg(feature = "wasm")]
#[wasm_bindgen]
pub fn convert_numpy(
    numpy_code: &str,
    data_size: Option<usize>,
) -> Result<ConversionResult, JsValue> {
    // Marker substrings -> operation, checked top to bottom.
    const NUMPY_PATTERNS: &[(&[&str], NumPyOp)] = &[
        (&["np.add", "numpy.add"], NumPyOp::Add),
        (&["np.dot", "numpy.dot"], NumPyOp::Dot),
        (&["np.sum", "numpy.sum"], NumPyOp::Sum),
        (&["np.mean", "numpy.mean"], NumPyOp::Mean),
        (&["np.array", "numpy.array"], NumPyOp::Array),
        (&["reshape"], NumPyOp::Reshape),
        (&["transpose", ".T"], NumPyOp::Transpose),
    ];

    let converter = NumPyConverter::new();

    let op = NUMPY_PATTERNS
        .iter()
        .find_map(|(markers, candidate)| {
            markers
                .iter()
                .any(|marker| numpy_code.contains(*marker))
                .then(|| candidate.clone())
        })
        .ok_or_else(|| JsValue::from_str("Unsupported NumPy operation"))?;

    let trueno_op = match converter.convert(&op) {
        Some(mapped) => mapped,
        None => return Err(JsValue::from_str("Conversion failed: operation not supported")),
    };

    let backend = converter.recommend_backend(&op, data_size.unwrap_or(1000));

    let rust_code = trueno_op.code_template.clone();
    let imports = trueno_op.imports.join("\n");
    let backend_recommendation = format!("{:?}", backend);
    let complexity = format!("{:?}", trueno_op.complexity);

    Ok(ConversionResult {
        original_code: numpy_code.to_string(),
        rust_code,
        imports,
        backend_recommendation,
        complexity,
    })
}
236
/// Translate an sklearn snippet into Aprender Rust code.
///
/// The algorithm is recognised by substring search against an ordered table;
/// the first entry whose name occurs in the snippet decides the algorithm.
///
/// # Arguments
/// * `sklearn_code` - sklearn algorithm string (e.g., "LinearRegression()")
/// * `data_size` - Optional data size for the backend recommendation;
///   1000 is assumed when omitted
///
/// # Returns
/// A [`ConversionResult`] carrying the code template, its imports, and the
/// `Debug` renderings of the recommended backend and the complexity.
///
/// # Errors
/// * "Unsupported sklearn algorithm" - no table entry matched the snippet
/// * "Conversion failed: algorithm not supported" - the converter returned
///   no mapping for the recognised algorithm
#[cfg(feature = "wasm")]
#[wasm_bindgen]
pub fn convert_sklearn(
    sklearn_code: &str,
    data_size: Option<usize>,
) -> Result<ConversionResult, JsValue> {
    // Class name -> algorithm, checked top to bottom.
    const SKLEARN_PATTERNS: &[(&str, SklearnAlgorithm)] = &[
        ("LinearRegression", SklearnAlgorithm::LinearRegression),
        ("LogisticRegression", SklearnAlgorithm::LogisticRegression),
        ("KMeans", SklearnAlgorithm::KMeans),
        ("DecisionTreeClassifier", SklearnAlgorithm::DecisionTreeClassifier),
        ("RandomForestClassifier", SklearnAlgorithm::RandomForestClassifier),
        ("StandardScaler", SklearnAlgorithm::StandardScaler),
    ];

    let converter = SklearnConverter::new();

    let algo = SKLEARN_PATTERNS
        .iter()
        .find_map(|(class_name, candidate)| {
            sklearn_code
                .contains(*class_name)
                .then(|| candidate.clone())
        })
        .ok_or_else(|| JsValue::from_str("Unsupported sklearn algorithm"))?;

    let aprender_algo = match converter.convert(&algo) {
        Some(mapped) => mapped,
        None => return Err(JsValue::from_str("Conversion failed: algorithm not supported")),
    };

    let backend = converter.recommend_backend(&algo, data_size.unwrap_or(1000));

    let rust_code = aprender_algo.code_template.clone();
    let imports = aprender_algo.imports.join("\n");
    let backend_recommendation = format!("{:?}", backend);
    let complexity = format!("{:?}", aprender_algo.complexity);

    Ok(ConversionResult {
        original_code: sklearn_code.to_string(),
        rust_code,
        imports,
        backend_recommendation,
        complexity,
    })
}
282
/// Translate a PyTorch snippet into Realizar Rust code.
///
/// A snippet containing both "load" and "model" is always treated as a
/// model-loading operation. Otherwise the operation is recognised by
/// substring search against an ordered table, first match wins.
///
/// # Arguments
/// * `pytorch_code` - PyTorch operation string (e.g., "model.generate()")
/// * `data_size` - Optional data size for the backend recommendation;
///   1,000,000 is assumed when omitted
///
/// # Returns
/// A [`ConversionResult`] carrying the code template, its imports, and the
/// `Debug` renderings of the recommended backend and the complexity.
///
/// # Errors
/// * "Unsupported PyTorch operation" - nothing in the snippet was recognised
/// * "Conversion failed: operation not supported" - the converter returned
///   no mapping for the recognised operation
#[cfg(feature = "wasm")]
#[wasm_bindgen]
pub fn convert_pytorch(
    pytorch_code: &str,
    data_size: Option<usize>,
) -> Result<ConversionResult, JsValue> {
    // Marker substrings -> operation, checked top to bottom.
    const PYTORCH_PATTERNS: &[(&[&str], PyTorchOperation)] = &[
        (&["generate"], PyTorchOperation::Generate),
        (&["forward"], PyTorchOperation::Forward),
        (&["encode"], PyTorchOperation::Encode),
        (&["decode"], PyTorchOperation::Decode),
        (&["nn.Linear"], PyTorchOperation::Linear),
        (&["Attention"], PyTorchOperation::Attention),
    ];

    let converter = PyTorchConverter::new();

    // LoadModel is a compound check (both words must be present), so it
    // cannot live in the any-of table above and takes precedence over it.
    let looks_like_model_load = pytorch_code.contains("load") && pytorch_code.contains("model");
    let op = if looks_like_model_load {
        PyTorchOperation::LoadModel
    } else {
        PYTORCH_PATTERNS
            .iter()
            .find_map(|(markers, candidate)| {
                markers
                    .iter()
                    .any(|marker| pytorch_code.contains(*marker))
                    .then(|| candidate.clone())
            })
            .ok_or_else(|| JsValue::from_str("Unsupported PyTorch operation"))?
    };

    let realizar_op = match converter.convert(&op) {
        Some(mapped) => mapped,
        None => return Err(JsValue::from_str("Conversion failed: operation not supported")),
    };

    // LLMs are large by default
    let backend = converter.recommend_backend(&op, data_size.unwrap_or(1_000_000));

    let rust_code = realizar_op.code_template.clone();
    let imports = realizar_op.imports.join("\n");
    let backend_recommendation = format!("{:?}", backend);
    let complexity = format!("{:?}", realizar_op.complexity);

    Ok(ConversionResult {
        original_code: pytorch_code.to_string(),
        rust_code,
        imports,
        backend_recommendation,
        complexity,
    })
}
333
/// Recommend a compute backend for a kind of operation and a data size.
///
/// # Arguments
/// * `operation_type` - Kind of operation, matched case-insensitively:
///   "element-wise" / "elementwise", "reduction" / "reduce",
///   "matmul" / "matrix-multiply"
/// * `data_size` - Size of data to process
///
/// # Returns
/// The `Debug` rendering of the selected backend (documented by the original
/// author as "Scalar", "SIMD" or "GPU").
///
/// # Errors
/// "Unknown operation type" when `operation_type` is none of the names above.
#[cfg(feature = "wasm")]
#[wasm_bindgen]
pub fn backend_recommend(operation_type: &str, data_size: usize) -> Result<String, JsValue> {
    let selector = BackendSelector::new();
    let normalized = operation_type.to_lowercase();

    let backend = match normalized.as_str() {
        "element-wise" | "elementwise" => selector.select_for_elementwise(data_size),
        "reduction" | "reduce" => {
            // Complexity tier is derived from the data size alone.
            let complexity = match data_size {
                0..=9_999 => OpComplexity::Low,
                10_000..=99_999 => OpComplexity::Medium,
                _ => OpComplexity::High,
            };
            selector.select_with_moe(complexity, data_size)
        }
        "matmul" | "matrix-multiply" => {
            // The same truncated square root of `data_size` is used for all
            // three matmul dimensions.
            let side = (data_size as f64).sqrt() as usize;
            selector.select_for_matmul(side, side, side)
        }
        _ => return Err(JsValue::from_str("Unknown operation type")),
    };

    Ok(format!("{:?}", backend))
}
368
/// Batuta version string, baked in at compile time from `CARGO_PKG_VERSION`.
#[cfg(feature = "wasm")]
#[wasm_bindgen]
pub fn version() -> String {
    String::from(env!("CARGO_PKG_VERSION"))
}
375
// Every test here needs the `wasm`-gated types, so the whole module is gated
// on the feature. This avoids an unused `use super::*;` in test builds that
// do not enable `wasm`.
#[cfg(all(test, feature = "wasm"))]
mod tests {
    use super::*;

    // ============================================================================
    // ANALYSIS RESULT TESTS (Native-compatible)
    // ============================================================================

    #[test]
    fn test_analysis_result_direct_construction() {
        let result = AnalysisResult {
            language: "Python".to_string(),
            has_numpy: true,
            has_sklearn: false,
            has_pytorch: true,
            lines_of_code: 42,
        };

        // Test direct field access (not through wasm_bindgen getters)
        assert_eq!(result.language, "Python");
        assert!(result.has_numpy);
        assert!(!result.has_sklearn);
        assert!(result.has_pytorch);
        assert_eq!(result.lines_of_code, 42);
    }

    #[test]
    fn test_analysis_result_serialization() {
        let result = AnalysisResult {
            language: "C/C++".to_string(),
            has_numpy: false,
            has_sklearn: false,
            has_pytorch: false,
            lines_of_code: 50,
        };

        let json = serde_json::to_string(&result).expect("json serialize failed");
        let deserialized: AnalysisResult =
            serde_json::from_str(&json).expect("json deserialize failed");

        // Every field must survive the round trip.
        assert_eq!(result.language, deserialized.language);
        assert_eq!(result.has_numpy, deserialized.has_numpy);
        assert_eq!(result.has_sklearn, deserialized.has_sklearn);
        assert_eq!(result.has_pytorch, deserialized.has_pytorch);
        assert_eq!(result.lines_of_code, deserialized.lines_of_code);
    }

    #[test]
    fn test_analysis_result_multiple_languages() {
        let rust_result = AnalysisResult {
            language: "Rust".to_string(),
            has_numpy: false,
            has_sklearn: false,
            has_pytorch: false,
            lines_of_code: 100,
        };
        assert_eq!(rust_result.language, "Rust");

        let python_result = AnalysisResult {
            language: "Python".to_string(),
            has_numpy: true,
            has_sklearn: true,
            has_pytorch: false,
            lines_of_code: 200,
        };
        assert_eq!(python_result.language, "Python");
        assert!(python_result.has_numpy);
        assert!(python_result.has_sklearn);
    }

    // ============================================================================
    // CONVERSION RESULT TESTS (Native-compatible)
    // ============================================================================

    #[test]
    fn test_conversion_result_direct_construction() {
        let result = ConversionResult {
            original_code: "np.add(a, b)".to_string(),
            rust_code: "a.add(&b)".to_string(),
            imports: "use trueno::Vector;".to_string(),
            backend_recommendation: "SIMD".to_string(),
            complexity: "Low".to_string(),
        };

        assert_eq!(result.original_code, "np.add(a, b)");
        assert_eq!(result.rust_code, "a.add(&b)");
        assert_eq!(result.imports, "use trueno::Vector;");
        assert_eq!(result.backend_recommendation, "SIMD");
        assert_eq!(result.complexity, "Low");
    }

    #[test]
    fn test_conversion_result_serialization() {
        let result = ConversionResult {
            original_code: "original".to_string(),
            rust_code: "rust".to_string(),
            imports: "imports".to_string(),
            backend_recommendation: "SIMD".to_string(),
            complexity: "Medium".to_string(),
        };

        let json = serde_json::to_string(&result).expect("json serialize failed");
        let deserialized: ConversionResult =
            serde_json::from_str(&json).expect("json deserialize failed");

        // Every field must survive the round trip.
        assert_eq!(result.original_code, deserialized.original_code);
        assert_eq!(result.rust_code, deserialized.rust_code);
        assert_eq!(result.imports, deserialized.imports);
        assert_eq!(result.backend_recommendation, deserialized.backend_recommendation);
        assert_eq!(result.complexity, deserialized.complexity);
    }

    #[test]
    fn test_conversion_result_all_backends() {
        // One result per backend label, each paired with a complexity label.
        for &(backend, complexity) in &[("Scalar", "Low"), ("SIMD", "Medium"), ("GPU", "High")] {
            let result = ConversionResult {
                original_code: "test".to_string(),
                rust_code: "test_rust".to_string(),
                imports: "".to_string(),
                backend_recommendation: backend.to_string(),
                complexity: complexity.to_string(),
            };
            assert_eq!(result.backend_recommendation, backend);
            assert_eq!(result.complexity, complexity);
        }
    }

    // NOTE: Tests for wasm_bindgen functions (analyze_code, convert_numpy, etc.)
    // cannot run on native targets. They require wasm32 target and wasm-bindgen-test.
    // The tests above cover the data structures (AnalysisResult, ConversionResult)
    // which is what can be tested natively. For full WASM function testing, use:
    // wasm-pack test --node --features wasm
}