cubek_convolution/definition/
blueprint.rs1use cubek_matmul::definition::TilingBlueprint;
2
3use crate::components::Dimensionality;
4
5#[derive(Clone, Debug, PartialEq, Eq, Hash)]
12pub enum ConvBlueprint {
13 Forward(ForwardBlueprint),
14 BackwardData(BackwardDataBlueprint),
15 BackwardWeight(BackwardWeightBlueprint),
16}
17
18#[derive(Clone, Debug, PartialEq, Eq, Hash)]
19pub struct ForwardBlueprint {
20 pub matmul: TilingBlueprint,
21 pub dimensionality: Dimensionality,
22 pub has_bias: bool,
23}
24
25#[derive(Clone, Debug, PartialEq, Eq, Hash)]
26pub struct BackwardDataBlueprint {
27 pub matmul: TilingBlueprint,
28 pub dimensionality: Dimensionality,
29}
30
31#[derive(Clone, Debug, PartialEq, Eq, Hash)]
32pub struct BackwardWeightBlueprint {
33 pub matmul: TilingBlueprint,
34 pub dimensionality: Dimensionality,
35}
36
37impl ConvBlueprint {
38 pub fn matmul(&self) -> &TilingBlueprint {
39 match self {
40 ConvBlueprint::Forward(b) => &b.matmul,
41 ConvBlueprint::BackwardData(b) => &b.matmul,
42 ConvBlueprint::BackwardWeight(b) => &b.matmul,
43 }
44 }
45
46 pub fn dimensionality(&self) -> Dimensionality {
47 match self {
48 ConvBlueprint::Forward(b) => b.dimensionality,
49 ConvBlueprint::BackwardData(b) => b.dimensionality,
50 ConvBlueprint::BackwardWeight(b) => b.dimensionality,
51 }
52 }
53}