cubecl_matmul/components/
selection.rs1use crate::components::{
2 TilingScheme,
3 batch::HypercubeSelection,
4 global::{LoadSpecializationConfig, load::LoaderMode},
5 stage::PartitionBuffering,
6};
7
8#[derive(Debug, Clone)]
9pub struct MatmulSelection {
10 pub plane_dim: u32,
11 pub tiling_scheme: TilingScheme,
12 pub quantized: bool,
13 pub partition_buffering: PartitionBuffering,
14 pub loading_precompute_strategy: LoadingPrecomputeStrategy,
15 pub loader_mode: LoaderMode,
16 pub load_specialization_config: LoadSpecializationConfig,
17 pub hypercube_selection: HypercubeSelection,
18}
19
20impl MatmulSelection {
21 pub fn builder(tiling_scheme: TilingScheme, plane_dim: u32) -> MatmulSelectionBuilder {
22 let hypercube_config = HypercubeSelection::builder(&tiling_scheme).build();
23 MatmulSelectionBuilder::new()
24 .tiling_scheme(tiling_scheme)
25 .hypercube_config(hypercube_config)
26 .plane_dim(plane_dim)
27 }
28}
29
30pub struct MatmulSelectionBuilder {
31 plane_dim: Option<u32>,
32 pub tiling_scheme: Option<TilingScheme>,
33 hypercube_selection: Option<HypercubeSelection>,
34 quantized: bool,
35 partition_buffering: PartitionBuffering,
36 loading_precompute_strategy: LoadingPrecomputeStrategy,
37 loader_mode: LoaderMode,
38 load_specialization_config: LoadSpecializationConfig,
39}
40
41impl MatmulSelectionBuilder {
42 fn new() -> Self {
43 Self {
44 plane_dim: None,
45 tiling_scheme: None,
46 hypercube_selection: None,
47 quantized: false,
48 partition_buffering: PartitionBuffering::default(),
49 loading_precompute_strategy: LoadingPrecomputeStrategy::default(),
50 loader_mode: LoaderMode::default(),
51 load_specialization_config: LoadSpecializationConfig::default(),
52 }
53 }
54
55 pub fn plane_dim(mut self, plane_dim: u32) -> Self {
56 self.plane_dim = Some(plane_dim);
57 self
58 }
59
60 pub fn tiling_scheme(mut self, tiling_scheme: TilingScheme) -> Self {
61 self.tiling_scheme = Some(tiling_scheme);
62 self
63 }
64
65 pub fn hypercube_config(mut self, hypercube_config: HypercubeSelection) -> Self {
66 self.hypercube_selection = Some(hypercube_config);
67 self
68 }
69
70 pub fn quantized(mut self, quantized: bool) -> Self {
71 self.quantized = quantized;
72 self
73 }
74
75 pub fn partition_buffering(mut self, partition_buffering: PartitionBuffering) -> Self {
76 self.partition_buffering = partition_buffering;
77 self
78 }
79
80 pub fn loading_precompute_strategy(
81 mut self,
82 loading_precompute_strategy: LoadingPrecomputeStrategy,
83 ) -> Self {
84 self.loading_precompute_strategy = loading_precompute_strategy;
85 self
86 }
87
88 pub fn loader_mode(mut self, loader_mode: LoaderMode) -> Self {
89 self.loader_mode = loader_mode;
90 self
91 }
92
93 pub fn load_specialization_config(
94 mut self,
95 load_specialization_config: LoadSpecializationConfig,
96 ) -> Self {
97 self.load_specialization_config = load_specialization_config;
98 self
99 }
100
101 pub fn build(self) -> MatmulSelection {
102 MatmulSelection {
103 plane_dim: self.plane_dim.unwrap(),
104 tiling_scheme: self.tiling_scheme.unwrap(),
105 hypercube_selection: self.hypercube_selection.unwrap(),
106 quantized: self.quantized,
107 partition_buffering: self.partition_buffering,
108 loading_precompute_strategy: self.loading_precompute_strategy,
109 loader_mode: self.loader_mode,
110 load_specialization_config: self.load_specialization_config,
111 }
112 }
113}
114
/// Strategy selector with three modes; [`MultiRowStrategy::Never`] is the default.
///
/// NOTE(review): judging by the name this governs use of multiple rows; the exact meaning
/// of each variant and payload is not visible here — confirm against the call sites.
#[derive(Clone, Copy, Debug, Default)]
pub enum MultiRowStrategy {
    /// Default variant; carries no data.
    #[default]
    Never,
    /// Carries a single `u32`.
    // NOTE(review): presumably a row count — confirm.
    Always(u32),
    /// Carries a `minimum_stage_count` threshold.
    // NOTE(review): presumably the multi-row path is taken only once this minimum is
    // reached — confirm.
    Adaptive {
        /// Threshold stored with the variant.
        minimum_stage_count: u32,
    },
}
125
/// Two-state switch for loading precomputation; [`LoadingPrecomputeStrategy::Never`] is the
/// default.
///
/// Converts to `bool` through `From`: `Always` becomes `true`, `Never` becomes `false`.
// NOTE(review): what exactly gets precomputed is decided by the consumers of this enum and
// is not visible here.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LoadingPrecomputeStrategy {
    /// Precomputation disabled (converts to `false`).
    #[default]
    Never,
    /// Precomputation enabled (converts to `true`).
    Always,
}

impl From<LoadingPrecomputeStrategy> for bool {
    /// Maps `Always` to `true` and `Never` to `false`.
    fn from(strategy: LoadingPrecomputeStrategy) -> Self {
        use LoadingPrecomputeStrategy::{Always, Never};

        // Kept as an exhaustive match so adding a variant forces a decision here.
        match strategy {
            Never => false,
            Always => true,
        }
    }
}