1use core::fmt;
25use sysinfo::System;
26
/// A compute device that work can be placed on.
///
/// [`Device::Cpu`] is always compiled in and is the [`Default`]. Each
/// accelerator variant only exists when its Cargo feature is enabled and
/// carries the index of the device within that backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Device {
    /// The host CPU (the default device).
    #[default]
    Cpu,

    /// CUDA device with the given index (requires the `cuda` feature).
    #[cfg(feature = "cuda")]
    Cuda(usize),

    /// Vulkan device with the given index (requires the `vulkan` feature).
    #[cfg(feature = "vulkan")]
    Vulkan(usize),

    /// Metal device with the given index (requires the `metal` feature).
    #[cfg(feature = "metal")]
    Metal(usize),

    /// wgpu device with the given index (requires the `wgpu` feature).
    #[cfg(feature = "wgpu")]
    Wgpu(usize),
}
54
55impl Device {
56 #[must_use]
58 pub fn is_available(self) -> bool {
59 match self {
60 Self::Cpu => true,
61 #[cfg(feature = "cuda")]
62 Self::Cuda(idx) => crate::backends::cuda::is_device_available(idx),
63 #[cfg(feature = "vulkan")]
64 Self::Vulkan(idx) => crate::backends::vulkan::is_device_available(idx),
65 #[cfg(feature = "metal")]
66 Self::Metal(idx) => crate::backends::metal::is_device_available(idx),
67 #[cfg(feature = "wgpu")]
68 Self::Wgpu(idx) => crate::backends::wgpu_backend::is_device_available(idx),
69 }
70 }
71
72 #[must_use]
74 pub const fn is_cpu(self) -> bool {
75 matches!(self, Self::Cpu)
76 }
77
78 #[must_use]
80 pub const fn is_gpu(self) -> bool {
81 !self.is_cpu()
82 }
83
84 #[must_use]
86 pub const fn index(self) -> usize {
87 match self {
88 Self::Cpu => 0,
89 #[cfg(feature = "cuda")]
90 Self::Cuda(idx) => idx,
91 #[cfg(feature = "vulkan")]
92 Self::Vulkan(idx) => idx,
93 #[cfg(feature = "metal")]
94 Self::Metal(idx) => idx,
95 #[cfg(feature = "wgpu")]
96 Self::Wgpu(idx) => idx,
97 }
98 }
99
100 #[must_use]
102 pub const fn device_type(self) -> &'static str {
103 match self {
104 Self::Cpu => "cpu",
105 #[cfg(feature = "cuda")]
106 Self::Cuda(_) => "cuda",
107 #[cfg(feature = "vulkan")]
108 Self::Vulkan(_) => "vulkan",
109 #[cfg(feature = "metal")]
110 Self::Metal(_) => "metal",
111 #[cfg(feature = "wgpu")]
112 Self::Wgpu(_) => "wgpu",
113 }
114 }
115
116 #[must_use]
118 pub const fn cpu() -> Self {
119 Self::Cpu
120 }
121
122 #[cfg(feature = "cuda")]
124 #[must_use]
125 pub const fn cuda(index: usize) -> Self {
126 Self::Cuda(index)
127 }
128}
129
130impl fmt::Display for Device {
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 match self {
133 Self::Cpu => write!(f, "cpu"),
134 #[cfg(feature = "cuda")]
135 Self::Cuda(idx) => write!(f, "cuda:{idx}"),
136 #[cfg(feature = "vulkan")]
137 Self::Vulkan(idx) => write!(f, "vulkan:{idx}"),
138 #[cfg(feature = "metal")]
139 Self::Metal(idx) => write!(f, "metal:{idx}"),
140 #[cfg(feature = "wgpu")]
141 Self::Wgpu(idx) => write!(f, "wgpu:{idx}"),
142 }
143 }
144}
145
/// Properties reported for a [`Device`] by [`Device::capabilities`].
#[derive(Debug, Clone)]
pub struct DeviceCapabilities {
    /// Human-readable device name (`"CPU"` for the host).
    pub name: String,
    /// Total device memory. For the CPU this is system RAM as reported by
    /// `sysinfo`; NOTE(review): confirm the unit used by each GPU backend.
    pub total_memory: usize,
    /// Memory currently available. For the CPU this comes from `sysinfo`'s
    /// available-memory figure; backend units should match `total_memory`
    /// (confirm per backend).
    pub available_memory: usize,
    /// Whether half-precision (`f16`) is supported.
    pub supports_f16: bool,
    /// Whether double-precision (`f64`) is supported.
    pub supports_f64: bool,
    /// Maximum threads per block. For the CPU this is the host's available
    /// parallelism.
    pub max_threads_per_block: usize,
    /// Optional `(major, minor)` compute capability; `None` for the CPU.
    /// Presumably only populated by the CUDA backend — confirm.
    pub compute_capability: Option<(usize, usize)>,
}
168
169impl Device {
170 #[must_use]
172 pub fn capabilities(self) -> DeviceCapabilities {
173 match self {
174 Self::Cpu => DeviceCapabilities {
175 name: "CPU".to_string(),
176 total_memory: get_system_memory(),
177 available_memory: get_available_memory(),
178 supports_f16: true,
179 supports_f64: true,
180 max_threads_per_block: num_cpus(),
181 compute_capability: None,
182 },
183 #[cfg(feature = "cuda")]
184 Self::Cuda(idx) => crate::backends::cuda::get_capabilities(idx),
185 #[cfg(feature = "vulkan")]
186 Self::Vulkan(idx) => crate::backends::vulkan::get_capabilities(idx),
187 #[cfg(feature = "metal")]
188 Self::Metal(idx) => crate::backends::metal::get_capabilities(idx),
189 #[cfg(feature = "wgpu")]
190 Self::Wgpu(idx) => crate::backends::wgpu_backend::get_capabilities(idx),
191 }
192 }
193}
194
195fn get_system_memory() -> usize {
201 let sys = System::new_all();
202 sys.total_memory() as usize
203}
204
205fn get_available_memory() -> usize {
207 let sys = System::new_all();
208 sys.available_memory() as usize
209}
210
/// Number of threads the host can run in parallel.
///
/// Falls back to `1` when the platform cannot report its parallelism.
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(parallelism) => parallelism.get(),
        Err(_) => 1,
    }
}
215
216impl DeviceCapabilities {
217 #[must_use]
219 pub const fn supports_f32(&self) -> bool {
220 true }
222}
223
/// Number of CUDA devices, as reported by the CUDA backend.
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
#[must_use]
pub fn cuda_device_count() -> usize {
    use crate::backends::cuda as backend;
    backend::device_count()
}
234
/// Number of Vulkan devices, as reported by the Vulkan backend.
///
/// Only compiled when the `vulkan` feature is enabled.
#[cfg(feature = "vulkan")]
#[must_use]
pub fn vulkan_device_count() -> usize {
    use crate::backends::vulkan as backend;
    backend::device_count()
}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// The CPU is always present, is not a GPU, and has index 0.
    #[test]
    fn test_cpu_device() {
        let device = Device::Cpu;
        assert!(device.is_cpu());
        assert!(!device.is_gpu());
        assert!(device.is_available());
        assert_eq!(device.device_type(), "cpu");
        assert_eq!(device.index(), 0);
    }

    /// `Device::cpu()` is just a const alias for the variant.
    #[test]
    fn test_cpu_constructor() {
        assert_eq!(Device::cpu(), Device::Cpu);
    }

    #[test]
    fn test_device_display() {
        let cpu = Device::Cpu;
        assert_eq!(format!("{cpu}"), "cpu");
        assert_eq!(cpu.to_string(), "cpu");
    }

    #[test]
    fn test_device_default() {
        let device = Device::default();
        assert_eq!(device, Device::Cpu);
    }

    /// Pins every hard-coded field of the CPU capability record.
    #[test]
    fn test_device_capabilities() {
        let caps = Device::Cpu.capabilities();
        assert_eq!(caps.name, "CPU");
        assert!(caps.supports_f32());
        assert!(caps.supports_f16);
        assert!(caps.supports_f64);
        // num_cpus() never reports less than one thread.
        assert!(caps.max_threads_per_block >= 1);
        assert_eq!(caps.compute_capability, None);
    }
}