Skip to main content

03_launch_kernel/
03-launch-kernel.rs

1use cudarc::{
2    driver::{CudaContext, DriverError, LaunchConfig, PushKernelArg},
3    nvrtc::Ptx,
4};
5
6fn main() -> Result<(), DriverError> {
7    let ctx = CudaContext::new(0)?;
8    let stream = ctx.default_stream();
9
10    // You can load a function from a pre-compiled PTX like so:
11    let module = ctx.load_module(Ptx::from_file("./examples/sin.ptx"))?;
12
13    // and then load a function from it:
14    let f = module.load_function("sin_kernel").unwrap();
15
16    let a_host = [1.0, 2.0, 3.0];
17
18    let a_dev = stream.clone_htod(&a_host)?;
19    let mut b_dev = a_dev.clone();
20
21    // we use a buidler pattern to launch kernels.
22    let n = 3i32;
23    let cfg = LaunchConfig::for_num_elems(n as u32);
24    let mut launch_args = stream.launch_builder(&f);
25    launch_args.arg(&mut b_dev);
26    launch_args.arg(&a_dev);
27    launch_args.arg(&n);
28    unsafe { launch_args.launch(cfg) }?;
29
30    let a_host_2 = stream.clone_dtoh(&a_dev)?;
31    let b_host = stream.clone_dtoh(&b_dev)?;
32
33    println!("Found {b_host:?}");
34    println!("Expected {:?}", a_host.map(f32::sin));
35    assert_eq!(&a_host, a_host_2.as_slice());
36
37    Ok(())
38}