// burn_tensor/tensor/element/cast.rs

1use core::mem::size_of;
2
3use half::{bf16, f16};
4
5/// A generic trait for converting a value to a number.
6/// Adapted from [num_traits::ToPrimitive] to support [bool].
7///
8/// A value can be represented by the target type when it lies within
9/// the range of scalars supported by the target type.
10/// For example, a negative integer cannot be represented by an unsigned
11/// integer type, and an `i64` with a very high magnitude might not be
12/// convertible to an `i32`.
13/// On the other hand, conversions with possible precision loss or truncation
14/// are admitted, like an `f32` with a decimal part to an integer type, or
15/// even a large `f64` saturating to `f32` infinity.
16///
17/// The methods *panic* when the value cannot be represented by the target type.
18pub trait ToElement {
19    /// Converts the value of `self` to an `isize`.
20    #[inline]
21    fn to_isize(&self) -> isize {
22        ToElement::to_isize(&self.to_i64())
23    }
24
25    /// Converts the value of `self` to an `i8`.
26    #[inline]
27    fn to_i8(&self) -> i8 {
28        ToElement::to_i8(&self.to_i64())
29    }
30
31    /// Converts the value of `self` to an `i16`.
32    #[inline]
33    fn to_i16(&self) -> i16 {
34        ToElement::to_i16(&self.to_i64())
35    }
36
37    /// Converts the value of `self` to an `i32`.
38    #[inline]
39    fn to_i32(&self) -> i32 {
40        ToElement::to_i32(&self.to_i64())
41    }
42
43    /// Converts the value of `self` to an `i64`.
44    fn to_i64(&self) -> i64;
45
46    /// Converts the value of `self` to an `i128`.
47    ///
48    /// The default implementation converts through `to_i64()`. Types implementing
49    /// this trait should override this method if they can represent a greater range.
50    #[inline]
51    fn to_i128(&self) -> i128 {
52        i128::from(self.to_i64())
53    }
54
55    /// Converts the value of `self` to a `usize`.
56    #[inline]
57    fn to_usize(&self) -> usize {
58        ToElement::to_usize(&self.to_u64())
59    }
60
61    /// Converts the value of `self` to a `u8`.
62    #[inline]
63    fn to_u8(&self) -> u8 {
64        ToElement::to_u8(&self.to_u64())
65    }
66
67    /// Converts the value of `self` to a `u16`.
68    #[inline]
69    fn to_u16(&self) -> u16 {
70        ToElement::to_u16(&self.to_u64())
71    }
72
73    /// Converts the value of `self` to a `u32`.
74    #[inline]
75    fn to_u32(&self) -> u32 {
76        ToElement::to_u32(&self.to_u64())
77    }
78
79    /// Converts the value of `self` to a `u64`.
80    fn to_u64(&self) -> u64;
81
82    /// Converts the value of `self` to a `u128`.
83    ///
84    /// The default implementation converts through `to_u64()`. Types implementing
85    /// this trait should override this method if they can represent a greater range.
86    #[inline]
87    fn to_u128(&self) -> u128 {
88        u128::from(self.to_u64())
89    }
90
91    /// Converts the value of `self` to an `f32`. Overflows may map to positive
92    /// or negative infinity.
93    #[inline]
94    fn to_f32(&self) -> f32 {
95        ToElement::to_f32(&self.to_f64())
96    }
97
98    /// Converts the value of `self` to an `f64`. Overflows may map to positive
99    /// or negative infinity.
100    ///
101    /// The default implementation tries to convert through `to_i64()`, and
102    /// failing that through `to_u64()`. Types implementing this trait should
103    /// override this method if they can represent a greater range.
104    #[inline]
105    fn to_f64(&self) -> f64 {
106        ToElement::to_f64(&self.to_u64())
107    }
108}
109
/// Generates checked conversion methods from the integer type `$SrcT` (used for
/// signed sources) into signed targets, one per `fn $method -> $DstT;` entry.
///
/// A target at least as wide as the source accepts every value. A narrower
/// target accepts only values inside `$DstT::MIN..=$DstT::MAX` and panics
/// otherwise.
macro_rules! impl_to_element_int_to_int {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let value = *self;
            let narrowing = size_of::<$SrcT>() > size_of::<$DstT>();
            // These bounds are only meaningful when narrowing; for wider targets
            // the casts wrap, but they are never consulted in that case.
            let lower = $DstT::MIN as $SrcT;
            let upper = $DstT::MAX as $SrcT;
            if narrowing && (value < lower || value > upper) {
                panic!("Element cannot be represented in the target type");
            }
            value as $DstT
        }
    )*}
}
125
/// Generates checked conversion methods from the signed integer type `$SrcT`
/// into unsigned targets, one per `fn $method -> $DstT;` entry.
///
/// Negative values always panic. Non-negative values panic only when the target
/// is strictly narrower than the source and the value exceeds `$DstT::MAX`.
macro_rules! impl_to_element_int_to_uint {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let value = *self;
            // No unsigned type can hold a negative value.
            let negative = value < 0;
            // A non-negative value can only overflow a strictly narrower target.
            let overflows = size_of::<$SrcT>() > size_of::<$DstT>()
                && value > $DstT::MAX as $SrcT;
            if negative || overflows {
                panic!("Element cannot be represented in the target type");
            }
            value as $DstT
        }
    )*}
}
140
141macro_rules! impl_to_element_int {
142    ($T:ident) => {
143        impl ToElement for $T {
144            impl_to_element_int_to_int! { $T:
145                fn to_isize -> isize;
146                fn to_i8 -> i8;
147                fn to_i16 -> i16;
148                fn to_i32 -> i32;
149                fn to_i64 -> i64;
150                fn to_i128 -> i128;
151            }
152
153            impl_to_element_int_to_uint! { $T:
154                fn to_usize -> usize;
155                fn to_u8 -> u8;
156                fn to_u16 -> u16;
157                fn to_u32 -> u32;
158                fn to_u64 -> u64;
159                fn to_u128 -> u128;
160            }
161
162            #[inline]
163            fn to_f32(&self) -> f32 {
164                *self as f32
165            }
166            #[inline]
167            fn to_f64(&self) -> f64 {
168                *self as f64
169            }
170        }
171    };
172}
173
174impl_to_element_int!(isize);
175impl_to_element_int!(i8);
176impl_to_element_int!(i16);
177impl_to_element_int!(i32);
178impl_to_element_int!(i64);
179impl_to_element_int!(i128);
180
/// Generates checked conversion methods from the unsigned integer type `$SrcT`
/// into signed targets, one per `fn $method -> $DstT;` entry.
///
/// Panics when the value exceeds `$DstT::MAX`.
macro_rules! impl_to_element_uint_to_int {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let value = *self;
            // A signed target spends one bit on the sign, so only a strictly wider
            // type is guaranteed to hold every unsigned value.
            let may_overflow = size_of::<$SrcT>() >= size_of::<$DstT>();
            if may_overflow && value > $DstT::MAX as $SrcT {
                panic!("Element cannot be represented in the target type");
            }
            value as $DstT
        }
    )*}
}
195
/// Generates checked conversion methods from the unsigned integer type `$SrcT`
/// into unsigned targets, one per `fn $method -> $DstT;` entry.
///
/// Widening or same-width conversions always succeed. Narrowing panics when the
/// value exceeds `$DstT::MAX`.
macro_rules! impl_to_element_uint_to_uint {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let value = *self;
            let narrowing = size_of::<$SrcT>() > size_of::<$DstT>();
            if narrowing && value > $DstT::MAX as $SrcT {
                panic!("Element cannot be represented in the target type");
            }
            value as $DstT
        }
    )*}
}
210
211macro_rules! impl_to_element_uint {
212    ($T:ident) => {
213        impl ToElement for $T {
214            impl_to_element_uint_to_int! { $T:
215                fn to_isize -> isize;
216                fn to_i8 -> i8;
217                fn to_i16 -> i16;
218                fn to_i32 -> i32;
219                fn to_i64 -> i64;
220                fn to_i128 -> i128;
221            }
222
223            impl_to_element_uint_to_uint! { $T:
224                fn to_usize -> usize;
225                fn to_u8 -> u8;
226                fn to_u16 -> u16;
227                fn to_u32 -> u32;
228                fn to_u64 -> u64;
229                fn to_u128 -> u128;
230            }
231
232            #[inline]
233            fn to_f32(&self) -> f32 {
234                *self as f32
235            }
236            #[inline]
237            fn to_f64(&self) -> f64 {
238                *self as f64
239            }
240        }
241    };
242}
243
244impl_to_element_uint!(usize);
245impl_to_element_uint!(u8);
246impl_to_element_uint!(u16);
247impl_to_element_uint!(u32);
248impl_to_element_uint!(u64);
249impl_to_element_uint!(u128);
250
/// Generates float-to-float conversion methods for the float type `$SrcT`, one
/// per `fn $method -> $DstT;` entry.
///
/// Like the sibling integer helper macros, each entry may carry outer attributes
/// (for example `#[cfg(...)]`), which are forwarded to the generated method.
/// Entries without attributes expand exactly as before.
macro_rules! impl_to_element_float_to_float {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            // Float-to-float `as` casts never panic. NaN stays NaN and infinities
            // are preserved. Narrowing rounds to nearest, and finite values that
            // are too large saturate to +-infinity.
            *self as $DstT
        }
    )*}
}
261
/// Converts the float expression `$float` to the integer type `$int` without a
/// range check, truncating toward zero.
///
/// Callers must first verify that the value is finite and that truncating it
/// yields a value representable by `$int`. The `impl_to_element_float_to_*`
/// macros do this with an open-interval check such as `(MIN - 1, MAX + 1)`.
macro_rules! float_to_int_unchecked {
    ($float:expr => $int:ty) => {{
        let unchecked_src = $float;
        // SAFETY: the caller has already checked that `unchecked_src` is neither
        // NaN nor infinite and lies strictly inside the target's `(MIN - 1, MAX + 1)`
        // interval, so truncation yields a representable integer.
        unsafe { unchecked_src.to_int_unchecked::<$int>() }
    }};
}
269
/// Generates checked float-to-signed-integer conversion methods for the float
/// type `$f`, one per `fn $method -> $i;` entry.
///
/// The fractional part is truncated toward zero. NaN, infinities, and values
/// whose truncation falls outside `$i`'s range panic.
macro_rules! impl_to_element_float_to_signed_int {
    ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $i:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $i {
            let value = *self;
            // Truncation toward zero maps every float in the open interval
            // `(MIN - 1, MAX + 1)` onto a valid integer. NaN fails every comparison
            // below and is therefore rejected.
            let in_range = if size_of::<$f>() > size_of::<$i>() {
                // The float type is wider, so `MIN - 1` and `MAX + 1` are both exact.
                const LOWER_EXCL: $f = $i::MIN as $f - 1.0;
                const UPPER_EXCL: $f = $i::MAX as $f + 1.0;
                LOWER_EXCL < value && value < UPPER_EXCL
            } else {
                // `MIN - 1` is not representable, but floats of this magnitude have
                // no fractional part, so an inclusive `MIN` bound is equivalent.
                const LOWER_INCL: $f = $i::MIN as $f;
                // `MAX` is not representable either; the cast rounds it up to
                // exactly `MAX + 1`, a power of two.
                const UPPER_EXCL: $f = $i::MAX as $f;
                LOWER_INCL <= value && value < UPPER_EXCL
            };
            if !in_range {
                panic!("Float cannot be represented in the target signed int type");
            }
            float_to_int_unchecked!(value => $i)
        }
    )*}
}
299
/// Generates checked float-to-unsigned-integer conversion methods for the float
/// type `$f`, one per `fn $method -> $u;` entry.
///
/// The fractional part is truncated toward zero, so values in `(-1, 0)` become
/// `0`. NaN, infinities, and values whose truncation falls outside `$u`'s range
/// panic.
macro_rules! impl_to_element_float_to_unsigned_int {
    ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $u:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $u {
            let value = *self;
            // Truncation toward zero maps every float in the open interval
            // `(-1, MAX + 1)` onto a valid integer. NaN fails both comparisons.
            let fits = if size_of::<$f>() > size_of::<$u>() {
                // The float type is wider, so `MAX + 1` is exact.
                const UPPER_EXCL: $f = $u::MAX as $f + 1.0;
                -1.0 < value && value < UPPER_EXCL
            } else {
                // `MAX` is not representable; the cast rounds it up to exactly
                // `MAX + 1`, a power of two. (`u128::MAX as f32` is infinity,
                // which still works as an exclusive upper bound.)
                const UPPER_EXCL: $f = $u::MAX as $f;
                -1.0 < value && value < UPPER_EXCL
            };
            if !fits {
                panic!("Float cannot be represented in the target unsigned int type");
            }
            float_to_int_unchecked!(value => $u)
        }
    )*}
}
326
327macro_rules! impl_to_element_float {
328    ($T:ident) => {
329        impl ToElement for $T {
330            impl_to_element_float_to_signed_int! { $T:
331                fn to_isize -> isize;
332                fn to_i8 -> i8;
333                fn to_i16 -> i16;
334                fn to_i32 -> i32;
335                fn to_i64 -> i64;
336                fn to_i128 -> i128;
337            }
338
339            impl_to_element_float_to_unsigned_int! { $T:
340                fn to_usize -> usize;
341                fn to_u8 -> u8;
342                fn to_u16 -> u16;
343                fn to_u32 -> u32;
344                fn to_u64 -> u64;
345                fn to_u128 -> u128;
346            }
347
348            impl_to_element_float_to_float! { $T:
349                fn to_f32 -> f32;
350                fn to_f64 -> f64;
351            }
352        }
353    };
354}
355
356impl_to_element_float!(f32);
357impl_to_element_float!(f64);
358
359impl ToElement for f16 {
360    #[inline]
361    fn to_i64(&self) -> i64 {
362        Self::to_f32(*self).to_i64()
363    }
364    #[inline]
365    fn to_u64(&self) -> u64 {
366        Self::to_f32(*self).to_u64()
367    }
368    #[inline]
369    fn to_i8(&self) -> i8 {
370        Self::to_f32(*self).to_i8()
371    }
372    #[inline]
373    fn to_u8(&self) -> u8 {
374        Self::to_f32(*self).to_u8()
375    }
376    #[inline]
377    fn to_i16(&self) -> i16 {
378        Self::to_f32(*self).to_i16()
379    }
380    #[inline]
381    fn to_u16(&self) -> u16 {
382        Self::to_f32(*self).to_u16()
383    }
384    #[inline]
385    fn to_i32(&self) -> i32 {
386        Self::to_f32(*self).to_i32()
387    }
388    #[inline]
389    fn to_u32(&self) -> u32 {
390        Self::to_f32(*self).to_u32()
391    }
392    #[inline]
393    fn to_f32(&self) -> f32 {
394        Self::to_f32(*self)
395    }
396    #[inline]
397    fn to_f64(&self) -> f64 {
398        Self::to_f64(*self)
399    }
400}
401
402impl ToElement for bf16 {
403    #[inline]
404    fn to_i64(&self) -> i64 {
405        Self::to_f32(*self).to_i64()
406    }
407    #[inline]
408    fn to_u64(&self) -> u64 {
409        Self::to_f32(*self).to_u64()
410    }
411    #[inline]
412    fn to_i8(&self) -> i8 {
413        Self::to_f32(*self).to_i8()
414    }
415    #[inline]
416    fn to_u8(&self) -> u8 {
417        Self::to_f32(*self).to_u8()
418    }
419    #[inline]
420    fn to_i16(&self) -> i16 {
421        Self::to_f32(*self).to_i16()
422    }
423    #[inline]
424    fn to_u16(&self) -> u16 {
425        Self::to_f32(*self).to_u16()
426    }
427    #[inline]
428    fn to_i32(&self) -> i32 {
429        Self::to_f32(*self).to_i32()
430    }
431    #[inline]
432    fn to_u32(&self) -> u32 {
433        Self::to_f32(*self).to_u32()
434    }
435    #[inline]
436    fn to_f32(&self) -> f32 {
437        Self::to_f32(*self)
438    }
439    #[inline]
440    fn to_f64(&self) -> f64 {
441        Self::to_f64(*self)
442    }
443}
444
445#[cfg(feature = "cubecl")]
446impl ToElement for cubecl::flex32 {
447    #[inline]
448    fn to_i64(&self) -> i64 {
449        Self::to_f32(*self).to_i64()
450    }
451    #[inline]
452    fn to_u64(&self) -> u64 {
453        Self::to_f32(*self).to_u64()
454    }
455    #[inline]
456    fn to_i8(&self) -> i8 {
457        Self::to_f32(*self).to_i8()
458    }
459    #[inline]
460    fn to_u8(&self) -> u8 {
461        Self::to_f32(*self).to_u8()
462    }
463    #[inline]
464    fn to_i16(&self) -> i16 {
465        Self::to_f32(*self).to_i16()
466    }
467    #[inline]
468    fn to_u16(&self) -> u16 {
469        Self::to_f32(*self).to_u16()
470    }
471    #[inline]
472    fn to_i32(&self) -> i32 {
473        Self::to_f32(*self).to_i32()
474    }
475    #[inline]
476    fn to_u32(&self) -> u32 {
477        Self::to_f32(*self).to_u32()
478    }
479    #[inline]
480    fn to_f32(&self) -> f32 {
481        Self::to_f32(*self)
482    }
483    #[inline]
484    fn to_f64(&self) -> f64 {
485        Self::to_f64(*self)
486    }
487}
488
489impl ToElement for bool {
490    #[inline]
491    fn to_i64(&self) -> i64 {
492        *self as i64
493    }
494    #[inline]
495    fn to_u64(&self) -> u64 {
496        *self as u64
497    }
498    #[inline]
499    fn to_i8(&self) -> i8 {
500        *self as i8
501    }
502    #[inline]
503    fn to_u8(&self) -> u8 {
504        *self as u8
505    }
506    #[inline]
507    fn to_i16(&self) -> i16 {
508        *self as i16
509    }
510    #[inline]
511    fn to_u16(&self) -> u16 {
512        *self as u16
513    }
514    #[inline]
515    fn to_i32(&self) -> i32 {
516        *self as i32
517    }
518    #[inline]
519    fn to_u32(&self) -> u32 {
520        *self as u32
521    }
522    #[inline]
523    fn to_f32(&self) -> f32 {
524        self.to_u8() as f32
525    }
526    #[inline]
527    fn to_f64(&self) -> f64 {
528        self.to_u8() as f64
529    }
530}
531
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn to_element_float() {
        let f32_toolarge = 1e39f64;
        assert_eq!(f32_toolarge.to_f32(), f32::INFINITY);
        assert_eq!((-f32_toolarge).to_f32(), f32::NEG_INFINITY);
        assert_eq!((f32::MAX as f64).to_f32(), f32::MAX);
        assert_eq!((-f32::MAX as f64).to_f32(), -f32::MAX);
        assert_eq!(f64::INFINITY.to_f32(), f32::INFINITY);
        assert_eq!((f64::NEG_INFINITY).to_f32(), f32::NEG_INFINITY);
        assert!((f64::NAN).to_f32().is_nan());
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u8_underflow() {
        let _x = (-1i8).to_u8();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u16_underflow() {
        let _x = (-1i8).to_u16();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u32_underflow() {
        let _x = (-1i8).to_u32();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u64_underflow() {
        let _x = (-1i8).to_u64();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u128_underflow() {
        let _x = (-1i8).to_u128();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_usize_underflow() {
        let _x = (-1i8).to_usize();
    }

    // The literals are suffixed so these tests exercise an unsigned source, as
    // their names say; unsuffixed literals would fall back to `i32`.
    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u8_overflow() {
        let _x = 256u16.to_u8();
    }

    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u16_overflow() {
        let _x = 65_536u32.to_u16();
    }

    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u32_overflow() {
        let _x = 4_294_967_296u64.to_u32();
    }

    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u64_overflow() {
        let _x = 18_446_744_073_709_551_616u128.to_u64();
    }

    #[test]
    #[should_panic]
    fn to_element_unsigned_to_signed_overflow() {
        // Same width: `u32::MAX` exceeds `i32::MAX`.
        let _x = u32::MAX.to_i32();
    }

    #[test]
    fn to_element_int_to_float() {
        assert_eq!((-1).to_f32(), -1.0);
        assert_eq!((-1).to_f64(), -1.0);
        assert_eq!(255.to_f32(), 255.0);
        assert_eq!(65_535.to_f64(), 65_535.0);
    }

    #[test]
    fn to_element_float_to_int() {
        assert_eq!((-1.0).to_i8(), -1);
        assert_eq!(1.0.to_u8(), 1);
        assert_eq!(1.8.to_u16(), 1);
        assert_eq!(123.456.to_u32(), 123);
    }

    #[test]
    fn to_element_float_to_int_boundaries() {
        // Values in (-1, 0) truncate to zero for unsigned targets.
        assert_eq!((-0.9f32).to_u8(), 0);
        assert_eq!(255.9f32.to_u8(), 255);
        assert_eq!((-128.9f64).to_i8(), -128);
        assert_eq!((i32::MIN as f32).to_i32(), i32::MIN);
        assert_eq!(f32::MAX.to_u128(), f32::MAX as u128);
    }

    #[test]
    #[should_panic]
    fn to_element_float_to_i32_overflow() {
        // `i32::MAX as f32` rounds up to 2^31, which is out of range.
        let _x = (i32::MAX as f32).to_i32();
    }

    #[test]
    #[should_panic]
    fn to_element_float_nan_to_int() {
        let _x = f32::NAN.to_i32();
    }

    #[test]
    #[should_panic]
    fn to_element_float_infinity_to_uint() {
        let _x = f64::INFINITY.to_u64();
    }

    #[test]
    fn to_element_bool() {
        assert_eq!(true.to_u8(), 1);
        assert_eq!(false.to_i64(), 0);
        assert_eq!(true.to_i128(), 1);
        assert_eq!(true.to_f32(), 1.0);
        assert_eq!(false.to_f64(), 0.0);
    }
}