Skip to main content

cubek_convolution/definition/
blueprint.rs

1use cubek_matmul::definition::TilingBlueprint;
2
3use crate::components::Dimensionality;
4
5/// Per-operation comptime blueprint for the convolution kernel family.
6///
7/// The blueprint captures the minimal comptime information needed to specialize
8/// the kernel: which operation (forward / data-grad / weight-grad), the matmul
9/// `TilingBlueprint`, and per-operation comptime tunables. A different blueprint
10/// retriggers JIT compilation, so it is kept minimal.
11#[derive(Clone, Debug, PartialEq, Eq, Hash)]
12pub enum ConvBlueprint {
13    Forward(ForwardBlueprint),
14    BackwardData(BackwardDataBlueprint),
15    BackwardWeight(BackwardWeightBlueprint),
16}
17
18#[derive(Clone, Debug, PartialEq, Eq, Hash)]
19pub struct ForwardBlueprint {
20    pub matmul: TilingBlueprint,
21    pub dimensionality: Dimensionality,
22    pub has_bias: bool,
23}
24
25#[derive(Clone, Debug, PartialEq, Eq, Hash)]
26pub struct BackwardDataBlueprint {
27    pub matmul: TilingBlueprint,
28    pub dimensionality: Dimensionality,
29}
30
31#[derive(Clone, Debug, PartialEq, Eq, Hash)]
32pub struct BackwardWeightBlueprint {
33    pub matmul: TilingBlueprint,
34    pub dimensionality: Dimensionality,
35}
36
37impl ConvBlueprint {
38    pub fn matmul(&self) -> &TilingBlueprint {
39        match self {
40            ConvBlueprint::Forward(b) => &b.matmul,
41            ConvBlueprint::BackwardData(b) => &b.matmul,
42            ConvBlueprint::BackwardWeight(b) => &b.matmul,
43        }
44    }
45
46    pub fn dimensionality(&self) -> Dimensionality {
47        match self {
48            ConvBlueprint::Forward(b) => b.dimensionality,
49            ConvBlueprint::BackwardData(b) => b.dimensionality,
50            ConvBlueprint::BackwardWeight(b) => b.dimensionality,
51        }
52    }
53}