jax_rs/
dtype.rs

1//! Data type definitions and utilities.
2
3use std::fmt;
4
/// Element type of the values held by an array.
///
/// Mirrors the jax-js `DType` enum: a deliberately small set of numeric
/// types that are cheap to work with on the web.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DType {
    /// Single-precision floating point, 32 bits wide.
    Float32,
    /// Half-precision floating point, 16 bits wide (kept as raw `u16` bits).
    Float16,
    /// Double-precision floating point, 64 bits wide.
    Float64,
    /// Signed integer, 8 bits wide.
    Int8,
    /// Signed integer, 16 bits wide.
    Int16,
    /// Signed integer, 32 bits wide.
    Int32,
    /// Signed integer, 64 bits wide.
    Int64,
    /// Unsigned integer, 8 bits wide.
    Uint8,
    /// Unsigned integer, 16 bits wide.
    Uint16,
    /// Unsigned integer, 32 bits wide.
    Uint32,
    /// Unsigned integer, 64 bits wide.
    Uint64,
    /// Boolean, occupying one byte per element.
    Bool,
}
36
37impl DType {
38    /// Returns the byte width of this dtype.
39    #[inline]
40    pub const fn byte_width(self) -> usize {
41        match self {
42            DType::Bool | DType::Int8 | DType::Uint8 => 1,
43            DType::Float16 | DType::Int16 | DType::Uint16 => 2,
44            DType::Float32 | DType::Int32 | DType::Uint32 => 4,
45            DType::Float64 | DType::Int64 | DType::Uint64 => 8,
46        }
47    }
48
49    /// Returns true if this is a floating-point dtype.
50    #[inline]
51    pub const fn is_float(self) -> bool {
52        matches!(self, DType::Float32 | DType::Float16 | DType::Float64)
53    }
54
55    /// Returns true if this is an integer dtype.
56    #[inline]
57    pub const fn is_int(self) -> bool {
58        matches!(
59            self,
60            DType::Int8
61                | DType::Int16
62                | DType::Int32
63                | DType::Int64
64                | DType::Uint8
65                | DType::Uint16
66                | DType::Uint32
67                | DType::Uint64
68        )
69    }
70
71    /// Returns true if this is a signed integer dtype.
72    #[inline]
73    pub const fn is_signed(self) -> bool {
74        matches!(
75            self,
76            DType::Int8 | DType::Int16 | DType::Int32 | DType::Int64
77        )
78    }
79
80    /// Returns true if this is an unsigned integer dtype.
81    #[inline]
82    pub const fn is_unsigned(self) -> bool {
83        matches!(
84            self,
85            DType::Uint8 | DType::Uint16 | DType::Uint32 | DType::Uint64
86        )
87    }
88
89    /// Promotes two dtypes according to JAX's type promotion rules.
90    ///
91    /// Type lattice: `bool -> uint8 -> uint16 -> uint32 -> int8 -> int16 -> int32 -> float16 -> float32 -> float64`
92    ///
93    /// # Examples
94    ///
95    /// ```
96    /// # use jax_rs::DType;
97    /// assert_eq!(DType::promote(DType::Bool, DType::Int32), DType::Int32);
98    /// assert_eq!(DType::promote(DType::Uint32, DType::Int32), DType::Int32);
99    /// assert_eq!(DType::promote(DType::Int32, DType::Float16), DType::Float16);
100    /// assert_eq!(DType::promote(DType::Float16, DType::Float32), DType::Float32);
101    /// ```
102    pub fn promote(dtype1: DType, dtype2: DType) -> DType {
103        if dtype1 == dtype2 {
104            return dtype1;
105        }
106
107        // Promotion order (higher rank = later in chain)
108        let rank = |d: DType| match d {
109            DType::Bool => 0,
110            DType::Uint8 => 1,
111            DType::Uint16 => 2,
112            DType::Uint32 => 3,
113            DType::Uint64 => 4,
114            DType::Int8 => 5,
115            DType::Int16 => 6,
116            DType::Int32 => 7,
117            DType::Int64 => 8,
118            DType::Float16 => 9,
119            DType::Float32 => 10,
120            DType::Float64 => 11,
121        };
122
123        if rank(dtype1) > rank(dtype2) {
124            dtype1
125        } else {
126            dtype2
127        }
128    }
129
130    /// Cast a float32 value to this dtype (returns as f32 for storage).
131    #[inline]
132    pub fn cast_from_f32(self, value: f32) -> f32 {
133        match self {
134            DType::Float32 | DType::Float64 | DType::Float16 => value,
135            DType::Int8 => (value as i8) as f32,
136            DType::Int16 => (value as i16) as f32,
137            DType::Int32 | DType::Int64 => (value as i32) as f32,
138            DType::Uint8 => (value as u8) as f32,
139            DType::Uint16 => (value as u16) as f32,
140            DType::Uint32 | DType::Uint64 => (value as u32) as f32,
141            DType::Bool => if value != 0.0 { 1.0 } else { 0.0 },
142        }
143    }
144}
145
146impl fmt::Display for DType {
147    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148        match self {
149            DType::Float32 => write!(f, "float32"),
150            DType::Float16 => write!(f, "float16"),
151            DType::Float64 => write!(f, "float64"),
152            DType::Int8 => write!(f, "int8"),
153            DType::Int16 => write!(f, "int16"),
154            DType::Int32 => write!(f, "int32"),
155            DType::Int64 => write!(f, "int64"),
156            DType::Uint8 => write!(f, "uint8"),
157            DType::Uint16 => write!(f, "uint16"),
158            DType::Uint32 => write!(f, "uint32"),
159            DType::Uint64 => write!(f, "uint64"),
160            DType::Bool => write!(f, "bool"),
161        }
162    }
163}
164
165impl DType {
166    /// Parse a string into a DType.
167    pub fn from_str(s: &str) -> Option<DType> {
168        match s.to_lowercase().as_str() {
169            "float32" | "f32" => Some(DType::Float32),
170            "float16" | "f16" => Some(DType::Float16),
171            "float64" | "f64" => Some(DType::Float64),
172            "int8" | "i8" => Some(DType::Int8),
173            "int16" | "i16" => Some(DType::Int16),
174            "int32" | "i32" => Some(DType::Int32),
175            "int64" | "i64" => Some(DType::Int64),
176            "uint8" | "u8" => Some(DType::Uint8),
177            "uint16" | "u16" => Some(DType::Uint16),
178            "uint32" | "u32" => Some(DType::Uint32),
179            "uint64" | "u64" => Some(DType::Uint64),
180            "bool" => Some(DType::Bool),
181            _ => None,
182        }
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Every dtype paired with its expected element size in bytes and display name.
    const ALL: [(DType, usize, &str); 12] = [
        (DType::Float32, 4, "float32"),
        (DType::Float16, 2, "float16"),
        (DType::Float64, 8, "float64"),
        (DType::Int8, 1, "int8"),
        (DType::Int16, 2, "int16"),
        (DType::Int32, 4, "int32"),
        (DType::Int64, 8, "int64"),
        (DType::Uint8, 1, "uint8"),
        (DType::Uint16, 2, "uint16"),
        (DType::Uint32, 4, "uint32"),
        (DType::Uint64, 8, "uint64"),
        (DType::Bool, 1, "bool"),
    ];

    #[test]
    fn test_byte_width() {
        for &(dtype, width, _) in ALL.iter() {
            assert_eq!(dtype.byte_width(), width);
        }
    }

    #[test]
    fn test_is_float() {
        for &dtype in &[DType::Float32, DType::Float16, DType::Float64] {
            assert!(dtype.is_float());
        }
        for &dtype in &[DType::Int32, DType::Uint32, DType::Bool] {
            assert!(!dtype.is_float());
        }
    }

    #[test]
    fn test_is_int() {
        let integers = [
            DType::Int8,
            DType::Int16,
            DType::Int32,
            DType::Int64,
            DType::Uint8,
            DType::Uint16,
            DType::Uint32,
            DType::Uint64,
        ];
        for &dtype in integers.iter() {
            assert!(dtype.is_int());
        }
        for &dtype in &[DType::Float32, DType::Bool] {
            assert!(!dtype.is_int());
        }
    }

    #[test]
    fn test_type_promotion() {
        // (lhs, rhs, expected promoted dtype)
        let cases = [
            (DType::Bool, DType::Int32, DType::Int32),
            (DType::Uint32, DType::Int32, DType::Int32),
            (DType::Int32, DType::Float16, DType::Float16),
            (DType::Float16, DType::Float32, DType::Float32),
            (DType::Uint32, DType::Float32, DType::Float32),
            (DType::Float32, DType::Float32, DType::Float32),
            (DType::Uint8, DType::Uint16, DType::Uint16),
            (DType::Int8, DType::Int16, DType::Int16),
        ];
        for &(lhs, rhs, expected) in cases.iter() {
            assert_eq!(DType::promote(lhs, rhs), expected);
        }
    }

    #[test]
    fn test_display() {
        for &(dtype, _, name) in ALL.iter() {
            assert_eq!(dtype.to_string(), name);
        }
    }

    #[test]
    fn test_from_str() {
        let cases = [
            ("float32", Some(DType::Float32)),
            ("f32", Some(DType::Float32)),
            ("int8", Some(DType::Int8)),
            ("i8", Some(DType::Int8)),
            ("uint16", Some(DType::Uint16)),
            ("bool", Some(DType::Bool)),
            ("unknown", None),
        ];
        for &(text, expected) in cases.iter() {
            assert_eq!(DType::from_str(text), expected);
        }
    }

    #[test]
    fn test_cast_from_f32() {
        // (target dtype, input, expected stored value)
        let cases = [
            (DType::Int8, 127.5, 127.0),
            (DType::Int8, -128.0, -128.0),
            (DType::Uint8, 255.5, 255.0),
            (DType::Bool, 0.0, 0.0),
            (DType::Bool, 42.0, 1.0),
        ];
        for &(dtype, input, expected) in cases.iter() {
            assert_eq!(dtype.cast_from_f32(input), expected);
        }
    }
}