1#![warn(missing_docs)]
7#![allow(unsafe_code)]
9
10pub mod error;
11pub mod parser;
12pub mod transpiler;
13pub mod runtime;
14pub mod memory;
15pub mod kernel;
16pub mod backend;
17pub mod utils;
18pub mod prelude;
19pub mod profiling;
20
21pub mod neural_integration;
23
24pub use error::{CudaRustError, Result};
26pub use parser::CudaParser;
27pub use transpiler::{Transpiler, CudaTranspiler};
28pub use runtime::Runtime;
29
30pub use neural_integration::{
34 NeuralBridge, BridgeConfig, NeuralOperation, ActivationFunction as NeuralActivationFunction,
35 SystemCapabilities as NeuralCapabilities, initialize as init_neural_integration,
36 get_capabilities as get_neural_capabilities,
37};
38
39pub struct CudaRust {
41 parser: CudaParser,
42 transpiler: Transpiler,
43}
44
45impl CudaRust {
46 pub fn new() -> Self {
48 Self {
49 parser: CudaParser::new(),
50 transpiler: Transpiler::new(),
51 }
52 }
53
54 pub fn transpile(&self, cuda_source: &str) -> Result<String> {
56 let ast = self.parser.parse(cuda_source)?;
58
59 let rust_code = self.transpiler.transpile(ast)?;
61
62 Ok(rust_code)
63 }
64
65 #[cfg(feature = "webgpu-only")]
67 pub fn to_webgpu(&self, cuda_source: &str) -> Result<String> {
68 let ast = self.parser.parse(cuda_source)?;
69 let wgsl = self.transpiler.to_wgsl(ast)?;
70 Ok(wgsl)
71 }
72}
73
74impl Default for CudaRust {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80pub fn init() -> Result<Runtime> {
82 Runtime::new()
83}
84
/// WebAssembly bindings that expose the transpiler to JavaScript.
///
/// This module is compiled only when targeting `wasm32`.
#[cfg(target_arch = "wasm32")]
pub mod wasm {
    use wasm_bindgen::prelude::*;

    /// Module start hook, marked `#[wasm_bindgen(start)]`.
    ///
    /// Installs `console_error_panic_hook` so Rust panic messages are forwarded to the
    /// JavaScript console. When the `debug-transpiler` feature is enabled, it also
    /// initializes `console_log` at `Debug` level.
    #[wasm_bindgen(start)]
    pub fn init_wasm() {
        console_error_panic_hook::set_once();

        #[cfg(feature = "debug-transpiler")]
        {
            // Logging is only a debugging aid. If logger setup fails (presumably because a
            // global logger is already installed), carry on without it instead of aborting.
            let _ = console_log::init_with_level(log::Level::Debug);
        }
    }

    /// Transpiles CUDA source code to Rust source code, callable from JavaScript.
    ///
    /// A fresh [`super::CudaRust`] is constructed for each call.
    ///
    /// # Errors
    ///
    /// On failure, the crate error's `Display` text is returned as a JavaScript string value.
    #[wasm_bindgen]
    pub fn transpile_cuda(cuda_code: &str) -> Result<String, JsValue> {
        super::CudaRust::new()
            .transpile(cuda_code)
            .map_err(|err| JsValue::from_str(&err.to_string()))
    }
}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal element-wise addition kernel should transpile without error.
    #[test]
    fn test_basic_transpilation() {
        let cuda_rust = CudaRust::new();
        let cuda_code = r#"
        __global__ void add(float* a, float* b, float* c) {
            int i = threadIdx.x;
            c[i] = a[i] + b[i];
        }
        "#;

        // Report the actual transpiler error on failure instead of a bare
        // `assertion failed: result.is_ok()`.
        if let Err(err) = cuda_rust.transpile(cuda_code) {
            panic!("transpiling a basic add kernel failed: {}", err);
        }
    }
}