burn_tensor/tensor/element/
cast.rs1use core::mem::size_of;
2
3use half::{bf16, f16};
4
5pub trait ToElement {
19 #[inline]
21 fn to_isize(&self) -> isize {
22 ToElement::to_isize(&self.to_i64())
23 }
24
25 #[inline]
27 fn to_i8(&self) -> i8 {
28 ToElement::to_i8(&self.to_i64())
29 }
30
31 #[inline]
33 fn to_i16(&self) -> i16 {
34 ToElement::to_i16(&self.to_i64())
35 }
36
37 #[inline]
39 fn to_i32(&self) -> i32 {
40 ToElement::to_i32(&self.to_i64())
41 }
42
43 fn to_i64(&self) -> i64;
45
46 #[inline]
51 fn to_i128(&self) -> i128 {
52 i128::from(self.to_i64())
53 }
54
55 #[inline]
57 fn to_usize(&self) -> usize {
58 ToElement::to_usize(&self.to_u64())
59 }
60
61 #[inline]
63 fn to_u8(&self) -> u8 {
64 ToElement::to_u8(&self.to_u64())
65 }
66
67 #[inline]
69 fn to_u16(&self) -> u16 {
70 ToElement::to_u16(&self.to_u64())
71 }
72
73 #[inline]
75 fn to_u32(&self) -> u32 {
76 ToElement::to_u32(&self.to_u64())
77 }
78
79 fn to_u64(&self) -> u64;
81
82 #[inline]
87 fn to_u128(&self) -> u128 {
88 u128::from(self.to_u64())
89 }
90
91 #[inline]
94 fn to_f16(&self) -> f16 {
95 f16::from_f32(self.to_f32())
96 }
97
98 #[inline]
101 fn to_bf16(&self) -> bf16 {
102 bf16::from_f32(self.to_f32())
103 }
104
105 #[inline]
108 fn to_f32(&self) -> f32 {
109 ToElement::to_f32(&self.to_f64())
110 }
111
112 #[inline]
119 fn to_f64(&self) -> f64 {
120 ToElement::to_f64(&self.to_u64())
121 }
122
123 #[inline]
131 fn to_bool(&self) -> bool {
132 ToElement::to_bool(&self.to_u64())
133 }
134}
135
/// Generates range-checked signed-to-signed integer conversion methods.
///
/// Invoked inside an `impl` block as
/// `impl_to_element_int_to_int! { Src: fn to_x -> X; ... }`. Each generated method
/// returns `*self as Dst` when the value lies in `Dst::MIN..=Dst::MAX` and panics
/// otherwise.
macro_rules! impl_to_element_int_to_int {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let value = *self;
            // A destination at least as wide as the source holds every source value,
            // so the bounds only need checking when narrowing.
            let widening = size_of::<$SrcT>() <= size_of::<$DstT>();
            let in_range = widening || {
                let lower = $DstT::MIN as $SrcT;
                let upper = $DstT::MAX as $SrcT;
                lower <= value && value <= upper
            };
            if !in_range {
                panic!(
                    "Element cannot be represented in the target type: {:?}({:?}) => {:?}",
                    core::any::type_name::<$SrcT>(),
                    self,
                    core::any::type_name::<$DstT>(),
                )
            }
            value as $DstT
        }
    )*}
}
156
/// Generates range-checked signed-to-unsigned integer conversion methods.
///
/// Invoked inside an `impl` block as
/// `impl_to_element_int_to_uint! { Src: fn to_x -> X; ... }`. Negative values, and
/// values above `Dst::MAX` when narrowing, panic. Everything else is cast with `as`.
macro_rules! impl_to_element_int_to_uint {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let value = *self;
            // A non-negative value always fits a target at least as wide as the source.
            // Otherwise it has to be compared against the target's maximum.
            let below_max = if size_of::<$SrcT>() <= size_of::<$DstT>() {
                true
            } else {
                let upper = $DstT::MAX as $SrcT;
                value <= upper
            };
            if value >= 0 && below_max {
                return value as $DstT;
            }
            panic!(
                "Element cannot be represented in the target type: {:?}({:?}) => {:?}",
                core::any::type_name::<$SrcT>(),
                self,
                core::any::type_name::<$DstT>(),
            )
        }
    )*}
}
176
177macro_rules! impl_to_element_int {
178 ($T:ident) => {
179 impl ToElement for $T {
180 impl_to_element_int_to_int! { $T:
181 fn to_isize -> isize;
182 fn to_i8 -> i8;
183 fn to_i16 -> i16;
184 fn to_i32 -> i32;
185 fn to_i64 -> i64;
186 fn to_i128 -> i128;
187 }
188
189 impl_to_element_int_to_uint! { $T:
190 fn to_usize -> usize;
191 fn to_u8 -> u8;
192 fn to_u16 -> u16;
193 fn to_u32 -> u32;
194 fn to_u64 -> u64;
195 fn to_u128 -> u128;
196 }
197
198 #[inline]
199 fn to_f32(&self) -> f32 {
200 *self as f32
201 }
202 #[inline]
203 fn to_f64(&self) -> f64 {
204 *self as f64
205 }
206 #[inline]
207 fn to_bool(&self) -> bool {
208 *self != 0
209 }
210 }
211 };
212}
213
214impl_to_element_int!(isize);
215impl_to_element_int!(i8);
216impl_to_element_int!(i16);
217impl_to_element_int!(i32);
218impl_to_element_int!(i64);
219impl_to_element_int!(i128);
220
/// Generates range-checked unsigned-to-signed integer conversion methods.
///
/// Invoked inside an `impl` block as
/// `impl_to_element_uint_to_int! { Src: fn to_x -> X; ... }`. Values above `Dst::MAX`
/// panic; everything else is cast with `as`.
macro_rules! impl_to_element_uint_to_int {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let value = *self;
            let limit = $DstT::MAX as $SrcT;
            // Only a strictly wider signed target holds every unsigned value: at equal
            // width the sign bit costs the target its top half.
            let strictly_wider = size_of::<$SrcT>() < size_of::<$DstT>();
            if strictly_wider || value <= limit {
                return value as $DstT;
            }
            panic!(
                "Element cannot be represented in the target type: {:?}({:?}) => {:?}",
                core::any::type_name::<$SrcT>(),
                self,
                core::any::type_name::<$DstT>(),
            )
        }
    )*}
}
240
/// Generates range-checked unsigned-to-unsigned integer conversion methods.
///
/// Invoked inside an `impl` block as
/// `impl_to_element_uint_to_uint! { Src: fn to_x -> X; ... }`. Narrowing a value above
/// `Dst::MAX` panics; everything else is cast with `as`.
macro_rules! impl_to_element_uint_to_uint {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let value = *self;
            let limit = $DstT::MAX as $SrcT;
            // Only a narrowing conversion can overflow.
            let narrowing = size_of::<$SrcT>() > size_of::<$DstT>();
            if narrowing && value > limit {
                panic!(
                    "Element cannot be represented in the target type: {:?}({:?}) => {:?}",
                    core::any::type_name::<$SrcT>(),
                    self,
                    core::any::type_name::<$DstT>(),
                );
            }
            value as $DstT
        }
    )*}
}
260
261macro_rules! impl_to_element_uint {
262 ($T:ident) => {
263 impl ToElement for $T {
264 impl_to_element_uint_to_int! { $T:
265 fn to_isize -> isize;
266 fn to_i8 -> i8;
267 fn to_i16 -> i16;
268 fn to_i32 -> i32;
269 fn to_i64 -> i64;
270 fn to_i128 -> i128;
271 }
272
273 impl_to_element_uint_to_uint! { $T:
274 fn to_usize -> usize;
275 fn to_u8 -> u8;
276 fn to_u16 -> u16;
277 fn to_u32 -> u32;
278 fn to_u64 -> u64;
279 fn to_u128 -> u128;
280 }
281
282 #[inline]
283 fn to_f32(&self) -> f32 {
284 *self as f32
285 }
286 #[inline]
287 fn to_f64(&self) -> f64 {
288 *self as f64
289 }
290 #[inline]
291 fn to_bool(&self) -> bool {
292 *self != 0
293 }
294 }
295 };
296}
297
298impl_to_element_uint!(usize);
299impl_to_element_uint!(u8);
300impl_to_element_uint!(u16);
301impl_to_element_uint!(u32);
302impl_to_element_uint!(u64);
303impl_to_element_uint!(u128);
304
/// Generates float-to-float conversion methods as plain `as` casts.
///
/// The cast never fails: NaN stays NaN, infinities are preserved, and narrowing a
/// finite value rounds to the nearest target value, saturating to ±infinity on overflow.
macro_rules! impl_to_element_float_to_float {
    ($SrcT:ident : $( fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        fn $method(&self) -> $DstT {
            let source: $SrcT = *self;
            source as $DstT
        }
    )*}
}
315
/// Truncating float-to-integer cast with no range check of its own.
///
/// Expands to a call to the float's inherent `to_int_unchecked`. The caller must have
/// already established that the value is finite and that its truncation fits in `$int`.
macro_rules! float_to_int_unchecked {
    ($float:expr => $int:ty) => {{
        let float_value = $float;
        // SAFETY: callers only expand this after checking that the value is not NaN, not
        // infinite, and lies inside the range whose truncation toward zero fits `$int`,
        // which is the precondition of `to_int_unchecked`.
        unsafe { float_value.to_int_unchecked::<$int>() }
    }};
}
323
/// Generates range-checked float-to-signed-integer conversion methods.
///
/// Invoked inside an `impl` block as
/// `impl_to_element_float_to_signed_int! { f32: fn to_i8 -> i8; ... }`. The fractional
/// part is truncated toward zero. NaN, infinities and values whose truncation does not
/// fit the target type panic.
macro_rules! impl_to_element_float_to_signed_int {
    ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $i:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $i {
            let value = *self;
            // Truncation toward zero maps exactly the open interval `(MIN - 1, MAX + 1)`
            // onto `MIN..=MAX`. NaN fails every comparison, so it ends up in the panic.
            let representable = if size_of::<$f>() > size_of::<$i>() {
                // The float is wider than the integer, so both open bounds are exact.
                const LOWER_EXCLUSIVE: $f = $i::MIN as $f - 1.0;
                const UPPER_EXCLUSIVE: $f = $i::MAX as $f + 1.0;
                LOWER_EXCLUSIVE < value && value < UPPER_EXCLUSIVE
            } else {
                // `MIN - 1` is not representable at this width, but `MIN` (a power of
                // two) is. No float lies strictly between them, so an inclusive `MIN`
                // bound is equivalent. The cast rounds `MAX` up to exactly `MAX + 1`.
                const LOWER_INCLUSIVE: $f = $i::MIN as $f;
                const UPPER_EXCLUSIVE: $f = $i::MAX as $f;
                LOWER_INCLUSIVE <= value && value < UPPER_EXCLUSIVE
            };
            if representable {
                float_to_int_unchecked!(value => $i)
            } else {
                panic!("Float cannot be represented in the target signed int type")
            }
        }
    )*}
}
353
/// Generates range-checked float-to-unsigned-integer conversion methods.
///
/// Invoked inside an `impl` block as
/// `impl_to_element_float_to_unsigned_int! { f64: fn to_u8 -> u8; ... }`. The
/// fractional part is truncated toward zero, so values in `(-1, 0)` become `0`. NaN,
/// infinities and values whose truncation does not fit the target type panic.
macro_rules! impl_to_element_float_to_unsigned_int {
    ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $u:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $u {
            // A float wider than the integer represents `MAX + 1` exactly.
            const WIDE_UPPER_EXCLUSIVE: $f = $u::MAX as $f + 1.0;
            // Otherwise `MAX` is not representable and the cast rounds it up to exactly
            // `MAX + 1` (a power of two), or to infinity for `u128::MAX as f32`, which
            // is still a valid exclusive bound.
            const NARROW_UPPER_EXCLUSIVE: $f = $u::MAX as $f;

            let value = *self;
            let upper_exclusive = if size_of::<$f>() > size_of::<$u>() {
                WIDE_UPPER_EXCLUSIVE
            } else {
                NARROW_UPPER_EXCLUSIVE
            };
            // NaN fails both comparisons and reaches the panic.
            if value > -1.0 && value < upper_exclusive {
                float_to_int_unchecked!(value => $u)
            } else {
                panic!("Float cannot be represented in the target unsigned int type")
            }
        }
    )*}
}
380
381macro_rules! impl_to_element_float {
382 ($T:ident) => {
383 impl ToElement for $T {
384 impl_to_element_float_to_signed_int! { $T:
385 fn to_isize -> isize;
386 fn to_i8 -> i8;
387 fn to_i16 -> i16;
388 fn to_i32 -> i32;
389 fn to_i64 -> i64;
390 fn to_i128 -> i128;
391 }
392
393 impl_to_element_float_to_unsigned_int! { $T:
394 fn to_usize -> usize;
395 fn to_u8 -> u8;
396 fn to_u16 -> u16;
397 fn to_u32 -> u32;
398 fn to_u64 -> u64;
399 fn to_u128 -> u128;
400 }
401
402 impl_to_element_float_to_float! { $T:
403 fn to_f32 -> f32;
404 fn to_f64 -> f64;
405 }
406
407 #[inline]
408 fn to_bool(&self) -> bool {
409 *self != 0.0
410 }
411 }
412 };
413}
414
415impl_to_element_float!(f32);
416impl_to_element_float!(f64);
417
418impl ToElement for f16 {
419 #[inline]
420 fn to_i64(&self) -> i64 {
421 Self::to_f32(*self).to_i64()
422 }
423 #[inline]
424 fn to_u64(&self) -> u64 {
425 Self::to_f32(*self).to_u64()
426 }
427 #[inline]
428 fn to_i8(&self) -> i8 {
429 Self::to_f32(*self).to_i8()
430 }
431 #[inline]
432 fn to_u8(&self) -> u8 {
433 Self::to_f32(*self).to_u8()
434 }
435 #[inline]
436 fn to_i16(&self) -> i16 {
437 Self::to_f32(*self).to_i16()
438 }
439 #[inline]
440 fn to_u16(&self) -> u16 {
441 Self::to_f32(*self).to_u16()
442 }
443 #[inline]
444 fn to_i32(&self) -> i32 {
445 Self::to_f32(*self).to_i32()
446 }
447 #[inline]
448 fn to_u32(&self) -> u32 {
449 Self::to_f32(*self).to_u32()
450 }
451 #[inline]
452 fn to_f16(&self) -> f16 {
453 *self
454 }
455 #[inline]
456 fn to_f32(&self) -> f32 {
457 Self::to_f32(*self)
458 }
459 #[inline]
460 fn to_f64(&self) -> f64 {
461 Self::to_f64(*self)
462 }
463 #[inline]
464 fn to_bool(&self) -> bool {
465 *self != f16::from_f32_const(0.0)
466 }
467}
468
469impl ToElement for bf16 {
470 #[inline]
471 fn to_i64(&self) -> i64 {
472 Self::to_f32(*self).to_i64()
473 }
474 #[inline]
475 fn to_u64(&self) -> u64 {
476 Self::to_f32(*self).to_u64()
477 }
478 #[inline]
479 fn to_i8(&self) -> i8 {
480 Self::to_f32(*self).to_i8()
481 }
482 #[inline]
483 fn to_u8(&self) -> u8 {
484 Self::to_f32(*self).to_u8()
485 }
486 #[inline]
487 fn to_i16(&self) -> i16 {
488 Self::to_f32(*self).to_i16()
489 }
490 #[inline]
491 fn to_u16(&self) -> u16 {
492 Self::to_f32(*self).to_u16()
493 }
494 #[inline]
495 fn to_i32(&self) -> i32 {
496 Self::to_f32(*self).to_i32()
497 }
498 #[inline]
499 fn to_u32(&self) -> u32 {
500 Self::to_f32(*self).to_u32()
501 }
502 #[inline]
503 fn to_bf16(&self) -> bf16 {
504 *self
505 }
506 #[inline]
507 fn to_f32(&self) -> f32 {
508 Self::to_f32(*self)
509 }
510 #[inline]
511 fn to_f64(&self) -> f64 {
512 Self::to_f64(*self)
513 }
514 #[inline]
515 fn to_bool(&self) -> bool {
516 *self != bf16::from_f32_const(0.0)
517 }
518}
519
/// Conversions for cubecl's `flex32`, available with the `cubecl` feature.
///
/// Every integer target goes through the value's `f32` representation (`Self::to_f32`)
/// and reuses the `f32` conversions, including their range checks.
#[cfg(feature = "cubecl")]
impl ToElement for cubecl::flex32 {
    #[inline]
    fn to_i64(&self) -> i64 {
        Self::to_f32(*self).to_i64()
    }
    #[inline]
    fn to_u64(&self) -> u64 {
        Self::to_f32(*self).to_u64()
    }
    #[inline]
    fn to_i8(&self) -> i8 {
        Self::to_f32(*self).to_i8()
    }
    #[inline]
    fn to_u8(&self) -> u8 {
        Self::to_f32(*self).to_u8()
    }
    #[inline]
    fn to_i16(&self) -> i16 {
        Self::to_f32(*self).to_i16()
    }
    #[inline]
    fn to_u16(&self) -> u16 {
        Self::to_f32(*self).to_u16()
    }
    #[inline]
    fn to_i32(&self) -> i32 {
        Self::to_f32(*self).to_i32()
    }
    #[inline]
    fn to_u32(&self) -> u32 {
        Self::to_f32(*self).to_u32()
    }
    /// Overridden because the trait default goes through `to_i64`, which would panic
    /// for `f32`-range values above `i64::MAX` that still fit in an `i128`.
    #[inline]
    fn to_i128(&self) -> i128 {
        Self::to_f32(*self).to_i128()
    }
    /// Overridden because the trait default goes through `to_u64`, which would panic
    /// for `f32`-range values above `u64::MAX` that still fit in a `u128`.
    #[inline]
    fn to_u128(&self) -> u128 {
        Self::to_f32(*self).to_u128()
    }
    #[inline]
    fn to_f32(&self) -> f32 {
        Self::to_f32(*self)
    }
    #[inline]
    fn to_f64(&self) -> f64 {
        Self::to_f64(*self)
    }
    #[inline]
    fn to_bool(&self) -> bool {
        *self != cubecl::flex32::from_f32(0.0)
    }
}
567
568impl ToElement for bool {
569 #[inline]
570 fn to_i64(&self) -> i64 {
571 *self as i64
572 }
573 #[inline]
574 fn to_u64(&self) -> u64 {
575 *self as u64
576 }
577 #[inline]
578 fn to_i8(&self) -> i8 {
579 *self as i8
580 }
581 #[inline]
582 fn to_u8(&self) -> u8 {
583 *self as u8
584 }
585 #[inline]
586 fn to_i16(&self) -> i16 {
587 *self as i16
588 }
589 #[inline]
590 fn to_u16(&self) -> u16 {
591 *self as u16
592 }
593 #[inline]
594 fn to_i32(&self) -> i32 {
595 *self as i32
596 }
597 #[inline]
598 fn to_u32(&self) -> u32 {
599 *self as u32
600 }
601 #[inline]
602 fn to_f32(&self) -> f32 {
603 self.to_u8() as f32
604 }
605 #[inline]
606 fn to_f64(&self) -> f64 {
607 self.to_u8() as f64
608 }
609 #[inline]
610 fn to_bool(&self) -> bool {
611 *self
612 }
613}
614
/// Unit tests for the [`ToElement`] conversions.
///
/// Gated on `cfg(test)` so neither the module nor its glob import is compiled into
/// regular builds. This replaces the former `#[allow(unused_imports)]` workaround.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn to_element_float() {
        let f32_toolarge = 1e39f64;
        assert_eq!(f32_toolarge.to_f32(), f32::INFINITY);
        assert_eq!((-f32_toolarge).to_f32(), f32::NEG_INFINITY);
        assert_eq!((f32::MAX as f64).to_f32(), f32::MAX);
        assert_eq!((-f32::MAX as f64).to_f32(), -f32::MAX);
        assert_eq!(f64::INFINITY.to_f32(), f32::INFINITY);
        assert_eq!((f64::NEG_INFINITY).to_f32(), f32::NEG_INFINITY);
        assert!((f64::NAN).to_f32().is_nan());
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u8_underflow() {
        let _x = (-1i8).to_u8();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u16_underflow() {
        let _x = (-1i8).to_u16();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u32_underflow() {
        let _x = (-1i8).to_u32();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u64_underflow() {
        let _x = (-1i8).to_u64();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u128_underflow() {
        let _x = (-1i8).to_u128();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_usize_underflow() {
        let _x = (-1i8).to_usize();
    }

    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u8_overflow() {
        let _x = 256.to_u8();
    }

    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u16_overflow() {
        let _x = 65_536.to_u16();
    }

    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u32_overflow() {
        let _x = 4_294_967_296u64.to_u32();
    }

    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u64_overflow() {
        let _x = 18_446_744_073_709_551_616u128.to_u64();
    }

    #[test]
    fn to_element_int_to_float() {
        assert_eq!((-1).to_f32(), -1.0);
        assert_eq!((-1).to_f64(), -1.0);
        assert_eq!(255.to_f32(), 255.0);
        assert_eq!(65_535.to_f64(), 65_535.0);
    }

    #[test]
    fn to_element_float_to_int() {
        assert_eq!((-1.0).to_i8(), -1);
        assert_eq!(1.0.to_u8(), 1);
        assert_eq!(1.8.to_u16(), 1);
        assert_eq!(123.456.to_u32(), 123);
    }

    #[test]
    fn to_element_float_to_int_truncates_toward_zero() {
        // Values just inside the `(MIN - 1, MAX + 1)` window truncate into range.
        assert_eq!(127.9f32.to_i8(), 127);
        assert_eq!((-128.9f32).to_i8(), -128);
        assert_eq!((-0.9f64).to_u8(), 0);
        assert_eq!(255.5f64.to_u8(), 255);
    }

    #[test]
    #[should_panic]
    fn to_element_float_to_i8_overflow() {
        let _x = 128.0f32.to_i8();
    }

    #[test]
    #[should_panic]
    fn to_element_float_nan_to_int() {
        let _x = f32::NAN.to_i32();
    }

    #[test]
    fn to_element_bool() {
        assert_eq!(true.to_u8(), 1);
        assert_eq!(false.to_i64(), 0);
        assert_eq!(true.to_i128(), 1);
        assert_eq!(true.to_f32(), 1.0);
        assert_eq!(false.to_f64(), 0.0);
        assert!(3i32.to_bool());
        assert!(!0.0f32.to_bool());
    }
}