1use core::fmt;
30use sysinfo::System;
31
/// A compute device that work can be placed on.
///
/// `Cpu` is always compiled in. Each accelerator variant only exists when the
/// Cargo feature of the same name is enabled, and carries the index of the
/// device within that backend.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Device {
    /// The host CPU.
    Cpu,

    /// A CUDA device, identified by its index.
    #[cfg(feature = "cuda")]
    Cuda(usize),

    /// A Vulkan device, identified by its index.
    #[cfg(feature = "vulkan")]
    Vulkan(usize),

    /// A Metal device, identified by its index.
    #[cfg(feature = "metal")]
    Metal(usize),

    /// A wgpu device, identified by its index.
    #[cfg(feature = "wgpu")]
    Wgpu(usize),
}
58
59impl Device {
60 #[must_use]
62 pub fn is_available(self) -> bool {
63 match self {
64 Self::Cpu => true,
65 #[cfg(feature = "cuda")]
66 Self::Cuda(idx) => crate::backends::cuda::is_device_available(idx),
67 #[cfg(feature = "vulkan")]
68 Self::Vulkan(idx) => crate::backends::vulkan::is_device_available(idx),
69 #[cfg(feature = "metal")]
70 Self::Metal(idx) => crate::backends::metal::is_device_available(idx),
71 #[cfg(feature = "wgpu")]
72 Self::Wgpu(idx) => crate::backends::wgpu::is_device_available(idx),
73 }
74 }
75
76 #[must_use]
78 pub const fn is_cpu(self) -> bool {
79 matches!(self, Self::Cpu)
80 }
81
82 #[must_use]
84 pub const fn is_gpu(self) -> bool {
85 !self.is_cpu()
86 }
87
88 #[must_use]
90 pub const fn index(self) -> usize {
91 match self {
92 Self::Cpu => 0,
93 #[cfg(feature = "cuda")]
94 Self::Cuda(idx) => idx,
95 #[cfg(feature = "vulkan")]
96 Self::Vulkan(idx) => idx,
97 #[cfg(feature = "metal")]
98 Self::Metal(idx) => idx,
99 #[cfg(feature = "wgpu")]
100 Self::Wgpu(idx) => idx,
101 }
102 }
103
104 #[must_use]
106 pub const fn device_type(self) -> &'static str {
107 match self {
108 Self::Cpu => "cpu",
109 #[cfg(feature = "cuda")]
110 Self::Cuda(_) => "cuda",
111 #[cfg(feature = "vulkan")]
112 Self::Vulkan(_) => "vulkan",
113 #[cfg(feature = "metal")]
114 Self::Metal(_) => "metal",
115 #[cfg(feature = "wgpu")]
116 Self::Wgpu(_) => "wgpu",
117 }
118 }
119
120 #[must_use]
122 pub const fn cpu() -> Self {
123 Self::Cpu
124 }
125
126 #[cfg(feature = "cuda")]
128 #[must_use]
129 pub const fn cuda(index: usize) -> Self {
130 Self::Cuda(index)
131 }
132}
133
134impl Default for Device {
135 fn default() -> Self {
136 Self::Cpu
137 }
138}
139
140impl fmt::Display for Device {
141 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
142 match self {
143 Self::Cpu => write!(f, "cpu"),
144 #[cfg(feature = "cuda")]
145 Self::Cuda(idx) => write!(f, "cuda:{idx}"),
146 #[cfg(feature = "vulkan")]
147 Self::Vulkan(idx) => write!(f, "vulkan:{idx}"),
148 #[cfg(feature = "metal")]
149 Self::Metal(idx) => write!(f, "metal:{idx}"),
150 #[cfg(feature = "wgpu")]
151 Self::Wgpu(idx) => write!(f, "wgpu:{idx}"),
152 }
153 }
154}
155
/// Static description of what a [`Device`] offers.
///
/// For the CPU the values are filled in by `Device::capabilities`; for
/// accelerator backends they come from that backend's `get_capabilities`.
#[derive(Clone, Debug)]
pub struct DeviceCapabilities {
    /// Human-readable device name (`"CPU"` for the host CPU).
    pub name: String,
    /// Total memory of the device. For the CPU this is the total system
    /// memory reported by `sysinfo` (units follow `sysinfo` — presumably
    /// bytes; confirm against the pinned version).
    pub total_memory: usize,
    /// Memory currently available, in the same units as `total_memory`.
    pub available_memory: usize,
    /// Whether the device is reported as supporting 16-bit floats.
    pub supports_f16: bool,
    /// Whether the device is reported as supporting 64-bit floats.
    pub supports_f64: bool,
    /// Maximum threads per block. For the CPU this is the available
    /// parallelism of the host.
    pub max_threads_per_block: usize,
    /// Compute capability pair, `None` for the CPU (presumably
    /// `(major, minor)` on CUDA — confirm with the backend).
    pub compute_capability: Option<(usize, usize)>,
}
178
179impl Device {
180 #[must_use]
182 pub fn capabilities(self) -> DeviceCapabilities {
183 match self {
184 Self::Cpu => DeviceCapabilities {
185 name: "CPU".to_string(),
186 total_memory: get_system_memory(),
187 available_memory: get_available_memory(),
188 supports_f16: true,
189 supports_f64: true,
190 max_threads_per_block: num_cpus(),
191 compute_capability: None,
192 },
193 #[cfg(feature = "cuda")]
194 Self::Cuda(idx) => crate::backends::cuda::get_capabilities(idx),
195 #[cfg(feature = "vulkan")]
196 Self::Vulkan(idx) => crate::backends::vulkan::get_capabilities(idx),
197 #[cfg(feature = "metal")]
198 Self::Metal(idx) => crate::backends::metal::get_capabilities(idx),
199 #[cfg(feature = "wgpu")]
200 Self::Wgpu(idx) => crate::backends::wgpu::get_capabilities(idx),
201 }
202 }
203}
204
205fn get_system_memory() -> usize {
211 let sys = System::new_all();
212 sys.total_memory() as usize
213}
214
215fn get_available_memory() -> usize {
217 let sys = System::new_all();
218 sys.available_memory() as usize
219}
220
/// Returns the parallelism available to this process, falling back to `1`
/// when the standard library cannot determine it.
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    }
}
227
228impl DeviceCapabilities {
229 #[must_use]
231 pub const fn supports_f32(&self) -> bool {
232 true }
234}
235
/// Returns the number of CUDA devices reported by the CUDA backend.
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
#[must_use]
pub fn cuda_device_count() -> usize {
    use crate::backends::cuda;

    cuda::device_count()
}
246
/// Returns the number of Vulkan devices reported by the Vulkan backend.
///
/// Only compiled when the `vulkan` feature is enabled.
#[cfg(feature = "vulkan")]
#[must_use]
pub fn vulkan_device_count() -> usize {
    use crate::backends::vulkan;

    vulkan::device_count()
}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// The CPU device identifies itself as a CPU and is always available.
    #[test]
    fn test_cpu_device() {
        let cpu = Device::Cpu;
        assert!(cpu.is_cpu());
        assert!(!cpu.is_gpu());
        assert!(cpu.is_available());
        assert_eq!(cpu.device_type(), "cpu");
    }

    /// `Display` renders the CPU as the bare string `cpu`.
    #[test]
    fn test_device_display() {
        let rendered = Device::Cpu.to_string();
        assert_eq!(rendered, "cpu");
    }

    /// The default device is the CPU.
    #[test]
    fn test_device_default() {
        assert_eq!(Device::default(), Device::Cpu);
    }

    /// CPU capabilities carry the expected name and report `f32` support.
    #[test]
    fn test_device_capabilities() {
        let cpu_caps = Device::Cpu.capabilities();
        assert_eq!(cpu_caps.name, "CPU");
        assert!(cpu_caps.supports_f32());
    }
}