cubek_convolution/kernels/algorithm/
specialized.rs1use cubecl::{
2 Runtime, client::ComputeClient, ir::StorageType, prelude::TensorBinding, server::LaunchError,
3};
4use cubek_matmul::{
5 components::global::read::{
6 AsyncPartialLoadingStrategy, async_partial_tma::AsyncPartialTmaLoading,
7 },
8 definition::AvailableVectorSizes,
9 launch::{TensorArgs, TensorMapArgs},
10 routines::specialized::SpecializedAlgorithm,
11};
12use cubek_std::tile::Strided;
13use std::marker::PhantomData;
14
15use crate::{
16 algorithm::{Algorithm, contiguous_pitched_layout, into_tensor_handle_tma},
17 components::{
18 ConvolutionOperation,
19 global::{args::RuntimeArgs, read::strategy::sync_bias::SyncBiasLoading},
20 },
21};
22
23pub struct SpecializedConv<L: AsyncPartialLoadingStrategy<RuntimeArgs>> {
25 _loader: PhantomData<L>,
26}
27
28pub struct SpecializedTmaConv;
29
30impl<L: AsyncPartialLoadingStrategy<RuntimeArgs, TileKind = Strided>> Algorithm
31 for SpecializedConv<L>
32{
33 type Routine = SpecializedAlgorithm<L, SyncBiasLoading>;
34 type Args = TensorArgs<RuntimeArgs>;
35 const IS_SPECIALIZED: bool = true;
36
37 fn correct_layout<R: Runtime>(
38 client: &ComputeClient<R>,
39 handle: TensorBinding<R>,
40 dtype: StorageType,
41 _operation: ConvolutionOperation,
42 ) -> Result<TensorBinding<R>, LaunchError> {
43 contiguous_pitched_layout(client, handle, dtype)
44 }
45}
46
47impl Algorithm for SpecializedTmaConv {
48 type Routine = SpecializedAlgorithm<AsyncPartialTmaLoading, SyncBiasLoading>;
49 type Args = TensorMapArgs<RuntimeArgs>;
50 const IS_SPECIALIZED: bool = true;
51
52 fn correct_layout<R: Runtime>(
53 client: &ComputeClient<R>,
54 handle: TensorBinding<R>,
55 dtype: StorageType,
56 operation: ConvolutionOperation,
57 ) -> Result<TensorBinding<R>, LaunchError> {
58 into_tensor_handle_tma(client, handle, dtype, operation)
59 }
60
61 fn filter_vector_sizes(vector_sizes: AvailableVectorSizes) -> AvailableVectorSizes {
62 AvailableVectorSizes {
63 lhs: vec![1],
64 rhs: vec![1],
65 out: vector_sizes.out,
66 }
67 }
68}