1use core::fmt;
30use sysinfo::System;
31
/// Identifies where computation runs: the host CPU, or an indexed device on
/// one of the feature-gated GPU backends.
///
/// Every GPU variant carries the device index used by its backend. Variants
/// only exist when the matching Cargo feature is enabled.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum Device {
    /// The host CPU; this is the `Default` device.
    #[default]
    Cpu,

    /// A CUDA device, selected by index (requires the `cuda` feature).
    #[cfg(feature = "cuda")]
    Cuda(usize),

    /// A Vulkan device, selected by index (requires the `vulkan` feature).
    #[cfg(feature = "vulkan")]
    Vulkan(usize),

    /// A Metal device, selected by index (requires the `metal` feature).
    #[cfg(feature = "metal")]
    Metal(usize),

    /// A wgpu device, selected by index (requires the `wgpu` feature).
    #[cfg(feature = "wgpu")]
    Wgpu(usize),
}
59
60impl Device {
61 #[must_use]
63 pub fn is_available(self) -> bool {
64 match self {
65 Self::Cpu => true,
66 #[cfg(feature = "cuda")]
67 Self::Cuda(idx) => crate::backends::cuda::is_device_available(idx),
68 #[cfg(feature = "vulkan")]
69 Self::Vulkan(idx) => crate::backends::vulkan::is_device_available(idx),
70 #[cfg(feature = "metal")]
71 Self::Metal(idx) => crate::backends::metal::is_device_available(idx),
72 #[cfg(feature = "wgpu")]
73 Self::Wgpu(idx) => crate::backends::wgpu_backend::is_device_available(idx),
74 }
75 }
76
77 #[must_use]
79 pub const fn is_cpu(self) -> bool {
80 matches!(self, Self::Cpu)
81 }
82
83 #[must_use]
85 pub const fn is_gpu(self) -> bool {
86 !self.is_cpu()
87 }
88
89 #[must_use]
91 pub const fn index(self) -> usize {
92 match self {
93 Self::Cpu => 0,
94 #[cfg(feature = "cuda")]
95 Self::Cuda(idx) => idx,
96 #[cfg(feature = "vulkan")]
97 Self::Vulkan(idx) => idx,
98 #[cfg(feature = "metal")]
99 Self::Metal(idx) => idx,
100 #[cfg(feature = "wgpu")]
101 Self::Wgpu(idx) => idx,
102 }
103 }
104
105 #[must_use]
107 pub const fn device_type(self) -> &'static str {
108 match self {
109 Self::Cpu => "cpu",
110 #[cfg(feature = "cuda")]
111 Self::Cuda(_) => "cuda",
112 #[cfg(feature = "vulkan")]
113 Self::Vulkan(_) => "vulkan",
114 #[cfg(feature = "metal")]
115 Self::Metal(_) => "metal",
116 #[cfg(feature = "wgpu")]
117 Self::Wgpu(_) => "wgpu",
118 }
119 }
120
121 #[must_use]
123 pub const fn cpu() -> Self {
124 Self::Cpu
125 }
126
127 #[cfg(feature = "cuda")]
129 #[must_use]
130 pub const fn cuda(index: usize) -> Self {
131 Self::Cuda(index)
132 }
133}
134
135impl fmt::Display for Device {
136 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137 match self {
138 Self::Cpu => write!(f, "cpu"),
139 #[cfg(feature = "cuda")]
140 Self::Cuda(idx) => write!(f, "cuda:{idx}"),
141 #[cfg(feature = "vulkan")]
142 Self::Vulkan(idx) => write!(f, "vulkan:{idx}"),
143 #[cfg(feature = "metal")]
144 Self::Metal(idx) => write!(f, "metal:{idx}"),
145 #[cfg(feature = "wgpu")]
146 Self::Wgpu(idx) => write!(f, "wgpu:{idx}"),
147 }
148 }
149}
150
/// A snapshot of what a [`Device`] offers, as reported by its backend.
///
/// For the CPU these values are filled in locally; GPU variants obtain them
/// from their backend's `get_capabilities` function.
#[derive(Clone, Debug)]
pub struct DeviceCapabilities {
    /// Human-readable device name (the CPU reports `"CPU"`).
    pub name: String,
    /// Total memory of the device. Presumably bytes; the CPU path uses
    /// sysinfo's figure. Confirm the unit for each GPU backend.
    pub total_memory: usize,
    /// Memory currently available on the device (same unit as `total_memory`).
    pub available_memory: usize,
    /// Whether half-precision (`f16`) data is supported.
    pub supports_f16: bool,
    /// Whether double-precision (`f64`) data is supported.
    pub supports_f64: bool,
    /// Upper bound on threads per block; the CPU reports its available
    /// parallelism here.
    pub max_threads_per_block: usize,
    /// Optional `(major, minor)` pair. `None` for the CPU; presumably a
    /// CUDA-style compute capability for GPUs. TODO: confirm per backend.
    pub compute_capability: Option<(usize, usize)>,
}
173
174impl Device {
175 #[must_use]
177 pub fn capabilities(self) -> DeviceCapabilities {
178 match self {
179 Self::Cpu => DeviceCapabilities {
180 name: "CPU".to_string(),
181 total_memory: get_system_memory(),
182 available_memory: get_available_memory(),
183 supports_f16: true,
184 supports_f64: true,
185 max_threads_per_block: num_cpus(),
186 compute_capability: None,
187 },
188 #[cfg(feature = "cuda")]
189 Self::Cuda(idx) => crate::backends::cuda::get_capabilities(idx),
190 #[cfg(feature = "vulkan")]
191 Self::Vulkan(idx) => crate::backends::vulkan::get_capabilities(idx),
192 #[cfg(feature = "metal")]
193 Self::Metal(idx) => crate::backends::metal::get_capabilities(idx),
194 #[cfg(feature = "wgpu")]
195 Self::Wgpu(idx) => crate::backends::wgpu_backend::get_capabilities(idx),
196 }
197 }
198}
199
200fn get_system_memory() -> usize {
206 let sys = System::new_all();
207 sys.total_memory() as usize
208}
209
210fn get_available_memory() -> usize {
212 let sys = System::new_all();
213 sys.available_memory() as usize
214}
215
/// Returns the standard library's parallelism hint, falling back to `1`
/// when `available_parallelism` reports an error.
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(threads) => threads.get(),
        Err(_) => 1,
    }
}
222
223impl DeviceCapabilities {
224 #[must_use]
226 pub const fn supports_f32(&self) -> bool {
227 true }
229}
230
/// Returns the device count reported by the CUDA backend
/// (`crate::backends::cuda::device_count`).
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
#[must_use]
pub fn cuda_device_count() -> usize {
    crate::backends::cuda::device_count()
}
241
/// Returns the device count reported by the Vulkan backend
/// (`crate::backends::vulkan::device_count`).
///
/// Only compiled when the `vulkan` feature is enabled.
#[cfg(feature = "vulkan")]
#[must_use]
pub fn vulkan_device_count() -> usize {
    crate::backends::vulkan::device_count()
}
248
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cpu_device() {
        let device = Device::Cpu;
        assert!(device.is_cpu());
        assert!(!device.is_gpu());
        assert!(device.is_available());
        assert_eq!(device.device_type(), "cpu");
    }

    #[test]
    fn test_device_display() {
        let cpu = Device::Cpu;
        assert_eq!(format!("{cpu}"), "cpu");
    }

    #[test]
    fn test_device_default() {
        let device = Device::default();
        assert_eq!(device, Device::Cpu);
    }

    #[test]
    fn test_device_capabilities() {
        let caps = Device::Cpu.capabilities();
        assert_eq!(caps.name, "CPU");
        assert!(caps.supports_f32());
    }

    // The constructor and index accessor were previously untested.
    #[test]
    fn test_cpu_constructor_and_index() {
        let device = Device::cpu();
        assert_eq!(device, Device::Cpu);
        assert_eq!(device.index(), 0);
    }

    // For the CPU, the Display output and device_type() must agree.
    #[test]
    fn test_display_matches_device_type() {
        let cpu = Device::Cpu;
        assert_eq!(cpu.to_string(), cpu.device_type());
    }

    // Pins the hard-coded CPU capability flags and the parallelism floor.
    #[test]
    fn test_cpu_capability_flags() {
        let caps = Device::Cpu.capabilities();
        assert!(caps.supports_f16);
        assert!(caps.supports_f64);
        assert!(caps.max_threads_per_block >= 1);
        assert!(caps.compute_capability.is_none());
    }
}