//! Base definitions for convolution routines (`cubek_convolution/routines/base.rs`).
use crate::components::{ConvolutionOperation, global::args::RuntimeArgs};
use cubecl::{
    prelude::*,
    std::tensor::{into_contiguous_pitched, is_contiguous_pitched},
};
use cubek_matmul::{
    definition::{AvailableVectorSizes, Blueprint},
    launch::MatmulArgs,
    routines::Routine as MatmulRoutine,
};
use std::fmt::Display;
12
/// Specifications for a convolution routine.
///
/// A `Routine` is the convolution-side counterpart of `cubek_matmul::routines::Routine`:
/// it pairs a per-operation matmul routine with the metadata needed to wire the
/// kernel up (input args, optional layout fixups, vector-size filtering).
///
/// `Blueprint` and `Strategy` are surfaced as direct associated types so callers
/// don't have to reach through `MatmulRoutine` to bound them.
pub trait Routine {
    /// Tuning blueprint shared with the underlying matmul routine.
    type Blueprint: Blueprint;
    /// Strategy selector shared with the underlying matmul routine;
    /// `Default + Display + Clone` so it can be constructed, logged, and copied.
    type Strategy: Default + Display + Clone;

    /// The matmul routine this convolution delegates to; its associated
    /// `Blueprint`/`Strategy` are constrained to match the ones re-exported above.
    type MatmulRoutine: MatmulRoutine<RuntimeArgs, Blueprint = Self::Blueprint, Strategy = Self::Strategy>;
    /// Kernel-argument binding type, configured through [`RuntimeArgs`].
    type Args: MatmulArgs<Config = RuntimeArgs>;

    /// Whether to select specialized load flow in tests. Should replace with something cleaner
    /// eventually, but this is nice and simple.
    const IS_SPECIALIZED: bool = false;

    /// Ensure `handle` has a memory layout this routine can consume for the
    /// given `operation`, returning a (possibly re-laid-out) binding.
    /// Implementations typically pass through valid layouts and otherwise
    /// copy via `client` (see the helpers below in this module).
    fn correct_layout<R: Runtime>(
        client: &ComputeClient<R>,
        handle: TensorBinding<R>,
        dtype: StorageType,
        operation: ConvolutionOperation,
    ) -> Result<TensorBinding<R>, LaunchError>;

    /// Restrict the candidate vector sizes for this routine.
    /// Default implementation accepts all sizes unchanged.
    fn filter_vector_sizes(vector_sizes: AvailableVectorSizes) -> AvailableVectorSizes {
        vector_sizes
    }
}
43
44pub(crate) fn contiguous_pitched_layout<R: Runtime>(
45    client: &ComputeClient<R>,
46    binding: TensorBinding<R>,
47    dtype: StorageType,
48) -> Result<TensorBinding<R>, LaunchError> {
49    let binding = if has_valid_layout(&binding) {
50        binding
51    } else {
52        into_contiguous_pitched(client, binding, dtype).binding()
53    };
54    Ok(binding)
55}
56
57fn has_valid_layout<R: Runtime>(binding: &TensorBinding<R>) -> bool {
58    let rank = binding.shape.len();
59    let dim_c = rank - 1;
60    binding.strides[dim_c] == 1
61}
62
/// Required stride alignment, in bytes, for tensors consumed via TMA;
/// converted to an element count (divided by `dtype.size()`) before use.
const TMA_STRIDE_ALIGN: usize = 16;
64
65pub(crate) fn into_tensor_handle_tma<R: Runtime>(
66    client: &ComputeClient<R>,
67    handle: TensorBinding<R>,
68    dtype: StorageType,
69    operation: ConvolutionOperation,
70) -> Result<TensorBinding<R>, LaunchError> {
71    let binding = if has_valid_layout_tma(&handle, dtype, operation) {
72        handle
73    } else {
74        into_contiguous_pitched(client, handle, dtype).binding()
75    };
76    Ok(binding)
77}
78
79pub(crate) fn has_valid_layout_tma<R: Runtime>(
80    binding: &TensorBinding<R>,
81    dtype: StorageType,
82    operation: ConvolutionOperation,
83) -> bool {
84    let stride_align = TMA_STRIDE_ALIGN / dtype.size();
85    let rank = binding.shape.len();
86    let dim_c = rank - 1;
87
88    let aligned = binding.strides[..dim_c]
89        .iter()
90        .all(|stride| stride % stride_align == 0);
91
92    let valid_layout = binding.strides[dim_c] == 1;
93
94    let is_valid_wgrad = if operation == ConvolutionOperation::BackwardWeight {
95        is_contiguous_pitched(&binding.shape, &binding.strides)
96    } else {
97        true
98    };
99
100    valid_layout && aligned && is_valid_wgrad
101}