1use crate::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5#[cfg(target_os = "linux")]
6use std::process::Command;
7
/// Hardware vendor of a detected TPU / AI-accelerator device.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TPUVendor {
    /// Google (Cloud TPU and Coral Edge TPU devices).
    Google,
    /// Intel (Habana accelerators, PCI vendor id 0x1da3).
    Intel,
    /// Groq.
    Groq,
    /// Cerebras.
    Cerebras,
    /// Any vendor not explicitly modeled; carries the reported name.
    Unknown(String),
}
17
18impl std::fmt::Display for TPUVendor {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 match self {
21 TPUVendor::Google => write!(f, "Google"),
22 TPUVendor::Intel => write!(f, "Intel"),
23 TPUVendor::Groq => write!(f, "Groq"),
24 TPUVendor::Cerebras => write!(f, "Cerebras"),
25 TPUVendor::Unknown(name) => write!(f, "{name}"),
26 }
27 }
28}
29
/// Specific accelerator architecture / generation of a detected device.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TPUArchitecture {
    /// Google Cloud TPU v2.
    GoogleTPUv2,
    /// Google Cloud TPU v3.
    GoogleTPUv3,
    /// Google Cloud TPU v4.
    GoogleTPUv4,
    /// Google Cloud TPU v5.
    GoogleTPUv5,
    /// Google Coral Edge TPU (USB or PCIe/M.2 modules).
    GoogleCoralEdge,
    /// Intel Habana Gaudi.
    IntelHabanaGaudi,
    /// Intel Habana Gaudi2.
    IntelHabanaGaudi2,
    /// Intel Habana Goya.
    IntelHabanaGoya,
    /// Groq LPU.
    GroqLPU,
    /// Cerebras Wafer-Scale Engine.
    CerebrasWSE,
    /// Architecture not recognized; carries the reported name.
    Unknown(String),
}
45
46#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
48pub enum TPUConnectionType {
49 CloudTPU, PCIe, USB, M2, Network, Unknown,
55}
56
/// Everything known about a single detected TPU-class accelerator.
///
/// Fields that a detector cannot determine are left as `None` / empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TPUInfo {
    /// Hardware vendor.
    pub vendor: TPUVendor,

    /// Human-readable model / product name.
    pub model_name: String,

    /// Architecture / generation of the device.
    pub architecture: TPUArchitecture,

    /// How the device is attached to the host.
    pub connection_type: TPUConnectionType,

    /// Peak throughput in TOPS, when known.
    pub tops_performance: Option<f32>,

    /// On-device memory in GB, when known.
    pub memory_gb: Option<f32>,

    /// Memory bandwidth in GB/s, when known.
    pub memory_bandwidth_gbps: Option<f32>,

    /// Number of compute cores, when known.
    pub core_count: Option<u32>,

    /// Driver / software-stack version string, when detectable.
    pub driver_version: Option<String>,

    /// Firmware version, when detectable.
    pub firmware_version: Option<String>,

    /// Device identifier (e.g. USB vid:pid, PCI device id, or cloud TPU name).
    pub device_id: Option<String>,

    /// ML frameworks known to target this device.
    pub supported_frameworks: Vec<String>,

    /// Power consumption (presumably watts — TODO confirm unit convention).
    pub power_consumption: Option<f32>,

    /// Device temperature; not populated by the detectors in this module.
    pub temperature: Option<f32>,

    /// Clock frequency (presumably MHz — TODO confirm unit convention).
    pub clock_frequency: Option<u32>,

    /// Data types the device can compute with (e.g. "bfloat16", "int8").
    pub supported_dtypes: Vec<String>,

    /// Free-form capability flags as key/value strings.
    pub capabilities: HashMap<String, String>,
}
111
112impl TPUInfo {
113 pub fn query_all() -> Result<Vec<TPUInfo>> {
115 let mut tpus = Vec::new();
116
117 tpus.extend(Self::detect_google_tpus()?);
119 tpus.extend(Self::detect_intel_habana()?);
120 tpus.extend(Self::detect_edge_tpus()?);
121 tpus.extend(Self::detect_groq_lpus()?);
122 tpus.extend(Self::detect_cerebras_wse()?);
123
124 Ok(tpus)
125 }
126
127 fn detect_google_tpus() -> Result<Vec<TPUInfo>> {
129 let mut tpus = Vec::new();
130
131 if let Ok(tpu_name) = std::env::var("TPU_NAME") {
133 let (architecture, tops, memory_gb) = Self::parse_google_tpu_specs(&tpu_name);
135
136 tpus.push(TPUInfo {
137 vendor: TPUVendor::Google,
138 model_name: format!("Google Cloud TPU ({tpu_name})"),
139 architecture,
140 connection_type: TPUConnectionType::CloudTPU,
141 tops_performance: Some(tops),
142 memory_gb: Some(memory_gb),
143 memory_bandwidth_gbps: Some(600.0), core_count: Some(2), driver_version: Self::get_tpu_driver_version(),
146 firmware_version: None,
147 device_id: Some(tpu_name),
148 supported_frameworks: vec![
149 "TensorFlow".to_string(),
150 "JAX".to_string(),
151 "PyTorch/XLA".to_string(),
152 ],
153 power_consumption: Some(200.0), temperature: None,
155 clock_frequency: Some(1000), supported_dtypes: vec![
157 "bfloat16".to_string(),
158 "float32".to_string(),
159 "int8".to_string(),
160 "int32".to_string(),
161 ],
162 capabilities: HashMap::from([
163 ("matrix_units".to_string(), "true".to_string()),
164 ("vector_units".to_string(), "true".to_string()),
165 ("scalar_units".to_string(), "true".to_string()),
166 ]),
167 });
168 }
169
170 Ok(tpus)
171 }
172
173 fn parse_google_tpu_specs(tpu_name: &str) -> (TPUArchitecture, f32, f32) {
174 if tpu_name.contains("v5") {
175 (TPUArchitecture::GoogleTPUv5, 275.0, 16.0)
176 } else if tpu_name.contains("v4") {
177 (TPUArchitecture::GoogleTPUv4, 275.0, 32.0)
178 } else if tpu_name.contains("v3") {
179 (TPUArchitecture::GoogleTPUv3, 123.0, 16.0)
180 } else if tpu_name.contains("v2") {
181 (TPUArchitecture::GoogleTPUv2, 45.0, 8.0)
182 } else {
183 (TPUArchitecture::GoogleTPUv4, 275.0, 32.0) }
185 }
186
187 fn detect_edge_tpus() -> Result<Vec<TPUInfo>> {
189 let tpus = Vec::new();
190
191 #[cfg(target_os = "linux")]
192 {
193 if let Ok(output) = Command::new("lsusb").output() {
195 let output_str = String::from_utf8_lossy(&output.stdout);
196 for line in output_str.lines() {
197 if line.contains("18d1") && line.contains("9302") { tpus.push(TPUInfo {
199 vendor: TPUVendor::Google,
200 model_name: "Google Coral Edge TPU".to_string(),
201 architecture: TPUArchitecture::GoogleCoralEdge,
202 connection_type: TPUConnectionType::USB,
203 tops_performance: Some(4.0), memory_gb: None, memory_bandwidth_gbps: Some(2.0), core_count: Some(1),
207 driver_version: Self::get_edge_tpu_driver_version(),
208 firmware_version: None,
209 device_id: Some("18d1:9302".to_string()),
210 supported_frameworks: vec![
211 "TensorFlow Lite".to_string(),
212 "PyCoral".to_string(),
213 "OpenVINO".to_string(),
214 ],
215 power_consumption: Some(2.0), temperature: None,
217 clock_frequency: Some(500), supported_dtypes: vec![
219 "int8".to_string(),
220 "uint8".to_string(),
221 ],
222 capabilities: HashMap::from([
223 ("quantized_only".to_string(), "true".to_string()),
224 ("edge_optimized".to_string(), "true".to_string()),
225 ]),
226 });
227 }
228 }
229 }
230
231 if let Ok(output) = Command::new("lspci").output() {
233 let output_str = String::from_utf8_lossy(&output.stdout);
234 for line in output_str.lines() {
235 if line.contains("Coral") || (line.contains("Google") && line.contains("Edge")) {
236 tpus.push(TPUInfo {
237 vendor: TPUVendor::Google,
238 model_name: "Google Coral Edge TPU (PCIe)".to_string(),
239 architecture: TPUArchitecture::GoogleCoralEdge,
240 connection_type: TPUConnectionType::M2,
241 tops_performance: Some(4.0),
242 memory_gb: None,
243 memory_bandwidth_gbps: Some(8.0), core_count: Some(1),
245 driver_version: Self::get_edge_tpu_driver_version(),
246 firmware_version: None,
247 device_id: None,
248 supported_frameworks: vec![
249 "TensorFlow Lite".to_string(),
250 "PyCoral".to_string(),
251 ],
252 power_consumption: Some(2.5), temperature: None,
254 clock_frequency: Some(500),
255 supported_dtypes: vec![
256 "int8".to_string(),
257 "uint8".to_string(),
258 ],
259 capabilities: HashMap::from([
260 ("quantized_only".to_string(), "true".to_string()),
261 ("edge_optimized".to_string(), "true".to_string()),
262 ("pcie_interface".to_string(), "true".to_string()),
263 ]),
264 });
265 }
266 }
267 }
268 }
269
270 Ok(tpus)
271 }
272
273 fn detect_intel_habana() -> Result<Vec<TPUInfo>> {
275 let tpus = Vec::new();
276
277 #[cfg(target_os = "linux")]
278 {
279 if std::path::Path::new("/sys/class/accel").exists() {
281 if let Ok(entries) = std::fs::read_dir("/sys/class/accel") {
282 for entry in entries.flatten() {
283 if let Some(name) = entry.file_name().to_str() {
284 if name.starts_with("accel") {
285 let device_path = format!("/sys/class/accel/{}/device", name);
287 if let Ok(vendor) = std::fs::read_to_string(format!("{}/vendor", device_path)) {
288 if vendor.trim() == "0x1da3" { if let Ok(device) = std::fs::read_to_string(format!("{}/device", device_path)) {
290 let (model_name, architecture, tops) = match device.trim() {
291 "0x1000" => ("Intel Habana Gaudi", TPUArchitecture::IntelHabanaGaudi, 400.0),
292 "0x1020" => ("Intel Habana Gaudi2", TPUArchitecture::IntelHabanaGaudi2, 800.0),
293 "0x1050" => ("Intel Habana Goya", TPUArchitecture::IntelHabanaGoya, 100.0),
294 _ => ("Intel Habana Device", TPUArchitecture::IntelHabanaGaudi, 400.0),
295 };
296
297 tpus.push(TPUInfo {
298 vendor: TPUVendor::Intel,
299 model_name: model_name.to_string(),
300 architecture,
301 connection_type: TPUConnectionType::PCIe,
302 tops_performance: Some(tops),
303 memory_gb: Some(32.0), memory_bandwidth_gbps: Some(2400.0), core_count: Some(8), driver_version: Self::get_habana_driver_version(),
307 firmware_version: None,
308 device_id: Some(device.trim().to_string()),
309 supported_frameworks: vec![
310 "PyTorch".to_string(),
311 "TensorFlow".to_string(),
312 "ONNX Runtime".to_string(),
313 "Habana SynapseAI".to_string(),
314 ],
315 power_consumption: Some(350.0), temperature: None,
317 clock_frequency: Some(1300), supported_dtypes: vec![
319 "float32".to_string(),
320 "bfloat16".to_string(),
321 "float16".to_string(),
322 "int8".to_string(),
323 ],
324 capabilities: HashMap::from([
325 ("matrix_multiply_engine".to_string(), "true".to_string()),
326 ("tensor_processor_core".to_string(), "true".to_string()),
327 ("high_bandwidth_memory".to_string(), "true".to_string()),
328 ]),
329 });
330 }
331 }
332 }
333 }
334 }
335 }
336 }
337 }
338 }
339
340 Ok(tpus)
341 }
342
343 fn detect_groq_lpus() -> Result<Vec<TPUInfo>> {
345 Ok(Vec::new())
347 }
348
349 fn detect_cerebras_wse() -> Result<Vec<TPUInfo>> {
350 Ok(Vec::new())
352 }
353
354 fn get_tpu_driver_version() -> Option<String> {
356 #[cfg(target_os = "linux")]
358 {
359 if let Ok(output) = Command::new("python3")
360 .args(["-c", "import tensorflow; print(tensorflow.__version__)"])
361 .output()
362 {
363 if output.status.success() {
364 return Some(format!("TensorFlow {}", String::from_utf8_lossy(&output.stdout).trim()));
365 }
366 }
367 }
368 None
369 }
370
    /// Hardware vendor of this device.
    pub fn vendor(&self) -> &TPUVendor {
        &self.vendor
    }
375
    /// Human-readable model / product name.
    pub fn model_name(&self) -> &str {
        &self.model_name
    }
380
    /// Architecture / generation of this device.
    pub fn architecture(&self) -> &TPUArchitecture {
        &self.architecture
    }
385
    /// Peak throughput in TOPS, if known.
    pub fn tops_performance(&self) -> Option<f32> {
        self.tops_performance
    }
390
391 pub fn supports_framework(&self, framework: &str) -> bool {
393 self.supported_frameworks.iter()
394 .any(|f| f.to_lowercase().contains(&framework.to_lowercase()))
395 }
396
397 pub fn supports_dtype(&self, dtype: &str) -> bool {
399 self.supported_dtypes.iter()
400 .any(|d| d.to_lowercase() == dtype.to_lowercase())
401 }
402}