hardware_query/tpu.rs
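//! Best-effort detection of TPUs and other AI accelerators: Google Cloud
//! TPUs (via the `TPU_NAME` environment variable), Google Coral Edge TPUs
//! (via `lsusb`/`lspci`), Intel Habana devices (via `/sys/class/accel`),
//! plus placeholders for Groq LPUs and Cerebras WSEs.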

use crate::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[cfg(target_os = "linux")]
use std::process::Command;

/// TPU vendor information
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TPUVendor {
    Google,
    Intel,
    Groq,
    Cerebras,
    Unknown(String),
}

impl std::fmt::Display for TPUVendor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TPUVendor::Google => write!(f, "Google"),
            TPUVendor::Intel => write!(f, "Intel"),
            TPUVendor::Groq => write!(f, "Groq"),
            TPUVendor::Cerebras => write!(f, "Cerebras"),
            TPUVendor::Unknown(name) => write!(f, "{name}"),
        }
    }
}

/// TPU generation and architecture
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TPUArchitecture {
    GoogleTPUv2,
    GoogleTPUv3,
    GoogleTPUv4,
    GoogleTPUv5,
    GoogleCoralEdge,
    IntelHabanaGaudi,
    IntelHabanaGaudi2,
    IntelHabanaGoya,
    GroqLPU,
    CerebrasWSE,
    Unknown(String),
}

/// TPU connection type
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TPUConnectionType {
    CloudTPU,    // Google Cloud TPU
    PCIe,        // PCIe card
    USB,         // USB device (Edge TPU)
    M2,          // M.2 module
    Network,     // Network-attached
    Unknown,
}

/// TPU information structure
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TPUInfo {
    /// TPU vendor
    pub vendor: TPUVendor,

    /// TPU model name
    pub model_name: String,

    /// TPU architecture/generation
    pub architecture: TPUArchitecture,

    /// Connection type
    pub connection_type: TPUConnectionType,

    /// Performance in TOPS (Tera Operations Per Second)
    pub tops_performance: Option<f32>,

    /// Memory size in GB
    pub memory_gb: Option<f32>,

    /// Memory bandwidth in GB/s
    pub memory_bandwidth_gbps: Option<f32>,

    /// Number of cores/processing units
    pub core_count: Option<u32>,

    /// Driver version
    pub driver_version: Option<String>,

    /// Firmware version
    pub firmware_version: Option<String>,

    /// Device ID
    pub device_id: Option<String>,

    /// Supported frameworks
    pub supported_frameworks: Vec<String>,

    /// Power consumption in watts
    pub power_consumption: Option<f32>,

    /// Operating temperature in Celsius
    pub temperature: Option<f32>,

    /// Clock frequency in MHz
    pub clock_frequency: Option<u32>,

    /// Supported data types
    pub supported_dtypes: Vec<String>,

    /// Additional capabilities
    pub capabilities: HashMap<String, String>,
}

impl TPUInfo {
    /// Query all available TPUs in the system
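    ///
    /// # Example
    ///
    /// ```ignore
    /// // Sketch (left as `ignore`): the `use` path assumes `TPUInfo` is
    /// // re-exported at the crate root; adjust it to the actual export.
    /// use hardware_query::TPUInfo;
    ///
    /// let tpus = TPUInfo::query_all().expect("failed to query TPUs");
    /// for tpu in &tpus {
    ///     println!("{} ({})", tpu.model_name(), tpu.vendor());
    /// }
    /// ```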
    pub fn query_all() -> Result<Vec<TPUInfo>> {
        let mut tpus = Vec::new();

        // Detect various TPU types
        tpus.extend(Self::detect_google_tpus()?);
        tpus.extend(Self::detect_intel_habana()?);
        tpus.extend(Self::detect_edge_tpus()?);
        tpus.extend(Self::detect_groq_lpus()?);
        tpus.extend(Self::detect_cerebras_wse()?);

        Ok(tpus)
    }

    /// Detect Google TPUs (Cloud and Edge)
    fn detect_google_tpus() -> Result<Vec<TPUInfo>> {
        let mut tpus = Vec::new();

        // Check for Google Cloud TPU via environment
        if let Ok(tpu_name) = std::env::var("TPU_NAME") {
            // Parse TPU version from name or metadata
            let (architecture, tops, memory_gb) = Self::parse_google_tpu_specs(&tpu_name);

            tpus.push(TPUInfo {
                vendor: TPUVendor::Google,
                model_name: format!("Google Cloud TPU ({tpu_name})"),
                architecture,
                connection_type: TPUConnectionType::CloudTPU,
                tops_performance: Some(tops),
                memory_gb: Some(memory_gb),
                memory_bandwidth_gbps: Some(600.0), // Rough estimate; varies by TPU generation
                core_count: Some(2), // Typical core count
                driver_version: Self::get_tpu_driver_version(),
                firmware_version: None,
                device_id: Some(tpu_name),
                supported_frameworks: vec![
                    "TensorFlow".to_string(),
                    "JAX".to_string(),
                    "PyTorch/XLA".to_string(),
                ],
                power_consumption: Some(200.0), // Estimated
                temperature: None,
                clock_frequency: Some(1000), // ~1GHz
                supported_dtypes: vec![
                    "bfloat16".to_string(),
                    "float32".to_string(),
                    "int8".to_string(),
                    "int32".to_string(),
                ],
                capabilities: HashMap::from([
                    ("matrix_units".to_string(), "true".to_string()),
                    ("vector_units".to_string(), "true".to_string()),
                    ("scalar_units".to_string(), "true".to_string()),
                ]),
            });
        }

        Ok(tpus)
    }

    /// Map a Cloud TPU name to (architecture, approximate TOPS, memory in GB)
    fn parse_google_tpu_specs(tpu_name: &str) -> (TPUArchitecture, f32, f32) {
        if tpu_name.contains("v5") {
            (TPUArchitecture::GoogleTPUv5, 275.0, 16.0)
        } else if tpu_name.contains("v4") {
            (TPUArchitecture::GoogleTPUv4, 275.0, 32.0)
        } else if tpu_name.contains("v3") {
            (TPUArchitecture::GoogleTPUv3, 123.0, 16.0)
        } else if tpu_name.contains("v2") {
            (TPUArchitecture::GoogleTPUv2, 45.0, 8.0)
        } else {
            (TPUArchitecture::GoogleTPUv4, 275.0, 32.0) // Default to v4
        }
    }

    /// Detect Google Coral Edge TPUs
    fn detect_edge_tpus() -> Result<Vec<TPUInfo>> {
        // `mut` is only exercised inside the Linux-only block below.
        #[allow(unused_mut)]
        let mut tpus = Vec::new();

        #[cfg(target_os = "linux")]
        {
            // Check for Edge TPU via USB
            if let Ok(output) = Command::new("lsusb").output() {
                let output_str = String::from_utf8_lossy(&output.stdout);
                for line in output_str.lines() {
                    if line.contains("18d1") && line.contains("9302") { // Google Edge TPU USB
                        tpus.push(TPUInfo {
                            vendor: TPUVendor::Google,
                            model_name: "Google Coral Edge TPU".to_string(),
                            architecture: TPUArchitecture::GoogleCoralEdge,
                            connection_type: TPUConnectionType::USB,
                            tops_performance: Some(4.0), // 4 TOPS at INT8
                            memory_gb: None, // Uses host memory
                            memory_bandwidth_gbps: Some(2.0), // Rough estimate; limited by the USB host link
                            core_count: Some(1),
                            driver_version: Self::get_edge_tpu_driver_version(),
                            firmware_version: None,
                            device_id: Some("18d1:9302".to_string()),
                            supported_frameworks: vec![
                                "TensorFlow Lite".to_string(),
                                "PyCoral".to_string(),
                                "OpenVINO".to_string(),
                            ],
                            power_consumption: Some(2.0), // ~2W
                            temperature: None,
                            clock_frequency: Some(500), // ~500MHz
                            supported_dtypes: vec![
                                "int8".to_string(),
                                "uint8".to_string(),
                            ],
                            capabilities: HashMap::from([
                                ("quantized_only".to_string(), "true".to_string()),
                                ("edge_optimized".to_string(), "true".to_string()),
                            ]),
                        });
                    }
                }
            }

            // Check for Edge TPU via PCIe (M.2 or Mini PCIe)
            if let Ok(output) = Command::new("lspci").output() {
                let output_str = String::from_utf8_lossy(&output.stdout);
                for line in output_str.lines() {
                    if line.contains("Coral") || (line.contains("Google") && line.contains("Edge")) {
                        tpus.push(TPUInfo {
                            vendor: TPUVendor::Google,
                            model_name: "Google Coral Edge TPU (PCIe)".to_string(),
                            architecture: TPUArchitecture::GoogleCoralEdge,
                            connection_type: TPUConnectionType::M2,
                            tops_performance: Some(4.0),
                            memory_gb: None,
                            memory_bandwidth_gbps: Some(8.0), // Rough estimate for the PCIe interface
                            core_count: Some(1),
                            driver_version: Self::get_edge_tpu_driver_version(),
                            firmware_version: None,
                            device_id: None,
                            supported_frameworks: vec![
                                "TensorFlow Lite".to_string(),
                                "PyCoral".to_string(),
                            ],
                            power_consumption: Some(2.5), // Slightly higher for PCIe
                            temperature: None,
                            clock_frequency: Some(500),
                            supported_dtypes: vec![
                                "int8".to_string(),
                                "uint8".to_string(),
                            ],
                            capabilities: HashMap::from([
                                ("quantized_only".to_string(), "true".to_string()),
                                ("edge_optimized".to_string(), "true".to_string()),
                                ("pcie_interface".to_string(), "true".to_string()),
                            ]),
                        });
                    }
                }
            }
        }

        Ok(tpus)
    }

    /// Detect Intel Habana accelerators
    fn detect_intel_habana() -> Result<Vec<TPUInfo>> {
        // `mut` is only exercised inside the Linux-only block below.
        #[allow(unused_mut)]
        let mut tpus = Vec::new();

        #[cfg(target_os = "linux")]
        {
            // Check for Habana devices via sysfs
            if std::path::Path::new("/sys/class/accel").exists() {
                if let Ok(entries) = std::fs::read_dir("/sys/class/accel") {
                    for entry in entries.flatten() {
                        if let Some(name) = entry.file_name().to_str() {
                            if name.starts_with("accel") {
                                // Try to determine if it's Habana
                                let device_path = format!("/sys/class/accel/{}/device", name);
                                if let Ok(vendor) = std::fs::read_to_string(format!("{}/vendor", device_path)) {
                                    if vendor.trim() == "0x1da3" { // Habana Labs (Intel) PCI vendor ID
                                        if let Ok(device) = std::fs::read_to_string(format!("{}/device", device_path)) {
                                            let (model_name, architecture, tops) = match device.trim() {
                                                "0x1000" => ("Intel Habana Gaudi", TPUArchitecture::IntelHabanaGaudi, 400.0),
                                                "0x1020" => ("Intel Habana Gaudi2", TPUArchitecture::IntelHabanaGaudi2, 800.0),
                                                "0x1050" => ("Intel Habana Goya", TPUArchitecture::IntelHabanaGoya, 100.0),
                                                _ => ("Intel Habana Device", TPUArchitecture::IntelHabanaGaudi, 400.0),
                                            };

                                            tpus.push(TPUInfo {
                                                vendor: TPUVendor::Intel,
                                                model_name: model_name.to_string(),
                                                architecture,
                                                connection_type: TPUConnectionType::PCIe,
                                                tops_performance: Some(tops),
                                                memory_gb: Some(32.0), // Typical for Gaudi
                                                memory_bandwidth_gbps: Some(2400.0), // HBM2E bandwidth
                                                core_count: Some(8), // Typical core count
                                                driver_version: Self::get_habana_driver_version(),
                                                firmware_version: None,
                                                device_id: Some(device.trim().to_string()),
                                                supported_frameworks: vec![
                                                    "PyTorch".to_string(),
                                                    "TensorFlow".to_string(),
                                                    "ONNX Runtime".to_string(),
                                                    "Habana SynapseAI".to_string(),
                                                ],
                                                power_consumption: Some(350.0), // High performance = high power
                                                temperature: None,
                                                clock_frequency: Some(1300), // ~1.3GHz
                                                supported_dtypes: vec![
                                                    "float32".to_string(),
                                                    "bfloat16".to_string(),
                                                    "float16".to_string(),
                                                    "int8".to_string(),
                                                ],
                                                capabilities: HashMap::from([
                                                    ("matrix_multiply_engine".to_string(), "true".to_string()),
                                                    ("tensor_processor_core".to_string(), "true".to_string()),
                                                    ("high_bandwidth_memory".to_string(), "true".to_string()),
                                                ]),
                                            });
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        Ok(tpus)
    }

    // Placeholder implementations for other vendors
    fn detect_groq_lpus() -> Result<Vec<TPUInfo>> {
        // Note: Groq LPU detection requires Groq SDK and runtime
        Ok(Vec::new())
    }

    fn detect_cerebras_wse() -> Result<Vec<TPUInfo>> {
        // Note: Cerebras WSE detection requires Cerebras SDK and runtime
        Ok(Vec::new())
    }

    // Helper functions for driver version detection
    fn get_tpu_driver_version() -> Option<String> {
        // Try to get TensorFlow version with TPU support
        #[cfg(target_os = "linux")]
        {
            if let Ok(output) = Command::new("python3")
                .args(["-c", "import tensorflow; print(tensorflow.__version__)"])
                .output()
            {
                if output.status.success() {
                    return Some(format!("TensorFlow {}", String::from_utf8_lossy(&output.stdout).trim()));
                }
            }
        }
        None
    }
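
    // The Edge TPU and Habana code paths above call the two helpers below,
    // which are not defined anywhere else in this file. These are minimal
    // sketches of how the versions could be read on Linux; the `modinfo`
    // invocation and the sysfs path are assumptions about the host rather
    // than guaranteed interfaces, and both helpers simply return `None`
    // when the information is unavailable.

    fn get_edge_tpu_driver_version() -> Option<String> {
        // Sketch: read the version of the "apex" kernel module used by the
        // Coral gasket/apex driver stack (assumes `modinfo` is installed).
        #[cfg(target_os = "linux")]
        {
            if let Ok(output) = Command::new("modinfo")
                .args(["-F", "version", "apex"])
                .output()
            {
                if output.status.success() {
                    let version = String::from_utf8_lossy(&output.stdout).trim().to_string();
                    if !version.is_empty() {
                        return Some(version);
                    }
                }
            }
        }
        None
    }

    fn get_habana_driver_version() -> Option<String> {
        // Sketch: the habanalabs kernel module exposes its version through
        // sysfs when built with MODULE_VERSION; absence just yields `None`.
        #[cfg(target_os = "linux")]
        {
            if let Ok(version) = std::fs::read_to_string("/sys/module/habanalabs/version") {
                let trimmed = version.trim();
                if !trimmed.is_empty() {
                    return Some(trimmed.to_string());
                }
            }
        }
        None
    }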

    /// Get TPU vendor
    pub fn vendor(&self) -> &TPUVendor {
        &self.vendor
    }

    /// Get TPU model name
    pub fn model_name(&self) -> &str {
        &self.model_name
    }

    /// Get TPU architecture
    pub fn architecture(&self) -> &TPUArchitecture {
        &self.architecture
    }

    /// Get performance in TOPS
    pub fn tops_performance(&self) -> Option<f32> {
        self.tops_performance
    }

    /// Check if TPU supports a specific framework
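    ///
    /// The match is case-insensitive and allows substrings, so `"pytorch"`
    /// matches an entry such as `"PyTorch/XLA"`.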
    pub fn supports_framework(&self, framework: &str) -> bool {
        self.supported_frameworks.iter()
            .any(|f| f.to_lowercase().contains(&framework.to_lowercase()))
    }

    /// Check if TPU supports a specific data type
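    ///
    /// The comparison is case-insensitive but exact, so `"int8"` matches
    /// while `"int"` does not.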
    pub fn supports_dtype(&self, dtype: &str) -> bool {
        self.supported_dtypes.iter()
            .any(|d| d.to_lowercase() == dtype.to_lowercase())
    }
}