Skip to main content

numr/dtype/
element.rs

1//! Element trait for mapping Rust types to DType
2
3use super::DType;
4use bytemuck::{Pod, Zeroable};
5use std::ops::{Add, Div, Mul, Sub};
6
7/// Trait for types that can be elements of a tensor
8///
9/// This trait connects Rust's type system to numr's runtime dtype system.
10/// It's implemented for all primitive numeric types.
11///
12/// # Bounds
13/// - `Copy + Clone + Send + Sync + 'static` - Basic trait requirements
14/// - `Pod + Zeroable` - Safe memory transmutation (bytemuck)
15/// - `Add + Sub + Mul + Div` - Arithmetic operations (Output = Self)
16/// - `PartialOrd` - Comparison for min/max operations
17///
18/// Note: `Neg` is NOT required since unsigned types don't support it.
19/// Negation is handled via to_f64/from_f64 conversion in kernels.
20pub trait Element:
21    Copy
22    + Clone
23    + Send
24    + Sync
25    + Pod
26    + Zeroable
27    + 'static
28    + Add<Output = Self>
29    + Sub<Output = Self>
30    + Mul<Output = Self>
31    + Div<Output = Self>
32    + PartialOrd
33{
34    /// The corresponding DType for this Rust type
35    const DTYPE: DType;
36
37    /// Convert to f64 for generic numeric operations
38    ///
39    /// # Complex Number Behavior
40    ///
41    /// For complex types (Complex64, Complex128), this returns the **magnitude** (|z|),
42    /// not the real part. This is consistent with:
43    /// - PartialOrd using magnitude for comparison
44    /// - The need for a single scalar representation
45    ///
46    /// If you need the real part, access `.re` directly on the complex type.
47    fn to_f64(self) -> f64;
48
49    /// Convert from f64 to this type
50    ///
51    /// # Complex Number Behavior
52    ///
53    /// For complex types, this creates a **real number** (imaginary part = 0).
54    fn from_f64(v: f64) -> Self;
55
56    /// Convert to f32
57    ///
58    /// Default implementation goes through f64. Types that have direct f32
59    /// conversion (f16, bf16, fp8) override this for efficiency.
60    #[inline]
61    fn to_f32(self) -> f32 {
62        self.to_f64() as f32
63    }
64
65    /// Convert from f32 to this type
66    ///
67    /// Default implementation goes through f64. Types that have direct f32
68    /// conversion (f16, bf16, fp8) override this for efficiency.
69    #[inline]
70    fn from_f32(v: f32) -> Self {
71        Self::from_f64(v as f64)
72    }
73
74    /// Zero value
75    fn zero() -> Self;
76
77    /// One value
78    fn one() -> Self;
79}
80
81impl Element for f64 {
82    const DTYPE: DType = DType::F64;
83
84    #[inline]
85    fn to_f64(self) -> f64 {
86        self
87    }
88
89    #[inline]
90    fn from_f64(v: f64) -> Self {
91        v
92    }
93
94    #[inline]
95    fn to_f32(self) -> f32 {
96        self as f32
97    }
98
99    #[inline]
100    fn from_f32(v: f32) -> Self {
101        v as f64
102    }
103
104    #[inline]
105    fn zero() -> Self {
106        0.0
107    }
108
109    #[inline]
110    fn one() -> Self {
111        1.0
112    }
113}
114
115impl Element for f32 {
116    const DTYPE: DType = DType::F32;
117
118    #[inline]
119    fn to_f64(self) -> f64 {
120        self as f64
121    }
122
123    #[inline]
124    fn from_f64(v: f64) -> Self {
125        v as f32
126    }
127
128    #[inline]
129    fn to_f32(self) -> f32 {
130        self
131    }
132
133    #[inline]
134    fn from_f32(v: f32) -> Self {
135        v
136    }
137
138    #[inline]
139    fn zero() -> Self {
140        0.0
141    }
142
143    #[inline]
144    fn one() -> Self {
145        1.0
146    }
147}
148
149impl Element for i64 {
150    const DTYPE: DType = DType::I64;
151
152    #[inline]
153    fn to_f64(self) -> f64 {
154        self as f64
155    }
156
157    #[inline]
158    fn from_f64(v: f64) -> Self {
159        v as i64
160    }
161
162    #[inline]
163    fn zero() -> Self {
164        0
165    }
166
167    #[inline]
168    fn one() -> Self {
169        1
170    }
171}
172
173impl Element for i32 {
174    const DTYPE: DType = DType::I32;
175
176    #[inline]
177    fn to_f64(self) -> f64 {
178        self as f64
179    }
180
181    #[inline]
182    fn from_f64(v: f64) -> Self {
183        v as i32
184    }
185
186    #[inline]
187    fn zero() -> Self {
188        0
189    }
190
191    #[inline]
192    fn one() -> Self {
193        1
194    }
195}
196
197impl Element for i16 {
198    const DTYPE: DType = DType::I16;
199
200    #[inline]
201    fn to_f64(self) -> f64 {
202        self as f64
203    }
204
205    #[inline]
206    fn from_f64(v: f64) -> Self {
207        v as i16
208    }
209
210    #[inline]
211    fn zero() -> Self {
212        0
213    }
214
215    #[inline]
216    fn one() -> Self {
217        1
218    }
219}
220
221impl Element for i8 {
222    const DTYPE: DType = DType::I8;
223
224    #[inline]
225    fn to_f64(self) -> f64 {
226        self as f64
227    }
228
229    #[inline]
230    fn from_f64(v: f64) -> Self {
231        v as i8
232    }
233
234    #[inline]
235    fn zero() -> Self {
236        0
237    }
238
239    #[inline]
240    fn one() -> Self {
241        1
242    }
243}
244
245impl Element for u64 {
246    const DTYPE: DType = DType::U64;
247
248    #[inline]
249    fn to_f64(self) -> f64 {
250        self as f64
251    }
252
253    #[inline]
254    fn from_f64(v: f64) -> Self {
255        v as u64
256    }
257
258    #[inline]
259    fn zero() -> Self {
260        0
261    }
262
263    #[inline]
264    fn one() -> Self {
265        1
266    }
267}
268
269impl Element for u32 {
270    const DTYPE: DType = DType::U32;
271
272    #[inline]
273    fn to_f64(self) -> f64 {
274        self as f64
275    }
276
277    #[inline]
278    fn from_f64(v: f64) -> Self {
279        v as u32
280    }
281
282    #[inline]
283    fn zero() -> Self {
284        0
285    }
286
287    #[inline]
288    fn one() -> Self {
289        1
290    }
291}
292
293impl Element for u16 {
294    const DTYPE: DType = DType::U16;
295
296    #[inline]
297    fn to_f64(self) -> f64 {
298        self as f64
299    }
300
301    #[inline]
302    fn from_f64(v: f64) -> Self {
303        v as u16
304    }
305
306    #[inline]
307    fn zero() -> Self {
308        0
309    }
310
311    #[inline]
312    fn one() -> Self {
313        1
314    }
315}
316
317impl Element for u8 {
318    const DTYPE: DType = DType::U8;
319
320    #[inline]
321    fn to_f64(self) -> f64 {
322        self as f64
323    }
324
325    #[inline]
326    fn from_f64(v: f64) -> Self {
327        v as u8
328    }
329
330    #[inline]
331    fn zero() -> Self {
332        0
333    }
334
335    #[inline]
336    fn one() -> Self {
337        1
338    }
339}
340
341// Note: bool doesn't implement Pod, so we can't implement Element for it directly.
342// Boolean tensors use u8 internally.
343
344// ============================================================================
345// Half-precision floating point types (requires "f16" feature)
346// ============================================================================
347
#[cfg(feature = "f16")]
impl Element for half::f16 {
    const DTYPE: DType = DType::F16;

    /// Additive identity.
    #[inline]
    fn zero() -> Self {
        half::f16::ZERO
    }

    /// Multiplicative identity.
    #[inline]
    fn one() -> Self {
        half::f16::ONE
    }

    // The conversions below name the `half::f16::...` items explicitly. Inherent
    // associated items take precedence over trait items in Rust's resolution,
    // so these call the `half` crate's own conversions rather than recursing
    // into this impl.

    /// Widens to `f64` via `half`'s conversion.
    #[inline]
    fn to_f64(self) -> f64 {
        half::f16::to_f64(self)
    }

    /// Narrows from `f64` via `half`'s conversion.
    #[inline]
    fn from_f64(v: f64) -> Self {
        half::f16::from_f64(v)
    }

    /// Widens to `f32` via `half`'s conversion.
    #[inline]
    fn to_f32(self) -> f32 {
        half::f16::to_f32(self)
    }

    /// Narrows from `f32` via `half`'s conversion.
    #[inline]
    fn from_f32(v: f32) -> Self {
        half::f16::from_f32(v)
    }
}
382
#[cfg(feature = "f16")]
impl Element for half::bf16 {
    const DTYPE: DType = DType::BF16;

    /// Additive identity.
    #[inline]
    fn zero() -> Self {
        half::bf16::ZERO
    }

    /// Multiplicative identity.
    #[inline]
    fn one() -> Self {
        half::bf16::ONE
    }

    // The conversions below name the `half::bf16::...` items explicitly. Inherent
    // associated items take precedence over trait items in Rust's resolution,
    // so these call the `half` crate's own conversions rather than recursing
    // into this impl.

    /// Widens to `f64` via `half`'s conversion.
    #[inline]
    fn to_f64(self) -> f64 {
        half::bf16::to_f64(self)
    }

    /// Narrows from `f64` via `half`'s conversion.
    #[inline]
    fn from_f64(v: f64) -> Self {
        half::bf16::from_f64(v)
    }

    /// Widens to `f32` via `half`'s conversion.
    #[inline]
    fn to_f32(self) -> f32 {
        half::bf16::to_f32(self)
    }

    /// Narrows from `f32` via `half`'s conversion.
    #[inline]
    fn from_f32(v: f32) -> Self {
        half::bf16::from_f32(v)
    }
}
417
418// ============================================================================
// 8-bit floating point types (FP8E4M3, FP8E5M2; these impls are not feature-gated)
420// ============================================================================
421
422impl Element for super::fp8::FP8E4M3 {
423    const DTYPE: DType = DType::FP8E4M3;
424
425    #[inline]
426    fn to_f64(self) -> f64 {
427        self.to_f32() as f64
428    }
429
430    #[inline]
431    fn from_f64(v: f64) -> Self {
432        Self::from_f32(v as f32)
433    }
434
435    #[inline]
436    fn to_f32(self) -> f32 {
437        self.to_f32()
438    }
439
440    #[inline]
441    fn from_f32(v: f32) -> Self {
442        Self::from_f32(v)
443    }
444
445    #[inline]
446    fn zero() -> Self {
447        Self::ZERO
448    }
449
450    #[inline]
451    fn one() -> Self {
452        Self::ONE
453    }
454}
455
456impl Element for super::fp8::FP8E5M2 {
457    const DTYPE: DType = DType::FP8E5M2;
458
459    #[inline]
460    fn to_f64(self) -> f64 {
461        self.to_f32() as f64
462    }
463
464    #[inline]
465    fn from_f64(v: f64) -> Self {
466        Self::from_f32(v as f32)
467    }
468
469    #[inline]
470    fn to_f32(self) -> f32 {
471        self.to_f32()
472    }
473
474    #[inline]
475    fn from_f32(v: f32) -> Self {
476        Self::from_f32(v)
477    }
478
479    #[inline]
480    fn zero() -> Self {
481        Self::ZERO
482    }
483
484    #[inline]
485    fn one() -> Self {
486        Self::ONE
487    }
488}
489
490// ============================================================================
491// Complex types
492//
493// Complex number conversion semantics:
494// - to_f64(): Returns magnitude (|z| = sqrt(re² + im²))
495//   This is intentional - a lossy conversion that provides a single scalar.
496//   For the real part, use z.re directly.
497// - from_f64(): Creates a real number (im = 0)
498//
499// These semantics are consistent with PartialOrd (compare by magnitude).
500// ============================================================================
501
502impl Element for super::complex::Complex64 {
503    const DTYPE: DType = DType::Complex64;
504
505    /// Returns magnitude (|z|) - this is a lossy conversion.
506    /// For the real part, use `.re` directly.
507    #[inline]
508    fn to_f64(self) -> f64 {
509        self.magnitude() as f64
510    }
511
512    /// Creates a real complex number (im = 0)
513    #[inline]
514    fn from_f64(v: f64) -> Self {
515        Self::new(v as f32, 0.0)
516    }
517
518    #[inline]
519    fn zero() -> Self {
520        Self::ZERO
521    }
522
523    #[inline]
524    fn one() -> Self {
525        Self::ONE
526    }
527}
528
529impl Element for super::complex::Complex128 {
530    const DTYPE: DType = DType::Complex128;
531
532    /// Returns magnitude (|z|) - this is a lossy conversion.
533    /// For the real part, use `.re` directly.
534    #[inline]
535    fn to_f64(self) -> f64 {
536        self.magnitude()
537    }
538
539    /// Creates a real complex number (im = 0)
540    #[inline]
541    fn from_f64(v: f64) -> Self {
542        Self::new(v, 0.0)
543    }
544
545    #[inline]
546    fn zero() -> Self {
547        Self::ZERO
548    }
549
550    #[inline]
551    fn one() -> Self {
552        Self::ONE
553    }
554}
555
#[cfg(test)]
mod tests {
    use super::*;

    /// Every primitive impl in this module maps to its matching DType.
    #[test]
    fn test_element_dtype() {
        assert_eq!(f64::DTYPE, DType::F64);
        assert_eq!(f32::DTYPE, DType::F32);
        assert_eq!(i64::DTYPE, DType::I64);
        assert_eq!(i32::DTYPE, DType::I32);
        assert_eq!(i16::DTYPE, DType::I16);
        assert_eq!(i8::DTYPE, DType::I8);
        assert_eq!(u64::DTYPE, DType::U64);
        assert_eq!(u32::DTYPE, DType::U32);
        assert_eq!(u16::DTYPE, DType::U16);
        assert_eq!(u8::DTYPE, DType::U8);
    }

    #[test]
    fn test_element_conversions() {
        assert_eq!(f32::from_f64(2.5).to_f64(), 2.5f32 as f64);
        assert_eq!(i32::from_f64(42.0), 42);
        assert_eq!(f64::from_f32(1.5), 1.5);
    }

    /// Integer `from_f64`/`from_f32` follow Rust `as` semantics:
    /// truncate toward zero, saturate at the bounds, NaN -> 0.
    #[test]
    fn test_integer_from_float_truncates_and_saturates() {
        assert_eq!(i32::from_f64(2.9), 2);
        assert_eq!(i32::from_f64(-2.9), -2);
        assert_eq!(i8::from_f64(1000.0), i8::MAX);
        assert_eq!(i8::from_f64(-1000.0), i8::MIN);
        assert_eq!(u8::from_f64(300.0), u8::MAX);
        assert_eq!(u8::from_f64(-1.0), 0);
        assert_eq!(i64::from_f64(f64::NAN), 0);
        assert_eq!(u16::from_f32(3.75), 3);
    }

    #[test]
    fn test_identities() {
        assert_eq!(f64::zero(), 0.0);
        assert_eq!(f32::one(), 1.0);
        assert_eq!(i16::zero(), 0);
        assert_eq!(u64::one(), 1);
    }

    #[test]
    fn test_fp8_element_dtype() {
        use super::super::fp8::{FP8E4M3, FP8E5M2};
        assert_eq!(FP8E4M3::DTYPE, DType::FP8E4M3);
        assert_eq!(FP8E5M2::DTYPE, DType::FP8E5M2);
    }

    #[test]
    fn test_fp8_element_conversions() {
        use super::super::fp8::{FP8E4M3, FP8E5M2};

        // FP8E4M3 roundtrip
        let e4m3 = FP8E4M3::from_f64(2.0);
        assert!((e4m3.to_f64() - 2.0).abs() < 0.1);

        // FP8E5M2 roundtrip
        let e5m2 = FP8E5M2::from_f64(100.0);
        assert!((e5m2.to_f64() - 100.0).abs() < 15.0);

        // Zero and one
        assert_eq!(FP8E4M3::zero().to_f32(), 0.0);
        assert!((FP8E4M3::one().to_f32() - 1.0).abs() < 0.01);
        assert_eq!(FP8E5M2::zero().to_f32(), 0.0);
        assert!((FP8E5M2::one().to_f32() - 1.0).abs() < 0.01);
    }

    /// Complex `to_f64` reports the magnitude; `from_f64` builds a real value.
    #[test]
    fn test_complex_element_conversions() {
        use super::super::complex::{Complex128, Complex64};

        let z = Complex128::new(3.0, 4.0);
        assert!((<Complex128 as Element>::to_f64(z) - 5.0).abs() < 1e-12);

        let w = Complex64::new(3.0, 4.0);
        assert!((<Complex64 as Element>::to_f64(w) - 5.0).abs() < 1e-6);

        // A purely real value has magnitude |v|.
        let r = <Complex128 as Element>::from_f64(-2.0);
        assert!((<Complex128 as Element>::to_f64(r) - 2.0).abs() < 1e-12);

        let zero = <Complex64 as Element>::zero();
        let one = <Complex64 as Element>::one();
        assert!(<Complex64 as Element>::to_f64(zero).abs() < 1e-6);
        assert!((<Complex64 as Element>::to_f64(one) - 1.0).abs() < 1e-6);
    }
}