Skip to main content

conv2d/
conv2d.rs

1//! Minimal cuDNN 2-D forward convolution demo.
2//!
3//! Builds NCHW tensor + filter + convolution descriptors, queries the
4//! required workspace, runs `cudnnConvolutionForward`, then verifies the
5//! output sum against a known fixture (all-ones input, all-ones filter
6//! → output cell value equals the receptive-field overlap).
7//!
8//! Run with:
9//!
10//! ```text
11//! cargo run --example conv2d -p baracuda-cudnn
12//! ```
13
14use baracuda_cudnn::{
15    convolution_forward, convolution_forward_workspace_size, ConvMode,
16    ConvolutionDescriptor, DType, FilterDescriptor, FwdAlgo, Handle, TensorDescriptor,
17    TensorFormat,
18};
19use baracuda_driver::{Context, Device, DeviceBuffer};
20
21fn main() -> Result<(), Box<dyn std::error::Error>> {
22    baracuda_driver::init()?;
23    let device = Device::get(0)?;
24    println!("device: {}", device.name()?);
25    let ctx = Context::new(&device)?;
26    let cudnn = Handle::new()?;
27    println!("cuDNN version (packed): {}", baracuda_cudnn::version()?);
28
29    // Shapes: NCHW 1×1×8×8 input, 1 output channel, 3×3 kernel,
30    // padding 1, stride 1, dilation 1 → output is the same 8×8.
31    let (n, c, h, w) = (1, 1, 8, 8);
32    let (k, kh, kw) = (1, 3, 3);
33    let (pad_h, pad_w) = (1, 1);
34    let (str_h, str_w) = (1, 1);
35    let (dil_h, dil_w) = (1, 1);
36
37    let x_desc = TensorDescriptor::new_4d(TensorFormat::Nchw, DType::F32, n, c, h, w)?;
38    let w_desc = FilterDescriptor::new_4d(TensorFormat::Nchw, DType::F32, k, c, kh, kw)?;
39    let conv = ConvolutionDescriptor::new_2d(
40        pad_h,
41        pad_w,
42        str_h,
43        str_w,
44        dil_h,
45        dil_w,
46        ConvMode::CrossCorrelation,
47        DType::F32,
48    )?;
49
50    // Compute output dims (should equal `h`, `w` for pad=1, k=3, stride=1).
51    let (on, oc, oh, ow) = conv.output_dim_2d(&x_desc, &w_desc)?;
52    println!("output dim: {on}×{oc}×{oh}×{ow}");
53    let y_desc = TensorDescriptor::new_4d(TensorFormat::Nchw, DType::F32, on, oc, oh, ow)?;
54
55    let algo = FwdAlgo::ImplicitGemm;
56    let ws_bytes = convolution_forward_workspace_size(&cudnn, &x_desc, &w_desc, &conv, &y_desc, algo)?;
57    println!("workspace bytes: {ws_bytes}");
58
59    // Inputs: all ones in X and W → each output cell = number of valid
60    // (in-bounds, non-padded) (kh, kw, c) taps contributing to it. For a
61    // padded interior cell that's 3×3 = 9; corners get 4; edges get 6.
62    let x_host = vec![1.0f32; (n * c * h * w) as usize];
63    let w_host = vec![1.0f32; (k * c * kh * kw) as usize];
64    let x_buf = DeviceBuffer::from_slice(&ctx, &x_host)?;
65    let w_buf = DeviceBuffer::from_slice(&ctx, &w_host)?;
66    let mut y_buf: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, (on * oc * oh * ow) as usize)?;
67    let mut ws: DeviceBuffer<u8> = DeviceBuffer::zeros(&ctx, ws_bytes.max(1))?;
68
69    convolution_forward(
70        &cudnn,
71        1.0,
72        &x_desc,
73        &x_buf,
74        &w_desc,
75        &w_buf,
76        &conv,
77        algo,
78        &mut ws,
79        0.0,
80        &y_desc,
81        &mut y_buf,
82    )?;
83
84    let mut y_host = vec![0.0f32; (on * oc * oh * ow) as usize];
85    y_buf.copy_to_host(&mut y_host)?;
86
87    // Corners contribute 4, edges 6, interior 9; sum over an 8×8 image:
88    //   4 corners × 4 + (4 edges × 6 cells) × 6 + (6×6 interior) × 9
89    // = 16 + 144 + 324 = 484.
90    let sum: f32 = y_host.iter().sum();
91    println!("output sum = {sum}  (expected 484 for all-ones 8×8 with 3×3 pad-1)");
92    assert!((sum - 484.0).abs() < 1e-2, "unexpected conv2d sum {sum}");
93    println!("OK");
94    Ok(())
95}