//! 09_constant_memory/09-constant-memory.rs
//!
//! Example: uploading polynomial coefficients into CUDA constant memory with `cudarc`.

1use cudarc::{
2    driver::{CudaContext, DriverError, LaunchConfig, PushKernelArg},
3    nvrtc::compile_ptx,
4};
5
6fn main() -> Result<(), DriverError> {
7    let ctx = CudaContext::new(0)?;
8    let stream = ctx.default_stream();
9
10    // Load the module containing the kernel with constant memory
11    let ptx = compile_ptx(include_str!("./constant_memory.cu")).expect("compile failure");
12    let module = ctx.load_module(ptx)?;
13
14    // Get the constant memory symbol as a CudaSlice<u8>
15    let mut coefficients_symbol = module.get_global("coefficients", &stream)?;
16    println!(
17        "Constant memory symbol 'coefficients' has {} bytes",
18        coefficients_symbol.len()
19    );
20
21    // Set up polynomial coefficients: 1.0 + 2.0*x + 3.0*x^2 + 4.0*x^3
22    let coefficients = [1.0f32, 2.0, 3.0, 4.0];
23
24    // Transmute the symbol to f32 and copy coefficients to constant memory
25    let mut symbol_f32 = unsafe { coefficients_symbol.transmute_mut::<f32>(4).unwrap() };
26    stream.memcpy_htod(&coefficients, &mut symbol_f32)?;
27
28    // Load the kernel function
29    let polynomial_kernel = module.load_function("polynomial_kernel")?;
30
31    // Prepare input data
32    let input = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
33    let n = input.len();
34
35    // Copy input to device
36    let input_dev = stream.clone_htod(&input)?;
37    let mut output_dev = stream.alloc_zeros::<f32>(n)?;
38
39    // Launch kernel
40    let cfg = LaunchConfig::for_num_elems(n as u32);
41    unsafe {
42        stream
43            .launch_builder(&polynomial_kernel)
44            .arg(&mut output_dev)
45            .arg(&input_dev)
46            .arg(&(n as i32))
47            .launch(cfg)
48    }?;
49
50    // Copy results back
51    let output = stream.clone_dtoh(&output_dev)?;
52
53    // Verify results
54    println!("\nPolynomial evaluation (1.0 + 2.0*x + 3.0*x^2 + 4.0*x^3):");
55    for (i, (&x, &y)) in input.iter().zip(output.iter()).enumerate() {
56        let expected = coefficients[0]
57            + coefficients[1] * x
58            + coefficients[2] * x * x
59            + coefficients[3] * x * x * x;
60        println!("  f({:.1}) = {:.1} (expected {:.1})", x, y, expected);
61        assert!(
62            (y - expected).abs() < 1e-4,
63            "Mismatch at index {}: got {}, expected {}",
64            i,
65            y,
66            expected
67        );
68    }
69
70    println!("\nAll results match expected values!");
71
72    Ok(())
73}