pub struct TensorDescriptor { /* private fields */ }
Expand description
A cuDNN tensor descriptor: 4-D via `new_4d`, or N-dimensional via `new_nd`.
Implementations§
Source§impl TensorDescriptor
impl TensorDescriptor
Sourcepub fn new_4d(
format: TensorFormat,
dtype: DType,
n: i32,
c: i32,
h: i32,
w: i32,
) -> Result<Self>
pub fn new_4d( format: TensorFormat, dtype: DType, n: i32, c: i32, h: i32, w: i32, ) -> Result<Self>
Describe an N × C × H × W tensor with the given format and dtype.
Examples found in repository?
examples/conv2d.rs (line 37)
21fn main() -> Result<(), Box<dyn std::error::Error>> {
22 baracuda_driver::init()?;
23 let device = Device::get(0)?;
24 println!("device: {}", device.name()?);
25 let ctx = Context::new(&device)?;
26 let cudnn = Handle::new()?;
27 println!("cuDNN version (packed): {}", baracuda_cudnn::version()?);
28
29 // Shapes: NCHW 1×1×8×8 input, 1 output channel, 3×3 kernel,
30 // padding 1, stride 1, dilation 1 → output is the same 8×8.
31 let (n, c, h, w) = (1, 1, 8, 8);
32 let (k, kh, kw) = (1, 3, 3);
33 let (pad_h, pad_w) = (1, 1);
34 let (str_h, str_w) = (1, 1);
35 let (dil_h, dil_w) = (1, 1);
36
37 let x_desc = TensorDescriptor::new_4d(TensorFormat::Nchw, DType::F32, n, c, h, w)?;
38 let w_desc = FilterDescriptor::new_4d(TensorFormat::Nchw, DType::F32, k, c, kh, kw)?;
39 let conv = ConvolutionDescriptor::new_2d(
40 pad_h,
41 pad_w,
42 str_h,
43 str_w,
44 dil_h,
45 dil_w,
46 ConvMode::CrossCorrelation,
47 DType::F32,
48 )?;
49
50 // Compute output dims (should equal `h`, `w` for pad=1, k=3, stride=1).
51 let (on, oc, oh, ow) = conv.output_dim_2d(&x_desc, &w_desc)?;
52 println!("output dim: {on}×{oc}×{oh}×{ow}");
53 let y_desc = TensorDescriptor::new_4d(TensorFormat::Nchw, DType::F32, on, oc, oh, ow)?;
54
55 let algo = FwdAlgo::ImplicitGemm;
56 let ws_bytes = convolution_forward_workspace_size(&cudnn, &x_desc, &w_desc, &conv, &y_desc, algo)?;
57 println!("workspace bytes: {ws_bytes}");
58
59 // Inputs: all ones in X and W → each output cell = number of valid
60 // (in-bounds, non-padded) (kh, kw, c) taps contributing to it. For an
61 // interior cell that's 3×3 = 9; corners get 2×2 = 4; edges get 2×3 = 6.
62 let x_host = vec![1.0f32; (n * c * h * w) as usize];
63 let w_host = vec![1.0f32; (k * c * kh * kw) as usize];
64 let x_buf = DeviceBuffer::from_slice(&ctx, &x_host)?;
65 let w_buf = DeviceBuffer::from_slice(&ctx, &w_host)?;
66 let mut y_buf: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, (on * oc * oh * ow) as usize)?;
67 let mut ws: DeviceBuffer<u8> = DeviceBuffer::zeros(&ctx, ws_bytes.max(1))?;
68
69 convolution_forward(
70 &cudnn,
71 1.0,
72 &x_desc,
73 &x_buf,
74 &w_desc,
75 &w_buf,
76 &conv,
77 algo,
78 &mut ws,
79 0.0,
80 &y_desc,
81 &mut y_buf,
82 )?;
83
84 let mut y_host = vec![0.0f32; (on * oc * oh * ow) as usize];
85 y_buf.copy_to_host(&mut y_host)?;
86
87 // Corners contribute 4, edges 6, interior 9; sum over an 8×8 image:
88 // 4 corners × 4 + (4 edges × 6 cells) × 6 + (6×6 interior) × 9
89 // = 16 + 144 + 324 = 484.
90 let sum: f32 = y_host.iter().sum();
91 println!("output sum = {sum} (expected 484 for all-ones 8×8 with 3×3 pad-1)");
92 assert!((sum - 484.0).abs() < 1e-2, "unexpected conv2d sum {sum}");
93 println!("OK");
94 Ok(())
95}
Sourcepub fn new_nd(dtype: DType, dims: &[i32], strides: &[i32]) -> Result<Self>
pub fn new_nd(dtype: DType, dims: &[i32], strides: &[i32]) -> Result<Self>
Describe an N-dimensional tensor. `dims` and `strides` must have the
same length (at most 8 entries) and together describe a valid,
non-overlapping layout that cuDNN supports.
Sourcepub fn as_raw(&self) -> cudnnTensorDescriptor_t
pub fn as_raw(&self) -> cudnnTensorDescriptor_t
Raw descriptor. Use with care.
Trait Implementations§
Source§impl Debug for TensorDescriptor
impl Debug for TensorDescriptor
Source§impl Drop for TensorDescriptor
impl Drop for TensorDescriptor
impl Send for TensorDescriptor
Auto Trait Implementations§
impl !Sync for TensorDescriptor
impl Freeze for TensorDescriptor
impl RefUnwindSafe for TensorDescriptor
impl Unpin for TensorDescriptor
impl UnsafeUnpin for TensorDescriptor
impl UnwindSafe for TensorDescriptor
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T where
T: ?Sized,
impl<T> BorrowMut<T> for T where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more