entrenar/efficiency/device/
compute.rs1use serde::{Deserialize, Serialize};
4
5use super::apple::AppleSiliconInfo;
6use super::cpu::CpuInfo;
7use super::gpu::GpuInfo;
8use super::tpu::TpuInfo;
9
10#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
12pub enum ComputeDevice {
13 Cpu(CpuInfo),
15 Gpu(GpuInfo),
17 Tpu(TpuInfo),
19 AppleSilicon(AppleSiliconInfo),
21}
22
23impl ComputeDevice {
24 pub fn detect() -> Vec<Self> {
26 let mut devices = Vec::new();
27
28 devices.push(Self::Cpu(CpuInfo::detect()));
30
31 if let Some(apple) = AppleSiliconInfo::detect() {
33 devices.push(Self::AppleSilicon(apple));
34 }
35
36 devices
40 }
41
42 pub fn is_gpu(&self) -> bool {
44 matches!(self, Self::Gpu(_))
45 }
46
47 pub fn is_cpu(&self) -> bool {
49 matches!(self, Self::Cpu(_))
50 }
51
52 pub fn is_tpu(&self) -> bool {
54 matches!(self, Self::Tpu(_))
55 }
56
57 pub fn is_apple_silicon(&self) -> bool {
59 matches!(self, Self::AppleSilicon(_))
60 }
61
62 pub fn memory_bytes(&self) -> u64 {
64 match self {
65 Self::Cpu(info) => {
66 u64::from(info.cores) * 4 * 1024 * 1024 * 1024 }
70 Self::Gpu(info) => info.vram_bytes,
71 Self::Tpu(info) => info.hbm_bytes,
72 Self::AppleSilicon(info) => info.unified_memory_bytes,
73 }
74 }
75
76 pub fn name(&self) -> &str {
78 match self {
79 Self::Cpu(info) => &info.model,
80 Self::Gpu(info) => &info.name,
81 Self::Tpu(info) => &info.version,
82 Self::AppleSilicon(info) => &info.chip,
83 }
84 }
85
86 pub fn compute_units(&self) -> u32 {
88 match self {
89 Self::Cpu(info) => info.threads,
90 Self::Gpu(_) => 0, Self::Tpu(info) => info.cores,
92 Self::AppleSilicon(info) => info.total_cpu_cores() + info.gpu_cores,
93 }
94 }
95
96 pub fn relative_compute_power(&self) -> f64 {
98 match self {
99 Self::Cpu(info) => f64::from(info.threads) / 8.0, Self::Gpu(info) => {
101 10.0 * (info.vram_gb() / 8.0) }
104 Self::Tpu(info) => {
105 50.0 * (f64::from(info.cores) / 8.0)
107 }
108 Self::AppleSilicon(info) => {
109 (f64::from(info.p_cores) * 1.5 + f64::from(info.e_cores) * 0.5) / 8.0
111 + f64::from(info.gpu_cores) * 0.5
112 }
113 }
114 }
115}
116
117impl std::fmt::Display for ComputeDevice {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 match self {
120 Self::Cpu(info) => write!(
121 f,
122 "CPU: {} ({} cores, {} threads, {})",
123 info.model, info.cores, info.threads, info.simd
124 ),
125 Self::Gpu(info) => {
126 write!(f, "GPU: {} ({:.1} GB VRAM", info.name, info.vram_gb())?;
127 if let Some((major, minor)) = info.compute_capability {
128 write!(f, ", SM {major}.{minor}")?;
129 }
130 write!(f, ")")
131 }
132 Self::Tpu(info) => write!(
133 f,
134 "TPU: {} ({} cores, {:.1} GB HBM)",
135 info.version,
136 info.cores,
137 info.hbm_gb()
138 ),
139 Self::AppleSilicon(info) => write!(
140 f,
141 "Apple Silicon: {} ({}P+{}E cores, {} GPU cores, {:.1} GB)",
142 info.chip,
143 info.p_cores,
144 info.e_cores,
145 info.gpu_cores,
146 info.unified_memory_gb()
147 ),
148 }
149 }
150}