cubecl_matmul/components/global/specialization/
specializer.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::components::global::specialization::config::LoadingSides;
5use crate::components::global::specialization::roles::RoleRuleConfig;
6use crate::components::global::{PlaneRoleConfig, SpecializedLoadingSides};
7
8#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
9pub enum SpecializerKind {
11 Specialized {
12 main_flow_loading_side: LoadingSides,
13 load_only_loading_side: LoadingSides,
14 role_rule_config: RoleRuleConfig,
15 },
16 NotSpecialized,
17}
18
19#[derive(CubeType, Copy, Clone, Debug, Hash, PartialEq, Eq)]
20pub struct Specializer {
22 #[cube(comptime)]
23 pub kind: SpecializerKind,
24}
25
26#[cube]
27impl Specializer {
28 pub fn new(
29 #[comptime] plane_role_config: PlaneRoleConfig,
30 #[comptime] loading_sides: SpecializedLoadingSides,
31 ) -> Specializer {
32 if plane_role_config.has_specialization() {
33 Specializer {
34 kind: comptime! {
35 SpecializerKind::Specialized {
36 main_flow_loading_side: loading_sides.main_flow,
37 load_only_loading_side: loading_sides.load_only,
38 role_rule_config: plane_role_config.rule
39 }
40 },
41 }
42 } else {
43 Specializer {
44 kind: comptime! {SpecializerKind::NotSpecialized},
45 }
46 }
47 }
48}