Skip to main content

ferrum_types/
devices.rs

//! Device and computation types: the [`Device`] a computation runs on and the [`DataType`] of its tensors.
2
3use serde::{Deserialize, Serialize};
4
5/// Device type for computation
6#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
7pub enum Device {
8    /// CPU device
9    CPU,
10    /// NVIDIA CUDA device with device index
11    CUDA(usize),
12    /// AMD ROCm device with device index
13    ROCm(usize),
14    /// Apple GPU using Metal Performance Shaders
15    #[cfg(any(target_os = "macos", target_os = "ios"))]
16    Metal,
17}
18
19impl Device {
20    /// Check if device is GPU-based
21    pub fn is_gpu(&self) -> bool {
22        matches!(self, Device::CUDA(_) | Device::ROCm(_)) || {
23            #[cfg(any(target_os = "macos", target_os = "ios"))]
24            {
25                matches!(self, Device::Metal)
26            }
27            #[cfg(not(any(target_os = "macos", target_os = "ios")))]
28            {
29                false
30            }
31        }
32    }
33
34    /// Get device index for GPU devices
35    pub fn index(&self) -> Option<usize> {
36        match self {
37            Device::CUDA(idx) | Device::ROCm(idx) => Some(*idx),
38            _ => None,
39        }
40    }
41}
42
43impl std::fmt::Display for Device {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        match self {
46            Device::CPU => write!(f, "cpu"),
47            Device::CUDA(idx) => write!(f, "cuda:{}", idx),
48            Device::ROCm(idx) => write!(f, "rocm:{}", idx),
49            #[cfg(any(target_os = "macos", target_os = "ios"))]
50            Device::Metal => write!(f, "metal"),
51        }
52    }
53}
54
55/// Data type for tensors
56#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
57pub enum DataType {
58    /// 32-bit floating point
59    FP32,
60    /// 16-bit floating point (IEEE 754)
61    FP16,
62    /// 16-bit brain floating point
63    BF16,
64    /// 8-bit floating point (E5M2 or E4M3)
65    FP8,
66    /// 32-bit signed integer
67    INT32,
68    /// 16-bit signed integer
69    INT16,
70    /// 8-bit signed integer
71    INT8,
72    /// 4-bit signed integer
73    INT4,
74    /// 32-bit unsigned integer
75    UINT32,
76    /// 16-bit unsigned integer  
77    UINT16,
78    /// 8-bit unsigned integer
79    UINT8,
80    /// 4-bit unsigned integer
81    UINT4,
82    /// Boolean
83    BOOL,
84}
85
86impl DataType {
87    /// Get size in bytes for this data type
88    pub fn size_bytes(&self) -> usize {
89        match self {
90            DataType::FP32 | DataType::INT32 | DataType::UINT32 => 4,
91            DataType::FP16 | DataType::BF16 | DataType::INT16 | DataType::UINT16 => 2,
92            DataType::FP8 | DataType::INT8 | DataType::UINT8 | DataType::BOOL => 1,
93            DataType::INT4 | DataType::UINT4 => 1, // Packed, but minimum 1 byte
94        }
95    }
96
97    /// Check if this is a floating point type
98    pub fn is_float(&self) -> bool {
99        matches!(
100            self,
101            DataType::FP32 | DataType::FP16 | DataType::BF16 | DataType::FP8
102        )
103    }
104
105    /// Check if this is an integer type
106    pub fn is_integer(&self) -> bool {
107        matches!(
108            self,
109            DataType::INT32
110                | DataType::INT16
111                | DataType::INT8
112                | DataType::INT4
113                | DataType::UINT32
114                | DataType::UINT16
115                | DataType::UINT8
116                | DataType::UINT4
117        )
118    }
119
120    /// Check if this is a quantized type (reduced precision)
121    pub fn is_quantized(&self) -> bool {
122        matches!(
123            self,
124            DataType::FP8 | DataType::INT8 | DataType::INT4 | DataType::UINT4
125        )
126    }
127}
128
129impl std::fmt::Display for DataType {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        let name = match self {
132            DataType::FP32 => "fp32",
133            DataType::FP16 => "fp16",
134            DataType::BF16 => "bf16",
135            DataType::FP8 => "fp8",
136            DataType::INT32 => "int32",
137            DataType::INT16 => "int16",
138            DataType::INT8 => "int8",
139            DataType::INT4 => "int4",
140            DataType::UINT32 => "uint32",
141            DataType::UINT16 => "uint16",
142            DataType::UINT8 => "uint8",
143            DataType::UINT4 => "uint4",
144            DataType::BOOL => "bool",
145        };
146        write!(f, "{}", name)
147    }
148}