1#![allow(missing_docs)]
7#![allow(unsafe_code)]
9
10pub mod error;
11pub mod parser;
12pub mod transpiler;
13pub mod runtime;
14pub mod memory;
15pub mod kernel;
16pub mod backend;
17pub mod utils;
18pub mod prelude;
19pub mod profiling;
20pub mod simd;
22
23pub mod neural_integration;
25
26#[cfg(not(target_arch = "wasm32"))]
28pub mod nutanix;
29
30pub use error::{CudaRustError, Result};
32pub use parser::CudaParser;
33pub use transpiler::{Transpiler, CudaTranspiler};
34pub use runtime::Runtime;
35
36pub use neural_integration::{
40 NeuralBridge, BridgeConfig, NeuralOperation, ActivationFunction as NeuralActivationFunction,
41 SystemCapabilities as NeuralCapabilities, initialize as init_neural_integration,
42 get_capabilities as get_neural_capabilities,
43};
44
45pub struct CudaRust {
47 parser: CudaParser,
48 transpiler: Transpiler,
49}
50
51impl CudaRust {
52 pub fn new() -> Self {
54 Self {
55 parser: CudaParser::new(),
56 transpiler: Transpiler::new(),
57 }
58 }
59
60 pub fn transpile(&self, cuda_source: &str) -> Result<String> {
62 let ast = self.parser.parse(cuda_source)?;
64
65 let rust_code = self.transpiler.transpile(ast)?;
67
68 Ok(rust_code)
69 }
70
71 #[cfg(feature = "webgpu-only")]
73 pub fn to_webgpu(&self, cuda_source: &str) -> Result<String> {
74 let ast = self.parser.parse(cuda_source)?;
75 let wgsl = self.transpiler.to_wgsl(ast)?;
76 Ok(wgsl)
77 }
78}
79
80impl Default for CudaRust {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86pub fn init() -> Result<Runtime> {
88 Runtime::new()
89}
90
/// JavaScript-facing bindings, compiled only for `wasm32` targets.
#[cfg(target_arch = "wasm32")]
pub mod wasm {
    use wasm_bindgen::prelude::*;

    /// Module start hook, registered via `#[wasm_bindgen(start)]`.
    ///
    /// Installs `console_error_panic_hook` (once). With the `debug-transpiler`
    /// feature, it also initializes `console_log` at `Debug` level.
    #[wasm_bindgen(start)]
    pub fn init_wasm() {
        console_error_panic_hook::set_once();

        #[cfg(feature = "debug-transpiler")]
        {
            // The init result is deliberately discarded with `.ok()`. Presumably
            // this tolerates a logger that is already installed.
            // NOTE(review): confirm.
            console_log::init_with_level(log::Level::Debug).ok();
        }
    }

    /// Transpiles CUDA source for JavaScript callers.
    ///
    /// Each call builds a fresh [`super::CudaRust`]. On failure, the error's
    /// `Display` text is handed back to JS as a string value.
    #[wasm_bindgen]
    pub fn transpile_cuda(cuda_code: &str) -> Result<String, JsValue> {
        super::CudaRust::new()
            .transpile(cuda_code)
            .map_err(|err| JsValue::from_str(&err.to_string()))
    }
}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a minimal element-wise add kernel must transpile without error.
    #[test]
    fn test_basic_transpilation() {
        let cuda_rust = CudaRust::new();
        let cuda_code = r#"
            __global__ void add(float* a, float* b, float* c) {
                int i = threadIdx.x;
                c[i] = a[i] + b[i];
            }
        "#;

        // A bare `assert!(result.is_ok())` discards the error. Panicking with
        // the error's Display text shows *why* transpilation failed.
        if let Err(err) = cuda_rust.transpile(cuda_code) {
            panic!("transpilation of basic add kernel failed: {}", err);
        }
    }
}