10_function_attributes/
10-function-attributes.rs

1use cudarc::{
2    driver::{CudaContext, DriverError},
3    nvrtc::Ptx,
4};
5
6fn main() -> Result<(), DriverError> {
7    let ctx = CudaContext::new(0)?;
8
9    println!("Device: {}", ctx.name()?);
10    println!();
11
12    // Load the module with the sin_kernel
13    let module = ctx.load_module(Ptx::from_file("./examples/sin.ptx"))?;
14    let sin_kernel = module.load_function("sin_kernel")?;
15
16    // Query function attributes
17    println!("=== Function Attributes for 'sin_kernel' ===");
18    println!();
19
20    println!("Resource Usage:");
21    println!("  Registers per thread:     {}", sin_kernel.num_regs()?);
22    println!(
23        "  Static shared memory:     {} bytes",
24        sin_kernel.shared_size_bytes()?
25    );
26    println!(
27        "  Constant memory:          {} bytes",
28        sin_kernel.const_size_bytes()?
29    );
30    println!(
31        "  Local memory per thread:  {} bytes",
32        sin_kernel.local_size_bytes()?
33    );
34    println!();
35
36    println!("Limits:");
37    println!(
38        "  Max threads per block:    {}",
39        sin_kernel.max_threads_per_block()?
40    );
41    println!();
42
43    println!("Compilation Info:");
44    let ptx_ver = sin_kernel.ptx_version()?;
45    let bin_ver = sin_kernel.binary_version()?;
46    println!(
47        "  PTX version:              {}.{}",
48        ptx_ver / 10,
49        ptx_ver % 10
50    );
51    println!(
52        "  Binary version:           {}.{}",
53        bin_ver / 10,
54        bin_ver % 10
55    );
56    println!();
57
58    // Use occupancy API to get optimal launch configuration
59    extern "C" fn no_dynamic_smem(_block_size: std::ffi::c_int) -> usize {
60        0
61    }
62    let (min_grid_size, block_size) =
63        sin_kernel.occupancy_max_potential_block_size(no_dynamic_smem, 0, 0, None)?;
64
65    println!("=== Optimal Launch Configuration (sin_kernel) ===");
66    println!("  Suggested block size:     {}", block_size);
67    println!("  Min grid size:            {}", min_grid_size);
68    println!("  Total threads per grid:   {}", min_grid_size * block_size);
69
70    Ok(())
71}