// cubek_convolution/kernels/algorithm/specialized.rs
use std::marker::PhantomData;
2
3use cubecl::{
4 Runtime,
5 client::ComputeClient,
6 ir::StorageType,
7 prelude::TensorHandleRef,
8 server::LaunchError,
9 std::{CubeOption, tensor::TensorHandle},
10};
11use cubek_matmul::{
12 components::{
13 global::read::{AsyncPartialLoadingStrategy, async_partial_tma::AsyncPartialTmaLoading},
14 tile::{TileMatmulFamily, io::Strided},
15 },
16 definition::AvailableLineSizes,
17 launch::{TensorArgs, TensorMapArgs},
18 routines::specialized::SpecializedAlgorithm,
19};
20
21use crate::{
22 algorithm::{Algorithm, into_tensor_handle, into_tensor_handle_tma},
23 components::{
24 ConvolutionOperation,
25 global::{args::RuntimeArgs, read::strategy::sync_bias::SyncBiasLoading},
26 },
27};
28
29pub struct SpecializedConv<TMM: TileMatmulFamily, L: AsyncPartialLoadingStrategy<RuntimeArgs>> {
31 _tmm: PhantomData<TMM>,
32 _loader: PhantomData<L>,
33}
34
35pub struct SpecializedTmaConv<TMM: TileMatmulFamily> {
40 _tmm: PhantomData<TMM>,
41}
42
43impl<
44 TMM: TileMatmulFamily<
45 LhsTile = Strided,
46 RhsTile = Strided,
47 AccTile = CubeOption<Strided>,
48 OutTile = Strided,
49 >,
50 L: AsyncPartialLoadingStrategy<RuntimeArgs, TileKind = Strided>,
51> Algorithm for SpecializedConv<TMM, L>
52{
53 type Routine = SpecializedAlgorithm<TMM, L, SyncBiasLoading>;
54 type Args = TensorArgs<RuntimeArgs>;
55 const IS_SPECIALIZED: bool = true;
56
57 fn into_tensor_handle<R: Runtime>(
58 client: &ComputeClient<R>,
59 handle: &TensorHandleRef<'_, R>,
60 dtype: StorageType,
61 _operation: ConvolutionOperation,
62 ) -> Result<TensorHandle<R>, LaunchError> {
63 into_tensor_handle(client, handle, dtype)
64 }
65}
66
67impl<
68 TMM: TileMatmulFamily<
69 LhsTile = Strided,
70 RhsTile = Strided,
71 AccTile = CubeOption<Strided>,
72 OutTile = Strided,
73 >,
74> Algorithm for SpecializedTmaConv<TMM>
75{
76 type Routine = SpecializedAlgorithm<TMM, AsyncPartialTmaLoading, SyncBiasLoading>;
77 type Args = TensorMapArgs<RuntimeArgs>;
78 const IS_SPECIALIZED: bool = true;
79
80 fn into_tensor_handle<R: Runtime>(
81 client: &ComputeClient<R>,
82 handle: &TensorHandleRef<'_, R>,
83 dtype: StorageType,
84 operation: ConvolutionOperation,
85 ) -> Result<TensorHandle<R>, LaunchError> {
86 into_tensor_handle_tma(client, handle, dtype, operation)
87 }
88
89 fn filter_line_sizes(line_sizes: AvailableLineSizes) -> AvailableLineSizes {
90 AvailableLineSizes {
91 lhs: vec![1],
92 rhs: vec![1],
93 out: line_sizes.out,
94 }
95 }
96}