cubek_convolution/routines/
specialized.rs1use cubecl::{
2 Runtime, client::ComputeClient, ir::StorageType, prelude::TensorBinding, server::LaunchError,
3};
4use cubek_matmul::{
5 components::{
6 global::read::{
7 AsyncPartialLoadingStrategy, async_partial_cyclic::AsyncPartialCyclicLoading,
8 async_partial_strided::AsyncPartialStridedLoading,
9 async_partial_tma::AsyncPartialTmaLoading,
10 },
11 stage::ColMajorTilingOrder,
12 },
13 definition::{AvailableVectorSizes, TilingBlueprint},
14 launch::{TensorArgs, TensorMapArgs},
15 routines::specialized::{SpecializedAlgorithm, SpecializedStrategy},
16};
17use cubek_std::tile::Strided;
18use std::marker::PhantomData;
19
20use crate::{
21 components::{
22 ConvolutionOperation,
23 global::{args::RuntimeArgs, read::strategy::sync_bias::SyncBiasLoading},
24 },
25 routines::{Routine, contiguous_pitched_layout, into_tensor_handle_tma},
26};
27
28pub struct SpecializedConv<L: AsyncPartialLoadingStrategy<RuntimeArgs>> {
30 _loader: PhantomData<L>,
31}
32
33pub type SpecializedAsyncCyclicConv =
34 SpecializedConv<AsyncPartialCyclicLoading<ColMajorTilingOrder>>;
35pub type SpecializedAsyncStridedConv = SpecializedConv<AsyncPartialStridedLoading>;
36
/// Unit marker type for the specialized convolution routine whose `Routine`
/// implementation is built on [`AsyncPartialTmaLoading`] and
/// [`TensorMapArgs`].
pub struct SpecializedTmaConv;
38
39impl<L: AsyncPartialLoadingStrategy<RuntimeArgs, TileKind = Strided>> Routine
40 for SpecializedConv<L>
41{
42 type Blueprint = TilingBlueprint;
43 type Strategy = SpecializedStrategy;
44 type MatmulRoutine = SpecializedAlgorithm<L, SyncBiasLoading>;
45 type Args = TensorArgs<RuntimeArgs>;
46 const IS_SPECIALIZED: bool = true;
47
48 fn correct_layout<R: Runtime>(
49 client: &ComputeClient<R>,
50 handle: TensorBinding<R>,
51 dtype: StorageType,
52 _operation: ConvolutionOperation,
53 ) -> Result<TensorBinding<R>, LaunchError> {
54 contiguous_pitched_layout(client, handle, dtype)
55 }
56}
57
58impl Routine for SpecializedTmaConv {
59 type Blueprint = TilingBlueprint;
60 type Strategy = SpecializedStrategy;
61 type MatmulRoutine = SpecializedAlgorithm<AsyncPartialTmaLoading, SyncBiasLoading>;
62 type Args = TensorMapArgs<RuntimeArgs>;
63 const IS_SPECIALIZED: bool = true;
64
65 fn correct_layout<R: Runtime>(
66 client: &ComputeClient<R>,
67 handle: TensorBinding<R>,
68 dtype: StorageType,
69 operation: ConvolutionOperation,
70 ) -> Result<TensorBinding<R>, LaunchError> {
71 into_tensor_handle_tma(client, handle, dtype, operation)
72 }
73
74 fn filter_vector_sizes(vector_sizes: AvailableVectorSizes) -> AvailableVectorSizes {
75 AvailableVectorSizes {
76 lhs: vec![1],
77 rhs: vec![1],
78 out: vector_sizes.out,
79 }
80 }
81}