Skip to main content

TensorDescriptor

Struct TensorDescriptor 

Source
pub struct TensorDescriptor { /* private fields */ }
Expand description

A cuDNN tensor descriptor. The common 4-D case is covered by `new_4d` and `new_4d_ex`; arbitrary-rank layouts use `new_nd`.

Implementations§

Source§

impl TensorDescriptor

Source

pub fn new_4d(format: TensorFormat, dtype: DType, n: i32, c: i32, h: i32, w: i32) -> Result<Self>

Describe an N × C × H × W tensor with the given format and dtype.

Examples found in repository?
examples/conv2d.rs (line 37)
21fn main() -> Result<(), Box<dyn std::error::Error>> {
22    baracuda_driver::init()?;
23    let device = Device::get(0)?;
24    println!("device: {}", device.name()?);
25    let ctx = Context::new(&device)?;
26    let cudnn = Handle::new()?;
27    println!("cuDNN version (packed): {}", baracuda_cudnn::version()?);
28
29    // Shapes: NCHW 1×1×8×8 input, 1 output channel, 3×3 kernel,
30    // padding 1, stride 1, dilation 1 → output is the same 8×8.
31    let (n, c, h, w) = (1, 1, 8, 8);
32    let (k, kh, kw) = (1, 3, 3);
33    let (pad_h, pad_w) = (1, 1);
34    let (str_h, str_w) = (1, 1);
35    let (dil_h, dil_w) = (1, 1);
36
37    let x_desc = TensorDescriptor::new_4d(TensorFormat::Nchw, DType::F32, n, c, h, w)?;
38    let w_desc = FilterDescriptor::new_4d(TensorFormat::Nchw, DType::F32, k, c, kh, kw)?;
39    let conv = ConvolutionDescriptor::new_2d(
40        pad_h,
41        pad_w,
42        str_h,
43        str_w,
44        dil_h,
45        dil_w,
46        ConvMode::CrossCorrelation,
47        DType::F32,
48    )?;
49
50    // Compute output dims (should equal `h`, `w` for pad=1, k=3, stride=1).
51    let (on, oc, oh, ow) = conv.output_dim_2d(&x_desc, &w_desc)?;
52    println!("output dim: {on}×{oc}×{oh}×{ow}");
53    let y_desc = TensorDescriptor::new_4d(TensorFormat::Nchw, DType::F32, on, oc, oh, ow)?;
54
55    let algo = FwdAlgo::ImplicitGemm;
56    let ws_bytes = convolution_forward_workspace_size(&cudnn, &x_desc, &w_desc, &conv, &y_desc, algo)?;
57    println!("workspace bytes: {ws_bytes}");
58
59    // Inputs: all ones in X and W → each output cell = number of valid
60    // (in-bounds, non-padded) (kh, kw, c) taps contributing to it. For a
61    // padded interior cell that's 3×3 = 9; corners get 4; edges get 6.
62    let x_host = vec![1.0f32; (n * c * h * w) as usize];
63    let w_host = vec![1.0f32; (k * c * kh * kw) as usize];
64    let x_buf = DeviceBuffer::from_slice(&ctx, &x_host)?;
65    let w_buf = DeviceBuffer::from_slice(&ctx, &w_host)?;
66    let mut y_buf: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, (on * oc * oh * ow) as usize)?;
67    let mut ws: DeviceBuffer<u8> = DeviceBuffer::zeros(&ctx, ws_bytes.max(1))?;
68
69    convolution_forward(
70        &cudnn,
71        1.0,
72        &x_desc,
73        &x_buf,
74        &w_desc,
75        &w_buf,
76        &conv,
77        algo,
78        &mut ws,
79        0.0,
80        &y_desc,
81        &mut y_buf,
82    )?;
83
84    let mut y_host = vec![0.0f32; (on * oc * oh * ow) as usize];
85    y_buf.copy_to_host(&mut y_host)?;
86
87    // Corners contribute 4, edges 6, interior 9; sum over an 8×8 image:
88    //   4 corners × 4 + (4 edges × 6 cells) × 6 + (6×6 interior) × 9
89    // = 16 + 144 + 324 = 484.
90    let sum: f32 = y_host.iter().sum();
91    println!("output sum = {sum}  (expected 484 for all-ones 8×8 with 3×3 pad-1)");
92    assert!((sum - 484.0).abs() < 1e-2, "unexpected conv2d sum {sum}");
93    println!("OK");
94    Ok(())
95}
Source

pub fn new_nd(dtype: DType, dims: &[i32], strides: &[i32]) -> Result<Self>

Describe an N-dimensional tensor. dims and strides must have the same length (≤8) and correspond to a valid, non-overlapping cuDNN-supported layout.

Source

pub fn as_raw(&self) -> cudnnTensorDescriptor_t

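Returns the underlying raw `cudnnTensorDescriptor_t` handle. Use with care: this `TensorDescriptor` implements `Drop`, so the handle should not be used after the descriptor has been dropped.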

Source§

impl TensorDescriptor

Source

pub fn new_4d_ex(dtype: DType, n: i32, c: i32, h: i32, w: i32, n_stride: i32, c_stride: i32, h_stride: i32, w_stride: i32) -> Result<Self>

Strided 4-D constructor: takes explicit per-axis strides (`n_stride`, `c_stride`, `h_stride`, `w_stride`) instead of the row-major / channels-last layouts that `new_4d` derives from its `format` argument.

Source

pub fn get_4d(&self) -> Result<(DType, i32, i32, i32, i32, i32, i32, i32, i32)>

Read the 4-D parameters back out: (dtype, n, c, h, w, n_stride, c_stride, h_stride, w_stride).

Trait Implementations§

Source§

impl Debug for TensorDescriptor

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more
Source§

impl Drop for TensorDescriptor

Source§

fn drop(&mut self)

Executes the destructor for this type. Read more
Source§

fn pin_drop(self: Pin<&mut Self>)

🔬 This is a nightly-only experimental API (`pin_ergonomics`).
Executes the destructor for this type but, unlike `Drop::drop`, requires `self` to be pinned. Read more
Source§

impl Send for TensorDescriptor

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.