Skip to main content

matmul_kernel/
matmul-kernel.rs

1use cudarc::driver::{CudaContext, DriverError, LaunchConfig, PushKernelArg};
2use cudarc::nvrtc::compile_ptx;
3
/// CUDA C source for the `matmul` kernel, compiled to PTX at runtime.
///
/// Despite the name, this is CUDA C text, not PTX: it is what gets handed to
/// `compile_ptx`. The kernel multiplies two `N` x `N` matrices indexed as
/// `row * N + col`, with one thread producing one element of `C`. The thread's
/// `y` index selects the row and its `x` index selects the column.
///
/// Threads whose (row, col) fall outside the matrix do nothing at all, so the
/// launch grid may safely be larger than `N` x `N`.
const PTX_SRC: &str = "
extern \"C\" __global__ void matmul(float* A, float* B, float* C, int N) {
    int ROW = blockIdx.y*blockDim.y+threadIdx.y;
    int COL = blockIdx.x*blockDim.x+threadIdx.x;

    // Threads past the matrix edge must neither read nor write: the store
    // below has to stay inside this guard or it runs off the end of C.
    if (ROW < N && COL < N) {
        float tmpSum = 0.0f;

        // each thread computes one element of the output matrix
        for (int i = 0; i < N; i++) {
            tmpSum += A[ROW * N + i] * B[i * N + COL];
        }
        C[ROW * N + COL] = tmpSum;
    }
}
";
21
22fn main() -> Result<(), DriverError> {
23    let start = std::time::Instant::now();
24
25    let ptx = compile_ptx(PTX_SRC).unwrap();
26    println!("Compilation succeeded in {:?}", start.elapsed());
27
28    let ctx = CudaContext::new(0)?;
29    let stream = ctx.default_stream();
30    println!("Built in {:?}", start.elapsed());
31
32    let module = ctx.load_module(ptx)?;
33    let f = module.load_function("matmul")?;
34    println!("Loaded in {:?}", start.elapsed());
35
36    let a_host = [1.0f32, 2.0, 3.0, 4.0];
37    let b_host = [1.0f32, 2.0, 3.0, 4.0];
38    let mut c_host = [0.0f32; 4];
39
40    let a_dev = stream.clone_htod(&a_host)?;
41    let b_dev = stream.clone_htod(&b_host)?;
42    let mut c_dev = stream.clone_htod(&c_host)?;
43
44    println!("Copied in {:?}", start.elapsed());
45
46    let mut builder = stream.launch_builder(&f);
47    builder.arg(&a_dev);
48    builder.arg(&b_dev);
49    builder.arg(&mut c_dev);
50    builder.arg(&2i32);
51    let cfg = LaunchConfig {
52        block_dim: (2, 2, 1),
53        grid_dim: (1, 1, 1),
54        shared_mem_bytes: 0,
55    };
56    unsafe { builder.launch(cfg) }?;
57
58    stream.memcpy_dtoh(&c_dev, &mut c_host)?;
59    println!("Found {:?} in {:?}", c_host, start.elapsed());
60    Ok(())
61}