pub fn launch_conv<R: Runtime, MP: MatmulPrecision, Alg: Algorithm, const N_SPATIAL: usize>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: &TensorHandleRef<'_, R>,
    weight: &TensorHandleRef<'_, R>,
    bias: &Option<TensorHandleRef<'_, R>>,
    out: &TensorHandleRef<'_, R>,
    args: ConvolutionArgs<N_SPATIAL>,
) -> Result<(), ConvLaunchError>
where
    <<Alg as Algorithm>::Args as MatmulArgs>::Input<<MP as MatmulPrecision>::EI>: ConvInputsLaunch,
    <<Alg as Algorithm>::Args as MatmulArgs>::Output<<MP as MatmulPrecision>::EO>: ConcreteOutputFactory,
Perform an n-dimensional convolution using the implicit GEMM (im2col) algorithm, built on cubecl's tiling matmul components and parameterized by the specified algorithm (Alg).
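To make "implicit GEMM" concrete, here is a plain CPU reference sketch for the 2D case (depth dropped for brevity) using the same channels-last layouts as the arguments documented on this page. It is an illustration of the technique, not cubecl's kernel: Conv2dParams, out_size and conv2d_implicit_gemm are hypothetical names, and the stride/padding/dilation fields are an assumption about what a convolution configuration carries. Each output pixel is a GEMM row (M), each output channel a column (N), and the kernel window times in_channels is the reduction (K); the im2col matrix A is never stored, its elements are gathered from the input on the fly.

```rust
/// Hypothetical 2D convolution geometry (a stand-in, not cubecl's `ConvolutionArgs`).
#[derive(Clone, Copy, Debug)]
pub struct Conv2dParams {
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
}

/// Output length along one spatial dimension; assumes the dilated kernel
/// fits inside the padded input.
pub fn out_size(input: usize, kernel: usize, stride: usize, padding: usize, dilation: usize) -> usize {
    (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}

/// CPU reference for implicit-GEMM convolution with channels-last tensors:
/// input  [batches, height, width, in_channels]
/// weight [out_channels, kernel_h, kernel_w, in_channels]
/// output [batches, out_height, out_width, out_channels]
pub fn conv2d_implicit_gemm(
    input: &[f32],
    in_shape: [usize; 4],
    weight: &[f32],
    w_shape: [usize; 4],
    bias: Option<&[f32]>,
    p: Conv2dParams,
) -> (Vec<f32>, [usize; 4]) {
    let [batches, h, w, c_in] = in_shape;
    let [c_out, kh, kw, c_in_w] = w_shape;
    assert_eq!(c_in, c_in_w, "input and weight disagree on in_channels");
    let oh = out_size(h, kh, p.stride[0], p.padding[0], p.dilation[0]);
    let ow = out_size(w, kw, p.stride[1], p.padding[1], p.dilation[1]);

    // GEMM view: C[M x N] = A[M x K] * B[K x N]
    let m_dim = batches * oh * ow; // one row per output pixel
    let n_dim = c_out; // one column per output channel
    let k_dim = kh * kw * c_in; // reduction over the kernel window
    let mut out = vec![0.0f32; m_dim * n_dim];

    for m in 0..m_dim {
        // Row m -> (batch, out_y, out_x).
        let b = m / (oh * ow);
        let oy = (m / ow) % oh;
        let ox = m % ow;
        for n in 0..n_dim {
            // Bias is added once per output channel.
            let mut acc = bias.map_or(0.0, |bs| bs[n]);
            for k in 0..k_dim {
                // Column k of A -> (kernel_y, kernel_x, channel).
                let ky = k / (kw * c_in);
                let kx = (k / c_in) % kw;
                let c = k % c_in;
                // im2col element A[m][k], computed on the fly instead of stored.
                let iy = (oy * p.stride[0] + ky * p.dilation[0]) as isize - p.padding[0] as isize;
                let ix = (ox * p.stride[1] + kx * p.dilation[1]) as isize - p.padding[1] as isize;
                if iy < 0 || ix < 0 || iy >= h as isize || ix >= w as isize {
                    continue; // zero padding contributes nothing
                }
                let a = input[((b * h + iy as usize) * w + ix as usize) * c_in + c];
                // The weight layout flattens so that B[k][n] = weight[n * K + k].
                acc += a * weight[n * k_dim + k];
            }
            // The output layout flattens so that C[m][n] = out[m * N + n].
            out[m * n_dim + n] = acc;
        }
    }
    (out, [batches, oh, ow, c_out])
}
```

A design note on why these layouts suit the GEMM view: with in_channels innermost in both input and weight, consecutive k values inside one kernel position are contiguous in memory on both sides, and the output layout is already the row-major C matrix, so no transpose is needed before or after the matmul.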
Arguments (layouts shown for three spatial dimensions):
- input - The input feature map; layout should be [batches, depth, height, width, in_channels]
- weight - The weights (filter) applied to each kernel; layout should be [out_channels, kernel_d, kernel_h, kernel_w, in_channels]
- out - The output feature map; layout should be [batches, out_depth, out_height, out_width, out_channels]
- bias - The bias added to each output channel
- args - The options to use for the convolution
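The layouts above determine where each dimension lives, so the shape a caller must allocate for out follows from input, weight and the per-dimension stride, padding and dilation. A minimal validation sketch, assuming args carries those three per-dimension quantities and that the usual cross-correlation output-size formula applies; ConvGeometry and expected_out_shape are hypothetical names, not part of cubecl's API.

```rust
/// Hypothetical per-dimension geometry for three spatial dimensions
/// (a stand-in, not cubecl's `ConvolutionArgs`).
pub struct ConvGeometry {
    pub stride: [usize; 3],
    pub padding: [usize; 3],
    pub dilation: [usize; 3],
}

/// Checks shapes against the documented layouts and returns the expected
/// output shape [batches, out_depth, out_height, out_width, out_channels].
pub fn expected_out_shape(
    input: [usize; 5],  // [batches, depth, height, width, in_channels]
    weight: [usize; 5], // [out_channels, kernel_d, kernel_h, kernel_w, in_channels]
    bias_len: Option<usize>,
    g: &ConvGeometry,
) -> Result<[usize; 5], String> {
    let (in_channels, out_channels) = (input[4], weight[0]);
    if weight[4] != in_channels {
        return Err(format!("weight has {} in_channels, input has {}", weight[4], in_channels));
    }
    // One bias value per output channel.
    if let Some(len) = bias_len {
        if len != out_channels {
            return Err(format!("bias has {len} elements, expected {out_channels}"));
        }
    }
    let mut out = [input[0], 0, 0, 0, out_channels];
    for d in 0..3 {
        let (size, kernel) = (input[1 + d], weight[1 + d]);
        if kernel == 0 || g.stride[d] == 0 || g.dilation[d] == 0 {
            return Err(format!("spatial dim {d}: kernel, stride and dilation must be non-zero"));
        }
        let padded = size + 2 * g.padding[d];
        let span = g.dilation[d] * (kernel - 1) + 1; // extent of the dilated kernel
        if span > padded {
            return Err(format!("spatial dim {d}: dilated kernel ({span}) exceeds padded input ({padded})"));
        }
        out[1 + d] = (padded - span) / g.stride[d] + 1;
    }
    Ok(out)
}
```

For example, an input of [2, 8, 16, 16, 3] with a 3x3x3 kernel, 4 output channels, padding 1 everywhere and stride [1, 2, 2] yields an output shape of [2, 8, 8, 8, 4].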