Skip to main content

cuda_rust_wasm/runtime/
mod.rs

1//! CUDA-compatible runtime for Rust
2
3pub mod device;
4pub mod memory;
5pub mod kernel;
6pub mod stream;
7pub mod event;
8pub mod grid;
9pub mod cooperative_groups;
10pub mod dynamic_parallelism;
11pub mod cuda_graph;
12pub mod multi_gpu;
13pub mod half;
14pub mod bfloat16;
15pub mod benchmark;
16pub mod flash_attention;
17pub mod tensor_ops;
18pub mod kernel_fusion;
19pub mod occupancy;
20pub mod async_pipeline;
21pub mod quantization;
22pub mod warp_intrinsics;
23pub mod coalescing;
24
25use crate::{Result, runtime_error};
26use std::cell::RefCell;
27use std::sync::Arc;
28
29pub use grid::{Grid, Block, Dim3};
30pub use device::{Device, BackendType};
31pub use stream::Stream;
32pub use event::Event;
33pub use kernel::{launch_kernel, LaunchConfig, KernelFunction, ThreadContext};
34
35// ── Kernel execution context ──────────────────────────────────────
36
37/// Kernel execution context that mirrors CUDA's built-in variables.
38///
39/// Each thread in a kernel launch receives its own `KernelContext` via
40/// thread-local storage, providing access to `threadIdx`, `blockIdx`,
41/// `blockDim`, and `gridDim` equivalents.
42#[derive(Debug, Clone)]
43pub struct KernelContext {
44    /// Thread index within the block (analogous to CUDA `threadIdx`)
45    pub thread_idx: Dim3,
46    /// Block index within the grid (analogous to CUDA `blockIdx`)
47    pub block_idx: Dim3,
48    /// Dimensions of each block (analogous to CUDA `blockDim`)
49    pub block_dim: Dim3,
50    /// Dimensions of the grid (analogous to CUDA `gridDim`)
51    pub grid_dim: Dim3,
52    /// Optional barrier for `sync_threads()` synchronisation within a block
53    pub barrier: Option<Arc<std::sync::Barrier>>,
54}
55
56thread_local! {
57    static KERNEL_CONTEXT: RefCell<Option<KernelContext>> = RefCell::new(None);
58}
59
60/// Set the kernel context for the current thread.
61///
62/// Subsequent calls to `thread::index()`, `block::index()`, `block::dim()`,
63/// and `sync_threads()` will read from this context.
64pub fn set_kernel_context(ctx: KernelContext) {
65    KERNEL_CONTEXT.with(|c| {
66        *c.borrow_mut() = Some(ctx);
67    });
68}
69
70/// Clear the kernel context for the current thread.
71///
72/// After clearing, the accessor functions return their default values.
73pub fn clear_kernel_context() {
74    KERNEL_CONTEXT.with(|c| {
75        *c.borrow_mut() = None;
76    });
77}
78
79/// Execute a closure with a kernel context set, then clear the context.
80///
81/// This is the preferred way to scope kernel context to a region of code. The
82/// context is guaranteed to be cleared even if the closure panics (via drop
83/// semantics of the thread-local borrow).
84pub fn with_kernel_context<F, R>(ctx: KernelContext, f: F) -> R
85where
86    F: FnOnce() -> R,
87{
88    set_kernel_context(ctx);
89    let result = f();
90    clear_kernel_context();
91    result
92}
93
94// ── Main runtime context ──────────────────────────────────────────
95
96/// Main runtime context
97pub struct Runtime {
98    /// Current device
99    device: Arc<Device>,
100    /// Default stream
101    default_stream: Stream,
102}
103
104impl Runtime {
105    /// Create a new runtime instance
106    pub fn new() -> Result<Self> {
107        let device = Device::get_default()?;
108        let default_stream = Stream::new(device.clone())?;
109
110        Ok(Self {
111            device,
112            default_stream,
113        })
114    }
115
116    /// Get the current device
117    pub fn device(&self) -> &Arc<Device> {
118        &self.device
119    }
120
121    /// Get the default stream
122    pub fn default_stream(&self) -> &Stream {
123        &self.default_stream
124    }
125
126    /// Create a new stream
127    pub fn create_stream(&self) -> Result<Stream> {
128        Stream::new(self.device.clone())
129    }
130
131    /// Synchronize all operations
132    pub fn synchronize(&self) -> Result<()> {
133        self.default_stream.synchronize()
134    }
135}
136
137// ── Thread index access ───────────────────────────────────────────
138
139/// Thread index access (analogous to CUDA `threadIdx`)
140pub mod thread {
141    use super::grid::Dim3;
142    use super::KERNEL_CONTEXT;
143
144    /// Get current thread index.
145    ///
146    /// Returns the `thread_idx` from the active kernel context, or
147    /// `Dim3 { x: 0, y: 0, z: 0 }` if no context is set.
148    pub fn index() -> Dim3 {
149        KERNEL_CONTEXT.with(|c| {
150            c.borrow()
151                .as_ref()
152                .map(|ctx| ctx.thread_idx)
153                .unwrap_or(Dim3 { x: 0, y: 0, z: 0 })
154        })
155    }
156}
157
158// ── Block index and dimension access ──────────────────────────────
159
160/// Block index and dimension access (analogous to CUDA `blockIdx` / `blockDim`)
161pub mod block {
162    use super::grid::Dim3;
163    use super::KERNEL_CONTEXT;
164
165    /// Get current block index.
166    ///
167    /// Returns the `block_idx` from the active kernel context, or
168    /// `Dim3 { x: 0, y: 0, z: 0 }` if no context is set.
169    pub fn index() -> Dim3 {
170        KERNEL_CONTEXT.with(|c| {
171            c.borrow()
172                .as_ref()
173                .map(|ctx| ctx.block_idx)
174                .unwrap_or(Dim3 { x: 0, y: 0, z: 0 })
175        })
176    }
177
178    /// Get block dimensions.
179    ///
180    /// Returns the `block_dim` from the active kernel context, or
181    /// `Dim3 { x: 256, y: 1, z: 1 }` as a sensible default if no context is
182    /// set.
183    pub fn dim() -> Dim3 {
184        KERNEL_CONTEXT.with(|c| {
185            c.borrow()
186                .as_ref()
187                .map(|ctx| ctx.block_dim)
188                .unwrap_or(Dim3 { x: 256, y: 1, z: 1 })
189        })
190    }
191}
192
193// ── Grid dimension access ─────────────────────────────────────────
194
195/// Grid dimension access (analogous to CUDA `gridDim`)
196pub mod grid_dim {
197    use super::grid::Dim3;
198    use super::KERNEL_CONTEXT;
199
200    /// Get grid dimensions.
201    ///
202    /// Returns the `grid_dim` from the active kernel context, or
203    /// `Dim3 { x: 1, y: 1, z: 1 }` if no context is set.
204    pub fn dim() -> Dim3 {
205        KERNEL_CONTEXT.with(|c| {
206            c.borrow()
207                .as_ref()
208                .map(|ctx| ctx.grid_dim)
209                .unwrap_or(Dim3 { x: 1, y: 1, z: 1 })
210        })
211    }
212}
213
214// ── Thread synchronisation ────────────────────────────────────────
215
216/// Synchronize threads within a block (analogous to CUDA `__syncthreads()`).
217///
218/// If a `Barrier` is present in the current kernel context, all threads in the
219/// block must reach this call before any can proceed. If no barrier is set (e.g.
220/// single-threaded execution), this is a no-op.
221pub fn sync_threads() {
222    KERNEL_CONTEXT.with(|c| {
223        if let Some(ref ctx) = *c.borrow() {
224            if let Some(ref barrier) = ctx.barrier {
225                barrier.wait();
226            }
227        }
228    });
229}
230
// ── Tests ─────────────────────────────────────────────────────────

#[cfg(test)]
mod context_tests {
    use super::*;

    /// All-zero coordinate, used for indices that do not matter in a test.
    const ORIGIN: Dim3 = Dim3 { x: 0, y: 0, z: 0 };

    /// Build a barrier-less context from its four launch coordinates.
    fn plain_ctx(tid: Dim3, bid: Dim3, bdim: Dim3, gdim: Dim3) -> KernelContext {
        KernelContext {
            thread_idx: tid,
            block_idx: bid,
            block_dim: bdim,
            grid_dim: gdim,
            barrier: None,
        }
    }

    #[test]
    fn test_defaults_without_context() {
        // Start from a clean slate in case this thread saw an earlier context.
        clear_kernel_context();

        assert_eq!(thread::index(), ORIGIN);
        assert_eq!(block::index(), ORIGIN);
        assert_eq!(block::dim(), Dim3 { x: 256, y: 1, z: 1 });
        assert_eq!(grid_dim::dim(), Dim3 { x: 1, y: 1, z: 1 });
    }

    #[test]
    fn test_kernel_context() {
        let ctx = plain_ctx(
            Dim3 { x: 5, y: 3, z: 0 },
            Dim3 { x: 2, y: 1, z: 0 },
            Dim3 { x: 128, y: 4, z: 1 },
            Dim3 { x: 10, y: 10, z: 1 },
        );

        with_kernel_context(ctx, || {
            let t = thread::index();
            let b = block::index();
            let bd = block::dim();
            assert_eq!((t.x, t.y, t.z), (5, 3, 0));
            assert_eq!((b.x, b.y), (2, 1));
            assert_eq!((bd.x, bd.y), (128, 4));
            assert_eq!(grid_dim::dim().x, 10);
        });

        // The scoped context is gone once `with_kernel_context` returns.
        assert_eq!(thread::index().x, 0);
        assert_eq!(block::index().x, 0);
        assert_eq!(block::dim().x, 256);
    }

    #[test]
    fn test_set_and_clear_context() {
        set_kernel_context(plain_ctx(
            Dim3 { x: 7, y: 0, z: 0 },
            Dim3 { x: 3, y: 0, z: 0 },
            Dim3 { x: 64, y: 1, z: 1 },
            Dim3 { x: 8, y: 1, z: 1 },
        ));
        assert_eq!((thread::index().x, block::index().x), (7, 3));

        clear_kernel_context();
        assert_eq!((thread::index().x, block::index().x), (0, 0));
    }

    #[test]
    fn test_context_override() {
        let first = plain_ctx(
            Dim3 { x: 1, y: 0, z: 0 },
            ORIGIN,
            Dim3 { x: 32, y: 1, z: 1 },
            Dim3 { x: 1, y: 1, z: 1 },
        );
        let second = plain_ctx(
            Dim3 { x: 99, y: 0, z: 0 },
            Dim3 { x: 50, y: 0, z: 0 },
            Dim3 { x: 512, y: 1, z: 1 },
            Dim3 { x: 4, y: 1, z: 1 },
        );

        set_kernel_context(first);
        assert_eq!(thread::index().x, 1);

        // A second `set_kernel_context` replaces the first outright.
        set_kernel_context(second);
        assert_eq!(thread::index().x, 99);
        assert_eq!(block::dim().x, 512);

        clear_kernel_context();
    }

    #[test]
    fn test_sync_threads_no_barrier() {
        let unit = Dim3 { x: 1, y: 1, z: 1 };
        let ctx = plain_ctx(ORIGIN, ORIGIN, unit, unit);

        // With `barrier: None` this must return immediately rather than block.
        with_kernel_context(ctx, sync_threads);
    }

    #[test]
    fn test_sync_threads_with_barrier() {
        use std::sync::Barrier;

        const WIDTH: u32 = 4;
        let barrier = Arc::new(Barrier::new(WIDTH as usize));

        let mut workers = Vec::with_capacity(WIDTH as usize);
        for tid in 0..WIDTH {
            let shared = Arc::clone(&barrier);
            workers.push(std::thread::spawn(move || {
                let ctx = KernelContext {
                    thread_idx: Dim3 { x: tid, y: 0, z: 0 },
                    block_idx: ORIGIN,
                    block_dim: Dim3 { x: WIDTH, y: 1, z: 1 },
                    grid_dim: Dim3 { x: 1, y: 1, z: 1 },
                    barrier: Some(shared),
                };

                with_kernel_context(ctx, || {
                    // No worker gets past this point until all WIDTH have arrived.
                    sync_threads();
                    thread::index().x
                })
            }));
        }

        let mut seen: Vec<u32> = workers
            .into_iter()
            .map(|w| w.join().expect("thread should not panic"))
            .collect();
        seen.sort_unstable();
        assert_eq!(seen, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_sync_threads_no_context() {
        clear_kernel_context();
        // Without any context this is a no-op and must not panic.
        sync_threads();
    }
}