cubek_convolution/kernels/forward/algorithm/
mod.rs1use cubek_matmul::definition::{
2 AvailableLineSizes, LoadingPrecomputeStrategy, MatmulElems, MatmulLineSizes, MatmulSetupError,
3 MultiRowStrategy, TilingBlueprint,
4};
5use cubek_matmul::{
6 components::{
7 global::{LoadFlows, read::ReaderMode},
8 stage::{PartitionBuffering, StageMatmulFamily},
9 tile::TileMatmulFamily,
10 },
11 launch::MatmulArgs,
12};
13
14use cubecl::{
15 ir::DeviceProperties,
16 std::tensor::{TensorHandle, into_contiguous_pitched_ref, is_contiguous_pitched},
17};
18
19use cubecl::prelude::*;
20
21use crate::components::{
22 ConvolutionOperation, ConvolutionProblem,
23 global::{GlobalConfig, GlobalConvolutionFamily},
24};
25
26pub mod simple;
27
28pub trait Algorithm {
30 type TileMatmul: TileMatmulFamily;
31 type StageMatmul: StageMatmulFamily;
32 type GlobalConvolution: GlobalConvolutionFamily;
33
34 type Args: MatmulArgs;
35
36 fn cube_count(selection: &TilingBlueprint, problem: &ConvolutionProblem) -> CubeCount {
37 let m_stage = selection.tiling_scheme.elements_per_stage_along_m();
38 let n_stage = selection.tiling_scheme.elements_per_stage_along_n();
39 let cubes_needed_m = (problem.m as u32).div_ceil(m_stage);
40 let cubes_needed_n = (problem.n as u32).div_ceil(n_stage);
41
42 CubeCount::Static(cubes_needed_m, cubes_needed_n, 1)
43 }
44
45 fn multi_row_strategy() -> MultiRowStrategy {
46 MultiRowStrategy::Never
47 }
48
49 fn loading_precompute_strategy() -> LoadingPrecomputeStrategy {
50 LoadingPrecomputeStrategy::Never
51 }
52
53 fn reader_mode() -> ReaderMode {
54 ReaderMode::Relaxed
55 }
56
57 fn load_specialization() -> LoadFlows {
58 LoadFlows::default()
59 }
60
61 fn partition_buffering_strategy() -> PartitionBuffering {
62 PartitionBuffering::Double
63 }
64
65 fn expand_config(
67 device_props: &DeviceProperties,
68 problem: &ConvolutionProblem,
69 selection: &TilingBlueprint,
70 line_sizes: &MatmulLineSizes,
71 dtypes: &MatmulElems,
72 ) -> Result<GlobalConfig<Self::GlobalConvolution>, MatmulSetupError> {
73 Self::GlobalConvolution::expand_config(device_props, problem, selection, line_sizes, dtypes)
74 }
75
76 fn into_tensor_handle<R: Runtime>(
77 client: &ComputeClient<R>,
78 handle: &TensorHandleRef<'_, R>,
79 dtype: StorageType,
80 operation: ConvolutionOperation,
81 ) -> Result<TensorHandle<R>, LaunchError>;
82
83 fn filter_line_sizes(line_sizes: AvailableLineSizes) -> AvailableLineSizes {
84 line_sizes
85 }
86
87 fn selection<R: Runtime>(
88 client: &ComputeClient<R>,
89 problem: &ConvolutionProblem,
90 plane_dim: u32,
91 line_sizes: &MatmulLineSizes,
92 matmul_elems: &mut MatmulElems,
93 ) -> Result<TilingBlueprint, MatmulSetupError>;
94}
95
96pub(crate) fn into_tensor_handle<R: Runtime>(
97 client: &ComputeClient<R>,
98 handle: &TensorHandleRef<'_, R>,
99 dtype: StorageType,
100) -> Result<TensorHandle<R>, LaunchError> {
101 let handle = if has_valid_layout(handle) {
102 TensorHandle::from_ref(handle, dtype)
103 } else {
104 into_contiguous_pitched_ref(client, handle, dtype)?
105 };
106 Ok(handle)
107}
108
109fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>) -> bool {
110 let rank = handle.shape.len();
111 let dim_c = rank - 1;
112 handle.strides[dim_c] == 1
113}
114
/// Alignment that [`has_valid_layout_tma`] requires of every stride except the last one.
///
/// It is interpreted relative to `handle.elem_size`, so it is presumably a byte count.
const TMA_STRIDE_ALIGN: usize = 16;
116
117pub(crate) fn into_tensor_handle_tma<R: Runtime>(
118 client: &ComputeClient<R>,
119 handle: &TensorHandleRef<'_, R>,
120 dtype: StorageType,
121 operation: ConvolutionOperation,
122) -> Result<TensorHandle<R>, LaunchError> {
123 let handle = if has_valid_layout_tma(handle, operation) {
124 TensorHandle::from_ref(handle, dtype)
125 } else {
126 into_contiguous_pitched_ref(client, handle, dtype)?
127 };
128 Ok(handle)
129}
130
131pub(crate) fn has_valid_layout_tma<R: Runtime>(
132 handle: &TensorHandleRef<'_, R>,
133 operation: ConvolutionOperation,
134) -> bool {
135 let stride_align = TMA_STRIDE_ALIGN / handle.elem_size;
136 let rank = handle.shape.len();
137 let dim_c = rank - 1;
138
139 let aligned = handle.strides[..dim_c]
140 .iter()
141 .all(|stride| stride % stride_align == 0);
142
143 let valid_layout = handle.strides[dim_c] == 1;
144
145 let is_valid_wgrad = if operation == ConvolutionOperation::BackwardWeight {
146 is_contiguous_pitched(handle.shape, handle.strides)
147 } else {
148 true
149 };
150
151 valid_layout && aligned && is_valid_wgrad
152}