baracuda_runtime/
launch.rs1use core::ffi::c_void;
4
5use baracuda_cuda_sys::runtime::{cudaStream_t, runtime, types::dim3};
6use baracuda_types::KernelArg;
7
8use crate::error::{check, Result};
9use crate::module::Kernel;
10use crate::stream::Stream;
11
/// Three-dimensional extent used for kernel grid and block sizes.
///
/// Construct it with a struct literal or through `From`: a bare `u32` becomes `(x, 1, 1)`,
/// a `(u32, u32)` pair becomes `(x, y, 1)`, and a `(u32, u32, u32)` triple maps directly.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct Dim3 {
    /// Extent along the x axis.
    pub x: u32,
    /// Extent along the y axis.
    pub y: u32,
    /// Extent along the z axis.
    pub z: u32,
}
19
20impl Dim3 {
21 #[inline]
22 fn to_sys(self) -> dim3 {
23 dim3::new(self.x, self.y, self.z)
24 }
25}
26
27impl From<u32> for Dim3 {
28 fn from(x: u32) -> Self {
29 Self { x, y: 1, z: 1 }
30 }
31}
32
33impl From<(u32, u32)> for Dim3 {
34 fn from((x, y): (u32, u32)) -> Self {
35 Self { x, y, z: 1 }
36 }
37}
38
39impl From<(u32, u32, u32)> for Dim3 {
40 fn from((x, y, z): (u32, u32, u32)) -> Self {
41 Self { x, y, z }
42 }
43}
44
45impl Kernel {
46 #[inline]
48 pub fn launch(&self) -> LaunchBuilder<'_> {
49 LaunchBuilder {
50 kernel: self,
51 grid: Dim3 { x: 1, y: 1, z: 1 },
52 block: Dim3 { x: 1, y: 1, z: 1 },
53 shared_mem_bytes: 0,
54 stream: None,
55 args: Vec::new(),
56 }
57 }
58}
59
60#[must_use = "the launch builder does nothing until `.launch()` is called"]
62pub struct LaunchBuilder<'k> {
63 kernel: &'k Kernel,
64 grid: Dim3,
65 block: Dim3,
66 shared_mem_bytes: usize,
67 stream: Option<&'k Stream>,
68 args: Vec<*mut c_void>,
69}
70
71impl core::fmt::Debug for LaunchBuilder<'_> {
72 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
73 f.debug_struct("LaunchBuilder")
74 .field("grid", &self.grid)
75 .field("block", &self.block)
76 .field("shared_mem_bytes", &self.shared_mem_bytes)
77 .field("arg_count", &self.args.len())
78 .finish_non_exhaustive()
79 }
80}
81
82impl<'k> LaunchBuilder<'k> {
83 #[inline]
84 pub fn grid(mut self, grid: impl Into<Dim3>) -> Self {
85 self.grid = grid.into();
86 self
87 }
88
89 #[inline]
90 pub fn block(mut self, block: impl Into<Dim3>) -> Self {
91 self.block = block.into();
92 self
93 }
94
95 #[inline]
96 pub fn shared_mem_bytes(mut self, bytes: usize) -> Self {
97 self.shared_mem_bytes = bytes;
98 self
99 }
100
101 #[inline]
102 pub fn stream(mut self, stream: &'k Stream) -> Self {
103 self.stream = Some(stream);
104 self
105 }
106
107 #[inline]
108 pub fn arg<K: KernelArg>(mut self, arg: K) -> Self {
109 self.args.push(arg.as_kernel_arg_ptr());
110 self
111 }
112
113 pub unsafe fn launch(mut self) -> Result<()> { unsafe {
122 let r = runtime()?;
123 let cu = r.cuda_launch_kernel()?;
124 let stream_handle: cudaStream_t = self.stream.map_or(core::ptr::null_mut(), |s| s.as_raw());
125 let args_ptr = if self.args.is_empty() {
126 core::ptr::null_mut()
127 } else {
128 self.args.as_mut_ptr()
129 };
130 check(cu(
131 self.kernel.as_launch_ptr(),
132 self.grid.to_sys(),
133 self.block.to_sys(),
134 args_ptr,
135 self.shared_mem_bytes,
136 stream_handle,
137 ))
138 }}
139
140 pub unsafe fn launch_cooperative(mut self) -> Result<()> { unsafe {
151 let r = runtime()?;
152 let cu = r.cuda_launch_cooperative_kernel()?;
153 let stream_handle: cudaStream_t = self.stream.map_or(core::ptr::null_mut(), |s| s.as_raw());
154 let args_ptr = if self.args.is_empty() {
155 core::ptr::null_mut()
156 } else {
157 self.args.as_mut_ptr()
158 };
159 check(cu(
160 self.kernel.as_launch_ptr(),
161 self.grid.to_sys(),
162 self.block.to_sys(),
163 args_ptr,
164 self.shared_mem_bytes,
165 stream_handle,
166 ))
167 }}
168}