1pub mod device;
4pub mod memory;
5pub mod kernel;
6pub mod stream;
7pub mod event;
8pub mod grid;
9pub mod cooperative_groups;
10pub mod dynamic_parallelism;
11pub mod cuda_graph;
12pub mod multi_gpu;
13pub mod half;
14pub mod bfloat16;
15pub mod benchmark;
16pub mod flash_attention;
17pub mod tensor_ops;
18pub mod kernel_fusion;
19pub mod occupancy;
20pub mod async_pipeline;
21pub mod quantization;
22pub mod warp_intrinsics;
23pub mod coalescing;
24
25use crate::{Result, runtime_error};
26use std::cell::RefCell;
27use std::sync::Arc;
28
29pub use grid::{Grid, Block, Dim3};
30pub use device::{Device, BackendType};
31pub use stream::Stream;
32pub use event::Event;
33pub use kernel::{launch_kernel, LaunchConfig, KernelFunction, ThreadContext};
34
35#[derive(Debug, Clone)]
43pub struct KernelContext {
44 pub thread_idx: Dim3,
46 pub block_idx: Dim3,
48 pub block_dim: Dim3,
50 pub grid_dim: Dim3,
52 pub barrier: Option<Arc<std::sync::Barrier>>,
54}
55
56thread_local! {
57 static KERNEL_CONTEXT: RefCell<Option<KernelContext>> = RefCell::new(None);
58}
59
60pub fn set_kernel_context(ctx: KernelContext) {
65 KERNEL_CONTEXT.with(|c| {
66 *c.borrow_mut() = Some(ctx);
67 });
68}
69
70pub fn clear_kernel_context() {
74 KERNEL_CONTEXT.with(|c| {
75 *c.borrow_mut() = None;
76 });
77}
78
79pub fn with_kernel_context<F, R>(ctx: KernelContext, f: F) -> R
85where
86 F: FnOnce() -> R,
87{
88 set_kernel_context(ctx);
89 let result = f();
90 clear_kernel_context();
91 result
92}
93
94pub struct Runtime {
98 device: Arc<Device>,
100 default_stream: Stream,
102}
103
104impl Runtime {
105 pub fn new() -> Result<Self> {
107 let device = Device::get_default()?;
108 let default_stream = Stream::new(device.clone())?;
109
110 Ok(Self {
111 device,
112 default_stream,
113 })
114 }
115
116 pub fn device(&self) -> &Arc<Device> {
118 &self.device
119 }
120
121 pub fn default_stream(&self) -> &Stream {
123 &self.default_stream
124 }
125
126 pub fn create_stream(&self) -> Result<Stream> {
128 Stream::new(self.device.clone())
129 }
130
131 pub fn synchronize(&self) -> Result<()> {
133 self.default_stream.synchronize()
134 }
135}
136
137pub mod thread {
141 use super::grid::Dim3;
142 use super::KERNEL_CONTEXT;
143
144 pub fn index() -> Dim3 {
149 KERNEL_CONTEXT.with(|c| {
150 c.borrow()
151 .as_ref()
152 .map(|ctx| ctx.thread_idx)
153 .unwrap_or(Dim3 { x: 0, y: 0, z: 0 })
154 })
155 }
156}
157
158pub mod block {
162 use super::grid::Dim3;
163 use super::KERNEL_CONTEXT;
164
165 pub fn index() -> Dim3 {
170 KERNEL_CONTEXT.with(|c| {
171 c.borrow()
172 .as_ref()
173 .map(|ctx| ctx.block_idx)
174 .unwrap_or(Dim3 { x: 0, y: 0, z: 0 })
175 })
176 }
177
178 pub fn dim() -> Dim3 {
184 KERNEL_CONTEXT.with(|c| {
185 c.borrow()
186 .as_ref()
187 .map(|ctx| ctx.block_dim)
188 .unwrap_or(Dim3 { x: 256, y: 1, z: 1 })
189 })
190 }
191}
192
193pub mod grid_dim {
197 use super::grid::Dim3;
198 use super::KERNEL_CONTEXT;
199
200 pub fn dim() -> Dim3 {
205 KERNEL_CONTEXT.with(|c| {
206 c.borrow()
207 .as_ref()
208 .map(|ctx| ctx.grid_dim)
209 .unwrap_or(Dim3 { x: 1, y: 1, z: 1 })
210 })
211 }
212}
213
214pub fn sync_threads() {
222 KERNEL_CONTEXT.with(|c| {
223 if let Some(ref ctx) = *c.borrow() {
224 if let Some(ref barrier) = ctx.barrier {
225 barrier.wait();
226 }
227 }
228 });
229}
230
#[cfg(test)]
mod context_tests {
    use super::*;

    /// Builds a context without a barrier from explicit coordinates.
    fn context(thread_idx: Dim3, block_idx: Dim3, block_dim: Dim3, grid_dim: Dim3) -> KernelContext {
        KernelContext {
            thread_idx,
            block_idx,
            block_dim,
            grid_dim,
            barrier: None,
        }
    }

    /// With no context installed, every accessor reports its fallback value.
    #[test]
    fn test_defaults_without_context() {
        clear_kernel_context();

        let origin = Dim3 { x: 0, y: 0, z: 0 };
        assert_eq!(thread::index(), origin);
        assert_eq!(block::index(), origin);
        assert_eq!(block::dim(), Dim3 { x: 256, y: 1, z: 1 });
        assert_eq!(grid_dim::dim(), Dim3 { x: 1, y: 1, z: 1 });
    }

    /// Values are visible inside `with_kernel_context` and gone afterwards.
    #[test]
    fn test_kernel_context() {
        let ctx = context(
            Dim3 { x: 5, y: 3, z: 0 },
            Dim3 { x: 2, y: 1, z: 0 },
            Dim3 { x: 128, y: 4, z: 1 },
            Dim3 { x: 10, y: 10, z: 1 },
        );

        with_kernel_context(ctx, || {
            let t = thread::index();
            assert_eq!((t.x, t.y, t.z), (5, 3, 0));
            let b = block::index();
            assert_eq!((b.x, b.y), (2, 1));
            let bd = block::dim();
            assert_eq!((bd.x, bd.y), (128, 4));
            assert_eq!(grid_dim::dim().x, 10);
        });

        // Leaving the closure removes the context again.
        assert_eq!(thread::index().x, 0);
        assert_eq!(block::index().x, 0);
        assert_eq!(block::dim().x, 256);
    }

    /// `set_kernel_context` and `clear_kernel_context` toggle visibility.
    #[test]
    fn test_set_and_clear_context() {
        set_kernel_context(context(
            Dim3 { x: 7, y: 0, z: 0 },
            Dim3 { x: 3, y: 0, z: 0 },
            Dim3 { x: 64, y: 1, z: 1 },
            Dim3 { x: 8, y: 1, z: 1 },
        ));
        assert_eq!(thread::index().x, 7);
        assert_eq!(block::index().x, 3);

        clear_kernel_context();
        assert_eq!(thread::index().x, 0);
        assert_eq!(block::index().x, 0);
    }

    /// A second `set_kernel_context` replaces the first outright.
    #[test]
    fn test_context_override() {
        let first = context(
            Dim3 { x: 1, y: 0, z: 0 },
            Dim3 { x: 0, y: 0, z: 0 },
            Dim3 { x: 32, y: 1, z: 1 },
            Dim3 { x: 1, y: 1, z: 1 },
        );
        let second = context(
            Dim3 { x: 99, y: 0, z: 0 },
            Dim3 { x: 50, y: 0, z: 0 },
            Dim3 { x: 512, y: 1, z: 1 },
            Dim3 { x: 4, y: 1, z: 1 },
        );

        set_kernel_context(first);
        assert_eq!(thread::index().x, 1);

        set_kernel_context(second);
        assert_eq!(thread::index().x, 99);
        assert_eq!(block::dim().x, 512);

        clear_kernel_context();
    }

    /// Without a barrier, `sync_threads` must return immediately.
    #[test]
    fn test_sync_threads_no_barrier() {
        let origin = Dim3 { x: 0, y: 0, z: 0 };
        let unit = Dim3 { x: 1, y: 1, z: 1 };
        with_kernel_context(context(origin, origin, unit, unit), sync_threads);
    }

    /// Threads sharing a barrier all get past `sync_threads` and keep their
    /// own thread index.
    #[test]
    fn test_sync_threads_with_barrier() {
        use std::sync::Barrier;

        let num_threads: u32 = 4;
        let barrier = Arc::new(Barrier::new(num_threads as usize));

        let mut handles = Vec::with_capacity(num_threads as usize);
        for tid in 0..num_threads {
            let shared = Arc::clone(&barrier);
            handles.push(std::thread::spawn(move || {
                let mut ctx = context(
                    Dim3 { x: tid, y: 0, z: 0 },
                    Dim3 { x: 0, y: 0, z: 0 },
                    Dim3 { x: num_threads, y: 1, z: 1 },
                    Dim3 { x: 1, y: 1, z: 1 },
                );
                ctx.barrier = Some(shared);

                with_kernel_context(ctx, || {
                    sync_threads();
                    thread::index().x
                })
            }));
        }

        let mut results: Vec<u32> = handles
            .into_iter()
            .map(|h| h.join().expect("thread should not panic"))
            .collect();
        results.sort_unstable();
        assert_eq!(results, vec![0, 1, 2, 3]);
    }

    /// With no context at all, `sync_threads` is a no-op rather than a panic.
    #[test]
    fn test_sync_threads_no_context() {
        clear_kernel_context();
        sync_threads();
    }
}
380}