cubecl_cpp/hip/
arch.rs

1use crate::shared::Architecture;
2
/// AMD GPU architecture families distinguished by this backend.
///
/// Each variant groups the `gfx*` target names that [`AMDArchitecture::parse`]
/// maps onto it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AMDArchitecture {
    // RDNA
    /// Names starting with `gfx11`, e.g. gfx1100, gfx1101, gfx1102.
    GFX11,
    /// Names starting with `gfx10`, e.g. gfx1030, gfx1031, gfx1032.
    GFX10,
    // CDNA
    /// Exactly `gfx908`.
    GFX908,
    /// Exactly `gfx90a`.
    GFX90A,
    /// Names starting with `gfx94`, e.g. gfx940, gfx941, gfx942.
    GFX94,
    /// Not particularly specific architecture: any name not matched above.
    Other,
}

impl AMDArchitecture {
    /// Maps a target name such as `"gfx1100"` onto an architecture family.
    ///
    /// Matching is case-insensitive. Surrounding whitespace and any
    /// colon-separated feature suffix (as in `"gfx90a:sramecc+:xnack-"`) are
    /// ignored, so a full target ID resolves to the same family as the bare
    /// processor name.
    ///
    /// # Errors
    ///
    /// Currently never returns `Err`: unrecognised names yield
    /// `Ok(AMDArchitecture::Other)`.
    pub fn parse(arg: &str) -> Result<Self, String> {
        // Keep only the processor part of a target ID like "gfx908:sramecc+:xnack-";
        // without this, the exact-match arms below would miss such inputs.
        let processor = arg.split(':').next().unwrap_or(arg).trim();
        // The names compared against are pure ASCII, so ASCII case folding suffices.
        let norm = processor.to_ascii_lowercase();

        let arch = if norm.starts_with("gfx11") {
            AMDArchitecture::GFX11
        } else if norm.starts_with("gfx10") {
            AMDArchitecture::GFX10
        } else if norm == "gfx908" {
            AMDArchitecture::GFX908
        } else if norm == "gfx90a" {
            AMDArchitecture::GFX90A
        } else if norm.starts_with("gfx94") {
            AMDArchitecture::GFX94
        } else {
            AMDArchitecture::Other
        };
        Ok(arch)
    }
}
36
37impl Architecture for AMDArchitecture {
38    fn warp_size(&self) -> u32 {
39        // CDNA supports wave64 (gfx9 and gfx940+) and RDNA wave32 (gfx10, gfx11)
40        match self {
41            AMDArchitecture::GFX10 | AMDArchitecture::GFX11 => 32,
42            AMDArchitecture::GFX908 | AMDArchitecture::GFX90A | AMDArchitecture::GFX94 => 64,
43            AMDArchitecture::Other => 0,
44        }
45    }
46
47    fn is_wmma_capable(&self) -> bool {
48        matches!(self, AMDArchitecture::GFX10 | AMDArchitecture::GFX11)
49    }
50
51    fn is_mfma_capable(&self) -> bool {
52        matches!(
53            self,
54            AMDArchitecture::GFX908 | AMDArchitecture::GFX90A | AMDArchitecture::GFX94
55        )
56    }
57}