1use cpp::cpp;
2
3use crate::device::DeviceId;
4use crate::device::MemoryInfo;
5use crate::ffi::result;
6
/// Module-local shorthand for `std::result::Result` whose error type is the
/// crate's `crate::error::Error`.
type Result<T> = std::result::Result<T, crate::error::Error>;
8
9pub fn num_devices() -> Result<usize> {
13 let mut num = 0_i32;
14 let num_ptr = std::ptr::addr_of_mut!(num);
15 let ret = cpp!(unsafe [
16 num_ptr as "std::int32_t*"
17 ] -> i32 as "std::int32_t" {
18 return cudaGetDeviceCount(num_ptr);
19 });
20
21 result!(ret, num as usize)
22}
23
24pub struct Device;
28
29impl Device {
30 #[inline]
31 pub fn get() -> Result<DeviceId> {
32 let mut id: i32 = 0;
33 let id_ptr = std::ptr::addr_of_mut!(id);
34 let ret = cpp!(unsafe [
35 id_ptr as "int*"
36 ] -> i32 as "int" {
37 return cudaGetDevice(id_ptr);
38 });
39 result!(ret, id)
40 }
41
42 #[inline(always)]
43 pub fn get_or_panic() -> DeviceId {
44 Device::get().unwrap_or_else(|err| panic!("failed to get device: {err}"))
45 }
46
47 #[inline]
48 pub fn set(id: DeviceId) -> Result<()> {
49 let ret = cpp!(unsafe [
50 id as "int"
51 ] -> i32 as "int" {
52 return cudaSetDevice(id);
53 });
54 result!(ret)
55 }
56
57 #[inline(always)]
58 pub fn set_or_panic(id: DeviceId) {
59 Device::set(id).unwrap_or_else(|err| panic!("failed to set device {id}: {err}"));
60 }
61
62 pub fn synchronize() -> Result<()> {
63 let ret = cpp!(unsafe [] -> i32 as "std::int32_t" {
64 return cudaDeviceSynchronize();
65 });
66 result!(ret)
67 }
68
69 pub fn memory_info() -> Result<MemoryInfo> {
70 let mut free: usize = 0;
71 let free_ptr = std::ptr::addr_of_mut!(free);
72 let mut total: usize = 0;
73 let total_ptr = std::ptr::addr_of_mut!(total);
74
75 let ret = cpp!(unsafe [
76 free_ptr as "std::size_t*",
77 total_ptr as "std::size_t*"
78 ] -> i32 as "std::int32_t" {
79 return cudaMemGetInfo(free_ptr, total_ptr);
80 });
81 result!(ret, MemoryInfo { free, total })
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::*;

    // These tests call the CUDA runtime directly; `test_num_devices` requires at
    // least one device to be visible. Failures use `expect` so the underlying CUDA
    // error is printed instead of being swallowed by `matches!`/`is_ok()`.

    #[test]
    fn test_num_devices() {
        let num = num_devices().expect("num_devices failed");
        assert!(num > 0, "expected at least one CUDA device, found {num}");
    }

    #[test]
    fn test_get_device() {
        let id = Device::get().expect("Device::get failed");
        assert_eq!(id, 0);
    }

    #[test]
    fn test_get_device_or_panic() {
        assert_eq!(Device::get_or_panic(), 0);
    }

    #[test]
    fn test_set_device() {
        Device::set(0).expect("Device::set(0) failed");
        assert_eq!(Device::get().expect("Device::get failed"), 0);
    }

    #[test]
    fn test_set_device_or_panic() {
        Device::set_or_panic(0);
        assert_eq!(Device::get_or_panic(), 0);
    }

    #[test]
    fn test_synchronize() {
        Device::synchronize().expect("Device::synchronize failed");
    }

    #[test]
    fn test_memory_info() {
        let memory_info = Device::memory_info().expect("Device::memory_info failed");
        assert!(memory_info.free > 0);
        assert!(memory_info.total > 0);
        assert!(
            memory_info.free <= memory_info.total,
            "free memory ({}) exceeds total memory ({})",
            memory_info.free,
            memory_info.total
        );
    }
}