// 06_threading/06-threading.rs

use cudarc::driver::*;
use cudarc::nvrtc::compile_ptx;

use std::thread;

/// CUDA C++ source for the `hello_world` kernel; `main` compiles it at runtime with NVRTC.
///
/// `extern "C"` keeps the symbol name unmangled so it can be looked up as `"hello_world"`.
/// The kernel prints the integer it is given (`main` passes each worker's loop index).
/// The `\n` inside the `printf` format is a C escape that NVRTC interprets; this is a raw
/// Rust string, so the backslash is passed through verbatim.
const KERNEL_SRC: &str = r#"
extern "C" __global__ void hello_world(int i) {
    printf("Hello from the cuda kernel in thread %d\n", i);
}
"#;

12fn main() -> Result<(), DriverError> {
13    {
14        // Option 1: sharing ctx & module between threads
15        thread::scope(|s| {
16            let ptx = compile_ptx(KERNEL_SRC).unwrap();
17            let ctx = CudaContext::new(0)?;
18            let module = ctx.load_module(ptx)?;
19            for i in 0..10i32 {
20                let thread_ctx = ctx.clone();
21                let thread_module = module.clone();
22                s.spawn(move || {
23                    let stream = thread_ctx.default_stream();
24                    let f = thread_module.load_function("hello_world")?;
25                    unsafe {
26                        stream
27                            .launch_builder(&f)
28                            .arg(&i)
29                            .launch(LaunchConfig::for_num_elems(1))
30                    }
31                });
32            }
33            Ok(())
34        })?;
35    }
36
37    {
38        // Option 2: initializing different context in each
39        // Note that this will still schedule to the same stream since we are using the
40        // default stream here on the same device.
41        thread::scope(move |s| {
42            for i in 0..10i32 {
43                s.spawn(move || {
44                    let ptx = compile_ptx(KERNEL_SRC).unwrap();
45                    let ctx = CudaContext::new(0)?;
46                    let module = ctx.load_module(ptx)?;
47                    let stream = ctx.default_stream();
48                    let f = module.load_function("hello_world")?;
49                    unsafe {
50                        stream
51                            .launch_builder(&f)
52                            .arg(&i)
53                            .launch(LaunchConfig::for_num_elems(1))
54                    }
55                });
56            }
57            Ok(())
58        })?;
59    }
60
61    Ok(())
62}