Skip to main content

cubek_convolution/kernels/algorithm/mod.rs

1use cubek_matmul::launch::MatmulArgs;
2use cubek_matmul::{definition::AvailableLineSizes, routines::Routine};
3
4use cubecl::std::tensor::{TensorHandle, into_contiguous_pitched_ref, is_contiguous_pitched};
5
6use cubecl::prelude::*;
7
8use crate::components::{ConvolutionOperation, global::args::RuntimeArgs};
9
10pub mod simple;
11pub mod specialized;
12
13/// Specifications for a convolution algorithm
14pub trait Algorithm {
15    type Routine: Routine<RuntimeArgs>;
16    type Args: MatmulArgs<Config = RuntimeArgs>;
17
18    /// Whether to select specialized load flow in tests. Should replace with something cleaner
19    /// eventually, but this is nice and simple.
20    const IS_SPECIALIZED: bool = false;
21
22    fn into_tensor_handle<R: Runtime>(
23        client: &ComputeClient<R>,
24        handle: &TensorHandleRef<'_, R>,
25        dtype: StorageType,
26        operation: ConvolutionOperation,
27    ) -> Result<TensorHandle<R>, LaunchError>;
28
29    fn filter_line_sizes(line_sizes: AvailableLineSizes) -> AvailableLineSizes {
30        line_sizes
31    }
32}
33
34pub(crate) fn into_tensor_handle<R: Runtime>(
35    client: &ComputeClient<R>,
36    handle: &TensorHandleRef<'_, R>,
37    dtype: StorageType,
38) -> Result<TensorHandle<R>, LaunchError> {
39    let handle = if has_valid_layout(handle) {
40        TensorHandle::from_ref(handle, dtype)
41    } else {
42        into_contiguous_pitched_ref(client, handle, dtype)?
43    };
44    Ok(handle)
45}
46
47fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>) -> bool {
48    let rank = handle.shape.len();
49    let dim_c = rank - 1;
50    handle.strides[dim_c] == 1
51}
52
/// Alignment, in bytes, that every stride except the innermost one must satisfy.
///
/// A tensor that meets it can be used for TMA without a copy. It is checked by
/// `has_valid_layout_tma`.
const TMA_STRIDE_ALIGN: usize = 16;
54
55pub(crate) fn into_tensor_handle_tma<R: Runtime>(
56    client: &ComputeClient<R>,
57    handle: &TensorHandleRef<'_, R>,
58    dtype: StorageType,
59    operation: ConvolutionOperation,
60) -> Result<TensorHandle<R>, LaunchError> {
61    let handle = if has_valid_layout_tma(handle, operation) {
62        TensorHandle::from_ref(handle, dtype)
63    } else {
64        into_contiguous_pitched_ref(client, handle, dtype)?
65    };
66    Ok(handle)
67}
68
69pub(crate) fn has_valid_layout_tma<R: Runtime>(
70    handle: &TensorHandleRef<'_, R>,
71    operation: ConvolutionOperation,
72) -> bool {
73    let stride_align = TMA_STRIDE_ALIGN / handle.elem_size;
74    let rank = handle.shape.len();
75    let dim_c = rank - 1;
76
77    let aligned = handle.strides[..dim_c]
78        .iter()
79        .all(|stride| stride % stride_align == 0);
80
81    let valid_layout = handle.strides[dim_c] == 1;
82
83    let is_valid_wgrad = if operation == ConvolutionOperation::BackwardWeight {
84        is_contiguous_pitched(handle.shape, handle.strides)
85    } else {
86        true
87    };
88
89    valid_layout && aligned && is_valid_wgrad
90}