// ferrotorch_core/device.rs

/// Device on which a tensor's data resides.
///
/// Defined in Phase 1 with only `Cpu` functional. `Cuda` is present
/// from day one so the type is baked into every API before GPU work begins.
///
/// The type is `Copy` and derives `PartialEq`, `Eq` and `Hash`, so values can
/// be passed by value, compared directly, and used as hash-map or hash-set keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Device {
    /// CPU main memory.
    Cpu,
    /// CUDA GPU with the given device index.
    Cuda(usize),
}
12
13impl Device {
14    /// Returns `true` if this is a CPU device.
15    #[inline]
16    pub fn is_cpu(&self) -> bool {
17        matches!(self, Device::Cpu)
18    }
19
20    /// Returns `true` if this is a CUDA device.
21    #[inline]
22    pub fn is_cuda(&self) -> bool {
23        matches!(self, Device::Cuda(_))
24    }
25}
26
27impl core::fmt::Display for Device {
28    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
29        match self {
30            Device::Cpu => write!(f, "cpu"),
31            Device::Cuda(id) => write!(f, "cuda:{id}"),
32        }
33    }
34}
35
36impl Default for Device {
37    fn default() -> Self {
38        Device::Cpu
39    }
40}