Skip to main content

05_device_repr/
05-device-repr.rs

1use cudarc::{driver::*, nvrtc::compile_ptx};
2
/// Host-side mirror of the CUDA `MyCoolStruct` declared in `PTX_SRC`.
///
/// `#[repr(C)]` makes Rust lay the fields out in declaration order using C
/// padding rules, so the bytes handed to the kernel line up field-for-field
/// with the CUDA struct. Without it Rust may reorder fields and the kernel
/// would read the wrong bytes.
///
/// Field types pair up as `f32` / `float`, `f64` / `double`,
/// `u32` / `unsigned int` and `usize` / `size_t`. The last pairing assumes the
/// host and device agree on pointer width (presumably 64-bit — confirm for
/// your target).
#[repr(C)]
#[derive(Debug, Clone, Copy)]
struct MyCoolRustStruct {
    a: f32,
    b: f64,
    c: u32,
    d: usize,
}

13/// We have to implement this to send it to cuda!
14unsafe impl DeviceRepr for MyCoolRustStruct {}
15
/// CUDA C++ source for `my_custom_kernel`, compiled at runtime via NVRTC
/// (`compile_ptx`).
///
/// `MyCoolStruct` must keep the same field order and types as
/// `MyCoolRustStruct`, because the kernel receives it by value. Each field is
/// then checked with a device-side `assert` against the values `main` sends.
/// `extern "C"` keeps the symbol unmangled so
/// `load_function("my_custom_kernel")` can find it.
const PTX_SRC: &str = "
// here's the same struct in cuda
struct MyCoolStruct {
    float a;
    double b;
    unsigned int c;
    size_t d;
};
extern \"C\" __global__ void my_custom_kernel(MyCoolStruct thing) {
    assert(thing.a == 1.0);
    assert(thing.b == 2.34);
    assert(thing.c == 57);
    assert(thing.d == 420);
}
";

32fn main() -> Result<(), DriverError> {
33    let ctx = CudaContext::new(0)?;
34    let stream = ctx.default_stream();
35
36    let ptx = compile_ptx(PTX_SRC).unwrap();
37    let module = ctx.load_module(ptx)?;
38    let f = module.load_function("my_custom_kernel")?;
39
40    // try changing some of these values to see a device assert
41    let thing = MyCoolRustStruct {
42        a: 1.0,
43        b: 2.34,
44        c: 57,
45        d: 420,
46    };
47
48    let mut builder = stream.launch_builder(&f);
49    // since MyCoolRustStruct implements DeviceRepr, we can pass it to launch.
50    builder.arg(&thing);
51    unsafe { builder.launch(LaunchConfig::for_num_elems(1)) }?;
52
53    Ok(())
54}