1use oxicuda_driver::device::Device;
30
31use crate::error::LaunchError;
32use crate::grid::Dim3;
33
/// Launch configuration: grid dimensions, block dimensions and a
/// shared-memory byte count.
///
/// Values can be created with [`LaunchParams::new`], adjusted with
/// [`LaunchParams::with_shared_mem`], or assembled step by step through
/// [`LaunchParams::builder`]. Construction performs no checking; call
/// [`LaunchParams::validate`] to compare the values against a device's limits.
#[derive(Debug, Clone, Copy)]
pub struct LaunchParams {
    /// Grid dimensions.
    pub grid: Dim3,
    /// Block dimensions.
    pub block: Dim3,
    /// Shared-memory size requested for the launch, in bytes.
    pub shared_mem_bytes: u32,
}
58
59impl LaunchParams {
60 #[inline]
68 pub fn new(grid: impl Into<Dim3>, block: impl Into<Dim3>) -> Self {
69 Self {
70 grid: grid.into(),
71 block: block.into(),
72 shared_mem_bytes: 0,
73 }
74 }
75
76 #[inline]
80 pub fn with_shared_mem(mut self, bytes: u32) -> Self {
81 self.shared_mem_bytes = bytes;
82 self
83 }
84
85 #[inline]
87 pub fn builder() -> LaunchParamsBuilder {
88 LaunchParamsBuilder::default()
89 }
90
91 #[inline]
96 pub fn total_threads(&self) -> u64 {
97 self.grid.total() as u64 * self.block.total() as u64
98 }
99
100 pub fn validate(&self, device: &Device) -> Result<(), Box<dyn std::error::Error>> {
128 self.validate_inner(device)
129 }
130
131 fn validate_inner(&self, device: &Device) -> Result<(), Box<dyn std::error::Error>> {
133 if self.block.x == 0 {
135 return Err(Box::new(LaunchError::InvalidDimension {
136 dim: "block.x",
137 value: 0,
138 }));
139 }
140 if self.block.y == 0 {
141 return Err(Box::new(LaunchError::InvalidDimension {
142 dim: "block.y",
143 value: 0,
144 }));
145 }
146 if self.block.z == 0 {
147 return Err(Box::new(LaunchError::InvalidDimension {
148 dim: "block.z",
149 value: 0,
150 }));
151 }
152 if self.grid.x == 0 {
153 return Err(Box::new(LaunchError::InvalidDimension {
154 dim: "grid.x",
155 value: 0,
156 }));
157 }
158 if self.grid.y == 0 {
159 return Err(Box::new(LaunchError::InvalidDimension {
160 dim: "grid.y",
161 value: 0,
162 }));
163 }
164 if self.grid.z == 0 {
165 return Err(Box::new(LaunchError::InvalidDimension {
166 dim: "grid.z",
167 value: 0,
168 }));
169 }
170
171 let max_threads = device.max_threads_per_block()? as u32;
173 let block_total = self.block.total();
174 if block_total > max_threads {
175 return Err(Box::new(LaunchError::BlockSizeExceedsLimit {
176 requested: block_total,
177 max: max_threads,
178 }));
179 }
180
181 let (max_bx, max_by, max_bz) = device.max_block_dim()?;
183 if self.block.x > max_bx as u32 {
184 return Err(Box::new(LaunchError::InvalidDimension {
185 dim: "block.x",
186 value: self.block.x,
187 }));
188 }
189 if self.block.y > max_by as u32 {
190 return Err(Box::new(LaunchError::InvalidDimension {
191 dim: "block.y",
192 value: self.block.y,
193 }));
194 }
195 if self.block.z > max_bz as u32 {
196 return Err(Box::new(LaunchError::InvalidDimension {
197 dim: "block.z",
198 value: self.block.z,
199 }));
200 }
201
202 let (max_gx, max_gy, max_gz) = device.max_grid_dim()?;
204 if self.grid.x > max_gx as u32 {
205 return Err(Box::new(LaunchError::GridSizeExceedsLimit {
206 requested: self.grid.x,
207 max: max_gx as u32,
208 }));
209 }
210 if self.grid.y > max_gy as u32 {
211 return Err(Box::new(LaunchError::GridSizeExceedsLimit {
212 requested: self.grid.y,
213 max: max_gy as u32,
214 }));
215 }
216 if self.grid.z > max_gz as u32 {
217 return Err(Box::new(LaunchError::GridSizeExceedsLimit {
218 requested: self.grid.z,
219 max: max_gz as u32,
220 }));
221 }
222
223 let max_smem = device.max_shared_memory_per_block()? as u32;
225 if self.shared_mem_bytes > max_smem {
226 return Err(Box::new(LaunchError::SharedMemoryExceedsLimit {
227 requested: self.shared_mem_bytes,
228 max: max_smem,
229 }));
230 }
231
232 Ok(())
233 }
234}
235
/// Step-by-step builder for [`LaunchParams`].
///
/// Every setter is optional. Dimensions that were never set fall back to
/// `Dim3::x(1)` when [`LaunchParamsBuilder::build`] is called, and the
/// shared-memory request defaults to zero.
#[derive(Debug, Default)]
pub struct LaunchParamsBuilder {
    /// Grid dimensions, if set.
    grid: Option<Dim3>,
    /// Block dimensions, if set.
    block: Option<Dim3>,
    /// Shared-memory request in bytes.
    shared_mem_bytes: u32,
}
265
266impl LaunchParamsBuilder {
267 #[inline]
271 pub fn grid(mut self, dim: impl Into<Dim3>) -> Self {
272 self.grid = Some(dim.into());
273 self
274 }
275
276 #[inline]
280 pub fn block(mut self, dim: impl Into<Dim3>) -> Self {
281 self.block = Some(dim.into());
282 self
283 }
284
285 #[inline]
287 pub fn shared_mem(mut self, bytes: u32) -> Self {
288 self.shared_mem_bytes = bytes;
289 self
290 }
291
292 #[inline]
297 pub fn build(self) -> LaunchParams {
298 LaunchParams {
299 grid: self.grid.unwrap_or(Dim3::x(1)),
300 block: self.block.unwrap_or(Dim3::x(1)),
301 shared_mem_bytes: self.shared_mem_bytes,
302 }
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` accepts bare `u32` values and leaves shared memory at zero.
    #[test]
    fn launch_params_new_basic() {
        let LaunchParams {
            grid,
            block,
            shared_mem_bytes,
        } = LaunchParams::new(4u32, 256u32);
        assert_eq!(grid, Dim3::x(4));
        assert_eq!(block, Dim3::x(256));
        assert_eq!(shared_mem_bytes, 0);
    }

    /// `new` accepts ready-made `Dim3` values.
    #[test]
    fn launch_params_new_with_dim3() {
        let params = LaunchParams::new(Dim3::xy(4, 4), Dim3::xy(16, 16));
        assert_eq!(params.grid.total(), 16);
        assert_eq!(params.block.total(), 256);
    }

    /// `new` accepts tuples convertible into `Dim3`.
    #[test]
    fn launch_params_new_with_tuples() {
        let params = LaunchParams::new((4u32, 4u32), (16u32, 16u32));
        assert_eq!(params.grid, Dim3::xy(4, 4));
        assert_eq!(params.block, Dim3::xy(16, 16));
    }

    /// `with_shared_mem` overrides the shared-memory request.
    #[test]
    fn launch_params_with_shared_mem() {
        let params = LaunchParams::new(1u32, 256u32).with_shared_mem(8192);
        assert_eq!(params.shared_mem_bytes, 8192);
    }

    /// `total_threads` is grid volume times block volume.
    #[test]
    fn launch_params_total_threads() {
        let linear = LaunchParams::new(4u32, 256u32);
        assert_eq!(linear.total_threads(), 1024);

        let planar = LaunchParams::new(Dim3::xy(4, 4), Dim3::xy(16, 16));
        assert_eq!(planar.total_threads(), 16 * 256);
    }

    /// `total_threads` yields the exact 64-bit product for a large grid.
    #[test]
    fn launch_params_total_threads_large() {
        let params = LaunchParams::new(Dim3::xy(65535, 65535), Dim3::x(1024));
        let side = 65535u64;
        assert_eq!(params.total_threads(), side * side * 1024u64);
    }

    /// An untouched builder produces unit dimensions and no shared memory.
    #[test]
    fn builder_defaults() {
        let built = LaunchParams::builder().build();
        assert_eq!(built.grid, Dim3::x(1));
        assert_eq!(built.block, Dim3::x(1));
        assert_eq!(built.shared_mem_bytes, 0);
    }

    /// Every builder setter is reflected in the built value.
    #[test]
    fn builder_full() {
        let builder = LaunchParams::builder()
            .grid(128u32)
            .block(256u32)
            .shared_mem(4096);
        let built = builder.build();
        assert_eq!(built.grid, Dim3::x(128));
        assert_eq!(built.block, Dim3::x(256));
        assert_eq!(built.shared_mem_bytes, 4096);
    }

    /// Setting only the grid leaves the block at its default.
    #[test]
    fn builder_partial_grid_only() {
        let built = LaunchParams::builder().grid(64u32).build();
        assert_eq!(built.grid, Dim3::x(64));
        assert_eq!(built.block, Dim3::x(1));
    }

    /// Setting only the block leaves the grid at its default.
    #[test]
    fn builder_partial_block_only() {
        let built = LaunchParams::builder().block(512u32).build();
        assert_eq!(built.grid, Dim3::x(1));
        assert_eq!(built.block, Dim3::x(512));
    }

    /// Builder setters accept 2- and 3-element tuples.
    #[test]
    fn builder_with_tuple_dims() {
        let built = LaunchParams::builder()
            .grid((8u32, 8u32))
            .block((16u32, 16u32, 1u32))
            .build();
        assert_eq!(built.grid, Dim3::xy(8, 8));
        assert_eq!(built.block, Dim3::new(16, 16, 1));
    }

    /// Expected signature of `LaunchParams::validate`, used for compile-time checks.
    type ValidateFn = fn(&LaunchParams, &Device) -> Result<(), Box<dyn std::error::Error>>;

    /// No `Device` is constructed here, so `validate` is not called; this only
    /// pins the signature and confirms the zero component is stored as given.
    #[test]
    fn validate_zero_block_x() {
        let params = LaunchParams {
            grid: Dim3::x(1),
            block: Dim3::new(0, 1, 1),
            shared_mem_bytes: 0,
        };
        let _validate_fn: ValidateFn = LaunchParams::validate;
        assert_eq!(params.block.x, 0);
    }

    /// As above: only checks that a zero grid component is stored as given.
    #[test]
    fn validate_zero_grid_z() {
        let params = LaunchParams {
            grid: Dim3::new(1, 1, 0),
            block: Dim3::x(256),
            shared_mem_bytes: 0,
        };
        assert_eq!(params.grid.z, 0);
    }

    /// Compile-time check that `validate` has the expected signature.
    #[test]
    fn validate_signature_compiles() {
        let _: ValidateFn = LaunchParams::validate;
    }

    /// Runs `validate` against device 0 when one can be obtained; otherwise
    /// the test passes without asserting anything.
    #[cfg(feature = "gpu-tests")]
    #[test]
    fn validate_with_real_device() {
        oxicuda_driver::init().ok();
        let device = match Device::get(0) {
            Ok(device) => device,
            Err(_) => return,
        };

        let accepted = LaunchParams::new(4u32, 256u32);
        assert!(accepted.validate(&device).is_ok());

        // A 1024 x 1024 block is expected to be rejected by the device limits.
        let rejected = LaunchParams::new(1u32, Dim3::new(1024, 1024, 1));
        assert!(rejected.validate(&device).is_err());
    }
}