cuda_rust_wasm/runtime/
kernel.rs1use crate::{Result, runtime_error};
4use super::{Grid, Block, Device, Stream};
5use std::sync::Arc;
6use std::marker::PhantomData;
7
8pub trait KernelFunction<Args> {
10 fn execute(&self, args: Args, thread_ctx: ThreadContext);
12
13 fn name(&self) -> &str;
15}
16
17#[derive(Debug, Clone, Copy)]
19pub struct ThreadContext {
20 pub thread_idx: super::grid::Dim3,
22 pub block_idx: super::grid::Dim3,
24 pub block_dim: super::grid::Dim3,
26 pub grid_dim: super::grid::Dim3,
28}
29
30impl ThreadContext {
31 pub fn global_thread_id(&self) -> usize {
33 let block_offset = self.block_idx.x as usize * self.block_dim.x as usize;
34 block_offset + self.thread_idx.x as usize
35 }
36
37 pub fn global_thread_id_2d(&self) -> (usize, usize) {
39 let x = self.block_idx.x as usize * self.block_dim.x as usize + self.thread_idx.x as usize;
40 let y = self.block_idx.y as usize * self.block_dim.y as usize + self.thread_idx.y as usize;
41 (x, y)
42 }
43
44 pub fn global_thread_id_3d(&self) -> (usize, usize, usize) {
46 let x = self.block_idx.x as usize * self.block_dim.x as usize + self.thread_idx.x as usize;
47 let y = self.block_idx.y as usize * self.block_dim.y as usize + self.thread_idx.y as usize;
48 let z = self.block_idx.z as usize * self.block_dim.z as usize + self.thread_idx.z as usize;
49 (x, y, z)
50 }
51}
52
53pub struct LaunchConfig {
55 pub grid: Grid,
56 pub block: Block,
57 pub stream: Option<Arc<Stream>>,
58 pub shared_memory_bytes: usize,
59}
60
61impl LaunchConfig {
62 pub fn new(grid: Grid, block: Block) -> Self {
64 Self {
65 grid,
66 block,
67 stream: None,
68 shared_memory_bytes: 0,
69 }
70 }
71
72 pub fn with_stream(mut self, stream: Arc<Stream>) -> Self {
74 self.stream = Some(stream);
75 self
76 }
77
78 pub fn with_shared_memory(mut self, bytes: usize) -> Self {
80 self.shared_memory_bytes = bytes;
81 self
82 }
83}
84
/// Runs a kernel on the CPU by calling it once per thread of the launch.
struct CpuKernelExecutor<K, Args> {
    /// The kernel to invoke.
    kernel: K,
    /// Ties the executor to the kernel's argument type without storing one.
    phantom: PhantomData<Args>,
}
90
91impl<K, Args> CpuKernelExecutor<K, Args>
92where
93 K: KernelFunction<Args>,
94 Args: Clone + Send + Sync,
95{
96 fn execute(&self, config: &LaunchConfig, args: Args) -> Result<()> {
97 let total_blocks = config.grid.num_blocks();
98 let threads_per_block = config.block.num_threads();
99
100 for block_id in 0..total_blocks {
103 let block_idx = super::grid::Dim3 {
105 x: block_id % config.grid.dim.x,
106 y: (block_id / config.grid.dim.x) % config.grid.dim.y,
107 z: block_id / (config.grid.dim.x * config.grid.dim.y),
108 };
109
110 for thread_id in 0..threads_per_block {
111 let thread_idx = super::grid::Dim3 {
113 x: thread_id % config.block.dim.x,
114 y: (thread_id / config.block.dim.x) % config.block.dim.y,
115 z: thread_id / (config.block.dim.x * config.block.dim.y),
116 };
117
118 let thread_ctx = ThreadContext {
119 thread_idx,
120 block_idx,
121 block_dim: config.block.dim,
122 grid_dim: config.grid.dim,
123 };
124
125 self.kernel.execute(args.clone(), thread_ctx);
126 }
127 }
128
129 Ok(())
130 }
131}
132
133pub fn launch_kernel<K, Args>(
135 kernel: K,
136 config: LaunchConfig,
137 args: Args,
138) -> Result<()>
139where
140 K: KernelFunction<Args>,
141 Args: Clone + Send + Sync,
142{
143 config.block.validate()?;
145
146 let device = if let Some(ref stream) = config.stream {
148 stream.device()
149 } else {
150 Device::get_default()?
151 };
152
153 match device.backend() {
155 super::BackendType::CPU => {
156 let executor = CpuKernelExecutor {
157 kernel,
158 phantom: PhantomData,
159 };
160 executor.execute(&config, args)?;
161 }
162 super::BackendType::Native => {
163 return Err(runtime_error!("Native GPU backend not yet implemented"));
165 }
166 super::BackendType::WebGPU => {
167 return Err(runtime_error!("WebGPU backend not yet implemented"));
169 }
170 }
171
172 Ok(())
173}
174
/// Declares a unit struct and implements `KernelFunction` for it.
///
/// Arguments, in order:
/// 1. an optional visibility followed by the struct name (`MyKernel`,
///    `pub MyKernel`, `pub(crate) MyKernel`); without a visibility the
///    struct is private to the module where the macro is invoked,
/// 2. the kernel's argument type,
/// 3. a closure-like `|args_pattern, ctx| { body }` where `args_pattern`
///    destructures the argument value and `ctx` names the `ThreadContext`.
///
/// The generated `name()` returns the struct name as a string.
///
/// ```ignore
/// kernel_function!(pub Scale, (f32, usize), |(factor, len), ctx| {
///     let _ = (factor, len, ctx.global_thread_id());
/// });
/// ```
#[macro_export]
macro_rules! kernel_function {
    ($vis:vis $name:ident, $args:ty, |$args_pat:pat, $ctx:ident| $body:block) => {
        // `$vis` may be empty, which keeps the struct private as before.
        $vis struct $name;

        impl $crate::runtime::kernel::KernelFunction<$args> for $name {
            fn execute(&self, $args_pat: $args, $ctx: $crate::runtime::kernel::ThreadContext) {
                $body
            }

            fn name(&self) -> &str {
                stringify!($name)
            }
        }
    };
}