cubek_convolution/launch/
inputs.rs1use cubecl::{Runtime, prelude::TensorBinding};
2use cubek_std::InputBinding;
3
4use crate::components::ConvolutionOperation;
5
/// Geometric parameters of an `N_SPATIAL`-dimensional convolution.
///
/// Each field holds exactly one entry per spatial dimension. The order of the
/// dimensions is whatever the caller's tensor layout uses.
/// NOTE(review): axis ordering and whether `padding` is applied symmetrically
/// are not visible from this file — confirm against the kernel that consumes it.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ConvolutionArgs<const N_SPATIAL: usize> {
    /// Step between consecutive kernel applications, per spatial dimension.
    pub stride: [usize; N_SPATIAL],
    /// Amount of padding, per spatial dimension.
    pub padding: [usize; N_SPATIAL],
    /// Spacing between kernel taps, per spatial dimension.
    pub dilation: [usize; N_SPATIAL],
}
13
14#[allow(clippy::large_enum_variant)]
15pub enum ConvolutionInputs<R: Runtime> {
20 Forward {
21 input: InputBinding<R>,
22 weight: InputBinding<R>,
23 bias: Option<InputBinding<R>>,
24 out: TensorBinding<R>,
25 },
26 BackwardData {
27 out_grad: InputBinding<R>,
28 weights: InputBinding<R>,
29 in_grad: TensorBinding<R>,
30 },
31 BackwardWeight {
32 input: InputBinding<R>,
33 out_grad: InputBinding<R>,
34 weight_grad: TensorBinding<R>,
35 },
36}
37
38impl<R: Runtime> ConvolutionInputs<R> {
39 pub fn operation(&self) -> ConvolutionOperation {
40 match self {
41 ConvolutionInputs::Forward { .. } => ConvolutionOperation::Forward,
42 ConvolutionInputs::BackwardData { .. } => ConvolutionOperation::BackwardData,
43 ConvolutionInputs::BackwardWeight { .. } => ConvolutionOperation::BackwardWeight,
44 }
45 }
46}