1use crate::error::{CudaError, CudaResult};
4use crate::stream::CudaStream;
5use crate::memory::GpuMemoryPool;
6use std::sync::atomic::{AtomicBool, Ordering};
7use parking_lot::Mutex;
8
9#[cfg(feature = "cuda")]
10use crate::ffi;
11
/// Process-wide flag guarding one-time CUDA runtime initialization
/// (set/cleared by `CudaDevice::init`).
static INITIALIZED: AtomicBool = AtomicBool::new(false);
13
/// Handle to a single CUDA device plus the per-device resources this crate
/// manages for it (default stream, memory pool, cuBLAS handle).
#[derive(Debug)]
pub struct CudaDevice {
    /// CUDA device ordinal (the index passed to `cudaSetDevice`).
    pub id: i32,
    /// Device name as reported by the driver.
    pub name: String,
    /// (major, minor) compute capability.
    pub compute_capability: (i32, i32),
    /// Total global memory in bytes.
    pub total_memory: usize,
    /// Free memory in bytes, sampled once at construction — may be stale;
    /// use `get_free_memory()` for a live reading.
    pub free_memory: usize,
    /// Number of streaming multiprocessors.
    pub multiprocessor_count: i32,
    /// Hardware limit on threads per block.
    pub max_threads_per_block: i32,
    /// Hardware limit on resident threads per multiprocessor.
    pub max_threads_per_mp: i32,
    /// Threads per warp (32 on all current hardware, but taken from the driver).
    pub warp_size: i32,
    /// Shared memory available per block, in bytes.
    pub shared_mem_per_block: usize,
    /// The device's default (NULL) stream wrapper.
    pub default_stream: CudaStream,
    /// Pool allocator over this device's memory, guarded for shared use.
    pub memory_pool: Mutex<GpuMemoryPool>,
    // Owned cuBLAS handle; created in `new`, destroyed in `Drop`.
    #[cfg(feature = "cuda")]
    cublas_handle: ffi::cublasHandle_t,
}
45
46impl CudaDevice {
47 pub fn init() -> CudaResult<()> {
49 if INITIALIZED.swap(true, Ordering::SeqCst) {
50 return Err(CudaError::AlreadyInitialized);
51 }
52
53 #[cfg(feature = "cuda")]
55 unsafe {
56 let err = ffi::cudaSetDevice(0);
57 if err != 0 {
58 INITIALIZED.store(false, Ordering::SeqCst);
59 return Err(CudaError::DriverError(err));
60 }
61 }
62
63 Ok(())
64 }
65
66 pub fn count() -> CudaResult<i32> {
68 #[cfg(feature = "cuda")]
69 {
70 let mut count: i32 = 0;
71 unsafe {
72 let err = ffi::cudaGetDeviceCount(&mut count);
73 if err != 0 {
74 return Err(CudaError::DriverError(err));
75 }
76 }
77 Ok(count)
78 }
79
80 #[cfg(not(feature = "cuda"))]
81 {
82 Ok(0)
83 }
84 }
85
86 pub fn new(device_id: i32) -> CudaResult<Self> {
88 let count = Self::count()?;
89 if device_id >= count || device_id < 0 {
90 return Err(CudaError::InvalidDevice(device_id));
91 }
92
93 #[cfg(feature = "cuda")]
94 {
95 unsafe {
97 let err = ffi::cudaSetDevice(device_id);
98 if err != 0 {
99 return Err(CudaError::DriverError(err));
100 }
101 }
102
103 let mut props = cudaDeviceProp::default();
105 unsafe {
106 let err = ffi::cudaGetDeviceProperties(&mut props, device_id);
107 if err != 0 {
108 return Err(CudaError::DriverError(err));
109 }
110 }
111
112 let name = unsafe {
114 CStr::from_ptr(props.name.as_ptr())
115 .to_string_lossy()
116 .into_owned()
117 };
118
119 let mut free_mem: usize = 0;
121 let mut total_mem: usize = 0;
122 unsafe {
123 let err = ffi::cudaMemGetInfo(&mut free_mem, &mut total_mem);
124 if err != 0 {
125 return Err(CudaError::DriverError(err));
126 }
127 }
128
129 let mut cublas_handle: ffi::cublasHandle_t = std::ptr::null_mut();
131 unsafe {
132 let status = ffi::cublasCreate_v2(&mut cublas_handle);
133 if status != 0 {
134 return Err(CudaError::CublasError(status));
135 }
136 }
137
138 Ok(CudaDevice {
139 id: device_id,
140 name,
141 compute_capability: (props.major, props.minor),
142 total_memory: props.totalGlobalMem,
143 free_memory: free_mem,
144 multiprocessor_count: props.multiProcessorCount,
145 max_threads_per_block: props.maxThreadsPerBlock,
146 max_threads_per_mp: props.maxThreadsPerMultiProcessor,
147 warp_size: props.warpSize,
148 shared_mem_per_block: props.sharedMemPerBlock,
149 default_stream: CudaStream::default_stream(),
150 memory_pool: Mutex::new(GpuMemoryPool::new(device_id, free_mem)),
151 cublas_handle,
152 })
153 }
154
155 #[cfg(not(feature = "cuda"))]
156 {
157 Err(CudaError::DeviceNotFound)
158 }
159 }
160
161 pub fn set_current(&self) -> CudaResult<()> {
163 #[cfg(feature = "cuda")]
164 unsafe {
165 let err = ffi::cudaSetDevice(self.id);
166 if err != 0 {
167 return Err(CudaError::DriverError(err));
168 }
169 }
170 Ok(())
171 }
172
173 pub fn current() -> CudaResult<i32> {
175 let device: i32 = -1;
176
177 #[cfg(feature = "cuda")]
178 unsafe {
179 let err = ffi::cudaGetDevice(&mut device);
180 if err != 0 {
181 return Err(CudaError::DriverError(err));
182 }
183 }
184
185 Ok(device)
186 }
187
188 pub fn synchronize() -> CudaResult<()> {
190 #[cfg(feature = "cuda")]
191 unsafe {
192 let err = ffi::cudaDeviceSynchronize();
193 if err != 0 {
194 return Err(CudaError::SyncError);
195 }
196 }
197 Ok(())
198 }
199
200 pub fn get_free_memory(&self) -> CudaResult<usize> {
202 #[cfg(feature = "cuda")]
203 {
204 let mut free_mem: usize = 0;
205 let mut total_mem: usize = 0;
206 unsafe {
207 let err = ffi::cudaMemGetInfo(&mut free_mem, &mut total_mem);
208 if err != 0 {
209 return Err(CudaError::DriverError(err));
210 }
211 }
212 Ok(free_mem)
213 }
214
215 #[cfg(not(feature = "cuda"))]
216 Ok(0)
217 }
218
219 #[cfg(feature = "cuda")]
221 pub fn cublas_handle(&self) -> ffi::cublasHandle_t {
222 self.cublas_handle
223 }
224
225 pub fn properties_string(&self) -> String {
227 format!(
228 "Device {}: {}\n\
229 Compute Capability: {}.{}\n\
230 Total Memory: {:.2} GB\n\
231 Free Memory: {:.2} GB\n\
232 Multiprocessors: {}\n\
233 Max Threads/Block: {}\n\
234 Warp Size: {}\n\
235 Shared Mem/Block: {} KB",
236 self.id,
237 self.name,
238 self.compute_capability.0,
239 self.compute_capability.1,
240 self.total_memory as f64 / (1024.0 * 1024.0 * 1024.0),
241 self.free_memory as f64 / (1024.0 * 1024.0 * 1024.0),
242 self.multiprocessor_count,
243 self.max_threads_per_block,
244 self.warp_size,
245 self.shared_mem_per_block / 1024,
246 )
247 }
248
249 pub fn has_tensor_cores(&self) -> bool {
251 self.compute_capability.0 >= 7
253 }
254
255 pub fn supports_fp16(&self) -> bool {
257 self.compute_capability.0 >= 6
259 }
260
261 pub fn supports_bf16(&self) -> bool {
263 self.compute_capability.0 >= 8
265 }
266
267 pub fn optimal_block_size(&self, shared_mem_per_thread: usize) -> usize {
269 let max_threads = self.max_threads_per_block as usize;
270 let shared_limit = self.shared_mem_per_block / shared_mem_per_thread.max(1);
271
272 let optimal = max_threads.min(shared_limit);
274 (optimal / self.warp_size as usize) * self.warp_size as usize
275 }
276}
277
278impl Drop for CudaDevice {
279 fn drop(&mut self) {
280 #[cfg(feature = "cuda")]
281 unsafe {
282 if !self.cublas_handle.is_null() {
283 ffi::cublasDestroy_v2(self.cublas_handle);
284 }
285 }
286 }
287}
288
/// RAII guard that switches the calling thread to another CUDA device and
/// restores the previously current device when dropped.
pub struct DeviceGuard {
    // Device id that was current before this guard was created; read only by
    // `Drop` (and unused entirely without the `cuda` feature).
    #[allow(dead_code)]
    previous_device: i32,
}
294
295impl DeviceGuard {
296 pub fn new(_device_id: i32) -> CudaResult<Self> {
297 let previous = CudaDevice::current()?;
298
299 #[cfg(feature = "cuda")]
300 unsafe {
301 let err = ffi::cudaSetDevice(_device_id);
302 if err != 0 {
303 return Err(CudaError::DriverError(err));
304 }
305 }
306
307 Ok(DeviceGuard { previous_device: previous })
308 }
309}
310
impl Drop for DeviceGuard {
    /// Restores the previously current device. The result is deliberately
    /// discarded: `drop` has no way to report a failure.
    fn drop(&mut self) {
        #[cfg(feature = "cuda")]
        unsafe {
            let _ = ffi::cudaSetDevice(self.previous_device);
        }
    }
}
319
320pub fn get_all_devices() -> CudaResult<Vec<CudaDevice>> {
322 let count = CudaDevice::count()?;
323 let mut devices = Vec::with_capacity(count as usize);
324
325 for i in 0..count {
326 devices.push(CudaDevice::new(i)?);
327 }
328
329 Ok(devices)
330}
331
332pub fn select_best_device() -> CudaResult<CudaDevice> {
334 let devices = get_all_devices()?;
335
336 if devices.is_empty() {
337 return Err(CudaError::DeviceNotFound);
338 }
339
340 let best = devices.into_iter()
342 .max_by_key(|d| {
343 let compute_score = d.compute_capability.0 * 10 + d.compute_capability.1;
344 let memory_score = (d.total_memory / (1024 * 1024 * 1024)) as i32; compute_score * 100 + memory_score
346 })
347 .unwrap();
348
349 Ok(best)
350}