Trait dfdx::tensor_ops::TryConv2D
pub trait TryConv2D<Stride, Padding, Dilation, Groups>: Sized {
type Convolved;
type Error: Debug;
// Required method
fn try_conv2d(
self,
stride: Stride,
padding: Padding,
dilation: Dilation,
groups: Groups
) -> Result<Self::Convolved, Self::Error>;
// Provided method
fn conv2d(
self,
stride: Stride,
padding: Padding,
dilation: Dilation,
groups: Groups
) -> Self::Convolved { ... }
}
Expand description
Applies a 2D convolution to a tensor. As the examples below show, the receiver is an `(input, weight)` tuple.
Const dims require nightly Rust with the `generic_const_exprs` feature enabled:
ⓘ
#![feature(generic_const_exprs)]
let x: Tensor<Rank4<2, 3, 32, 32>, f32, _> = dev.sample_normal();
let w: Tensor<Rank4<6, 3, 3, 3>, f32, _> = dev.sample_normal();
let y = (x, w).conv2d(
Const::<1>, // stride
Const::<0>, // padding
Const::<1>, // dilation
Const::<1>, // groups
);
`usize` dims can be used on stable Rust:
let x: Tensor<_, f32, _> = dev.sample_normal_like(&(
2, // batch size
3, // input channels
32, // height
32, // width
));
let w: Tensor<_, f32, _> = dev.sample_normal_like(&(
6, // output channels
3, // input channels
3, // kernel size
3, // kernel size
));
let y = (x, w).conv2d(
1, // stride
0, // padding
1, // dilation
1, // groups
);
Required Associated Types
type Convolved
type Error: Debug
Required Methods
fn try_conv2d(
self,
stride: Stride,
padding: Padding,
dilation: Dilation,
groups: Groups
) -> Result<Self::Convolved, Self::Error>
fn try_conv2d( self, stride: Stride, padding: Padding, dilation: Dilation, groups: Groups ) -> Result<Self::Convolved, Self::Error>
Fallibly applies a 2D convolution to the input tensor.