pub mod device;
pub mod memory;
pub mod kernel;
pub mod stream;
pub mod event;
pub mod grid;
pub mod cooperative_groups;
pub mod dynamic_parallelism;
pub mod cuda_graph;
pub mod multi_gpu;
pub mod half;
pub mod bfloat16;
pub mod benchmark;
pub mod flash_attention;
pub mod tensor_ops;
pub mod kernel_fusion;
pub mod occupancy;
pub mod async_pipeline;
pub mod quantization;
pub mod warp_intrinsics;
pub mod coalescing;
use crate::{Result, runtime_error};
use std::cell::RefCell;
use std::sync::Arc;
pub use grid::{Grid, Block, Dim3};
pub use device::{Device, BackendType};
pub use stream::Stream;
pub use event::Event;
pub use kernel::{launch_kernel, LaunchConfig, KernelFunction, ThreadContext};
/// Launch coordinates exposed to kernel code running on the current OS thread.
///
/// A value of this type is stored per OS thread (see `KERNEL_CONTEXT`) and is read
/// back by the `thread`, `block` and `grid_dim` accessor modules and by `sync_threads`.
#[derive(Debug, Clone)]
pub struct KernelContext {
/// Value returned by `thread::index()` while this context is installed.
pub thread_idx: Dim3,
/// Value returned by `block::index()` while this context is installed.
pub block_idx: Dim3,
/// Value returned by `block::dim()` while this context is installed.
pub block_dim: Dim3,
/// Value returned by `grid_dim::dim()` while this context is installed.
pub grid_dim: Dim3,
/// Barrier that `sync_threads()` waits on; with `None`, `sync_threads()` returns immediately.
/// NOTE(review): presumably one barrier is shared by all threads of a block — confirm with the launcher.
pub barrier: Option<Arc<std::sync::Barrier>>,
}
thread_local! {
    /// Kernel context of the current OS thread, read by the `thread`, `block`
    /// and `grid_dim` accessors and by `sync_threads`. `None` means no kernel
    /// context is installed, so the accessors return their fallback values.
    static KERNEL_CONTEXT: RefCell<Option<KernelContext>> = RefCell::new(None);
}

/// Installs `ctx` as the current thread's kernel context, replacing (and
/// dropping) any context that was already installed.
pub fn set_kernel_context(ctx: KernelContext) {
    KERNEL_CONTEXT.with(|c| {
        *c.borrow_mut() = Some(ctx);
    });
}

/// Removes the current thread's kernel context. Afterwards the accessors return
/// their fallback values and `sync_threads` does not wait.
pub fn clear_kernel_context() {
    KERNEL_CONTEXT.with(|c| {
        *c.borrow_mut() = None;
    });
}

/// Runs `f` with `ctx` installed as the current thread's kernel context and
/// returns whatever `f` returns.
///
/// Whatever context was installed before the call (normally none) is put back
/// when `f` returns *or unwinds*. Consequently:
/// - a panic inside `f` that is later caught does not leave a stale context on
///   this OS thread;
/// - nested calls restore the outer context instead of clearing it.
pub fn with_kernel_context<F, R>(ctx: KernelContext, f: F) -> R
where
    F: FnOnce() -> R,
{
    // Restoring in `Drop` guarantees it also happens during panic unwinding.
    struct Restore(Option<KernelContext>);

    impl Drop for Restore {
        fn drop(&mut self) {
            let previous = self.0.take();
            KERNEL_CONTEXT.with(|c| {
                *c.borrow_mut() = previous;
            });
        }
    }

    let previous = KERNEL_CONTEXT.with(|c| c.replace(Some(ctx)));
    let _restore = Restore(previous);
    f()
}
/// Bundles the default device with a stream created on it, which serves as the
/// runtime's default stream.
pub struct Runtime {
    // Declaration order (device, then stream) is also the drop order.
    device: Arc<Device>,
    default_stream: Stream,
}

impl Runtime {
    /// Builds a runtime on `Device::get_default()` plus a fresh default stream.
    ///
    /// # Errors
    /// Propagates any error from `Device::get_default` or `Stream::new`.
    pub fn new() -> Result<Self> {
        let device = Device::get_default()?;
        Ok(Self {
            default_stream: Stream::new(Arc::clone(&device))?,
            device,
        })
    }

    /// Shared handle to the device this runtime was built on.
    pub fn device(&self) -> &Arc<Device> {
        &self.device
    }

    /// The stream created alongside the runtime.
    pub fn default_stream(&self) -> &Stream {
        &self.default_stream
    }

    /// Creates an additional stream on the runtime's device.
    ///
    /// # Errors
    /// Propagates any error from `Stream::new`.
    pub fn create_stream(&self) -> Result<Stream> {
        let device = Arc::clone(&self.device);
        Stream::new(device)
    }

    /// Synchronizes the default stream.
    ///
    /// # Errors
    /// Propagates any error from `Stream::synchronize`.
    pub fn synchronize(&self) -> Result<()> {
        self.default_stream().synchronize()
    }
}
/// Accessors for the per-thread coordinates of the installed kernel context.
pub mod thread {
    use super::grid::Dim3;
    use super::KERNEL_CONTEXT;

    /// The installed context's `thread_idx`, or `(0, 0, 0)` when no context is set.
    pub fn index() -> Dim3 {
        KERNEL_CONTEXT.with(|slot| match slot.borrow().as_ref() {
            Some(ctx) => ctx.thread_idx,
            None => Dim3 { x: 0, y: 0, z: 0 },
        })
    }
}
/// Accessors for the block-level coordinates of the installed kernel context.
pub mod block {
    use super::grid::Dim3;
    use super::KERNEL_CONTEXT;

    /// The installed context's `block_idx`, or `(0, 0, 0)` when no context is set.
    pub fn index() -> Dim3 {
        KERNEL_CONTEXT.with(|slot| {
            let current = slot.borrow();
            current
                .as_ref()
                .map_or(Dim3 { x: 0, y: 0, z: 0 }, |ctx| ctx.block_idx)
        })
    }

    /// The installed context's `block_dim`, or `(256, 1, 1)` when no context is set.
    pub fn dim() -> Dim3 {
        KERNEL_CONTEXT.with(|slot| {
            let current = slot.borrow();
            current
                .as_ref()
                .map_or(Dim3 { x: 256, y: 1, z: 1 }, |ctx| ctx.block_dim)
        })
    }
}
/// Accessor for the grid dimensions of the installed kernel context.
pub mod grid_dim {
    use super::grid::Dim3;
    use super::KERNEL_CONTEXT;

    /// The installed context's `grid_dim`, or `(1, 1, 1)` when no context is set.
    pub fn dim() -> Dim3 {
        let fallback = Dim3 { x: 1, y: 1, z: 1 };
        KERNEL_CONTEXT.with(|slot| {
            if let Some(ctx) = slot.borrow().as_ref() {
                ctx.grid_dim
            } else {
                fallback
            }
        })
    }
}
pub fn sync_threads() {
KERNEL_CONTEXT.with(|c| {
if let Some(ref ctx) = *c.borrow() {
if let Some(ref barrier) = ctx.barrier {
barrier.wait();
}
}
});
}
#[cfg(test)]
mod context_tests {
    use super::*;

    /// Shorthand constructor for a `Dim3`.
    fn d3(x: u32, y: u32, z: u32) -> Dim3 {
        Dim3 { x, y, z }
    }

    /// Builds a barrier-less context from its four coordinate triples.
    fn make_ctx(
        thread_idx: Dim3,
        block_idx: Dim3,
        block_dim: Dim3,
        grid_dim: Dim3,
    ) -> KernelContext {
        KernelContext {
            thread_idx,
            block_idx,
            block_dim,
            grid_dim,
            barrier: None,
        }
    }

    #[test]
    fn test_defaults_without_context() {
        clear_kernel_context();
        let origin = d3(0, 0, 0);
        assert_eq!(thread::index(), origin);
        assert_eq!(block::index(), origin);
        assert_eq!(block::dim(), d3(256, 1, 1));
        assert_eq!(grid_dim::dim(), d3(1, 1, 1));
    }

    #[test]
    fn test_kernel_context() {
        let ctx = make_ctx(d3(5, 3, 0), d3(2, 1, 0), d3(128, 4, 1), d3(10, 10, 1));
        with_kernel_context(ctx, || {
            let tid = thread::index();
            assert_eq!((tid.x, tid.y, tid.z), (5, 3, 0));
            let bid = block::index();
            assert_eq!((bid.x, bid.y), (2, 1));
            let bdim = block::dim();
            assert_eq!((bdim.x, bdim.y), (128, 4));
            assert_eq!(grid_dim::dim().x, 10);
        });
        // Once the scope ends, the accessors report their no-context defaults again.
        assert_eq!(thread::index().x, 0);
        assert_eq!(block::index().x, 0);
        assert_eq!(block::dim().x, 256);
    }

    #[test]
    fn test_set_and_clear_context() {
        set_kernel_context(make_ctx(d3(7, 0, 0), d3(3, 0, 0), d3(64, 1, 1), d3(8, 1, 1)));
        assert_eq!((thread::index().x, block::index().x), (7, 3));
        clear_kernel_context();
        assert_eq!((thread::index().x, block::index().x), (0, 0));
    }

    #[test]
    fn test_context_override() {
        let first = make_ctx(d3(1, 0, 0), d3(0, 0, 0), d3(32, 1, 1), d3(1, 1, 1));
        let second = make_ctx(d3(99, 0, 0), d3(50, 0, 0), d3(512, 1, 1), d3(4, 1, 1));
        set_kernel_context(first);
        assert_eq!(thread::index().x, 1);
        set_kernel_context(second);
        assert_eq!(thread::index().x, 99);
        assert_eq!(block::dim().x, 512);
        clear_kernel_context();
    }

    #[test]
    fn test_sync_threads_no_barrier() {
        let unit = d3(1, 1, 1);
        let ctx = make_ctx(d3(0, 0, 0), d3(0, 0, 0), unit, unit);
        with_kernel_context(ctx, sync_threads);
    }

    #[test]
    fn test_sync_threads_with_barrier() {
        use std::sync::Barrier;

        const NUM_THREADS: u32 = 4;
        let barrier = Arc::new(Barrier::new(NUM_THREADS as usize));

        let mut handles = Vec::with_capacity(NUM_THREADS as usize);
        for tid in 0..NUM_THREADS {
            let shared = Arc::clone(&barrier);
            handles.push(std::thread::spawn(move || {
                let ctx = KernelContext {
                    barrier: Some(shared),
                    ..make_ctx(d3(tid, 0, 0), d3(0, 0, 0), d3(NUM_THREADS, 1, 1), d3(1, 1, 1))
                };
                with_kernel_context(ctx, || {
                    sync_threads();
                    thread::index().x
                })
            }));
        }

        let mut seen: Vec<u32> = handles
            .into_iter()
            .map(|h| h.join().expect("thread should not panic"))
            .collect();
        seen.sort_unstable();
        assert_eq!(seen, (0..NUM_THREADS).collect::<Vec<u32>>());
    }

    #[test]
    fn test_sync_threads_no_context() {
        clear_kernel_context();
        sync_threads();
    }
}