Skip to main content

cuda_rust_wasm/runtime/
device.rs

1//! Device abstraction for different backends
2
3use crate::{Result, runtime_error};
4use std::sync::Arc;
5
/// Static description of a compute device as reported by its backend.
///
/// Values are plain data: two property sets compare equal exactly when every
/// field is equal, which makes them convenient to assert on in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceProperties {
    /// Human-readable device name.
    pub name: String,
    /// Total device memory in bytes (`0` when the backend cannot report it).
    pub total_memory: usize,
    /// Maximum number of threads in a single block / workgroup.
    pub max_threads_per_block: u32,
    /// Maximum number of blocks / workgroups in a grid dimension.
    pub max_blocks_per_grid: u32,
    /// Number of threads in a warp / wavefront.
    pub warp_size: u32,
    /// Compute capability as a `(major, minor)` pair.
    pub compute_capability: (u32, u32),
}
16
/// Kind of execution backend a [`Device`] runs on.
///
/// The type is `Copy`, comparable and hashable, so it can be used directly as
/// a key in maps and sets (for example, per-backend caches or registries).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BackendType {
    /// Native GPU backend.
    Native,
    /// WebGPU backend.
    WebGPU,
    /// CPU backend.
    CPU,
}
24
25/// Device abstraction
26pub struct Device {
27    backend: BackendType,
28    properties: DeviceProperties,
29    id: usize,
30}
31
32impl Device {
33    /// Get the default device
34    pub fn get_default() -> Result<Arc<Self>> {
35        // Detect available backend
36        let backend = Self::detect_backend();
37
38        let properties = match backend {
39            BackendType::Native => Self::get_native_properties()?,
40            BackendType::WebGPU => Self::get_webgpu_properties()?,
41            BackendType::CPU => Self::get_cpu_properties(),
42        };
43
44        Ok(Arc::new(Self {
45            backend,
46            properties,
47            id: 0,
48        }))
49    }
50
51    /// Get device by ID
52    pub fn get_by_id(id: usize) -> Result<Arc<Self>> {
53        // For now, only support device 0
54        if id != 0 {
55            return Err(runtime_error!("Device {} not found", id));
56        }
57        Self::get_default()
58    }
59
60    /// Get device count
61    pub fn count() -> Result<usize> {
62        // For now, always return 1
63        Ok(1)
64    }
65
66    /// Get device properties
67    pub fn properties(&self) -> &DeviceProperties {
68        &self.properties
69    }
70
71    /// Get backend type
72    pub fn backend(&self) -> BackendType {
73        self.backend
74    }
75
76    /// Get device ID
77    pub fn id(&self) -> usize {
78        self.id
79    }
80
81    /// Detect available backend
82    fn detect_backend() -> BackendType {
83        #[cfg(target_arch = "wasm32")]
84        {
85            return BackendType::WebGPU;
86        }
87
88        #[cfg(not(target_arch = "wasm32"))]
89        {
90            // Try native GPU detection
91            if crate::backend::native_gpu::is_cuda_available()
92                || crate::backend::native_gpu::is_rocm_available()
93            {
94                return BackendType::Native;
95            }
96
97            // Try WebGPU via wgpu
98            if Self::probe_webgpu() {
99                return BackendType::WebGPU;
100            }
101
102            BackendType::CPU
103        }
104    }
105
106    /// Probe whether a WebGPU-compatible adapter is available via wgpu.
107    #[cfg(not(target_arch = "wasm32"))]
108    fn probe_webgpu() -> bool {
109        use pollster::FutureExt;
110        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
111            backends: wgpu::Backends::all(),
112            ..Default::default()
113        });
114        instance
115            .request_adapter(&wgpu::RequestAdapterOptions::default())
116            .block_on()
117            .is_some()
118    }
119
120    /// Get native GPU properties by querying the system.
121    ///
122    /// Tries nvidia-smi for NVIDIA GPUs, then sysfs for AMD GPUs,
123    /// falling back to generic properties if neither is available.
124    fn get_native_properties() -> Result<DeviceProperties> {
125        // Try nvidia-smi first
126        if let Ok(output) = std::process::Command::new("nvidia-smi")
127            .args([
128                "--query-gpu=name,memory.total,driver_version",
129                "--format=csv,noheader,nounits",
130            ])
131            .output()
132        {
133            if output.status.success() {
134                let stdout = String::from_utf8_lossy(&output.stdout);
135                let line = stdout.trim();
136                let parts: Vec<&str> = line.split(", ").collect();
137                if parts.len() >= 2 {
138                    let name = parts[0].trim().to_string();
139                    let mem_mb: usize = parts[1].trim().parse().unwrap_or(8192);
140                    return Ok(DeviceProperties {
141                        name,
142                        total_memory: mem_mb * 1024 * 1024,
143                        max_threads_per_block: 1024,
144                        max_blocks_per_grid: 65535,
145                        warp_size: 32,
146                        compute_capability: (8, 0),
147                    });
148                }
149            }
150        }
151
152        // Try reading sysfs for AMD GPUs
153        if let Ok(entries) = std::fs::read_dir("/sys/class/drm") {
154            for entry in entries.flatten() {
155                let vendor_path = entry.path().join("device/vendor");
156                if let Ok(vendor) = std::fs::read_to_string(&vendor_path) {
157                    if vendor.trim() == "0x1002" {
158                        // AMD vendor ID
159                        let name = std::fs::read_to_string(
160                            entry.path().join("device/product_name"),
161                        )
162                        .unwrap_or_else(|_| "AMD GPU".to_string());
163                        return Ok(DeviceProperties {
164                            name: name.trim().to_string(),
165                            total_memory: 16 * 1024 * 1024 * 1024,
166                            max_threads_per_block: 1024,
167                            max_blocks_per_grid: 65535,
168                            warp_size: 64,
169                            compute_capability: (9, 0),
170                        });
171                    }
172                }
173            }
174        }
175
176        // Fallback: generic GPU properties
177        Ok(DeviceProperties {
178            name: "GPU Device (properties unavailable)".to_string(),
179            total_memory: 8 * 1024 * 1024 * 1024,
180            max_threads_per_block: 1024,
181            max_blocks_per_grid: 65535,
182            warp_size: 32,
183            compute_capability: (0, 0),
184        })
185    }
186
187    /// Get WebGPU device properties by querying a wgpu adapter.
188    ///
189    /// On non-wasm targets this creates a real wgpu instance and reads
190    /// the adapter info and limits.  Falls back to reasonable defaults
191    /// when no adapter is found or on wasm32.
192    fn get_webgpu_properties() -> Result<DeviceProperties> {
193        #[cfg(not(target_arch = "wasm32"))]
194        {
195            use pollster::FutureExt;
196            let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
197                backends: wgpu::Backends::all(),
198                ..Default::default()
199            });
200            if let Some(adapter) = instance
201                .request_adapter(&wgpu::RequestAdapterOptions::default())
202                .block_on()
203            {
204                let info = adapter.get_info();
205                let limits = adapter.limits();
206                return Ok(DeviceProperties {
207                    name: info.name,
208                    total_memory: 0, // WebGPU does not expose total memory
209                    max_threads_per_block: limits.max_compute_invocations_per_workgroup,
210                    max_blocks_per_grid: limits.max_compute_workgroups_per_dimension,
211                    warp_size: 32,
212                    compute_capability: (1, 0),
213                });
214            }
215        }
216        Ok(DeviceProperties {
217            name: "WebGPU Device".to_string(),
218            total_memory: 2 * 1024 * 1024 * 1024,
219            max_threads_per_block: 256,
220            max_blocks_per_grid: 65535,
221            warp_size: 32,
222            compute_capability: (1, 0),
223        })
224    }
225
226    /// Get CPU properties by reading system information.
227    ///
228    /// Reads /proc/cpuinfo for the model name and queries
229    /// `available_parallelism` for the thread count.
230    fn get_cpu_properties() -> DeviceProperties {
231        let name = std::fs::read_to_string("/proc/cpuinfo")
232            .ok()
233            .and_then(|info| {
234                info.lines()
235                    .find(|l| l.starts_with("model name"))
236                    .map(|l| l.split(':').nth(1).unwrap_or("CPU").trim().to_string())
237            })
238            .unwrap_or_else(|| "CPU Device".to_string());
239
240        let threads = std::thread::available_parallelism()
241            .map(|n| n.get())
242            .unwrap_or(1);
243
244        DeviceProperties {
245            name,
246            total_memory: 16 * 1024 * 1024 * 1024, // Would need /proc/meminfo
247            max_threads_per_block: threads as u32,
248            max_blocks_per_grid: 65535,
249            warp_size: 1,
250            compute_capability: (0, 0),
251        }
252    }
253}