05_device_repr/
05-device-repr.rs1use cudarc::{driver::*, nvrtc::compile_ptx};
2
/// Host-side mirror of the `MyCoolStruct` type declared in the CUDA source.
///
/// `#[repr(C)]` pins the field order and padding to the C layout, so the bytes
/// passed as the kernel argument line up field-for-field with the device-side
/// declaration. Keep the two definitions in sync when editing either one.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq)]
struct MyCoolRustStruct {
    /// Counterpart of `float a`.
    a: f32,
    /// Counterpart of `double b`.
    b: f64,
    /// Counterpart of `unsigned int c`.
    c: u32,
    /// Counterpart of `size_t d`.
    /// NOTE(review): assumes host `usize` and device `size_t` have the same
    /// width (true on 64-bit hosts) — confirm before targeting anything else.
    d: usize,
}
12
13unsafe impl DeviceRepr for MyCoolRustStruct {}
15
/// CUDA C++ source for the example kernel, handed to `compile_ptx` at runtime.
///
/// Despite the `PTX` in the name, this is source text rather than PTX.
/// `MyCoolStruct` declares the same fields, in the same order, as the Rust-side
/// `MyCoolRustStruct`, and the kernel asserts the field values it receives.
const PTX_SRC: &str = r#"
// here's the same struct in cuda
struct MyCoolStruct {
 float a;
 double b;
 unsigned int c;
 size_t d;
};
extern "C" __global__ void my_custom_kernel(MyCoolStruct thing) {
 assert(thing.a == 1.0);
 assert(thing.b == 2.34);
 assert(thing.c == 57);
 assert(thing.d == 420);
}
"#;
31
32fn main() -> Result<(), DriverError> {
33 let ctx = CudaContext::new(0)?;
34 let stream = ctx.default_stream();
35
36 let ptx = compile_ptx(PTX_SRC).unwrap();
37 let module = ctx.load_module(ptx)?;
38 let f = module.load_function("my_custom_kernel")?;
39
40 let thing = MyCoolRustStruct {
42 a: 1.0,
43 b: 2.34,
44 c: 57,
45 d: 420,
46 };
47
48 let mut builder = stream.launch_builder(&f);
49 builder.arg(&thing);
51 unsafe { builder.launch(LaunchConfig::for_num_elems(1)) }?;
52
53 Ok(())
54}