cubecl_matmul/components/global/specialization/
specializer.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::components::global::GlobalConfig;
5use crate::components::global::specialization::config::LoadingSides;
6use crate::components::global::specialization::roles::RoleRuleConfig;
7
8#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
9pub enum SpecializerKind {
11 Specialized {
12 main_flow_loading_side: LoadingSides,
13 load_only_loading_side: LoadingSides,
14 role_rule_config: RoleRuleConfig,
15 },
16 NotSpecialized,
17}
18
19#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
20pub struct Specializer {
22 #[cube(comptime)]
23 pub kind: SpecializerKind,
24}
25
26#[cube]
27impl Specializer {
28 pub fn new<G: GlobalConfig>(#[comptime] config: G) -> Specializer {
29 let plane_role_config = config.plane_role_config();
30 let loading_sides = config.specialized_loading_sides();
31
32 if config.plane_role_config().has_specialization() {
33 Specializer {
34 kind: comptime! {
35 SpecializerKind::Specialized {
36 main_flow_loading_side: loading_sides.main_flow,
37 load_only_loading_side: loading_sides.load_only,
38 role_rule_config: plane_role_config.rule
39 }
40 },
41 }
42 } else {
43 Specializer {
44 kind: comptime! {SpecializerKind::NotSpecialized},
45 }
46 }
47 }
48}