cubecl_cpp/hip/
arch.rs

1use crate::shared::Architecture;
2
/// AMD GPU architecture families recognized by the HIP backend.
///
/// Values are normally produced by `AMDArchitecture::parse` from a `gfx…` processor name;
/// the per-family capabilities (warp size, WMMA/MFMA support) are reported by the
/// `Architecture` trait implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AMDArchitecture {
    // ----- RDNA -----
    /// RDNA4: gfx1200, gfx1201 (any `gfx12*` name parses to this).
    GFX12,
    /// RDNA3: gfx1100, gfx1101, gfx1102 (any `gfx11*` name parses to this).
    GFX11,
    /// RDNA2: gfx1030, gfx1031, gfx1032 (any `gfx10*` name parses to this).
    GFX10,
    // ----- CDNA -----
    /// CDNA1 (MI100): gfx908.
    GFX908,
    /// CDNA2 (MI200): gfx90a.
    GFX90A,
    /// CDNA3 (MI300): gfx940, gfx941, gfx942.
    GFX94,
    /// Any architecture not matched by a more specific variant; no specific
    /// hardware capabilities are assumed for it.
    Other,
}
19
20impl AMDArchitecture {
21    pub fn parse(arg: &str) -> Result<Self, String> {
22        let norm = arg.to_lowercase();
23        if norm.starts_with("gfx12") {
24            Ok(AMDArchitecture::GFX12)
25        } else if norm.starts_with("gfx11") {
26            Ok(AMDArchitecture::GFX11)
27        } else if norm.starts_with("gfx10") {
28            Ok(AMDArchitecture::GFX10)
29        } else if norm == "gfx908" {
30            Ok(AMDArchitecture::GFX908)
31        } else if norm == "gfx90a" {
32            Ok(AMDArchitecture::GFX90A)
33        } else if norm.starts_with("gfx94") {
34            Ok(AMDArchitecture::GFX94)
35        } else {
36            Ok(AMDArchitecture::Other)
37        }
38    }
39}
40
41impl Architecture for AMDArchitecture {
42    fn warp_size(&self) -> u32 {
43        // CDNA supports wave64 (gfx9 and gfx940+) and RDNA wave32 (gfx10, gfx11, gfx12)
44        match self {
45            AMDArchitecture::GFX10 | AMDArchitecture::GFX11 | AMDArchitecture::GFX12 => 32,
46            AMDArchitecture::GFX908 | AMDArchitecture::GFX90A | AMDArchitecture::GFX94 => 64,
47            AMDArchitecture::Other => 0,
48        }
49    }
50
51    fn is_wmma_capable(&self) -> bool {
52        matches!(
53            self,
54            AMDArchitecture::GFX10 | AMDArchitecture::GFX11 | AMDArchitecture::GFX12
55        )
56    }
57
58    fn is_mfma_capable(&self) -> bool {
59        matches!(
60            self,
61            AMDArchitecture::GFX908 | AMDArchitecture::GFX90A | AMDArchitecture::GFX94
62        )
63    }
64}