ringkernel_wgpu_codegen/lib.rs
1//! WGSL code generation from Rust DSL for RingKernel.
2//!
3//! This crate provides transpilation from a restricted Rust DSL to WGSL (WebGPU Shading Language),
4//! enabling developers to write GPU kernels in Rust that target WebGPU-compatible devices.
5//!
6//! # Overview
7//!
8//! The transpiler supports the same DSL as `ringkernel-cuda-codegen`:
9//!
10//! - Primitive types: `f32`, `i32`, `u32`, `bool` (with 64-bit emulation)
11//! - Array slices: `&[T]`, `&mut [T]` → storage buffers
12//! - Arithmetic and comparison operators
13//! - Let bindings and if/else expressions
14//! - For/while/loop constructs
15//! - Stencil intrinsics via `GridPos` context
16//! - Ring kernel support with host-driven persistence emulation
17//!
18//! # Example
19//!
20//! ```ignore
21//! use ringkernel_wgpu_codegen::{transpile_stencil_kernel, StencilConfig};
22//! use syn::parse_quote;
23//!
24//! let func: syn::ItemFn = parse_quote! {
25//! fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
26//! let curr = p[pos.idx()];
27//! let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
28//! p_prev[pos.idx()] = 2.0 * curr - p_prev[pos.idx()] + c2 * lap;
29//! }
30//! };
31//!
32//! let config = StencilConfig::new("fdtd")
33//! .with_tile_size(16, 16)
34//! .with_halo(1);
35//!
36//! let wgsl_code = transpile_stencil_kernel(&func, &config)?;
37//! ```
38//!
39//! # WGSL Limitations
40//!
41//! Compared to CUDA, WGSL has some limitations that are handled with workarounds:
42//!
43//! - **No 64-bit atomics**: Emulated using lo/hi u32 pairs
44//! - **No f64**: Downcast to f32 with a warning
45//! - **No persistent kernels**: Emulated with host-driven dispatch loops
46//! - **No warp operations**: Mapped to subgroup operations where available
47//! - **No kernel-to-kernel messaging**: K2K is not supported in WGPU
48
49pub mod bindings;
50pub mod dsl;
51pub mod handler;
52pub mod intrinsics;
53pub mod loops;
54pub mod ring_kernel;
55pub mod shared;
56pub mod stencil;
57pub mod transpiler;
58pub mod types;
59pub mod u64_workarounds;
60pub mod validation;
61
62pub use bindings::{generate_bindings, AccessMode, BindingLayout};
63pub use dsl::*;
64pub use handler::{
65 HandlerCodegenConfig, HandlerParam, HandlerParamKind, HandlerReturnType, HandlerSignature,
66 WgslContextMethod,
67};
68pub use intrinsics::{IntrinsicRegistry, WgslIntrinsic};
69pub use loops::{LoopPattern, RangeInfo};
70pub use ring_kernel::RingKernelConfig;
71pub use shared::{SharedArray, SharedMemoryConfig, SharedMemoryDecl, SharedTile};
72pub use stencil::{Grid, GridPos, StencilConfig, StencilLaunchConfig};
73pub use transpiler::{transpile_function, WgslTranspiler};
74pub use types::{AddressSpace, TypeMapper, WgslType};
75pub use u64_workarounds::U64Helpers;
76pub use validation::{ValidationError, ValidationMode};
77
78use thiserror::Error;
79
80/// Errors that can occur during transpilation.
81#[derive(Error, Debug)]
82pub enum TranspileError {
83 /// Failed to parse Rust code.
84 #[error("Parse error: {0}")]
85 Parse(String),
86
87 /// DSL constraint violation.
88 #[error("Validation error: {0}")]
89 Validation(#[from] ValidationError),
90
91 /// Unsupported Rust construct.
92 #[error("Unsupported construct: {0}")]
93 Unsupported(String),
94
95 /// Type mapping failure.
96 #[error("Type error: {0}")]
97 Type(String),
98
99 /// WGSL-specific limitation.
100 #[error("WGSL limitation: {0}")]
101 WgslLimitation(String),
102}
103
104/// Result type for transpilation operations.
105pub type Result<T> = std::result::Result<T, TranspileError>;
106
107/// Transpile a Rust stencil kernel function to WGSL code.
108///
109/// This is the main entry point for stencil code generation. It takes a parsed
110/// Rust function and stencil configuration, validates the DSL constraints,
111/// and generates equivalent WGSL code.
112///
113/// # Arguments
114///
115/// * `func` - The parsed Rust function (from syn)
116/// * `config` - Stencil kernel configuration
117///
118/// # Returns
119///
120/// The generated WGSL source code as a string.
121///
122/// # Example
123///
124/// ```ignore
125/// use ringkernel_wgpu_codegen::{transpile_stencil_kernel, StencilConfig};
126/// use syn::parse_quote;
127///
128/// let func: syn::ItemFn = parse_quote! {
129/// fn heat(temp: &[f32], temp_new: &mut [f32], alpha: f32, pos: GridPos) {
130/// let t = temp[pos.idx()];
131/// let neighbors = pos.north(temp) + pos.south(temp) + pos.east(temp) + pos.west(temp);
132/// temp_new[pos.idx()] = t + alpha * (neighbors - 4.0 * t);
133/// }
134/// };
135///
136/// let config = StencilConfig::new("heat").with_tile_size(16, 16).with_halo(1);
137/// let wgsl = transpile_stencil_kernel(&func, &config)?;
138/// ```
139pub fn transpile_stencil_kernel(func: &syn::ItemFn, config: &StencilConfig) -> Result<String> {
140 // Validate DSL constraints
141 validation::validate_function(func)?;
142
143 // Create transpiler with stencil config
144 let mut transpiler = WgslTranspiler::new_stencil(config.clone());
145
146 // Generate WGSL code
147 transpiler.transpile_stencil(func)
148}
149
150/// Transpile a Rust function to a WGSL helper function.
151///
152/// This generates a callable function (not a compute entry point) from Rust code.
153pub fn transpile_device_function(func: &syn::ItemFn) -> Result<String> {
154 validation::validate_function(func)?;
155 transpile_function(func)
156}
157
158/// Transpile a Rust function to a WGSL `@compute` kernel.
159///
160/// This generates an compute shader entry point without stencil-specific patterns.
161/// Use DSL functions like `thread_idx_x()`, `block_idx_x()` to access WGSL indices.
162///
163/// # Example
164///
165/// ```ignore
166/// use ringkernel_wgpu_codegen::transpile_global_kernel;
167/// use syn::parse_quote;
168///
169/// let func: syn::ItemFn = parse_quote! {
170/// fn saxpy(x: &[f32], y: &mut [f32], a: f32, n: i32) {
171/// let idx = block_idx_x() * block_dim_x() + thread_idx_x();
172/// if idx >= n { return; }
173/// y[idx as usize] = a * x[idx as usize] + y[idx as usize];
174/// }
175/// };
176///
177/// let wgsl = transpile_global_kernel(&func)?;
178/// // Generates: @compute @workgroup_size(256) fn saxpy(...) { ... }
179/// ```
180pub fn transpile_global_kernel(func: &syn::ItemFn) -> Result<String> {
181 validation::validate_function(func)?;
182 let mut transpiler = WgslTranspiler::new_generic();
183 transpiler.transpile_global_kernel(func)
184}
185
186/// Transpile a Rust handler function to a WGSL ring kernel.
187///
188/// Ring kernels in WGPU are emulated using host-driven dispatch loops since
189/// WebGPU does not support true persistent kernels. The handler function
190/// processes a batch of messages per dispatch.
191///
192/// # Example
193///
194/// ```ignore
195/// use ringkernel_wgpu_codegen::{transpile_ring_kernel, RingKernelConfig};
196/// use syn::parse_quote;
197///
198/// let handler: syn::ItemFn = parse_quote! {
199/// fn process(value: f32) -> f32 {
200/// value * 2.0
201/// }
202/// };
203///
204/// let config = RingKernelConfig::new("processor")
205/// .with_workgroup_size(256)
206/// .with_hlc(true);
207///
208/// let wgsl = transpile_ring_kernel(&handler, &config)?;
209/// ```
210///
211/// # WGSL Limitations
212///
213/// - **No K2K**: Kernel-to-kernel messaging is not supported
214/// - **No persistent loop**: Host must re-dispatch until termination
215/// - **No 64-bit atomics**: Counters use lo/hi u32 pair emulation
216pub fn transpile_ring_kernel(handler: &syn::ItemFn, config: &RingKernelConfig) -> Result<String> {
217 // K2K is not supported in WGPU
218 if config.enable_k2k {
219 return Err(TranspileError::WgslLimitation(
220 "Kernel-to-kernel (K2K) messaging is not supported in WGPU. \
221 WebGPU does not allow GPU-direct communication between compute dispatches. \
222 Use host-mediated messaging instead."
223 .to_string(),
224 ));
225 }
226
227 // Validate handler with generic mode (loops allowed but not required)
228 validation::validate_function_with_mode(handler, ValidationMode::Generic)?;
229
230 // Create transpiler in generic mode for the handler
231 let mut transpiler = WgslTranspiler::new_ring_kernel(config.clone());
232
233 // Transpile the handler into a ring kernel wrapper
234 transpiler.transpile_ring_kernel(handler, config)
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Spot-checks that the `Display` output of two variants contains the
    /// expected category prefix.
    #[test]
    fn test_transpile_error_display() {
        let parse_failure = TranspileError::Parse("unexpected token".to_string());
        let rendered = parse_failure.to_string();
        assert!(rendered.contains("Parse error"));

        let limitation = TranspileError::WgslLimitation("no 64-bit atomics".to_string());
        let rendered = limitation.to_string();
        assert!(rendered.contains("WGSL limitation"));
    }

    /// A configuration with K2K enabled must be rejected with
    /// `TranspileError::WgslLimitation`.
    #[test]
    fn test_k2k_rejected() {
        let handler: syn::ItemFn = parse_quote! {
            fn forward(msg: f32) -> f32 {
                msg
            }
        };

        let config = RingKernelConfig::new("forwarder")
            .with_workgroup_size(64)
            .with_k2k(true);

        // One pattern covers both "is an error" and "is the right variant".
        let outcome = transpile_ring_kernel(&handler, &config);
        assert!(matches!(outcome, Err(TranspileError::WgslLimitation(_))));
    }
}
270}