1use core::fmt;
18use sysinfo::System;
19
/// Selects the compute device an operation should target.
///
/// `Cpu` is always compiled in and is the [`Default`] value. Every other
/// variant exists only when the matching Cargo feature is enabled and carries
/// the index of the device within that backend.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum Device {
    /// The host CPU (default device).
    #[default]
    Cpu,

    /// A CUDA device identified by index; requires the `cuda` feature.
    #[cfg(feature = "cuda")]
    Cuda(usize),

    /// A Vulkan device identified by index; requires the `vulkan` feature.
    #[cfg(feature = "vulkan")]
    Vulkan(usize),

    /// A Metal device identified by index; requires the `metal` feature.
    #[cfg(feature = "metal")]
    Metal(usize),

    /// A wgpu device identified by index; requires the `wgpu` feature.
    #[cfg(feature = "wgpu")]
    Wgpu(usize),
}
47
48impl Device {
49 #[must_use]
51 pub fn is_available(self) -> bool {
52 match self {
53 Self::Cpu => true,
54 #[cfg(feature = "cuda")]
55 Self::Cuda(idx) => crate::backends::cuda::is_device_available(idx),
56 #[cfg(feature = "vulkan")]
57 Self::Vulkan(idx) => crate::backends::vulkan::is_device_available(idx),
58 #[cfg(feature = "metal")]
59 Self::Metal(idx) => crate::backends::metal::is_device_available(idx),
60 #[cfg(feature = "wgpu")]
61 Self::Wgpu(idx) => crate::backends::wgpu_backend::is_device_available(idx),
62 }
63 }
64
65 #[must_use]
67 pub const fn is_cpu(self) -> bool {
68 matches!(self, Self::Cpu)
69 }
70
71 #[must_use]
73 pub const fn is_gpu(self) -> bool {
74 !self.is_cpu()
75 }
76
77 #[must_use]
79 pub const fn index(self) -> usize {
80 match self {
81 Self::Cpu => 0,
82 #[cfg(feature = "cuda")]
83 Self::Cuda(idx) => idx,
84 #[cfg(feature = "vulkan")]
85 Self::Vulkan(idx) => idx,
86 #[cfg(feature = "metal")]
87 Self::Metal(idx) => idx,
88 #[cfg(feature = "wgpu")]
89 Self::Wgpu(idx) => idx,
90 }
91 }
92
93 #[must_use]
95 pub const fn device_type(self) -> &'static str {
96 match self {
97 Self::Cpu => "cpu",
98 #[cfg(feature = "cuda")]
99 Self::Cuda(_) => "cuda",
100 #[cfg(feature = "vulkan")]
101 Self::Vulkan(_) => "vulkan",
102 #[cfg(feature = "metal")]
103 Self::Metal(_) => "metal",
104 #[cfg(feature = "wgpu")]
105 Self::Wgpu(_) => "wgpu",
106 }
107 }
108
109 #[must_use]
111 pub const fn cpu() -> Self {
112 Self::Cpu
113 }
114
115 #[cfg(feature = "cuda")]
117 #[must_use]
118 pub const fn cuda(index: usize) -> Self {
119 Self::Cuda(index)
120 }
121}
122
123impl fmt::Display for Device {
124 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125 match self {
126 Self::Cpu => write!(f, "cpu"),
127 #[cfg(feature = "cuda")]
128 Self::Cuda(idx) => write!(f, "cuda:{idx}"),
129 #[cfg(feature = "vulkan")]
130 Self::Vulkan(idx) => write!(f, "vulkan:{idx}"),
131 #[cfg(feature = "metal")]
132 Self::Metal(idx) => write!(f, "metal:{idx}"),
133 #[cfg(feature = "wgpu")]
134 Self::Wgpu(idx) => write!(f, "wgpu:{idx}"),
135 }
136 }
137}
138
/// Static description of what a [`Device`] offers.
///
/// Produced by `Device::capabilities`; the CPU entry is assembled in this
/// file, while GPU entries come from the respective backend modules.
#[derive(Clone, Debug)]
pub struct DeviceCapabilities {
    /// Human-readable device name (`"CPU"` for the CPU device).
    pub name: String,
    /// Total memory of the device. For the CPU this is the total RAM reported
    /// by `sysinfo`. NOTE(review): presumably bytes for every backend — confirm.
    pub total_memory: usize,
    /// Memory currently available on the device, in the same unit as
    /// `total_memory`.
    pub available_memory: usize,
    /// Whether 16-bit floating point is supported.
    pub supports_f16: bool,
    /// Whether 64-bit floating point is supported.
    pub supports_f64: bool,
    /// Upper bound on threads per block. The CPU device reports its available
    /// parallelism here.
    pub max_threads_per_block: usize,
    /// Compute capability pair, `None` for the CPU device.
    /// NOTE(review): presumably `(major, minor)` as used by CUDA — confirm
    /// against the backend that fills it in.
    pub compute_capability: Option<(usize, usize)>,
}
161
162impl Device {
163 #[must_use]
165 pub fn capabilities(self) -> DeviceCapabilities {
166 match self {
167 Self::Cpu => DeviceCapabilities {
168 name: "CPU".to_string(),
169 total_memory: get_system_memory(),
170 available_memory: get_available_memory(),
171 supports_f16: true,
172 supports_f64: true,
173 max_threads_per_block: num_cpus(),
174 compute_capability: None,
175 },
176 #[cfg(feature = "cuda")]
177 Self::Cuda(idx) => crate::backends::cuda::get_capabilities(idx),
178 #[cfg(feature = "vulkan")]
179 Self::Vulkan(idx) => crate::backends::vulkan::get_capabilities(idx),
180 #[cfg(feature = "metal")]
181 Self::Metal(idx) => crate::backends::metal::get_capabilities(idx),
182 #[cfg(feature = "wgpu")]
183 Self::Wgpu(idx) => crate::backends::wgpu_backend::get_capabilities(idx),
184 }
185 }
186}
187
188fn get_system_memory() -> usize {
194 let sys = System::new_all();
195 sys.total_memory() as usize
196}
197
198fn get_available_memory() -> usize {
200 let sys = System::new_all();
201 sys.available_memory() as usize
202}
203
/// Returns the parallelism available to this process, or `1` when the
/// standard library cannot determine it.
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(parallelism) => parallelism.get(),
        Err(_) => 1,
    }
}
210
211impl DeviceCapabilities {
212 #[must_use]
214 pub const fn supports_f32(&self) -> bool {
215 true }
217}
218
/// Returns the number of devices reported by the CUDA backend.
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
#[must_use]
pub fn cuda_device_count() -> usize {
    use crate::backends::cuda;

    cuda::device_count()
}
229
/// Returns the number of devices reported by the Vulkan backend.
///
/// Only compiled when the `vulkan` feature is enabled.
#[cfg(feature = "vulkan")]
#[must_use]
pub fn vulkan_device_count() -> usize {
    use crate::backends::vulkan;

    vulkan::device_count()
}
236
// Unit tests for the device abstraction. Everything here is hardware
// independent: GPU variants are only checked for their metadata, never for
// availability.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cpu_device() {
        let device = Device::Cpu;
        assert!(device.is_cpu());
        assert!(!device.is_gpu());
        assert!(device.is_available());
        assert_eq!(device.device_type(), "cpu");
        assert_eq!(device.index(), 0);
    }

    #[test]
    fn test_cpu_constructor() {
        assert_eq!(Device::cpu(), Device::Cpu);
    }

    #[test]
    fn test_device_display() {
        let cpu = Device::Cpu;
        assert_eq!(format!("{cpu}"), "cpu");
    }

    #[test]
    fn test_device_default() {
        let device = Device::default();
        assert_eq!(device, Device::Cpu);
    }

    #[test]
    fn test_device_capabilities() {
        let caps = Device::Cpu.capabilities();
        assert_eq!(caps.name, "CPU");
        assert!(caps.supports_f32());
        assert!(caps.supports_f16);
        assert!(caps.supports_f64);
        // `num_cpus` falls back to 1, so the CPU never reports zero threads.
        assert!(caps.max_threads_per_block >= 1);
        assert!(caps.compute_capability.is_none());
    }

    // Metadata of a CUDA device does not require CUDA hardware to be present.
    #[cfg(feature = "cuda")]
    #[test]
    fn test_cuda_device_metadata() {
        let device = Device::cuda(1);
        assert_eq!(device, Device::Cuda(1));
        assert!(device.is_gpu());
        assert!(!device.is_cpu());
        assert_eq!(device.index(), 1);
        assert_eq!(device.device_type(), "cuda");
        assert_eq!(format!("{device}"), "cuda:1");
    }
}