Skip to main content

cubek_convolution/launch/
inputs.rs

1use cubecl::{Runtime, prelude::TensorBinding};
2use cubek_std::InputBinding;
3
4use crate::components::ConvolutionOperation;
5
/// Spatial convolution arguments (stride / padding / dilation per spatial dim).
///
/// `N_SPATIAL` is the number of spatial dimensions; each array holds exactly one
/// entry per spatial dimension.
///
/// Besides `Clone`/`Debug`, the type derives `PartialEq`, `Eq` and `Hash` so two
/// argument sets can be compared or used as a map/set key. All fields are plain
/// `usize` arrays, so these derives hold for every `N_SPATIAL`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ConvolutionArgs<const N_SPATIAL: usize> {
    /// Stride, one entry per spatial dimension.
    pub stride: [usize; N_SPATIAL],
    /// Padding, one entry per spatial dimension.
    pub padding: [usize; N_SPATIAL],
    /// Dilation, one entry per spatial dimension.
    pub dilation: [usize; N_SPATIAL],
}
13
14#[allow(clippy::large_enum_variant)]
15/// Per-operation tensor bindings supplied to `launch_ref`.
16///
17/// Each variant carries exactly the bindings the corresponding operation needs.
18/// The discriminant maps 1:1 to `ConvolutionOperation`.
19pub enum ConvolutionInputs<R: Runtime> {
20    Forward {
21        input: InputBinding<R>,
22        weight: InputBinding<R>,
23        bias: Option<InputBinding<R>>,
24        out: TensorBinding<R>,
25    },
26    BackwardData {
27        out_grad: InputBinding<R>,
28        weights: InputBinding<R>,
29        in_grad: TensorBinding<R>,
30    },
31    BackwardWeight {
32        input: InputBinding<R>,
33        out_grad: InputBinding<R>,
34        weight_grad: TensorBinding<R>,
35    },
36}
37
38impl<R: Runtime> ConvolutionInputs<R> {
39    pub fn operation(&self) -> ConvolutionOperation {
40        match self {
41            ConvolutionInputs::Forward { .. } => ConvolutionOperation::Forward,
42            ConvolutionInputs::BackwardData { .. } => ConvolutionOperation::BackwardData,
43            ConvolutionInputs::BackwardWeight { .. } => ConvolutionOperation::BackwardWeight,
44        }
45    }
46}