cuda_rust_wasm/runtime/device.rs

//! Device abstraction over the supported compute backends (native GPU, WebGPU, and CPU).
2
3use crate::{Result, runtime_error};
4use std::sync::Arc;
5
/// Capabilities reported for a compute device.
#[derive(Clone, Debug)]
pub struct DeviceProperties {
    /// Human-readable device name.
    pub name: String,
    /// Total device memory, in bytes.
    pub total_memory: usize,
    /// Maximum number of threads in a single block.
    pub max_threads_per_block: u32,
    /// Maximum number of blocks in a grid.
    pub max_blocks_per_grid: u32,
    /// Threads per warp; the CPU backend reports 1 because it has no warps.
    pub warp_size: u32,
    /// Compute capability as a `(major, minor)` pair; `(0, 0)` for the CPU backend.
    pub compute_capability: (u32, u32),
}
16
/// Execution backend a device runs on.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BackendType {
    /// Native GPU backend; only chosen when the `cuda-backend` feature is
    /// enabled and CUDA is detected.
    Native,
    /// WebGPU backend, chosen on `wasm32` targets.
    WebGPU,
    /// CPU fallback used when no GPU backend is selected.
    CPU,
}
24
25/// Device abstraction
26pub struct Device {
27    backend: BackendType,
28    properties: DeviceProperties,
29    id: usize,
30}
31
32impl Device {
33    /// Get the default device
34    pub fn get_default() -> Result<Arc<Self>> {
35        // Detect available backend
36        let backend = Self::detect_backend();
37        
38        let properties = match backend {
39            BackendType::Native => Self::get_native_properties()?,
40            BackendType::WebGPU => Self::get_webgpu_properties()?,
41            BackendType::CPU => Self::get_cpu_properties(),
42        };
43        
44        Ok(Arc::new(Self {
45            backend,
46            properties,
47            id: 0,
48        }))
49    }
50    
51    /// Get device by ID
52    pub fn get_by_id(id: usize) -> Result<Arc<Self>> {
53        // For now, only support device 0
54        if id != 0 {
55            return Err(runtime_error!("Device {} not found", id));
56        }
57        Self::get_default()
58    }
59    
60    /// Get device count
61    pub fn count() -> Result<usize> {
62        // For now, always return 1
63        Ok(1)
64    }
65    
66    /// Get device properties
67    pub fn properties(&self) -> &DeviceProperties {
68        &self.properties
69    }
70    
71    /// Get backend type
72    pub fn backend(&self) -> BackendType {
73        self.backend
74    }
75    
76    /// Get device ID
77    pub fn id(&self) -> usize {
78        self.id
79    }
80    
81    /// Detect available backend
82    fn detect_backend() -> BackendType {
83        #[cfg(target_arch = "wasm32")]
84        {
85            BackendType::WebGPU
86        }
87        
88        #[cfg(not(target_arch = "wasm32"))]
89        {
90            // Check for native GPU support
91            #[cfg(feature = "cuda-backend")]
92            {
93                if Self::has_cuda() {
94                    return BackendType::Native;
95                }
96            }
97            
98            // Fallback to CPU
99            BackendType::CPU
100        }
101    }
102    
103    /// Check if CUDA is available
104    #[cfg(feature = "cuda-backend")]
105    fn has_cuda() -> bool {
106        // TODO: Actually check for CUDA availability
107        false
108    }
109    
110    /// Get native GPU properties
111    fn get_native_properties() -> Result<DeviceProperties> {
112        // TODO: Query actual GPU properties
113        Ok(DeviceProperties {
114            name: "NVIDIA GPU (Simulated)".to_string(),
115            total_memory: 8 * 1024 * 1024 * 1024, // 8GB
116            max_threads_per_block: 1024,
117            max_blocks_per_grid: 65535,
118            warp_size: 32,
119            compute_capability: (8, 0),
120        })
121    }
122    
123    /// Get WebGPU properties
124    fn get_webgpu_properties() -> Result<DeviceProperties> {
125        Ok(DeviceProperties {
126            name: "WebGPU Device".to_string(),
127            total_memory: 2 * 1024 * 1024 * 1024, // 2GB
128            max_threads_per_block: 256,
129            max_blocks_per_grid: 65535,
130            warp_size: 32,
131            compute_capability: (1, 0),
132        })
133    }
134    
135    /// Get CPU properties
136    fn get_cpu_properties() -> DeviceProperties {
137        DeviceProperties {
138            name: "CPU Device".to_string(),
139            total_memory: 16 * 1024 * 1024 * 1024, // 16GB
140            max_threads_per_block: 1024,
141            max_blocks_per_grid: 65535,
142            warp_size: 1, // No warps on CPU
143            compute_capability: (0, 0),
144        }
145    }
146}