cubek_convolution/kernels/algorithm/
mod.rs1use cubek_matmul::launch::MatmulArgs;
2use cubek_matmul::{definition::AvailableLineSizes, routines::Routine};
3
4use cubecl::std::tensor::{TensorHandle, into_contiguous_pitched_ref, is_contiguous_pitched};
5
6use cubecl::prelude::*;
7
8use crate::components::{ConvolutionOperation, global::args::RuntimeArgs};
9
10pub mod simple;
11pub mod specialized;
12
13pub trait Algorithm {
15 type Routine: Routine<RuntimeArgs>;
16 type Args: MatmulArgs<Config = RuntimeArgs>;
17
18 const IS_SPECIALIZED: bool = false;
21
22 fn into_tensor_handle<R: Runtime>(
23 client: &ComputeClient<R>,
24 handle: &TensorHandleRef<'_, R>,
25 dtype: StorageType,
26 operation: ConvolutionOperation,
27 ) -> Result<TensorHandle<R>, LaunchError>;
28
29 fn filter_line_sizes(line_sizes: AvailableLineSizes) -> AvailableLineSizes {
30 line_sizes
31 }
32}
33
34pub(crate) fn into_tensor_handle<R: Runtime>(
35 client: &ComputeClient<R>,
36 handle: &TensorHandleRef<'_, R>,
37 dtype: StorageType,
38) -> Result<TensorHandle<R>, LaunchError> {
39 let handle = if has_valid_layout(handle) {
40 TensorHandle::from_ref(handle, dtype)
41 } else {
42 into_contiguous_pitched_ref(client, handle, dtype)?
43 };
44 Ok(handle)
45}
46
47fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>) -> bool {
48 let rank = handle.shape.len();
49 let dim_c = rank - 1;
50 handle.strides[dim_c] == 1
51}
52
53const TMA_STRIDE_ALIGN: usize = 16;
54
55pub(crate) fn into_tensor_handle_tma<R: Runtime>(
56 client: &ComputeClient<R>,
57 handle: &TensorHandleRef<'_, R>,
58 dtype: StorageType,
59 operation: ConvolutionOperation,
60) -> Result<TensorHandle<R>, LaunchError> {
61 let handle = if has_valid_layout_tma(handle, operation) {
62 TensorHandle::from_ref(handle, dtype)
63 } else {
64 into_contiguous_pitched_ref(client, handle, dtype)?
65 };
66 Ok(handle)
67}
68
69pub(crate) fn has_valid_layout_tma<R: Runtime>(
70 handle: &TensorHandleRef<'_, R>,
71 operation: ConvolutionOperation,
72) -> bool {
73 let stride_align = TMA_STRIDE_ALIGN / handle.elem_size;
74 let rank = handle.shape.len();
75 let dim_c = rank - 1;
76
77 let aligned = handle.strides[..dim_c]
78 .iter()
79 .all(|stride| stride % stride_align == 0);
80
81 let valid_layout = handle.strides[dim_c] == 1;
82
83 let is_valid_wgrad = if operation == ConvolutionOperation::BackwardWeight {
84 is_contiguous_pitched(handle.shape, handle.strides)
85 } else {
86 true
87 };
88
89 valid_layout && aligned && is_valid_wgrad
90}