burn_tensor/tensor/element/
cast.rs

1use core::mem::size_of;
2
3use half::{bf16, f16};
4
5/// A generic trait for converting a value to a number.
6/// Adapted from [num_traits::ToPrimitive] to support [bool].
7///
8/// A value can be represented by the target type when it lies within
9/// the range of scalars supported by the target type.
10/// For example, a negative integer cannot be represented by an unsigned
11/// integer type, and an `i64` with a very high magnitude might not be
12/// convertible to an `i32`.
13/// On the other hand, conversions with possible precision loss or truncation
14/// are admitted, like an `f32` with a decimal part to an integer type, or
15/// even a large `f64` saturating to `f32` infinity.
16///
17/// The methods *panic* when the value cannot be represented by the target type.
18pub trait ToElement {
19    /// Converts the value of `self` to an `isize`.
20    #[inline]
21    fn to_isize(&self) -> isize {
22        ToElement::to_isize(&self.to_i64())
23    }
24
25    /// Converts the value of `self` to an `i8`.
26    #[inline]
27    fn to_i8(&self) -> i8 {
28        ToElement::to_i8(&self.to_i64())
29    }
30
31    /// Converts the value of `self` to an `i16`.
32    #[inline]
33    fn to_i16(&self) -> i16 {
34        ToElement::to_i16(&self.to_i64())
35    }
36
37    /// Converts the value of `self` to an `i32`.
38    #[inline]
39    fn to_i32(&self) -> i32 {
40        ToElement::to_i32(&self.to_i64())
41    }
42
43    /// Converts the value of `self` to an `i64`.
44    fn to_i64(&self) -> i64;
45
46    /// Converts the value of `self` to an `i128`.
47    ///
48    /// The default implementation converts through `to_i64()`. Types implementing
49    /// this trait should override this method if they can represent a greater range.
50    #[inline]
51    fn to_i128(&self) -> i128 {
52        i128::from(self.to_i64())
53    }
54
55    /// Converts the value of `self` to a `usize`.
56    #[inline]
57    fn to_usize(&self) -> usize {
58        ToElement::to_usize(&self.to_u64())
59    }
60
61    /// Converts the value of `self` to a `u8`.
62    #[inline]
63    fn to_u8(&self) -> u8 {
64        ToElement::to_u8(&self.to_u64())
65    }
66
67    /// Converts the value of `self` to a `u16`.
68    #[inline]
69    fn to_u16(&self) -> u16 {
70        ToElement::to_u16(&self.to_u64())
71    }
72
73    /// Converts the value of `self` to a `u32`.
74    #[inline]
75    fn to_u32(&self) -> u32 {
76        ToElement::to_u32(&self.to_u64())
77    }
78
79    /// Converts the value of `self` to a `u64`.
80    fn to_u64(&self) -> u64;
81
82    /// Converts the value of `self` to a `u128`.
83    ///
84    /// The default implementation converts through `to_u64()`. Types implementing
85    /// this trait should override this method if they can represent a greater range.
86    #[inline]
87    fn to_u128(&self) -> u128 {
88        u128::from(self.to_u64())
89    }
90
91    /// Converts the value of `self` to an `f16`. Overflows may map to positive
92    /// or negative infinity.
93    #[inline]
94    fn to_f16(&self) -> f16 {
95        f16::from_f32(self.to_f32())
96    }
97
98    /// Converts the value of `self` to an `bf16`. Overflows may map to positive
99    /// or negative infinity.
100    #[inline]
101    fn to_bf16(&self) -> bf16 {
102        bf16::from_f32(self.to_f32())
103    }
104
105    /// Converts the value of `self` to an `f32`. Overflows may map to positive
106    /// or negative infinity.
107    #[inline]
108    fn to_f32(&self) -> f32 {
109        ToElement::to_f32(&self.to_f64())
110    }
111
112    /// Converts the value of `self` to an `f64`. Overflows may map to positive
113    /// or negative infinity.
114    ///
115    /// The default implementation tries to convert through `to_i64()`, and
116    /// failing that through `to_u64()`. Types implementing this trait should
117    /// override this method if they can represent a greater range.
118    #[inline]
119    fn to_f64(&self) -> f64 {
120        ToElement::to_f64(&self.to_u64())
121    }
122
123    /// Converts the value of `self` to a bool.
124    /// Rust only considers 0 and 1 to be valid booleans, but for compatibility, C semantics are
125    /// adopted (anything that's not 0 is true).
126    ///
127    /// The default implementation tries to convert through `to_i64()`, and
128    /// failing that through `to_u64()`. Types implementing this trait should
129    /// override this method if they can represent a greater range.
130    #[inline]
131    fn to_bool(&self) -> bool {
132        ToElement::to_bool(&self.to_u64())
133    }
134}
135
// Generates checked signed-integer -> signed-integer conversion methods for
// `$SrcT`. Each generated method returns the value cast to `$DstT`, or panics
// when the value lies outside `$DstT`'s range.
macro_rules! impl_to_element_int_to_int {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let value = *self;
            // A destination at least as wide as the source holds every source
            // value. Only for a narrower destination are its bounds meaningful
            // once expressed in the source type, hence the short-circuit.
            let fits = size_of::<$SrcT>() <= size_of::<$DstT>() || {
                let lower = $DstT::MIN as $SrcT;
                let upper = $DstT::MAX as $SrcT;
                (lower..=upper).contains(&value)
            };
            if !fits {
                panic!(
                    "Element cannot be represented in the target type: {:?}({:?}) => {:?}",
                    core::any::type_name::<$SrcT>(),
                    self,
                    core::any::type_name::<$DstT>(),
                )
            }
            value as $DstT
        }
    )*}
}
156
// Generates checked signed-integer -> unsigned-integer conversion methods for
// `$SrcT`. Each generated method panics on negative values and on values
// above `$DstT::MAX`; otherwise it returns the value cast to `$DstT`.
macro_rules! impl_to_element_int_to_uint {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let value = *self;
            let upper = $DstT::MAX as $SrcT;
            // The upper bound only matters when the destination is narrower
            // than the source; a destination of equal or greater width holds
            // every non-negative source value.
            let too_large = size_of::<$SrcT>() > size_of::<$DstT>() && value > upper;
            if value < 0 || too_large {
                panic!(
                    "Element cannot be represented in the target type: {:?}({:?}) => {:?}",
                    core::any::type_name::<$SrcT>(),
                    self,
                    core::any::type_name::<$DstT>(),
                )
            }
            value as $DstT
        }
    )*}
}
176
177macro_rules! impl_to_element_int {
178    ($T:ident) => {
179        impl ToElement for $T {
180            impl_to_element_int_to_int! { $T:
181                fn to_isize -> isize;
182                fn to_i8 -> i8;
183                fn to_i16 -> i16;
184                fn to_i32 -> i32;
185                fn to_i64 -> i64;
186                fn to_i128 -> i128;
187            }
188
189            impl_to_element_int_to_uint! { $T:
190                fn to_usize -> usize;
191                fn to_u8 -> u8;
192                fn to_u16 -> u16;
193                fn to_u32 -> u32;
194                fn to_u64 -> u64;
195                fn to_u128 -> u128;
196            }
197
198            #[inline]
199            fn to_f32(&self) -> f32 {
200                *self as f32
201            }
202            #[inline]
203            fn to_f64(&self) -> f64 {
204                *self as f64
205            }
206            #[inline]
207            fn to_bool(&self) -> bool {
208                *self != 0
209            }
210        }
211    };
212}
213
214impl_to_element_int!(isize);
215impl_to_element_int!(i8);
216impl_to_element_int!(i16);
217impl_to_element_int!(i32);
218impl_to_element_int!(i64);
219impl_to_element_int!(i128);
220
// Generates checked unsigned-integer -> signed-integer conversion methods for
// `$SrcT`. Each generated method returns the value cast to `$DstT`, or panics
// when the value exceeds `$DstT::MAX`.
macro_rules! impl_to_element_uint_to_int {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let value = *self;
            // Only a strictly wider signed destination is guaranteed to hold
            // every unsigned source value; an equal width loses the top bit.
            if size_of::<$SrcT>() < size_of::<$DstT>() {
                return value as $DstT;
            }
            let upper = $DstT::MAX as $SrcT;
            if value > upper {
                panic!(
                    "Element cannot be represented in the target type: {:?}({:?}) => {:?}",
                    core::any::type_name::<$SrcT>(),
                    self,
                    core::any::type_name::<$DstT>(),
                )
            }
            value as $DstT
        }
    )*}
}
240
// Generates checked unsigned-integer -> unsigned-integer conversion methods
// for `$SrcT`. Each generated method returns the value cast to `$DstT`, or
// panics when the value exceeds `$DstT::MAX`.
macro_rules! impl_to_element_uint_to_uint {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let value = *self;
            let upper = $DstT::MAX as $SrcT;
            // Overflow is only possible when the destination is narrower.
            let narrowing = size_of::<$SrcT>() > size_of::<$DstT>();
            if narrowing && value > upper {
                panic!(
                    "Element cannot be represented in the target type: {:?}({:?}) => {:?}",
                    core::any::type_name::<$SrcT>(),
                    self,
                    core::any::type_name::<$DstT>(),
                )
            }
            value as $DstT
        }
    )*}
}
260
261macro_rules! impl_to_element_uint {
262    ($T:ident) => {
263        impl ToElement for $T {
264            impl_to_element_uint_to_int! { $T:
265                fn to_isize -> isize;
266                fn to_i8 -> i8;
267                fn to_i16 -> i16;
268                fn to_i32 -> i32;
269                fn to_i64 -> i64;
270                fn to_i128 -> i128;
271            }
272
273            impl_to_element_uint_to_uint! { $T:
274                fn to_usize -> usize;
275                fn to_u8 -> u8;
276                fn to_u16 -> u16;
277                fn to_u32 -> u32;
278                fn to_u64 -> u64;
279                fn to_u128 -> u128;
280            }
281
282            #[inline]
283            fn to_f32(&self) -> f32 {
284                *self as f32
285            }
286            #[inline]
287            fn to_f64(&self) -> f64 {
288                *self as f64
289            }
290            #[inline]
291            fn to_bool(&self) -> bool {
292                *self != 0
293            }
294        }
295    };
296}
297
298impl_to_element_uint!(usize);
299impl_to_element_uint!(u8);
300impl_to_element_uint!(u16);
301impl_to_element_uint!(u32);
302impl_to_element_uint!(u64);
303impl_to_element_uint!(u128);
304
// Generates float -> float conversion methods for `$SrcT`.
macro_rules! impl_to_element_float_to_float {
    ($SrcT:ident : $( fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        fn $method(&self) -> $DstT {
            // An `as` cast between float types never fails: NaN and +-inf are
            // carried over, and a finite value too large for a narrower
            // destination saturates to +-inf.
            let source: $SrcT = *self;
            source as $DstT
        }
    )*}
}
315
// Truncates a float expression toward zero into the integer type `$int`
// without any range check. Callers must have validated the value beforehand.
macro_rules! float_to_int_unchecked {
    ($float:expr => $int:ty) => {{
        let value = $float;
        // SAFETY: Must not be NaN or infinite; must be representable as the integer after
        // truncating. Callers check that the float is in the exclusive range `(MIN-1, MAX+1)`.
        unsafe { value.to_int_unchecked::<$int>() }
    }};
}
323
// Generates checked float -> signed-integer conversion methods for `$f`.
// Each generated method truncates toward zero and panics when the truncated
// value does not fit in `$i` (NaN and infinities included).
macro_rules! impl_to_element_float_to_signed_int {
    ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $i:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $i {
            let value = *self;
            // Truncation toward zero means every value in the exclusive range
            // `(MIN-1, MAX+1)` is acceptable. NaN fails every comparison below
            // and is therefore rejected.
            let representable = if size_of::<$f>() > size_of::<$i>() {
                // The float is wider than the integer, so both exclusive
                // bounds are exactly representable.
                const MIN_M1: $f = $i::MIN as $f - 1.0;
                const MAX_P1: $f = $i::MAX as $f + 1.0;
                MIN_M1 < value && value < MAX_P1
            } else {
                // `MIN-1` is not representable, but floats of this magnitude
                // have no fractional part, so an inclusive `MIN` bound works.
                const MIN: $f = $i::MIN as $f;
                // `MAX` is not representable either; the cast rounds it up to
                // exactly `MAX+1`, a power of two.
                const MAX_P1: $f = $i::MAX as $f;
                MIN <= value && value < MAX_P1
            };
            if representable {
                float_to_int_unchecked!(value => $i)
            } else {
                panic!("Float cannot be represented in the target signed int type")
            }
        }
    )*}
}
353
// Generates checked float -> unsigned-integer conversion methods for `$f`.
// Each generated method truncates toward zero and panics when the truncated
// value does not fit in `$u` (NaN and infinities included).
macro_rules! impl_to_element_float_to_unsigned_int {
    ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $u:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $u {
            let value = *self;
            // Truncation toward zero means every value in the exclusive range
            // `(-1, MAX+1)` is acceptable. NaN fails the comparisons below and
            // is therefore rejected.
            let upper_exclusive: $f = if size_of::<$f>() > size_of::<$u>() {
                // The float is wider than the integer, so `MAX+1` is exactly
                // representable.
                const MAX_P1: $f = $u::MAX as $f + 1.0;
                MAX_P1
            } else {
                // `MAX` is not representable; the cast rounds it up to exactly
                // `MAX+1`, a power of two.
                // (`u128::MAX as f32` is infinity, but this is still ok.)
                const MAX_P1: $f = $u::MAX as $f;
                MAX_P1
            };
            if value > -1.0 && value < upper_exclusive {
                float_to_int_unchecked!(value => $u)
            } else {
                panic!("Float cannot be represented in the target unsigned int type")
            }
        }
    )*}
}
380
381macro_rules! impl_to_element_float {
382    ($T:ident) => {
383        impl ToElement for $T {
384            impl_to_element_float_to_signed_int! { $T:
385                fn to_isize -> isize;
386                fn to_i8 -> i8;
387                fn to_i16 -> i16;
388                fn to_i32 -> i32;
389                fn to_i64 -> i64;
390                fn to_i128 -> i128;
391            }
392
393            impl_to_element_float_to_unsigned_int! { $T:
394                fn to_usize -> usize;
395                fn to_u8 -> u8;
396                fn to_u16 -> u16;
397                fn to_u32 -> u32;
398                fn to_u64 -> u64;
399                fn to_u128 -> u128;
400            }
401
402            impl_to_element_float_to_float! { $T:
403                fn to_f32 -> f32;
404                fn to_f64 -> f64;
405            }
406
407            #[inline]
408            fn to_bool(&self) -> bool {
409                *self != 0.0
410            }
411        }
412    };
413}
414
415impl_to_element_float!(f32);
416impl_to_element_float!(f64);
417
418impl ToElement for f16 {
419    #[inline]
420    fn to_i64(&self) -> i64 {
421        Self::to_f32(*self).to_i64()
422    }
423    #[inline]
424    fn to_u64(&self) -> u64 {
425        Self::to_f32(*self).to_u64()
426    }
427    #[inline]
428    fn to_i8(&self) -> i8 {
429        Self::to_f32(*self).to_i8()
430    }
431    #[inline]
432    fn to_u8(&self) -> u8 {
433        Self::to_f32(*self).to_u8()
434    }
435    #[inline]
436    fn to_i16(&self) -> i16 {
437        Self::to_f32(*self).to_i16()
438    }
439    #[inline]
440    fn to_u16(&self) -> u16 {
441        Self::to_f32(*self).to_u16()
442    }
443    #[inline]
444    fn to_i32(&self) -> i32 {
445        Self::to_f32(*self).to_i32()
446    }
447    #[inline]
448    fn to_u32(&self) -> u32 {
449        Self::to_f32(*self).to_u32()
450    }
451    #[inline]
452    fn to_f16(&self) -> f16 {
453        *self
454    }
455    #[inline]
456    fn to_f32(&self) -> f32 {
457        Self::to_f32(*self)
458    }
459    #[inline]
460    fn to_f64(&self) -> f64 {
461        Self::to_f64(*self)
462    }
463    #[inline]
464    fn to_bool(&self) -> bool {
465        *self != f16::from_f32_const(0.0)
466    }
467}
468
469impl ToElement for bf16 {
470    #[inline]
471    fn to_i64(&self) -> i64 {
472        Self::to_f32(*self).to_i64()
473    }
474    #[inline]
475    fn to_u64(&self) -> u64 {
476        Self::to_f32(*self).to_u64()
477    }
478    #[inline]
479    fn to_i8(&self) -> i8 {
480        Self::to_f32(*self).to_i8()
481    }
482    #[inline]
483    fn to_u8(&self) -> u8 {
484        Self::to_f32(*self).to_u8()
485    }
486    #[inline]
487    fn to_i16(&self) -> i16 {
488        Self::to_f32(*self).to_i16()
489    }
490    #[inline]
491    fn to_u16(&self) -> u16 {
492        Self::to_f32(*self).to_u16()
493    }
494    #[inline]
495    fn to_i32(&self) -> i32 {
496        Self::to_f32(*self).to_i32()
497    }
498    #[inline]
499    fn to_u32(&self) -> u32 {
500        Self::to_f32(*self).to_u32()
501    }
502    #[inline]
503    fn to_bf16(&self) -> bf16 {
504        *self
505    }
506    #[inline]
507    fn to_f32(&self) -> f32 {
508        Self::to_f32(*self)
509    }
510    #[inline]
511    fn to_f64(&self) -> f64 {
512        Self::to_f64(*self)
513    }
514    #[inline]
515    fn to_bool(&self) -> bool {
516        *self != bf16::from_f32_const(0.0)
517    }
518}
519
// `flex32` converts by first going to `f32` and then reusing the checked
// `f32` conversions, so out-of-range values panic exactly as they do for `f32`.
#[cfg(feature = "cubecl")]
impl ToElement for cubecl::flex32 {
    #[inline]
    fn to_i64(&self) -> i64 {
        Self::to_f32(*self).to_i64()
    }
    #[inline]
    fn to_u64(&self) -> u64 {
        Self::to_f32(*self).to_u64()
    }
    /// Converts through `f32` directly instead of the `to_i64()` default, so
    /// values beyond the `i64` range that still fit in an `i128` do not panic.
    #[inline]
    fn to_i128(&self) -> i128 {
        Self::to_f32(*self).to_i128()
    }
    /// Converts through `f32` directly instead of the `to_u64()` default, so
    /// values beyond the `u64` range that still fit in a `u128` do not panic.
    #[inline]
    fn to_u128(&self) -> u128 {
        Self::to_f32(*self).to_u128()
    }
    #[inline]
    fn to_i8(&self) -> i8 {
        Self::to_f32(*self).to_i8()
    }
    #[inline]
    fn to_u8(&self) -> u8 {
        Self::to_f32(*self).to_u8()
    }
    #[inline]
    fn to_i16(&self) -> i16 {
        Self::to_f32(*self).to_i16()
    }
    #[inline]
    fn to_u16(&self) -> u16 {
        Self::to_f32(*self).to_u16()
    }
    #[inline]
    fn to_i32(&self) -> i32 {
        Self::to_f32(*self).to_i32()
    }
    #[inline]
    fn to_u32(&self) -> u32 {
        Self::to_f32(*self).to_u32()
    }
    #[inline]
    fn to_f32(&self) -> f32 {
        Self::to_f32(*self)
    }
    #[inline]
    fn to_f64(&self) -> f64 {
        Self::to_f64(*self)
    }
    /// C semantics: anything that does not compare equal to zero is `true`.
    #[inline]
    fn to_bool(&self) -> bool {
        *self != cubecl::flex32::from_f32(0.0)
    }
}
567
568impl ToElement for bool {
569    #[inline]
570    fn to_i64(&self) -> i64 {
571        *self as i64
572    }
573    #[inline]
574    fn to_u64(&self) -> u64 {
575        *self as u64
576    }
577    #[inline]
578    fn to_i8(&self) -> i8 {
579        *self as i8
580    }
581    #[inline]
582    fn to_u8(&self) -> u8 {
583        *self as u8
584    }
585    #[inline]
586    fn to_i16(&self) -> i16 {
587        *self as i16
588    }
589    #[inline]
590    fn to_u16(&self) -> u16 {
591        *self as u16
592    }
593    #[inline]
594    fn to_i32(&self) -> i32 {
595        *self as i32
596    }
597    #[inline]
598    fn to_u32(&self) -> u32 {
599        *self as u32
600    }
601    #[inline]
602    fn to_f32(&self) -> f32 {
603        self.to_u8() as f32
604    }
605    #[inline]
606    fn to_f64(&self) -> f64 {
607        self.to_u8() as f64
608    }
609    #[inline]
610    fn to_bool(&self) -> bool {
611        *self
612    }
613}
614
// Unit tests for the `ToElement` conversions. Gated on `cfg(test)` so the
// module is not compiled into regular builds.
#[cfg(test)]
mod tests {
    use super::*;

    /// Float narrowing saturates to infinity and preserves NaN.
    #[test]
    fn to_element_float() {
        let f32_toolarge = 1e39f64;
        assert_eq!(f32_toolarge.to_f32(), f32::INFINITY);
        assert_eq!((-f32_toolarge).to_f32(), f32::NEG_INFINITY);
        assert_eq!((f32::MAX as f64).to_f32(), f32::MAX);
        assert_eq!((-f32::MAX as f64).to_f32(), -f32::MAX);
        assert_eq!(f64::INFINITY.to_f32(), f32::INFINITY);
        assert_eq!((f64::NEG_INFINITY).to_f32(), f32::NEG_INFINITY);
        assert!((f64::NAN).to_f32().is_nan());
    }

    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_signed_to_u8_underflow() {
        let _x = (-1i8).to_u8();
    }

    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_signed_to_u16_underflow() {
        let _x = (-1i8).to_u16();
    }

    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_signed_to_u32_underflow() {
        let _x = (-1i8).to_u32();
    }

    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_signed_to_u64_underflow() {
        let _x = (-1i8).to_u64();
    }

    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_signed_to_u128_underflow() {
        let _x = (-1i8).to_u128();
    }

    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_signed_to_usize_underflow() {
        let _x = (-1i8).to_usize();
    }

    // The literals below carry unsigned suffixes so these tests exercise the
    // unsigned source path; an unsuffixed literal would default to `i32`.
    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_unsigned_to_u8_overflow() {
        let _x = 256u16.to_u8();
    }

    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_unsigned_to_u16_overflow() {
        let _x = 65_536u32.to_u16();
    }

    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_unsigned_to_u32_overflow() {
        let _x = 4_294_967_296u64.to_u32();
    }

    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_unsigned_to_u64_overflow() {
        let _x = 18_446_744_073_709_551_616u128.to_u64();
    }

    #[test]
    fn to_element_int_to_float() {
        assert_eq!((-1).to_f32(), -1.0);
        assert_eq!((-1).to_f64(), -1.0);
        assert_eq!(255.to_f32(), 255.0);
        assert_eq!(65_535.to_f64(), 65_535.0);
    }

    /// Float to integer conversion truncates toward zero.
    #[test]
    fn to_element_float_to_int() {
        assert_eq!((-1.0).to_i8(), -1);
        assert_eq!(1.0.to_u8(), 1);
        assert_eq!(1.8.to_u16(), 1);
        assert_eq!(123.456.to_u32(), 123);
    }
}