Skip to main content

ringkernel_wgpu_codegen/
lib.rs

1//! WGSL code generation from Rust DSL for RingKernel.
2//!
3//! This crate provides transpilation from a restricted Rust DSL to WGSL (WebGPU Shading Language),
4//! enabling developers to write GPU kernels in Rust that target WebGPU-compatible devices.
5//!
6//! # Overview
7//!
8//! The transpiler supports the same DSL as `ringkernel-cuda-codegen`:
9//!
10//! - Primitive types: `f32`, `i32`, `u32`, `bool` (with 64-bit emulation)
11//! - Array slices: `&[T]`, `&mut [T]` → storage buffers
12//! - Arithmetic and comparison operators
13//! - Let bindings and if/else expressions
14//! - For/while/loop constructs
15//! - Stencil intrinsics via `GridPos` context
16//! - Ring kernel support with host-driven persistence emulation
17//!
18//! # Example
19//!
20//! ```ignore
21//! use ringkernel_wgpu_codegen::{transpile_stencil_kernel, StencilConfig};
22//! use syn::parse_quote;
23//!
24//! let func: syn::ItemFn = parse_quote! {
25//!     fn fdtd(p: &[f32], p_prev: &mut [f32], c2: f32, pos: GridPos) {
26//!         let curr = p[pos.idx()];
27//!         let lap = pos.north(p) + pos.south(p) + pos.east(p) + pos.west(p) - 4.0 * curr;
28//!         p_prev[pos.idx()] = 2.0 * curr - p_prev[pos.idx()] + c2 * lap;
29//!     }
30//! };
31//!
32//! let config = StencilConfig::new("fdtd")
33//!     .with_tile_size(16, 16)
34//!     .with_halo(1);
35//!
36//! let wgsl_code = transpile_stencil_kernel(&func, &config)?;
37//! ```
38//!
39//! # WGSL Limitations
40//!
41//! Compared to CUDA, WGSL has some limitations that are handled with workarounds:
42//!
43//! - **No 64-bit atomics**: Emulated using lo/hi u32 pairs
44//! - **No f64**: Downcast to f32 with a warning
45//! - **No persistent kernels**: Emulated with host-driven dispatch loops
46//! - **No warp operations**: Mapped to subgroup operations where available
47//! - **No kernel-to-kernel messaging**: K2K is not supported in WGPU
48
49pub mod bindings;
50pub mod dsl;
51pub mod handler;
52pub mod intrinsics;
53pub mod loops;
54pub mod ring_kernel;
55pub mod shared;
56pub mod stencil;
57pub mod transpiler;
58pub mod types;
59pub mod u64_workarounds;
60pub mod validation;
61
62pub use bindings::{generate_bindings, AccessMode, BindingLayout};
63pub use dsl::*;
64pub use handler::{
65    HandlerCodegenConfig, HandlerParam, HandlerParamKind, HandlerReturnType, HandlerSignature,
66    WgslContextMethod,
67};
68pub use intrinsics::{IntrinsicRegistry, WgslIntrinsic};
69pub use loops::{LoopPattern, RangeInfo};
70pub use ring_kernel::RingKernelConfig;
71pub use shared::{SharedArray, SharedMemoryConfig, SharedMemoryDecl, SharedTile};
72pub use stencil::{Grid, GridPos, StencilConfig, StencilLaunchConfig};
73pub use transpiler::{transpile_function, WgslTranspiler};
74pub use types::{AddressSpace, TypeMapper, WgslType};
75pub use u64_workarounds::U64Helpers;
76pub use validation::{ValidationError, ValidationMode};
77
78use thiserror::Error;
79
80/// Errors that can occur during transpilation.
81#[derive(Error, Debug)]
82pub enum TranspileError {
83    /// Failed to parse Rust code.
84    #[error("Parse error: {0}")]
85    Parse(String),
86
87    /// DSL constraint violation.
88    #[error("Validation error: {0}")]
89    Validation(#[from] ValidationError),
90
91    /// Unsupported Rust construct.
92    #[error("Unsupported construct: {0}")]
93    Unsupported(String),
94
95    /// Type mapping failure.
96    #[error("Type error: {0}")]
97    Type(String),
98
99    /// WGSL-specific limitation.
100    #[error("WGSL limitation: {0}")]
101    WgslLimitation(String),
102}
103
104/// Result type for transpilation operations.
105pub type Result<T> = std::result::Result<T, TranspileError>;
106
107/// Transpile a Rust stencil kernel function to WGSL code.
108///
109/// This is the main entry point for stencil code generation. It takes a parsed
110/// Rust function and stencil configuration, validates the DSL constraints,
111/// and generates equivalent WGSL code.
112///
113/// # Arguments
114///
115/// * `func` - The parsed Rust function (from syn)
116/// * `config` - Stencil kernel configuration
117///
118/// # Returns
119///
120/// The generated WGSL source code as a string.
121///
122/// # Example
123///
124/// ```ignore
125/// use ringkernel_wgpu_codegen::{transpile_stencil_kernel, StencilConfig};
126/// use syn::parse_quote;
127///
128/// let func: syn::ItemFn = parse_quote! {
129///     fn heat(temp: &[f32], temp_new: &mut [f32], alpha: f32, pos: GridPos) {
130///         let t = temp[pos.idx()];
131///         let neighbors = pos.north(temp) + pos.south(temp) + pos.east(temp) + pos.west(temp);
132///         temp_new[pos.idx()] = t + alpha * (neighbors - 4.0 * t);
133///     }
134/// };
135///
136/// let config = StencilConfig::new("heat").with_tile_size(16, 16).with_halo(1);
137/// let wgsl = transpile_stencil_kernel(&func, &config)?;
138/// ```
139pub fn transpile_stencil_kernel(func: &syn::ItemFn, config: &StencilConfig) -> Result<String> {
140    // Validate DSL constraints
141    validation::validate_function(func)?;
142
143    // Create transpiler with stencil config
144    let mut transpiler = WgslTranspiler::new_stencil(config.clone());
145
146    // Generate WGSL code
147    transpiler.transpile_stencil(func)
148}
149
150/// Transpile a Rust function to a WGSL helper function.
151///
152/// This generates a callable function (not a compute entry point) from Rust code.
153pub fn transpile_device_function(func: &syn::ItemFn) -> Result<String> {
154    validation::validate_function(func)?;
155    transpile_function(func)
156}
157
158/// Transpile a Rust function to a WGSL `@compute` kernel.
159///
160/// This generates an compute shader entry point without stencil-specific patterns.
161/// Use DSL functions like `thread_idx_x()`, `block_idx_x()` to access WGSL indices.
162///
163/// # Example
164///
165/// ```ignore
166/// use ringkernel_wgpu_codegen::transpile_global_kernel;
167/// use syn::parse_quote;
168///
169/// let func: syn::ItemFn = parse_quote! {
170///     fn saxpy(x: &[f32], y: &mut [f32], a: f32, n: i32) {
171///         let idx = block_idx_x() * block_dim_x() + thread_idx_x();
172///         if idx >= n { return; }
173///         y[idx as usize] = a * x[idx as usize] + y[idx as usize];
174///     }
175/// };
176///
177/// let wgsl = transpile_global_kernel(&func)?;
178/// // Generates: @compute @workgroup_size(256) fn saxpy(...) { ... }
179/// ```
180pub fn transpile_global_kernel(func: &syn::ItemFn) -> Result<String> {
181    validation::validate_function(func)?;
182    let mut transpiler = WgslTranspiler::new_generic();
183    transpiler.transpile_global_kernel(func)
184}
185
186/// Transpile a Rust handler function to a WGSL ring kernel.
187///
188/// Ring kernels in WGPU are emulated using host-driven dispatch loops since
189/// WebGPU does not support true persistent kernels. The handler function
190/// processes a batch of messages per dispatch.
191///
192/// # Example
193///
194/// ```ignore
195/// use ringkernel_wgpu_codegen::{transpile_ring_kernel, RingKernelConfig};
196/// use syn::parse_quote;
197///
198/// let handler: syn::ItemFn = parse_quote! {
199///     fn process(value: f32) -> f32 {
200///         value * 2.0
201///     }
202/// };
203///
204/// let config = RingKernelConfig::new("processor")
205///     .with_workgroup_size(256)
206///     .with_hlc(true);
207///
208/// let wgsl = transpile_ring_kernel(&handler, &config)?;
209/// ```
210///
211/// # WGSL Limitations
212///
213/// - **No K2K**: Kernel-to-kernel messaging is not supported
214/// - **No persistent loop**: Host must re-dispatch until termination
215/// - **No 64-bit atomics**: Counters use lo/hi u32 pair emulation
216pub fn transpile_ring_kernel(handler: &syn::ItemFn, config: &RingKernelConfig) -> Result<String> {
217    // K2K is not supported in WGPU
218    if config.enable_k2k {
219        return Err(TranspileError::WgslLimitation(
220            "Kernel-to-kernel (K2K) messaging is not supported in WGPU. \
221             WebGPU does not allow GPU-direct communication between compute dispatches. \
222             Use host-mediated messaging instead."
223                .to_string(),
224        ));
225    }
226
227    // Validate handler with generic mode (loops allowed but not required)
228    validation::validate_function_with_mode(handler, ValidationMode::Generic)?;
229
230    // Create transpiler in generic mode for the handler
231    let mut transpiler = WgslTranspiler::new_ring_kernel(config.clone());
232
233    // Transpile the handler into a ring kernel wrapper
234    transpiler.transpile_ring_kernel(handler, config)
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// The rendered `Display` text of each error carries its category prefix.
    #[test]
    fn test_transpile_error_display() {
        let cases = [
            (TranspileError::Parse("unexpected token".to_string()), "Parse error"),
            (TranspileError::WgslLimitation("no 64-bit atomics".to_string()), "WGSL limitation"),
        ];

        for (err, prefix) in cases.iter() {
            let rendered = err.to_string();
            assert!(
                rendered.contains(*prefix),
                "expected {:?} to contain {:?}",
                rendered,
                prefix
            );
        }
    }

    /// Requesting K2K messaging must fail with a WGSL limitation error.
    #[test]
    fn test_k2k_rejected() {
        let handler: syn::ItemFn = parse_quote! {
            fn forward(msg: f32) -> f32 {
                msg
            }
        };
        let config = RingKernelConfig::new("forwarder")
            .with_workgroup_size(64)
            .with_k2k(true);

        match transpile_ring_kernel(&handler, &config) {
            Err(TranspileError::WgslLimitation(_)) => {}
            other => panic!("expected WgslLimitation for K2K config, got {:?}", other),
        }
    }
}