cubecl_cpp/metal/
arch.rs

1use std::{fmt::Display, str::FromStr};
2
3use crate::shared::Architecture;
4
// We support the Metal 3 family of GPUs.
6
/// GPU architecture families distinguished by the Metal backend.
///
/// The common value-type traits are derived so that architectures can be
/// debug-printed, copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MetalArchitecture {
    /// The Metal 3 family of GPUs.
    Metal3,
    /// Any architecture that is not recognized as Metal 3.
    Other,
}
11
12impl FromStr for MetalArchitecture {
13    type Err = String;
14
15    fn from_str(s: &str) -> Result<Self, Self::Err> {
16        let norm = s.to_lowercase();
17        if norm.starts_with("metal3") {
18            Ok(MetalArchitecture::Metal3)
19        } else {
20            Ok(MetalArchitecture::Other)
21        }
22    }
23}
24
25impl Display for MetalArchitecture {
26    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
27        match self {
28            Self::Metal3 => write!(f, "metal3"),
29            Self::Other => write!(f, "other"),
30        }
31    }
32}
33
34impl Architecture for MetalArchitecture {
35    fn warp_size(&self) -> u32 {
36        64
37    }
38
39    fn is_wmma_capable(&self) -> bool {
40        true
41    }
42
43    fn is_mfma_capable(&self) -> bool {
44        false
45    }
46}