cubek_convolution/routines/
base.rs1use crate::components::{ConvolutionOperation, global::args::RuntimeArgs};
2use cubecl::{
3 prelude::*,
4 std::tensor::{into_contiguous_pitched, is_contiguous_pitched},
5};
6use cubek_matmul::{
7 definition::{AvailableVectorSizes, Blueprint},
8 launch::MatmulArgs,
9 routines::Routine as MatmulRoutine,
10};
11use std::fmt::Display;
12
13pub trait Routine {
22 type Blueprint: Blueprint;
23 type Strategy: Default + Display + Clone;
24
25 type MatmulRoutine: MatmulRoutine<RuntimeArgs, Blueprint = Self::Blueprint, Strategy = Self::Strategy>;
26 type Args: MatmulArgs<Config = RuntimeArgs>;
27
28 const IS_SPECIALIZED: bool = false;
31
32 fn correct_layout<R: Runtime>(
33 client: &ComputeClient<R>,
34 handle: TensorBinding<R>,
35 dtype: StorageType,
36 operation: ConvolutionOperation,
37 ) -> Result<TensorBinding<R>, LaunchError>;
38
39 fn filter_vector_sizes(vector_sizes: AvailableVectorSizes) -> AvailableVectorSizes {
40 vector_sizes
41 }
42}
43
44pub(crate) fn contiguous_pitched_layout<R: Runtime>(
45 client: &ComputeClient<R>,
46 binding: TensorBinding<R>,
47 dtype: StorageType,
48) -> Result<TensorBinding<R>, LaunchError> {
49 let binding = if has_valid_layout(&binding) {
50 binding
51 } else {
52 into_contiguous_pitched(client, binding, dtype).binding()
53 };
54 Ok(binding)
55}
56
57fn has_valid_layout<R: Runtime>(binding: &TensorBinding<R>) -> bool {
58 let rank = binding.shape.len();
59 let dim_c = rank - 1;
60 binding.strides[dim_c] == 1
61}
62
/// Byte alignment that every outer (non-innermost) global stride must satisfy for a
/// tensor to be addressed through TMA.
const TMA_STRIDE_ALIGN: usize = 16;
64
65pub(crate) fn into_tensor_handle_tma<R: Runtime>(
66 client: &ComputeClient<R>,
67 handle: TensorBinding<R>,
68 dtype: StorageType,
69 operation: ConvolutionOperation,
70) -> Result<TensorBinding<R>, LaunchError> {
71 let binding = if has_valid_layout_tma(&handle, dtype, operation) {
72 handle
73 } else {
74 into_contiguous_pitched(client, handle, dtype).binding()
75 };
76 Ok(binding)
77}
78
79pub(crate) fn has_valid_layout_tma<R: Runtime>(
80 binding: &TensorBinding<R>,
81 dtype: StorageType,
82 operation: ConvolutionOperation,
83) -> bool {
84 let stride_align = TMA_STRIDE_ALIGN / dtype.size();
85 let rank = binding.shape.len();
86 let dim_c = rank - 1;
87
88 let aligned = binding.strides[..dim_c]
89 .iter()
90 .all(|stride| stride % stride_align == 0);
91
92 let valid_layout = binding.strides[dim_c] == 1;
93
94 let is_valid_wgrad = if operation == ConvolutionOperation::BackwardWeight {
95 is_contiguous_pitched(&binding.shape, &binding.strides)
96 } else {
97 true
98 };
99
100 valid_layout && aligned && is_valid_wgrad
101}