cuda_rust_wasm/runtime/
device.rs1use crate::{Result, runtime_error};
4use std::sync::Arc;
5
/// Capabilities reported for a device.
///
/// In this module every field is filled in from a fixed, per-backend table
/// rather than queried from hardware.
#[derive(Clone, Debug)]
pub struct DeviceProperties {
    /// Human-readable device name.
    pub name: String,
    /// Total device memory; the backend tables express it in bytes
    /// (multiples of 1024³).
    pub total_memory: usize,
    /// Upper bound on threads in a single block.
    pub max_threads_per_block: u32,
    /// Upper bound on blocks in a grid.
    pub max_blocks_per_grid: u32,
    /// Warp (SIMT group) width; the CPU backend reports `1`.
    pub warp_size: u32,
    /// Compute capability pair — presumably CUDA-style `(major, minor)`; TODO confirm.
    pub compute_capability: (u32, u32),
}
16
/// Execution backend a [`Device`] runs on.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BackendType {
    /// Native CUDA device; chosen only on non-wasm targets with the
    /// `cuda-backend` feature enabled and a CUDA device detected.
    Native,
    /// WebGPU device; chosen whenever the target is `wasm32`.
    WebGPU,
    /// CPU fallback; chosen on non-wasm targets when no CUDA device is used.
    CPU,
}
24
25pub struct Device {
27 backend: BackendType,
28 properties: DeviceProperties,
29 id: usize,
30}
31
32impl Device {
33 pub fn get_default() -> Result<Arc<Self>> {
35 let backend = Self::detect_backend();
37
38 let properties = match backend {
39 BackendType::Native => Self::get_native_properties()?,
40 BackendType::WebGPU => Self::get_webgpu_properties()?,
41 BackendType::CPU => Self::get_cpu_properties(),
42 };
43
44 Ok(Arc::new(Self {
45 backend,
46 properties,
47 id: 0,
48 }))
49 }
50
51 pub fn get_by_id(id: usize) -> Result<Arc<Self>> {
53 if id != 0 {
55 return Err(runtime_error!("Device {} not found", id));
56 }
57 Self::get_default()
58 }
59
60 pub fn count() -> Result<usize> {
62 Ok(1)
64 }
65
66 pub fn properties(&self) -> &DeviceProperties {
68 &self.properties
69 }
70
71 pub fn backend(&self) -> BackendType {
73 self.backend
74 }
75
76 pub fn id(&self) -> usize {
78 self.id
79 }
80
81 fn detect_backend() -> BackendType {
83 #[cfg(target_arch = "wasm32")]
84 {
85 BackendType::WebGPU
86 }
87
88 #[cfg(not(target_arch = "wasm32"))]
89 {
90 #[cfg(feature = "cuda-backend")]
92 {
93 if Self::has_cuda() {
94 return BackendType::Native;
95 }
96 }
97
98 BackendType::CPU
100 }
101 }
102
103 #[cfg(feature = "cuda-backend")]
105 fn has_cuda() -> bool {
106 false
108 }
109
110 fn get_native_properties() -> Result<DeviceProperties> {
112 Ok(DeviceProperties {
114 name: "NVIDIA GPU (Simulated)".to_string(),
115 total_memory: 8 * 1024 * 1024 * 1024, max_threads_per_block: 1024,
117 max_blocks_per_grid: 65535,
118 warp_size: 32,
119 compute_capability: (8, 0),
120 })
121 }
122
123 fn get_webgpu_properties() -> Result<DeviceProperties> {
125 Ok(DeviceProperties {
126 name: "WebGPU Device".to_string(),
127 total_memory: 2 * 1024 * 1024 * 1024, max_threads_per_block: 256,
129 max_blocks_per_grid: 65535,
130 warp_size: 32,
131 compute_capability: (1, 0),
132 })
133 }
134
135 fn get_cpu_properties() -> DeviceProperties {
137 DeviceProperties {
138 name: "CPU Device".to_string(),
139 total_memory: 16 * 1024 * 1024 * 1024, max_threads_per_block: 1024,
141 max_blocks_per_grid: 65535,
142 warp_size: 1, compute_capability: (0, 0),
144 }
145 }
146}