// cubecl_matmul/components/global/specialization/specializer.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::components::global::GlobalConfig;
5use crate::components::global::specialization::config::LoadingSides;
6use crate::components::global::specialization::roles::RoleRuleConfig;
7
8#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
9/// Comptime information of specializer
10pub enum SpecializerKind {
11    Specialized {
12        main_flow_loading_side: LoadingSides,
13        load_only_loading_side: LoadingSides,
14        role_rule_config: RoleRuleConfig,
15    },
16    NotSpecialized,
17}
18
19#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
20/// Specialization information in cube functions
21pub struct Specializer {
22    #[cube(comptime)]
23    pub kind: SpecializerKind,
24}
25
26#[cube]
27impl Specializer {
28    pub fn new<G: GlobalConfig>(#[comptime] config: G) -> Specializer {
29        let plane_role_config = config.plane_role_config();
30        let loading_sides = config.specialized_loading_sides();
31
32        if config.plane_role_config().has_specialization() {
33            Specializer {
34                kind: comptime! {
35                    SpecializerKind::Specialized {
36                        main_flow_loading_side: loading_sides.main_flow,
37                        load_only_loading_side: loading_sides.load_only,
38                        role_rule_config: plane_role_config.rule
39                    }
40                },
41            }
42        } else {
43            Specializer {
44                kind: comptime! {SpecializerKind::NotSpecialized},
45            }
46        }
47    }
48}