1use baracuda_cudnn::{
15 convolution_forward, convolution_forward_workspace_size, ConvMode,
16 ConvolutionDescriptor, DType, FilterDescriptor, FwdAlgo, Handle, TensorDescriptor,
17 TensorFormat,
18};
19use baracuda_driver::{Context, Device, DeviceBuffer};
20
21fn main() -> Result<(), Box<dyn std::error::Error>> {
22 baracuda_driver::init()?;
23 let device = Device::get(0)?;
24 println!("device: {}", device.name()?);
25 let ctx = Context::new(&device)?;
26 let cudnn = Handle::new()?;
27 println!("cuDNN version (packed): {}", baracuda_cudnn::version()?);
28
29 let (n, c, h, w) = (1, 1, 8, 8);
32 let (k, kh, kw) = (1, 3, 3);
33 let (pad_h, pad_w) = (1, 1);
34 let (str_h, str_w) = (1, 1);
35 let (dil_h, dil_w) = (1, 1);
36
37 let x_desc = TensorDescriptor::new_4d(TensorFormat::Nchw, DType::F32, n, c, h, w)?;
38 let w_desc = FilterDescriptor::new_4d(TensorFormat::Nchw, DType::F32, k, c, kh, kw)?;
39 let conv = ConvolutionDescriptor::new_2d(
40 pad_h,
41 pad_w,
42 str_h,
43 str_w,
44 dil_h,
45 dil_w,
46 ConvMode::CrossCorrelation,
47 DType::F32,
48 )?;
49
50 let (on, oc, oh, ow) = conv.output_dim_2d(&x_desc, &w_desc)?;
52 println!("output dim: {on}×{oc}×{oh}×{ow}");
53 let y_desc = TensorDescriptor::new_4d(TensorFormat::Nchw, DType::F32, on, oc, oh, ow)?;
54
55 let algo = FwdAlgo::ImplicitGemm;
56 let ws_bytes = convolution_forward_workspace_size(&cudnn, &x_desc, &w_desc, &conv, &y_desc, algo)?;
57 println!("workspace bytes: {ws_bytes}");
58
59 let x_host = vec![1.0f32; (n * c * h * w) as usize];
63 let w_host = vec![1.0f32; (k * c * kh * kw) as usize];
64 let x_buf = DeviceBuffer::from_slice(&ctx, &x_host)?;
65 let w_buf = DeviceBuffer::from_slice(&ctx, &w_host)?;
66 let mut y_buf: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, (on * oc * oh * ow) as usize)?;
67 let mut ws: DeviceBuffer<u8> = DeviceBuffer::zeros(&ctx, ws_bytes.max(1))?;
68
69 convolution_forward(
70 &cudnn,
71 1.0,
72 &x_desc,
73 &x_buf,
74 &w_desc,
75 &w_buf,
76 &conv,
77 algo,
78 &mut ws,
79 0.0,
80 &y_desc,
81 &mut y_buf,
82 )?;
83
84 let mut y_host = vec![0.0f32; (on * oc * oh * ow) as usize];
85 y_buf.copy_to_host(&mut y_host)?;
86
87 let sum: f32 = y_host.iter().sum();
91 println!("output sum = {sum} (expected 484 for all-ones 8×8 with 3×3 pad-1)");
92 assert!((sum - 484.0).abs() < 1e-2, "unexpected conv2d sum {sum}");
93 println!("OK");
94 Ok(())
95}