cubek_convolution/kernels/algorithm/
mod.rs1use crate::components::{ConvolutionOperation, global::args::RuntimeArgs};
2use cubecl::{
3 prelude::*,
4 std::tensor::{into_contiguous_pitched, is_contiguous_pitched},
5};
6use cubek_matmul::{
7 launch::MatmulArgs,
8 {definition::AvailableVectorSizes, routines::Routine},
9};
10
11pub mod simple;
12pub mod specialized;
13
14pub trait Algorithm {
16 type Routine: Routine<RuntimeArgs>;
17 type Args: MatmulArgs<Config = RuntimeArgs>;
18
19 const IS_SPECIALIZED: bool = false;
22
23 fn correct_layout<R: Runtime>(
24 client: &ComputeClient<R>,
25 handle: TensorBinding<R>,
26 dtype: StorageType,
27 operation: ConvolutionOperation,
28 ) -> Result<TensorBinding<R>, LaunchError>;
29
30 fn filter_vector_sizes(vector_sizes: AvailableVectorSizes) -> AvailableVectorSizes {
31 vector_sizes
32 }
33}
34
35pub(crate) fn contiguous_pitched_layout<R: Runtime>(
36 client: &ComputeClient<R>,
37 binding: TensorBinding<R>,
38 dtype: StorageType,
39) -> Result<TensorBinding<R>, LaunchError> {
40 let binding = if has_valid_layout(&binding) {
41 binding
42 } else {
43 into_contiguous_pitched(client, binding, dtype).binding()
44 };
45 Ok(binding)
46}
47
48fn has_valid_layout<R: Runtime>(binding: &TensorBinding<R>) -> bool {
49 let rank = binding.shape.len();
50 let dim_c = rank - 1;
51 binding.strides[dim_c] == 1
52}
53
/// Alignment that the strides of a TMA-compatible tensor must satisfy.
///
/// It is divided by `StorageType::size()` to obtain the alignment in elements,
/// so it is expressed in the same unit as that size (presumably bytes —
/// confirm against the TMA descriptor requirements).
const TMA_STRIDE_ALIGN: usize = 16;
55
56pub(crate) fn into_tensor_handle_tma<R: Runtime>(
57 client: &ComputeClient<R>,
58 handle: TensorBinding<R>,
59 dtype: StorageType,
60 operation: ConvolutionOperation,
61) -> Result<TensorBinding<R>, LaunchError> {
62 let binding = if has_valid_layout_tma(&handle, dtype, operation) {
63 handle
64 } else {
65 into_contiguous_pitched(client, handle, dtype).binding()
66 };
67 Ok(binding)
68}
69
70pub(crate) fn has_valid_layout_tma<R: Runtime>(
71 binding: &TensorBinding<R>,
72 dtype: StorageType,
73 operation: ConvolutionOperation,
74) -> bool {
75 let stride_align = TMA_STRIDE_ALIGN / dtype.size();
76 let rank = binding.shape.len();
77 let dim_c = rank - 1;
78
79 let aligned = binding.strides[..dim_c]
80 .iter()
81 .all(|stride| stride % stride_align == 0);
82
83 let valid_layout = binding.strides[dim_c] == 1;
84
85 let is_valid_wgrad = if operation == ConvolutionOperation::BackwardWeight {
86 is_contiguous_pitched(&binding.shape, &binding.strides)
87 } else {
88 true
89 };
90
91 valid_layout && aligned && is_valid_wgrad
92}