cuda_rust_wasm/runtime/
device.rs1use crate::{Result, runtime_error};
4use std::sync::Arc;
5
/// Static description of a compute device, as handed out by `Device::properties`.
///
/// The values are filled in by the backend-specific probes on `Device`; several of
/// them are fixed defaults rather than measured quantities.
#[derive(Clone, Debug)]
pub struct DeviceProperties {
    /// Human-readable device name (GPU product name, adapter name, or CPU model string).
    pub name: String,
    /// Device memory size in bytes. The WebGPU adapter probe stores `0` here.
    pub total_memory: usize,
    /// Upper bound on threads in one block (for WebGPU: compute invocations per workgroup).
    pub max_threads_per_block: u32,
    /// Upper bound on blocks in a grid (for WebGPU: compute workgroups per dimension).
    pub max_blocks_per_grid: u32,
    /// Number of threads per warp; the probes use 32, 64 (AMD) or 1 (CPU).
    pub warp_size: u32,
    /// Compute capability pair — presumably `(major, minor)`; `(0, 0)` is used when unknown.
    pub compute_capability: (u32, u32),
}
16
/// Execution backend a `Device` is bound to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BackendType {
    /// Native GPU backend, chosen when CUDA or ROCm is reported as available.
    Native,
    /// WebGPU backend; always used on `wasm32`, otherwise used when a wgpu adapter exists.
    WebGPU,
    /// CPU fallback used when no GPU backend is detected.
    CPU,
}
24
25pub struct Device {
27 backend: BackendType,
28 properties: DeviceProperties,
29 id: usize,
30}
31
32impl Device {
33 pub fn get_default() -> Result<Arc<Self>> {
35 let backend = Self::detect_backend();
37
38 let properties = match backend {
39 BackendType::Native => Self::get_native_properties()?,
40 BackendType::WebGPU => Self::get_webgpu_properties()?,
41 BackendType::CPU => Self::get_cpu_properties(),
42 };
43
44 Ok(Arc::new(Self {
45 backend,
46 properties,
47 id: 0,
48 }))
49 }
50
51 pub fn get_by_id(id: usize) -> Result<Arc<Self>> {
53 if id != 0 {
55 return Err(runtime_error!("Device {} not found", id));
56 }
57 Self::get_default()
58 }
59
60 pub fn count() -> Result<usize> {
62 Ok(1)
64 }
65
66 pub fn properties(&self) -> &DeviceProperties {
68 &self.properties
69 }
70
71 pub fn backend(&self) -> BackendType {
73 self.backend
74 }
75
76 pub fn id(&self) -> usize {
78 self.id
79 }
80
81 fn detect_backend() -> BackendType {
83 #[cfg(target_arch = "wasm32")]
84 {
85 return BackendType::WebGPU;
86 }
87
88 #[cfg(not(target_arch = "wasm32"))]
89 {
90 if crate::backend::native_gpu::is_cuda_available()
92 || crate::backend::native_gpu::is_rocm_available()
93 {
94 return BackendType::Native;
95 }
96
97 if Self::probe_webgpu() {
99 return BackendType::WebGPU;
100 }
101
102 BackendType::CPU
103 }
104 }
105
106 #[cfg(not(target_arch = "wasm32"))]
108 fn probe_webgpu() -> bool {
109 use pollster::FutureExt;
110 let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
111 backends: wgpu::Backends::all(),
112 ..Default::default()
113 });
114 instance
115 .request_adapter(&wgpu::RequestAdapterOptions::default())
116 .block_on()
117 .is_some()
118 }
119
120 fn get_native_properties() -> Result<DeviceProperties> {
125 if let Ok(output) = std::process::Command::new("nvidia-smi")
127 .args([
128 "--query-gpu=name,memory.total,driver_version",
129 "--format=csv,noheader,nounits",
130 ])
131 .output()
132 {
133 if output.status.success() {
134 let stdout = String::from_utf8_lossy(&output.stdout);
135 let line = stdout.trim();
136 let parts: Vec<&str> = line.split(", ").collect();
137 if parts.len() >= 2 {
138 let name = parts[0].trim().to_string();
139 let mem_mb: usize = parts[1].trim().parse().unwrap_or(8192);
140 return Ok(DeviceProperties {
141 name,
142 total_memory: mem_mb * 1024 * 1024,
143 max_threads_per_block: 1024,
144 max_blocks_per_grid: 65535,
145 warp_size: 32,
146 compute_capability: (8, 0),
147 });
148 }
149 }
150 }
151
152 if let Ok(entries) = std::fs::read_dir("/sys/class/drm") {
154 for entry in entries.flatten() {
155 let vendor_path = entry.path().join("device/vendor");
156 if let Ok(vendor) = std::fs::read_to_string(&vendor_path) {
157 if vendor.trim() == "0x1002" {
158 let name = std::fs::read_to_string(
160 entry.path().join("device/product_name"),
161 )
162 .unwrap_or_else(|_| "AMD GPU".to_string());
163 return Ok(DeviceProperties {
164 name: name.trim().to_string(),
165 total_memory: 16 * 1024 * 1024 * 1024,
166 max_threads_per_block: 1024,
167 max_blocks_per_grid: 65535,
168 warp_size: 64,
169 compute_capability: (9, 0),
170 });
171 }
172 }
173 }
174 }
175
176 Ok(DeviceProperties {
178 name: "GPU Device (properties unavailable)".to_string(),
179 total_memory: 8 * 1024 * 1024 * 1024,
180 max_threads_per_block: 1024,
181 max_blocks_per_grid: 65535,
182 warp_size: 32,
183 compute_capability: (0, 0),
184 })
185 }
186
187 fn get_webgpu_properties() -> Result<DeviceProperties> {
193 #[cfg(not(target_arch = "wasm32"))]
194 {
195 use pollster::FutureExt;
196 let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
197 backends: wgpu::Backends::all(),
198 ..Default::default()
199 });
200 if let Some(adapter) = instance
201 .request_adapter(&wgpu::RequestAdapterOptions::default())
202 .block_on()
203 {
204 let info = adapter.get_info();
205 let limits = adapter.limits();
206 return Ok(DeviceProperties {
207 name: info.name,
208 total_memory: 0, max_threads_per_block: limits.max_compute_invocations_per_workgroup,
210 max_blocks_per_grid: limits.max_compute_workgroups_per_dimension,
211 warp_size: 32,
212 compute_capability: (1, 0),
213 });
214 }
215 }
216 Ok(DeviceProperties {
217 name: "WebGPU Device".to_string(),
218 total_memory: 2 * 1024 * 1024 * 1024,
219 max_threads_per_block: 256,
220 max_blocks_per_grid: 65535,
221 warp_size: 32,
222 compute_capability: (1, 0),
223 })
224 }
225
226 fn get_cpu_properties() -> DeviceProperties {
231 let name = std::fs::read_to_string("/proc/cpuinfo")
232 .ok()
233 .and_then(|info| {
234 info.lines()
235 .find(|l| l.starts_with("model name"))
236 .map(|l| l.split(':').nth(1).unwrap_or("CPU").trim().to_string())
237 })
238 .unwrap_or_else(|| "CPU Device".to_string());
239
240 let threads = std::thread::available_parallelism()
241 .map(|n| n.get())
242 .unwrap_or(1);
243
244 DeviceProperties {
245 name,
246 total_memory: 16 * 1024 * 1024 * 1024, max_threads_per_block: threads as u32,
248 max_blocks_per_grid: 65535,
249 warp_size: 1,
250 compute_capability: (0, 0),
251 }
252 }
253}