cubecl_matmul/components/selection.rs

use cubecl_core::{Runtime, client::ComputeClient, flex32, prelude::CubePrimitive, tf32};

use crate::components::{
    MatmulElems, TilingScheme,
    batch::HypercubeSelection,
    global::{LoadSpecializationConfig, read::ReaderMode},
    stage::{PartitionBuffering, SwizzleMode},
};

/// Full set of parameters selected for a matmul kernel launch.
#[derive(Debug, Clone)]
pub struct MatmulSelection {
    pub plane_dim: u32,
    pub tiling_scheme: TilingScheme,
    pub shared_swizzle: SwizzleConfig,
    pub partition_buffering: PartitionBuffering,
    pub loading_precompute_strategy: LoadingPrecomputeStrategy,
    pub reader_mode: ReaderMode,
    pub load_specialization_config: LoadSpecializationConfig,
    pub hypercube_selection: HypercubeSelection,
}

/// Modifies the given matmul element types based on the kind of accelerator the kernel is run on.
pub fn adjust_dtypes<R: Runtime>(
    client: &ComputeClient<R>,
    dtypes: &mut MatmulElems,
    requires_accelerator: bool,
) {
    let f32_dtype = f32::as_type_native_unchecked();
    let flex_dtype = flex32::as_type_native_unchecked();
    let tf32_dtype = tf32::as_type_native_unchecked();
    let f16_dtype = half::f16::as_type_native_unchecked();

    if requires_accelerator {
        if *dtypes.lhs_global == f32_dtype
            && *dtypes.rhs_global == f32_dtype
            && client.properties().supports_type(tf32_dtype)
        {
            // f32 inputs compute as tf32 in stage and register memory; the
            // global dtypes are left untouched.
            dtypes.lhs_stage.dtype = tf32_dtype;
            dtypes.rhs_stage.dtype = tf32_dtype;
            dtypes.lhs_register.dtype = tf32_dtype;
            dtypes.rhs_register.dtype = tf32_dtype;
        } else if *dtypes.lhs_global == flex_dtype
            && *dtypes.rhs_global == flex_dtype
            && client.properties().supports_type(f16_dtype)
        {
            // flex32 inputs are relaxed to f16 in stage and register memory.
            dtypes.lhs_stage.dtype = f16_dtype;
            dtypes.rhs_stage.dtype = f16_dtype;
            dtypes.lhs_register.dtype = f16_dtype;
            dtypes.rhs_register.dtype = f16_dtype;
        }
    }
}

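// Usage sketch (hypothetical: `CudaRuntime`, the `client` handle, and how the
// `MatmulElems` value is obtained are assumptions, not this crate's exact API):
//
//     let mut dtypes: MatmulElems = /* resolved from the tensor dtypes */;
//     adjust_dtypes::<CudaRuntime>(&client, &mut dtypes, true);
//     // On hardware with tf32 support, both inputs now use tf32 in stage and
//     // register memory while their global dtypes stay f32.
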
/// Per-operand swizzle modes for shared memory.
#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct SwizzleConfig {
    pub lhs: SwizzleMode,
    pub rhs: SwizzleMode,
    pub acc: SwizzleMode,
    pub out: SwizzleMode,
}

impl SwizzleConfig {
    /// Returns `true` if any operand uses a swizzle mode other than [`SwizzleMode::None`].
    pub fn has_swizzle(&self) -> bool {
        self.lhs != SwizzleMode::None
            || self.rhs != SwizzleMode::None
            || self.acc != SwizzleMode::None
            || self.out != SwizzleMode::None
    }
}

impl MatmulSelection {
    /// Starts a builder seeded with the given tiling scheme, the given plane
    /// dimension, and a default hypercube selection derived from the tiling scheme.
    pub fn builder(tiling_scheme: TilingScheme, plane_dim: u32) -> MatmulSelectionBuilder {
        let hypercube_config = HypercubeSelection::builder(&tiling_scheme).build();
        MatmulSelectionBuilder::new()
            .tiling_scheme(tiling_scheme)
            .hypercube_config(hypercube_config)
            .plane_dim(plane_dim)
    }
}

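// Usage sketch (the `tiling_scheme` value and a plane dimension of 32 are
// assumptions for illustration):
//
//     let selection = MatmulSelection::builder(tiling_scheme, 32)
//         .shared_swizzle(SwizzleConfig::default())
//         .build();
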
/// Builder for [`MatmulSelection`]. The plane dimension, tiling scheme, and
/// hypercube selection are required; all other fields fall back to their defaults.
pub struct MatmulSelectionBuilder {
    plane_dim: Option<u32>,
    pub tiling_scheme: Option<TilingScheme>,
    shared_swizzle: SwizzleConfig,
    hypercube_selection: Option<HypercubeSelection>,
    partition_buffering: PartitionBuffering,
    loading_precompute_strategy: LoadingPrecomputeStrategy,
    reader_mode: ReaderMode,
    load_specialization_config: LoadSpecializationConfig,
}

impl MatmulSelectionBuilder {
    fn new() -> Self {
        Self {
            plane_dim: None,
            tiling_scheme: None,
            shared_swizzle: Default::default(),
            hypercube_selection: None,
            partition_buffering: PartitionBuffering::default(),
            loading_precompute_strategy: LoadingPrecomputeStrategy::default(),
            reader_mode: ReaderMode::default(),
            load_specialization_config: LoadSpecializationConfig::default(),
        }
    }

    pub fn plane_dim(mut self, plane_dim: u32) -> Self {
        self.plane_dim = Some(plane_dim);
        self
    }

    pub fn tiling_scheme(mut self, tiling_scheme: TilingScheme) -> Self {
        self.tiling_scheme = Some(tiling_scheme);
        self
    }

    pub fn shared_swizzle(mut self, swizzle: SwizzleConfig) -> Self {
        self.shared_swizzle = swizzle;
        self
    }

    pub fn hypercube_config(mut self, hypercube_config: HypercubeSelection) -> Self {
        self.hypercube_selection = Some(hypercube_config);
        self
    }

    pub fn partition_buffering(mut self, partition_buffering: PartitionBuffering) -> Self {
        self.partition_buffering = partition_buffering;
        self
    }

    pub fn loading_precompute_strategy(
        mut self,
        loading_precompute_strategy: LoadingPrecomputeStrategy,
    ) -> Self {
        self.loading_precompute_strategy = loading_precompute_strategy;
        self
    }

    pub fn reader_mode(mut self, reader_mode: ReaderMode) -> Self {
        self.reader_mode = reader_mode;
        self
    }

    pub fn load_specialization_config(
        mut self,
        load_specialization_config: LoadSpecializationConfig,
    ) -> Self {
        self.load_specialization_config = load_specialization_config;
        self
    }

    /// Builds the final [`MatmulSelection`].
    ///
    /// # Panics
    ///
    /// Panics if the plane dimension, tiling scheme, or hypercube selection was never set.
    pub fn build(self) -> MatmulSelection {
        MatmulSelection {
            plane_dim: self.plane_dim.expect("plane_dim must be set"),
            tiling_scheme: self.tiling_scheme.expect("tiling_scheme must be set"),
            shared_swizzle: self.shared_swizzle,
            hypercube_selection: self
                .hypercube_selection
                .expect("hypercube_selection must be set"),
            partition_buffering: self.partition_buffering,
            loading_precompute_strategy: self.loading_precompute_strategy,
            reader_mode: self.reader_mode,
            load_specialization_config: self.load_specialization_config,
        }
    }
}

/// Controls how many rows of the output each plane handles.
#[derive(Debug, Clone, Copy, Default)]
pub enum MultiRowStrategy {
    /// Always one row per plane.
    #[default]
    Never,
    /// Always multiple rows per plane.
    Always(u32),
    /// Uses multiple rows only if the `m` dimension of the matmul implies at
    /// least the minimum number of stages along `m`.
    Adaptive { minimum_stage_count: u32 },
}

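// Sketch of how a heuristic could resolve the strategy into a row count
// (hypothetical; `m` and `stage_m` are assumed inputs, and the real resolution
// logic lives in the selection heuristics, not in this file):
//
//     let rows_per_plane = match strategy {
//         MultiRowStrategy::Never => 1,
//         MultiRowStrategy::Always(rows) => rows,
//         MultiRowStrategy::Adaptive { minimum_stage_count } => {
//             if m.div_ceil(stage_m) >= minimum_stage_count { 2 } else { 1 }
//         }
//     };
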
#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum LoadingPrecomputeStrategy {
    /// Don't precompute anything in loading jobs.
    #[default]
    Never,
    /// Precompute values that are shared across tasks.
    Always,
}

impl From<LoadingPrecomputeStrategy> for bool {
    fn from(strategy: LoadingPrecomputeStrategy) -> Self {
        match strategy {
            LoadingPrecomputeStrategy::Always => true,
            LoadingPrecomputeStrategy::Never => false,
        }
    }
}
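
// Usage: the strategy converts directly into a flag for loading jobs.
//
//     let precompute: bool = LoadingPrecomputeStrategy::Always.into();
//     assert!(precompute);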