Skip to main content

cubek_convolution/kernels/algorithm/
mod.rs

1use crate::components::{ConvolutionOperation, global::args::RuntimeArgs};
2use cubecl::{
3    prelude::*,
4    std::tensor::{into_contiguous_pitched, is_contiguous_pitched},
5};
6use cubek_matmul::{
7    launch::MatmulArgs,
8    {definition::AvailableVectorSizes, routines::Routine},
9};
10
11pub mod simple;
12pub mod specialized;
13
14/// Specifications for a convolution algorithm
15pub trait Algorithm {
16    type Routine: Routine<RuntimeArgs>;
17    type Args: MatmulArgs<Config = RuntimeArgs>;
18
19    /// Whether to select specialized load flow in tests. Should replace with something cleaner
20    /// eventually, but this is nice and simple.
21    const IS_SPECIALIZED: bool = false;
22
23    fn correct_layout<R: Runtime>(
24        client: &ComputeClient<R>,
25        handle: TensorBinding<R>,
26        dtype: StorageType,
27        operation: ConvolutionOperation,
28    ) -> Result<TensorBinding<R>, LaunchError>;
29
30    fn filter_vector_sizes(vector_sizes: AvailableVectorSizes) -> AvailableVectorSizes {
31        vector_sizes
32    }
33}
34
35pub(crate) fn contiguous_pitched_layout<R: Runtime>(
36    client: &ComputeClient<R>,
37    binding: TensorBinding<R>,
38    dtype: StorageType,
39) -> Result<TensorBinding<R>, LaunchError> {
40    let binding = if has_valid_layout(&binding) {
41        binding
42    } else {
43        into_contiguous_pitched(client, binding, dtype).binding()
44    };
45    Ok(binding)
46}
47
48fn has_valid_layout<R: Runtime>(binding: &TensorBinding<R>) -> bool {
49    let rank = binding.shape.len();
50    let dim_c = rank - 1;
51    binding.strides[dim_c] == 1
52}
53
54const TMA_STRIDE_ALIGN: usize = 16;
55
56pub(crate) fn into_tensor_handle_tma<R: Runtime>(
57    client: &ComputeClient<R>,
58    handle: TensorBinding<R>,
59    dtype: StorageType,
60    operation: ConvolutionOperation,
61) -> Result<TensorBinding<R>, LaunchError> {
62    let binding = if has_valid_layout_tma(&handle, dtype, operation) {
63        handle
64    } else {
65        into_contiguous_pitched(client, handle, dtype).binding()
66    };
67    Ok(binding)
68}
69
70pub(crate) fn has_valid_layout_tma<R: Runtime>(
71    binding: &TensorBinding<R>,
72    dtype: StorageType,
73    operation: ConvolutionOperation,
74) -> bool {
75    let stride_align = TMA_STRIDE_ALIGN / dtype.size();
76    let rank = binding.shape.len();
77    let dim_c = rank - 1;
78
79    let aligned = binding.strides[..dim_c]
80        .iter()
81        .all(|stride| stride % stride_align == 0);
82
83    let valid_layout = binding.strides[dim_c] == 1;
84
85    let is_valid_wgrad = if operation == ConvolutionOperation::BackwardWeight {
86        is_contiguous_pitched(&binding.shape, &binding.strides)
87    } else {
88        true
89    };
90
91    valid_layout && aligned && is_valid_wgrad
92}