// burn_tensor/tensor/element/cast.rs

1use core::mem::size_of;
2
3use half::{bf16, f16};
4
5/// A generic trait for converting a value to a number.
6/// Adapted from [num_traits::ToPrimitive] to support [bool].
7///
8/// A value can be represented by the target type when it lies within
9/// the range of scalars supported by the target type.
10/// For example, a negative integer cannot be represented by an unsigned
11/// integer type, and an `i64` with a very high magnitude might not be
12/// convertible to an `i32`.
13/// On the other hand, conversions with possible precision loss or truncation
14/// are admitted, like an `f32` with a decimal part to an integer type, or
15/// even a large `f64` saturating to `f32` infinity.
16///
17/// The methods *panic* when the value cannot be represented by the target type.
18pub trait ToElement {
19    /// Converts the value of `self` to an `isize`.
20    #[inline]
21    fn to_isize(&self) -> isize {
22        ToElement::to_isize(&self.to_i64())
23    }
24
25    /// Converts the value of `self` to an `i8`.
26    #[inline]
27    fn to_i8(&self) -> i8 {
28        ToElement::to_i8(&self.to_i64())
29    }
30
31    /// Converts the value of `self` to an `i16`.
32    #[inline]
33    fn to_i16(&self) -> i16 {
34        ToElement::to_i16(&self.to_i64())
35    }
36
37    /// Converts the value of `self` to an `i32`.
38    #[inline]
39    fn to_i32(&self) -> i32 {
40        ToElement::to_i32(&self.to_i64())
41    }
42
43    /// Converts the value of `self` to an `i64`.
44    fn to_i64(&self) -> i64;
45
46    /// Converts the value of `self` to an `i128`.
47    ///
48    /// The default implementation converts through `to_i64()`. Types implementing
49    /// this trait should override this method if they can represent a greater range.
50    #[inline]
51    fn to_i128(&self) -> i128 {
52        i128::from(self.to_i64())
53    }
54
55    /// Converts the value of `self` to a `usize`.
56    #[inline]
57    fn to_usize(&self) -> usize {
58        ToElement::to_usize(&self.to_u64())
59    }
60
61    /// Converts the value of `self` to a `u8`.
62    #[inline]
63    fn to_u8(&self) -> u8 {
64        ToElement::to_u8(&self.to_u64())
65    }
66
67    /// Converts the value of `self` to a `u16`.
68    #[inline]
69    fn to_u16(&self) -> u16 {
70        ToElement::to_u16(&self.to_u64())
71    }
72
73    /// Converts the value of `self` to a `u32`.
74    #[inline]
75    fn to_u32(&self) -> u32 {
76        ToElement::to_u32(&self.to_u64())
77    }
78
79    /// Converts the value of `self` to a `u64`.
80    fn to_u64(&self) -> u64;
81
82    /// Converts the value of `self` to a `u128`.
83    ///
84    /// The default implementation converts through `to_u64()`. Types implementing
85    /// this trait should override this method if they can represent a greater range.
86    #[inline]
87    fn to_u128(&self) -> u128 {
88        u128::from(self.to_u64())
89    }
90
91    /// Converts the value of `self` to an `f16`. Overflows may map to positive
92    /// or negative infinity.
93    #[inline]
94    fn to_f16(&self) -> f16 {
95        f16::from_f32(self.to_f32())
96    }
97
98    /// Converts the value of `self` to an `bf16`. Overflows may map to positive
99    /// or negative infinity.
100    #[inline]
101    fn to_bf16(&self) -> bf16 {
102        bf16::from_f32(self.to_f32())
103    }
104
105    /// Converts the value of `self` to an `f32`. Overflows may map to positive
106    /// or negative infinity.
107    #[inline]
108    fn to_f32(&self) -> f32 {
109        ToElement::to_f32(&self.to_f64())
110    }
111
112    /// Converts the value of `self` to an `f64`. Overflows may map to positive
113    /// or negative infinity.
114    ///
115    /// The default implementation tries to convert through `to_i64()`, and
116    /// failing that through `to_u64()`. Types implementing this trait should
117    /// override this method if they can represent a greater range.
118    #[inline]
119    fn to_f64(&self) -> f64 {
120        ToElement::to_f64(&self.to_u64())
121    }
122
123    /// Converts the value of `self` to a bool.
124    /// Rust only considers 0 and 1 to be valid booleans, but for compatibility, C semantics are
125    /// adopted (anything that's not 0 is true).
126    ///
127    /// The default implementation tries to convert through `to_i64()`, and
128    /// failing that through `to_u64()`. Types implementing this trait should
129    /// override this method if they can represent a greater range.
130    #[inline]
131    fn to_bool(&self) -> bool {
132        ToElement::to_bool(&self.to_u64())
133    }
134}
135
// Generates range-checked signed -> signed conversion methods for source type `$SrcT`.
// A destination at least as wide as the source holds every value; when narrowing, the
// destination's bounds are exact in the source type and the value is checked against them.
macro_rules! impl_to_element_int_to_int {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let value = *self;
            let narrowing = size_of::<$SrcT>() > size_of::<$DstT>();
            if narrowing {
                let lower = $DstT::MIN as $SrcT;
                let upper = $DstT::MAX as $SrcT;
                if value < lower || value > upper {
                    panic!("Element cannot be represented in the target type");
                }
            }
            value as $DstT
        }
    )*}
}
151
// Generates range-checked signed -> unsigned conversion methods for source type `$SrcT`.
// Negative values never fit. Non-negative values always fit a destination at least as
// wide as the source; otherwise they must not exceed the destination's maximum.
macro_rules! impl_to_element_int_to_uint {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let value = *self;
            let representable = if value < 0 {
                false
            } else if size_of::<$SrcT>() <= size_of::<$DstT>() {
                true
            } else {
                value <= $DstT::MAX as $SrcT
            };
            if !representable {
                panic!("Element cannot be represented in the target type");
            }
            value as $DstT
        }
    )*}
}
166
// Implements `ToElement` for the signed integer primitive `$T`.
// Integer targets are range-checked by the helper macros (panicking when the value does
// not fit), float targets use a plain `as` cast, and `to_bool` treats non-zero as `true`.
macro_rules! impl_to_element_int {
    ($T:ident) => {
        impl ToElement for $T {
            #[inline]
            fn to_bool(&self) -> bool {
                0 != *self
            }

            #[inline]
            fn to_f64(&self) -> f64 {
                let value = *self;
                value as f64
            }

            #[inline]
            fn to_f32(&self) -> f32 {
                let value = *self;
                value as f32
            }

            impl_to_element_int_to_uint! { $T:
                fn to_u8 -> u8;
                fn to_u16 -> u16;
                fn to_u32 -> u32;
                fn to_u64 -> u64;
                fn to_u128 -> u128;
                fn to_usize -> usize;
            }

            impl_to_element_int_to_int! { $T:
                fn to_i8 -> i8;
                fn to_i16 -> i16;
                fn to_i32 -> i32;
                fn to_i64 -> i64;
                fn to_i128 -> i128;
                fn to_isize -> isize;
            }
        }
    };
}
203
204impl_to_element_int!(isize);
205impl_to_element_int!(i8);
206impl_to_element_int!(i16);
207impl_to_element_int!(i32);
208impl_to_element_int!(i64);
209impl_to_element_int!(i128);
210
// Generates range-checked unsigned -> signed conversion methods for source type `$SrcT`.
// A strictly wider signed destination holds every value; an equal-width or narrower one
// only holds values up to its own maximum.
macro_rules! impl_to_element_uint_to_int {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let value = *self;
            let strictly_wider = size_of::<$DstT>() > size_of::<$SrcT>();
            if !strictly_wider && value > $DstT::MAX as $SrcT {
                panic!("Element cannot be represented in the target type");
            }
            value as $DstT
        }
    )*}
}
225
// Generates range-checked unsigned -> unsigned conversion methods for source type `$SrcT`.
// Only a strictly narrower destination can overflow, and then only above its own maximum
// (which is exact in the wider source type).
macro_rules! impl_to_element_uint_to_uint {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let value = *self;
            let narrowing = size_of::<$DstT>() < size_of::<$SrcT>();
            if narrowing && value > $DstT::MAX as $SrcT {
                panic!("Element cannot be represented in the target type");
            }
            value as $DstT
        }
    )*}
}
240
// Implements `ToElement` for the unsigned integer primitive `$T`.
// Integer targets are range-checked by the helper macros (panicking when the value does
// not fit), float targets use a plain `as` cast, and `to_bool` treats non-zero as `true`.
macro_rules! impl_to_element_uint {
    ($T:ident) => {
        impl ToElement for $T {
            #[inline]
            fn to_bool(&self) -> bool {
                0 != *self
            }

            #[inline]
            fn to_f64(&self) -> f64 {
                let value = *self;
                value as f64
            }

            #[inline]
            fn to_f32(&self) -> f32 {
                let value = *self;
                value as f32
            }

            impl_to_element_uint_to_uint! { $T:
                fn to_u8 -> u8;
                fn to_u16 -> u16;
                fn to_u32 -> u32;
                fn to_u64 -> u64;
                fn to_u128 -> u128;
                fn to_usize -> usize;
            }

            impl_to_element_uint_to_int! { $T:
                fn to_i8 -> i8;
                fn to_i16 -> i16;
                fn to_i32 -> i32;
                fn to_i64 -> i64;
                fn to_i128 -> i128;
                fn to_isize -> isize;
            }
        }
    };
}
277
278impl_to_element_uint!(usize);
279impl_to_element_uint!(u8);
280impl_to_element_uint!(u16);
281impl_to_element_uint!(u32);
282impl_to_element_uint!(u64);
283impl_to_element_uint!(u128);
284
// Generates float -> float conversion methods. These never panic: NaN and the
// infinities carry over, and a finite value too large for a narrower target
// becomes +-inf.
macro_rules! impl_to_element_float_to_float {
    ($SrcT:ident : $( fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        fn $method(&self) -> $DstT {
            let value = *self;
            value as $DstT
        }
    )*}
}
295
macro_rules! float_to_int_unchecked {
    // SAFETY: `to_int_unchecked` is undefined behaviour for NaN, for infinities and for
    // values whose truncation toward zero does not fit in the target type. Every
    // expansion must therefore be preceded by a range check that rules all of these out.
    ($value:expr => $target:ty) => {
        unsafe { $value.to_int_unchecked::<$target>() }
    };
}
303
// Generates range-checked float -> signed integer conversion methods for float type `$f`.
macro_rules! impl_to_element_float_to_signed_int {
    ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $i:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $i {
            let value = *self;
            // `as` truncates toward zero, so every value strictly between `MIN-1` and
            // `MAX+1` is accepted. NaN fails every comparison below and is rejected.
            let in_range = if size_of::<$f>() > size_of::<$i>() {
                // The float is wider than the integer: `MIN-1` and `MAX+1` are exact.
                const LOWER_EXCL: $f = $i::MIN as $f - 1.0;
                const UPPER_EXCL: $f = $i::MAX as $f + 1.0;
                LOWER_EXCL < value && value < UPPER_EXCL
            } else {
                // `MIN-1` is not representable, but there is no fractional part at this
                // magnitude, so the inclusive bound `MIN` is equivalent.
                const LOWER_INCL: $f = $i::MIN as $f;
                // `MAX` is not representable either; the cast rounds it up to exactly
                // `MAX+1` (a power of two), which serves as the exclusive upper bound.
                const UPPER_EXCL: $f = $i::MAX as $f;
                LOWER_INCL <= value && value < UPPER_EXCL
            };
            if in_range {
                float_to_int_unchecked!(value => $i)
            } else {
                panic!("Float cannot be represented in the target signed int type")
            }
        }
    )*}
}
333
// Generates range-checked float -> unsigned integer conversion methods for float type `$f`.
macro_rules! impl_to_element_float_to_unsigned_int {
    ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $u:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $u {
            // When the float is wider than the integer, `MAX+1` is computed exactly.
            const WIDE_UPPER_EXCL: $f = $u::MAX as $f + 1.0;
            // Otherwise `MAX` is not representable and the cast rounds it up to exactly
            // `MAX+1` (a power of two). `u128::MAX as f32` becomes infinity, which still
            // works as an exclusive upper bound.
            const NARROW_UPPER_EXCL: $f = $u::MAX as $f;

            let value = *self;
            let upper_excl = if size_of::<$f>() > size_of::<$u>() {
                WIDE_UPPER_EXCL
            } else {
                NARROW_UPPER_EXCL
            };
            // `as` truncates toward zero, so anything in `(-1, MAX+1)` fits. NaN fails both
            // comparisons and falls through to the panic.
            let in_range = value > -1.0 && value < upper_excl;
            if in_range {
                float_to_int_unchecked!(value => $u)
            } else {
                panic!("Float cannot be represented in the target unsigned int type")
            }
        }
    )*}
}
360
// Implements `ToElement` for the float primitive `$T`.
// Float targets never panic, integer targets are range-checked by the helper macros,
// and `to_bool` is `true` for anything that compares unequal to zero (NaN included).
macro_rules! impl_to_element_float {
    ($T:ident) => {
        impl ToElement for $T {
            impl_to_element_float_to_float! { $T:
                fn to_f32 -> f32;
                fn to_f64 -> f64;
            }

            impl_to_element_float_to_unsigned_int! { $T:
                fn to_u8 -> u8;
                fn to_u16 -> u16;
                fn to_u32 -> u32;
                fn to_u64 -> u64;
                fn to_u128 -> u128;
                fn to_usize -> usize;
            }

            impl_to_element_float_to_signed_int! { $T:
                fn to_i8 -> i8;
                fn to_i16 -> i16;
                fn to_i32 -> i32;
                fn to_i64 -> i64;
                fn to_i128 -> i128;
                fn to_isize -> isize;
            }

            #[inline]
            fn to_bool(&self) -> bool {
                0.0 != *self
            }
        }
    };
}
394
395impl_to_element_float!(f32);
396impl_to_element_float!(f64);
397
398impl ToElement for f16 {
399    #[inline]
400    fn to_i64(&self) -> i64 {
401        Self::to_f32(*self).to_i64()
402    }
403    #[inline]
404    fn to_u64(&self) -> u64 {
405        Self::to_f32(*self).to_u64()
406    }
407    #[inline]
408    fn to_i8(&self) -> i8 {
409        Self::to_f32(*self).to_i8()
410    }
411    #[inline]
412    fn to_u8(&self) -> u8 {
413        Self::to_f32(*self).to_u8()
414    }
415    #[inline]
416    fn to_i16(&self) -> i16 {
417        Self::to_f32(*self).to_i16()
418    }
419    #[inline]
420    fn to_u16(&self) -> u16 {
421        Self::to_f32(*self).to_u16()
422    }
423    #[inline]
424    fn to_i32(&self) -> i32 {
425        Self::to_f32(*self).to_i32()
426    }
427    #[inline]
428    fn to_u32(&self) -> u32 {
429        Self::to_f32(*self).to_u32()
430    }
431    #[inline]
432    fn to_f16(&self) -> f16 {
433        *self
434    }
435    #[inline]
436    fn to_f32(&self) -> f32 {
437        Self::to_f32(*self)
438    }
439    #[inline]
440    fn to_f64(&self) -> f64 {
441        Self::to_f64(*self)
442    }
443    #[inline]
444    fn to_bool(&self) -> bool {
445        *self != f16::from_f32_const(0.0)
446    }
447}
448
449impl ToElement for bf16 {
450    #[inline]
451    fn to_i64(&self) -> i64 {
452        Self::to_f32(*self).to_i64()
453    }
454    #[inline]
455    fn to_u64(&self) -> u64 {
456        Self::to_f32(*self).to_u64()
457    }
458    #[inline]
459    fn to_i8(&self) -> i8 {
460        Self::to_f32(*self).to_i8()
461    }
462    #[inline]
463    fn to_u8(&self) -> u8 {
464        Self::to_f32(*self).to_u8()
465    }
466    #[inline]
467    fn to_i16(&self) -> i16 {
468        Self::to_f32(*self).to_i16()
469    }
470    #[inline]
471    fn to_u16(&self) -> u16 {
472        Self::to_f32(*self).to_u16()
473    }
474    #[inline]
475    fn to_i32(&self) -> i32 {
476        Self::to_f32(*self).to_i32()
477    }
478    #[inline]
479    fn to_u32(&self) -> u32 {
480        Self::to_f32(*self).to_u32()
481    }
482    #[inline]
483    fn to_bf16(&self) -> bf16 {
484        *self
485    }
486    #[inline]
487    fn to_f32(&self) -> f32 {
488        Self::to_f32(*self)
489    }
490    #[inline]
491    fn to_f64(&self) -> f64 {
492        Self::to_f64(*self)
493    }
494    #[inline]
495    fn to_bool(&self) -> bool {
496        *self != bf16::from_f32_const(0.0)
497    }
498}
499
#[cfg(feature = "cubecl")]
impl ToElement for cubecl::flex32 {
    // Every numeric conversion widens to `f32` via the inherent `flex32::to_f32` and
    // reuses the `f32` implementation. This includes `to_i128`/`to_u128`: the trait
    // defaults route those through `to_i64`/`to_u64`, which would panic for any value
    // outside the 64-bit range even when it fits in 128 bits.
    #[inline]
    fn to_f32(&self) -> f32 {
        Self::to_f32(*self)
    }

    #[inline]
    fn to_f64(&self) -> f64 {
        Self::to_f64(*self)
    }

    #[inline]
    fn to_i8(&self) -> i8 {
        Self::to_f32(*self).to_i8()
    }

    #[inline]
    fn to_i16(&self) -> i16 {
        Self::to_f32(*self).to_i16()
    }

    #[inline]
    fn to_i32(&self) -> i32 {
        Self::to_f32(*self).to_i32()
    }

    #[inline]
    fn to_i64(&self) -> i64 {
        Self::to_f32(*self).to_i64()
    }

    #[inline]
    fn to_i128(&self) -> i128 {
        Self::to_f32(*self).to_i128()
    }

    #[inline]
    fn to_u8(&self) -> u8 {
        Self::to_f32(*self).to_u8()
    }

    #[inline]
    fn to_u16(&self) -> u16 {
        Self::to_f32(*self).to_u16()
    }

    #[inline]
    fn to_u32(&self) -> u32 {
        Self::to_f32(*self).to_u32()
    }

    #[inline]
    fn to_u64(&self) -> u64 {
        Self::to_f32(*self).to_u64()
    }

    #[inline]
    fn to_u128(&self) -> u128 {
        Self::to_f32(*self).to_u128()
    }

    // C semantics: any value that compares unequal to zero is `true`.
    #[inline]
    fn to_bool(&self) -> bool {
        *self != cubecl::flex32::from_f32(0.0)
    }
}
547
548impl ToElement for bool {
549    #[inline]
550    fn to_i64(&self) -> i64 {
551        *self as i64
552    }
553    #[inline]
554    fn to_u64(&self) -> u64 {
555        *self as u64
556    }
557    #[inline]
558    fn to_i8(&self) -> i8 {
559        *self as i8
560    }
561    #[inline]
562    fn to_u8(&self) -> u8 {
563        *self as u8
564    }
565    #[inline]
566    fn to_i16(&self) -> i16 {
567        *self as i16
568    }
569    #[inline]
570    fn to_u16(&self) -> u16 {
571        *self as u16
572    }
573    #[inline]
574    fn to_i32(&self) -> i32 {
575        *self as i32
576    }
577    #[inline]
578    fn to_u32(&self) -> u32 {
579        *self as u32
580    }
581    #[inline]
582    fn to_f32(&self) -> f32 {
583        self.to_u8() as f32
584    }
585    #[inline]
586    fn to_f64(&self) -> f64 {
587        self.to_u8() as f64
588    }
589    #[inline]
590    fn to_bool(&self) -> bool {
591        *self
592    }
593}
594
// Unit tests for `ToElement`. Gated on `cfg(test)` so the module (and its glob import)
// is not compiled into regular builds.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn to_element_float() {
        let f32_toolarge = 1e39f64;
        assert_eq!(f32_toolarge.to_f32(), f32::INFINITY);
        assert_eq!((-f32_toolarge).to_f32(), f32::NEG_INFINITY);
        assert_eq!((f32::MAX as f64).to_f32(), f32::MAX);
        assert_eq!((-f32::MAX as f64).to_f32(), -f32::MAX);
        assert_eq!(f64::INFINITY.to_f32(), f32::INFINITY);
        assert_eq!((f64::NEG_INFINITY).to_f32(), f32::NEG_INFINITY);
        assert!((f64::NAN).to_f32().is_nan());
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u8_underflow() {
        let _x = (-1i8).to_u8();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u16_underflow() {
        let _x = (-1i8).to_u16();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u32_underflow() {
        let _x = (-1i8).to_u32();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u64_underflow() {
        let _x = (-1i8).to_u64();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u128_underflow() {
        let _x = (-1i8).to_u128();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_usize_underflow() {
        let _x = (-1i8).to_usize();
    }

    // The "unsigned" overflow tests use explicitly unsigned sources so they exercise
    // the unsigned -> unsigned path their names describe.
    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u8_overflow() {
        let _x = 256u16.to_u8();
    }

    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u16_overflow() {
        let _x = 65_536u32.to_u16();
    }

    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u32_overflow() {
        let _x = 4_294_967_296u64.to_u32();
    }

    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u64_overflow() {
        let _x = 18_446_744_073_709_551_616u128.to_u64();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_i8_overflow() {
        let _x = 128i16.to_i8();
    }

    #[test]
    #[should_panic]
    fn to_element_unsigned_to_i8_overflow() {
        let _x = 128u8.to_i8();
    }

    #[test]
    fn to_element_int_to_int_bounds() {
        assert_eq!(127i16.to_i8(), 127);
        assert_eq!((-128i16).to_i8(), -128);
        assert_eq!(127u8.to_i8(), 127);
        assert_eq!(i64::MIN.to_i128(), i128::from(i64::MIN));
        assert_eq!(u64::MAX.to_u128(), u128::from(u64::MAX));
    }

    #[test]
    fn to_element_int_to_float() {
        assert_eq!((-1).to_f32(), -1.0);
        assert_eq!((-1).to_f64(), -1.0);
        assert_eq!(255.to_f32(), 255.0);
        assert_eq!(65_535.to_f64(), 65_535.0);
    }

    #[test]
    fn to_element_float_to_int() {
        assert_eq!((-1.0).to_i8(), -1);
        assert_eq!(1.0.to_u8(), 1);
        assert_eq!(1.8.to_u16(), 1);
        assert_eq!(123.456.to_u32(), 123);
    }

    #[test]
    fn to_element_float_to_int_bounds() {
        // Truncation toward zero admits values just inside `(MIN-1, MAX+1)`.
        assert_eq!(127.9f32.to_i8(), 127);
        assert_eq!((-128.9f32).to_i8(), -128);
        assert_eq!((-0.9f64).to_u8(), 0);
        assert_eq!(255.9f64.to_u8(), 255);
        assert_eq!((-2_147_483_648.0f32).to_i32(), i32::MIN);
    }

    #[test]
    #[should_panic]
    fn to_element_float_to_i32_overflow() {
        let _x = 2_147_483_648.0f32.to_i32();
    }

    #[test]
    #[should_panic]
    fn to_element_nan_to_int() {
        let _x = f32::NAN.to_i32();
    }

    #[test]
    fn to_element_bool() {
        assert_eq!(true.to_u8(), 1);
        assert_eq!(false.to_i64(), 0);
        assert_eq!(true.to_f64(), 1.0);
        assert!(2u8.to_bool());
        assert!((-1i32).to_bool());
        assert!(!0.0f32.to_bool());
        assert!(!(-0.0f64).to_bool());
        assert!(f64::NAN.to_bool());
    }
}