// cubecl_matmul/components/selection.rs

1use crate::components::{
2    TilingScheme,
3    batch::HypercubeSelection,
4    global::{LoadSpecializationConfig, load::LoaderMode},
5    stage::PartitionBuffering,
6};
7
8#[derive(Debug, Clone)]
9pub struct MatmulSelection {
10    pub plane_dim: u32,
11    pub tiling_scheme: TilingScheme,
12    pub quantized: bool,
13    pub partition_buffering: PartitionBuffering,
14    pub loading_precompute_strategy: LoadingPrecomputeStrategy,
15    pub loader_mode: LoaderMode,
16    pub load_specialization_config: LoadSpecializationConfig,
17    pub hypercube_selection: HypercubeSelection,
18}
19
20impl MatmulSelection {
21    pub fn builder(tiling_scheme: TilingScheme, plane_dim: u32) -> MatmulSelectionBuilder {
22        let hypercube_config = HypercubeSelection::builder(&tiling_scheme).build();
23        MatmulSelectionBuilder::new()
24            .tiling_scheme(tiling_scheme)
25            .hypercube_config(hypercube_config)
26            .plane_dim(plane_dim)
27    }
28}
29
30pub struct MatmulSelectionBuilder {
31    plane_dim: Option<u32>,
32    pub tiling_scheme: Option<TilingScheme>,
33    hypercube_selection: Option<HypercubeSelection>,
34    quantized: bool,
35    partition_buffering: PartitionBuffering,
36    loading_precompute_strategy: LoadingPrecomputeStrategy,
37    loader_mode: LoaderMode,
38    load_specialization_config: LoadSpecializationConfig,
39}
40
41impl MatmulSelectionBuilder {
42    fn new() -> Self {
43        Self {
44            plane_dim: None,
45            tiling_scheme: None,
46            hypercube_selection: None,
47            quantized: false,
48            partition_buffering: PartitionBuffering::default(),
49            loading_precompute_strategy: LoadingPrecomputeStrategy::default(),
50            loader_mode: LoaderMode::default(),
51            load_specialization_config: LoadSpecializationConfig::default(),
52        }
53    }
54
55    pub fn plane_dim(mut self, plane_dim: u32) -> Self {
56        self.plane_dim = Some(plane_dim);
57        self
58    }
59
60    pub fn tiling_scheme(mut self, tiling_scheme: TilingScheme) -> Self {
61        self.tiling_scheme = Some(tiling_scheme);
62        self
63    }
64
65    pub fn hypercube_config(mut self, hypercube_config: HypercubeSelection) -> Self {
66        self.hypercube_selection = Some(hypercube_config);
67        self
68    }
69
70    pub fn quantized(mut self, quantized: bool) -> Self {
71        self.quantized = quantized;
72        self
73    }
74
75    pub fn partition_buffering(mut self, partition_buffering: PartitionBuffering) -> Self {
76        self.partition_buffering = partition_buffering;
77        self
78    }
79
80    pub fn loading_precompute_strategy(
81        mut self,
82        loading_precompute_strategy: LoadingPrecomputeStrategy,
83    ) -> Self {
84        self.loading_precompute_strategy = loading_precompute_strategy;
85        self
86    }
87
88    pub fn loader_mode(mut self, loader_mode: LoaderMode) -> Self {
89        self.loader_mode = loader_mode;
90        self
91    }
92
93    pub fn load_specialization_config(
94        mut self,
95        load_specialization_config: LoadSpecializationConfig,
96    ) -> Self {
97        self.load_specialization_config = load_specialization_config;
98        self
99    }
100
101    pub fn build(self) -> MatmulSelection {
102        MatmulSelection {
103            plane_dim: self.plane_dim.unwrap(),
104            tiling_scheme: self.tiling_scheme.unwrap(),
105            hypercube_selection: self.hypercube_selection.unwrap(),
106            quantized: self.quantized,
107            partition_buffering: self.partition_buffering,
108            loading_precompute_strategy: self.loading_precompute_strategy,
109            loader_mode: self.loader_mode,
110            load_specialization_config: self.load_specialization_config,
111        }
112    }
113}
114
/// How many rows each plane should handle.
#[derive(Clone, Copy, Debug, Default)]
pub enum MultiRowStrategy {
    /// One row per plane, unconditionally. This is the default.
    #[default]
    Never,
    /// Multiple rows per plane, unconditionally. The payload is presumably the number of
    /// rows per plane — confirm against the consumer.
    Always(u32),
    /// Multiple rows are used only when the `m` dimension of the matmul yields at least
    /// `minimum_stage_count` stages along `m`.
    Adaptive { minimum_stage_count: u32 },
}
125
/// Whether loading jobs precompute values up front.
// Variant order is significant: the derived `Hash` depends on the discriminant.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum LoadingPrecomputeStrategy {
    /// Loading jobs precompute nothing. This is the default.
    #[default]
    Never,
    /// Values shared across tasks are precomputed.
    Always,
}
134
135impl From<LoadingPrecomputeStrategy> for bool {
136    fn from(strategy: LoadingPrecomputeStrategy) -> Self {
137        match strategy {
138            LoadingPrecomputeStrategy::Always => true,
139            LoadingPrecomputeStrategy::Never => false,
140        }
141    }
142}