burn_tensor/tensor/element/
cast.rs1use core::mem::size_of;
2
3use half::{bf16, f16};
4
5pub trait ToElement {
19 #[inline]
21 fn to_isize(&self) -> isize {
22 ToElement::to_isize(&self.to_i64())
23 }
24
25 #[inline]
27 fn to_i8(&self) -> i8 {
28 ToElement::to_i8(&self.to_i64())
29 }
30
31 #[inline]
33 fn to_i16(&self) -> i16 {
34 ToElement::to_i16(&self.to_i64())
35 }
36
37 #[inline]
39 fn to_i32(&self) -> i32 {
40 ToElement::to_i32(&self.to_i64())
41 }
42
43 fn to_i64(&self) -> i64;
45
46 #[inline]
51 fn to_i128(&self) -> i128 {
52 i128::from(self.to_i64())
53 }
54
55 #[inline]
57 fn to_usize(&self) -> usize {
58 ToElement::to_usize(&self.to_u64())
59 }
60
61 #[inline]
63 fn to_u8(&self) -> u8 {
64 ToElement::to_u8(&self.to_u64())
65 }
66
67 #[inline]
69 fn to_u16(&self) -> u16 {
70 ToElement::to_u16(&self.to_u64())
71 }
72
73 #[inline]
75 fn to_u32(&self) -> u32 {
76 ToElement::to_u32(&self.to_u64())
77 }
78
79 fn to_u64(&self) -> u64;
81
82 #[inline]
87 fn to_u128(&self) -> u128 {
88 u128::from(self.to_u64())
89 }
90
91 #[inline]
94 fn to_f16(&self) -> f16 {
95 f16::from_f32(self.to_f32())
96 }
97
98 #[inline]
101 fn to_bf16(&self) -> bf16 {
102 bf16::from_f32(self.to_f32())
103 }
104
105 #[inline]
108 fn to_f32(&self) -> f32 {
109 ToElement::to_f32(&self.to_f64())
110 }
111
112 #[inline]
119 fn to_f64(&self) -> f64 {
120 ToElement::to_f64(&self.to_u64())
121 }
122
123 #[inline]
131 fn to_bool(&self) -> bool {
132 ToElement::to_bool(&self.to_u64())
133 }
134}
135
/// Generates range-checked signed-to-signed conversion methods.
///
/// A widening (or same-width) conversion always fits. A narrowing one panics when
/// the source value lies outside `[DstT::MIN, DstT::MAX]`.
macro_rules! impl_to_element_int_to_int {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let lower = $DstT::MIN as $SrcT;
            let upper = $DstT::MAX as $SrcT;
            let narrowing = size_of::<$SrcT>() > size_of::<$DstT>();
            if narrowing && (*self < lower || *self > upper) {
                panic!("Element cannot be represented in the target type")
            }
            *self as $DstT
        }
    )*}
}
/// Generates range-checked signed-to-unsigned conversion methods.
///
/// Negative values always panic. Non-negative values panic only when the
/// conversion narrows and the value exceeds `DstT::MAX`.
macro_rules! impl_to_element_int_to_uint {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let upper = $DstT::MAX as $SrcT;
            let negative = *self < 0;
            let too_big = size_of::<$SrcT>() > size_of::<$DstT>() && *self > upper;
            if negative || too_big {
                panic!("Element cannot be represented in the target type")
            }
            *self as $DstT
        }
    )*}
}
167macro_rules! impl_to_element_int {
168 ($T:ident) => {
169 impl ToElement for $T {
170 impl_to_element_int_to_int! { $T:
171 fn to_isize -> isize;
172 fn to_i8 -> i8;
173 fn to_i16 -> i16;
174 fn to_i32 -> i32;
175 fn to_i64 -> i64;
176 fn to_i128 -> i128;
177 }
178
179 impl_to_element_int_to_uint! { $T:
180 fn to_usize -> usize;
181 fn to_u8 -> u8;
182 fn to_u16 -> u16;
183 fn to_u32 -> u32;
184 fn to_u64 -> u64;
185 fn to_u128 -> u128;
186 }
187
188 #[inline]
189 fn to_f32(&self) -> f32 {
190 *self as f32
191 }
192 #[inline]
193 fn to_f64(&self) -> f64 {
194 *self as f64
195 }
196 #[inline]
197 fn to_bool(&self) -> bool {
198 *self != 0
199 }
200 }
201 };
202}
203
204impl_to_element_int!(isize);
205impl_to_element_int!(i8);
206impl_to_element_int!(i16);
207impl_to_element_int!(i32);
208impl_to_element_int!(i64);
209impl_to_element_int!(i128);
210
/// Generates range-checked unsigned-to-signed conversion methods.
///
/// A strictly widening conversion always fits. Otherwise the value must not
/// exceed `DstT::MAX`, or the method panics.
macro_rules! impl_to_element_uint_to_int {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let upper = $DstT::MAX as $SrcT;
            let widening = size_of::<$SrcT>() < size_of::<$DstT>();
            if !widening && *self > upper {
                panic!("Element cannot be represented in the target type")
            }
            *self as $DstT
        }
    )*}
}
/// Generates range-checked unsigned-to-unsigned conversion methods.
///
/// Only a narrowing conversion can fail. It panics when the value exceeds
/// `DstT::MAX`.
macro_rules! impl_to_element_uint_to_uint {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let upper = $DstT::MAX as $SrcT;
            let shrinking = size_of::<$SrcT>() > size_of::<$DstT>();
            if shrinking && *self > upper {
                panic!("Element cannot be represented in the target type")
            }
            *self as $DstT
        }
    )*}
}
241macro_rules! impl_to_element_uint {
242 ($T:ident) => {
243 impl ToElement for $T {
244 impl_to_element_uint_to_int! { $T:
245 fn to_isize -> isize;
246 fn to_i8 -> i8;
247 fn to_i16 -> i16;
248 fn to_i32 -> i32;
249 fn to_i64 -> i64;
250 fn to_i128 -> i128;
251 }
252
253 impl_to_element_uint_to_uint! { $T:
254 fn to_usize -> usize;
255 fn to_u8 -> u8;
256 fn to_u16 -> u16;
257 fn to_u32 -> u32;
258 fn to_u64 -> u64;
259 fn to_u128 -> u128;
260 }
261
262 #[inline]
263 fn to_f32(&self) -> f32 {
264 *self as f32
265 }
266 #[inline]
267 fn to_f64(&self) -> f64 {
268 *self as f64
269 }
270 #[inline]
271 fn to_bool(&self) -> bool {
272 *self != 0
273 }
274 }
275 };
276}
277
278impl_to_element_uint!(usize);
279impl_to_element_uint!(u8);
280impl_to_element_uint!(u16);
281impl_to_element_uint!(u32);
282impl_to_element_uint!(u64);
283impl_to_element_uint!(u128);
284
/// Generates float-to-float conversion methods based on `as` casts.
macro_rules! impl_to_element_float_to_float {
    ($src:ident : $( fn $name:ident -> $dst:ident ; )*) => {$(
        #[inline]
        fn $name(&self) -> $dst {
            // A float-to-float `as` cast never fails: NaN stays NaN, infinities are
            // preserved, and finite values too large for a narrower target become
            // an infinity of the same sign.
            let value = *self;
            value as $dst
        }
    )*}
}
/// Truncates a float toward zero into an integer type without a runtime check.
macro_rules! float_to_int_unchecked {
    ($value:expr => $target:ty) => {
        // SAFETY: the float-to-int conversion macros in this module only expand
        // this after checking that the value is not NaN or infinite and lies in an
        // open interval whose truncation toward zero fits in `$target`. That is
        // exactly the precondition of `to_int_unchecked`.
        unsafe { $value.to_int_unchecked::<$target>() }
    };
}
/// Generates range-checked float-to-signed-integer conversion methods.
///
/// The value is truncated toward zero. NaN, infinities and values whose truncation
/// does not fit in the target type panic.
macro_rules! impl_to_element_float_to_signed_int {
    ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $i:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $i {
            // Truncation toward zero means every value in the open interval
            // `(MIN - 1, MAX + 1)` lands in range.
            let representable = if size_of::<$f>() > size_of::<$i>() {
                // The float is wider than the integer, so both bounds are exact.
                const MIN_M1: $f = $i::MIN as $f - 1.0;
                const MAX_P1: $f = $i::MAX as $f + 1.0;
                MIN_M1 < *self && *self < MAX_P1
            } else {
                // `MIN - 1` is not representable here, but `MIN` (a power of two) is.
                // At that magnitude there is no fractional part, so an inclusive
                // `MIN` bound is equivalent.
                const MIN: $f = $i::MIN as $f;
                // `MAX` is not representable either. It rounds up to exactly
                // `MAX + 1`, a power of two.
                const MAX_P1: $f = $i::MAX as $f;
                MIN <= *self && *self < MAX_P1
            };
            if !representable {
                panic!("Float cannot be represented in the target signed int type")
            }
            float_to_int_unchecked!(*self => $i)
        }
    )*}
}
/// Generates range-checked float-to-unsigned-integer conversion methods.
///
/// The value is truncated toward zero. NaN, infinities, values at or below `-1`
/// and values whose truncation exceeds the target's maximum panic.
macro_rules! impl_to_element_float_to_unsigned_int {
    ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $u:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $u {
            // Truncation toward zero maps every value in `(-1, MAX + 1)` into range.
            // Only the exclusive upper bound depends on the relative widths.
            let upper_exclusive: $f = if size_of::<$f>() > size_of::<$u>() {
                // The float is wider than the integer, so `MAX + 1` is exact.
                const MAX_P1: $f = $u::MAX as $f + 1.0;
                MAX_P1
            } else {
                // `MAX` is not representable. It rounds up to exactly `MAX + 1`, a
                // power of two. `u128::MAX as f32` is infinity, which still works as
                // an exclusive bound.
                const MAX_P1: $f = $u::MAX as $f;
                MAX_P1
            };
            if -1.0 < *self && *self < upper_exclusive {
                float_to_int_unchecked!(*self => $u)
            } else {
                panic!("Float cannot be represented in the target unsigned int type")
            }
        }
    )*}
}
361macro_rules! impl_to_element_float {
362 ($T:ident) => {
363 impl ToElement for $T {
364 impl_to_element_float_to_signed_int! { $T:
365 fn to_isize -> isize;
366 fn to_i8 -> i8;
367 fn to_i16 -> i16;
368 fn to_i32 -> i32;
369 fn to_i64 -> i64;
370 fn to_i128 -> i128;
371 }
372
373 impl_to_element_float_to_unsigned_int! { $T:
374 fn to_usize -> usize;
375 fn to_u8 -> u8;
376 fn to_u16 -> u16;
377 fn to_u32 -> u32;
378 fn to_u64 -> u64;
379 fn to_u128 -> u128;
380 }
381
382 impl_to_element_float_to_float! { $T:
383 fn to_f32 -> f32;
384 fn to_f64 -> f64;
385 }
386
387 #[inline]
388 fn to_bool(&self) -> bool {
389 *self != 0.0
390 }
391 }
392 };
393}
394
395impl_to_element_float!(f32);
396impl_to_element_float!(f64);
397
398impl ToElement for f16 {
399 #[inline]
400 fn to_i64(&self) -> i64 {
401 Self::to_f32(*self).to_i64()
402 }
403 #[inline]
404 fn to_u64(&self) -> u64 {
405 Self::to_f32(*self).to_u64()
406 }
407 #[inline]
408 fn to_i8(&self) -> i8 {
409 Self::to_f32(*self).to_i8()
410 }
411 #[inline]
412 fn to_u8(&self) -> u8 {
413 Self::to_f32(*self).to_u8()
414 }
415 #[inline]
416 fn to_i16(&self) -> i16 {
417 Self::to_f32(*self).to_i16()
418 }
419 #[inline]
420 fn to_u16(&self) -> u16 {
421 Self::to_f32(*self).to_u16()
422 }
423 #[inline]
424 fn to_i32(&self) -> i32 {
425 Self::to_f32(*self).to_i32()
426 }
427 #[inline]
428 fn to_u32(&self) -> u32 {
429 Self::to_f32(*self).to_u32()
430 }
431 #[inline]
432 fn to_f16(&self) -> f16 {
433 *self
434 }
435 #[inline]
436 fn to_f32(&self) -> f32 {
437 Self::to_f32(*self)
438 }
439 #[inline]
440 fn to_f64(&self) -> f64 {
441 Self::to_f64(*self)
442 }
443 #[inline]
444 fn to_bool(&self) -> bool {
445 *self != f16::from_f32_const(0.0)
446 }
447}
448
449impl ToElement for bf16 {
450 #[inline]
451 fn to_i64(&self) -> i64 {
452 Self::to_f32(*self).to_i64()
453 }
454 #[inline]
455 fn to_u64(&self) -> u64 {
456 Self::to_f32(*self).to_u64()
457 }
458 #[inline]
459 fn to_i8(&self) -> i8 {
460 Self::to_f32(*self).to_i8()
461 }
462 #[inline]
463 fn to_u8(&self) -> u8 {
464 Self::to_f32(*self).to_u8()
465 }
466 #[inline]
467 fn to_i16(&self) -> i16 {
468 Self::to_f32(*self).to_i16()
469 }
470 #[inline]
471 fn to_u16(&self) -> u16 {
472 Self::to_f32(*self).to_u16()
473 }
474 #[inline]
475 fn to_i32(&self) -> i32 {
476 Self::to_f32(*self).to_i32()
477 }
478 #[inline]
479 fn to_u32(&self) -> u32 {
480 Self::to_f32(*self).to_u32()
481 }
482 #[inline]
483 fn to_bf16(&self) -> bf16 {
484 *self
485 }
486 #[inline]
487 fn to_f32(&self) -> f32 {
488 Self::to_f32(*self)
489 }
490 #[inline]
491 fn to_f64(&self) -> f64 {
492 Self::to_f64(*self)
493 }
494 #[inline]
495 fn to_bool(&self) -> bool {
496 *self != bf16::from_f32_const(0.0)
497 }
498}
499
/// `flex32` conversions (only with the `cubecl` feature).
///
/// Numeric targets are reached through the type's inherent `to_f32` and then reuse
/// the `f32` implementation.
#[cfg(feature = "cubecl")]
impl ToElement for cubecl::flex32 {
    #[inline]
    fn to_i8(&self) -> i8 {
        let value = Self::to_f32(*self);
        value.to_i8()
    }
    #[inline]
    fn to_i16(&self) -> i16 {
        let value = Self::to_f32(*self);
        value.to_i16()
    }
    #[inline]
    fn to_i32(&self) -> i32 {
        let value = Self::to_f32(*self);
        value.to_i32()
    }
    #[inline]
    fn to_i64(&self) -> i64 {
        let value = Self::to_f32(*self);
        value.to_i64()
    }
    #[inline]
    fn to_u8(&self) -> u8 {
        let value = Self::to_f32(*self);
        value.to_u8()
    }
    #[inline]
    fn to_u16(&self) -> u16 {
        let value = Self::to_f32(*self);
        value.to_u16()
    }
    #[inline]
    fn to_u32(&self) -> u32 {
        let value = Self::to_f32(*self);
        value.to_u32()
    }
    #[inline]
    fn to_u64(&self) -> u64 {
        let value = Self::to_f32(*self);
        value.to_u64()
    }
    // `Self::` paths resolve to the inherent conversions, not these trait methods.
    #[inline]
    fn to_f32(&self) -> f32 {
        Self::to_f32(*self)
    }
    #[inline]
    fn to_f64(&self) -> f64 {
        Self::to_f64(*self)
    }
    #[inline]
    fn to_bool(&self) -> bool {
        let zero = cubecl::flex32::from_f32(0.0);
        *self != zero
    }
}
548impl ToElement for bool {
549 #[inline]
550 fn to_i64(&self) -> i64 {
551 *self as i64
552 }
553 #[inline]
554 fn to_u64(&self) -> u64 {
555 *self as u64
556 }
557 #[inline]
558 fn to_i8(&self) -> i8 {
559 *self as i8
560 }
561 #[inline]
562 fn to_u8(&self) -> u8 {
563 *self as u8
564 }
565 #[inline]
566 fn to_i16(&self) -> i16 {
567 *self as i16
568 }
569 #[inline]
570 fn to_u16(&self) -> u16 {
571 *self as u16
572 }
573 #[inline]
574 fn to_i32(&self) -> i32 {
575 *self as i32
576 }
577 #[inline]
578 fn to_u32(&self) -> u32 {
579 *self as u32
580 }
581 #[inline]
582 fn to_f32(&self) -> f32 {
583 self.to_u8() as f32
584 }
585 #[inline]
586 fn to_f64(&self) -> f64 {
587 self.to_u8() as f64
588 }
589 #[inline]
590 fn to_bool(&self) -> bool {
591 *self
592 }
593}
594
// Compiled only for `cargo test`; previously the module (and its glob import) was
// built into every configuration, which needed an `unused_imports` allowance.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn to_element_float() {
        let f32_toolarge = 1e39f64;
        assert_eq!(f32_toolarge.to_f32(), f32::INFINITY);
        assert_eq!((-f32_toolarge).to_f32(), f32::NEG_INFINITY);
        assert_eq!((f32::MAX as f64).to_f32(), f32::MAX);
        assert_eq!((-f32::MAX as f64).to_f32(), -f32::MAX);
        assert_eq!(f64::INFINITY.to_f32(), f32::INFINITY);
        assert_eq!((f64::NEG_INFINITY).to_f32(), f32::NEG_INFINITY);
        assert!((f64::NAN).to_f32().is_nan());
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u8_underflow() {
        let _x = (-1i8).to_u8();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u16_underflow() {
        let _x = (-1i8).to_u16();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u32_underflow() {
        let _x = (-1i8).to_u32();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u64_underflow() {
        let _x = (-1i8).to_u64();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_u128_underflow() {
        let _x = (-1i8).to_u128();
    }

    #[test]
    #[should_panic]
    fn to_element_signed_to_usize_underflow() {
        let _x = (-1i8).to_usize();
    }

    // Explicit unsigned suffixes: an unsuffixed literal would fall back to `i32`
    // and exercise the signed code path instead.
    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u8_overflow() {
        let _x = 256u16.to_u8();
    }

    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u16_overflow() {
        let _x = 65_536u32.to_u16();
    }

    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u32_overflow() {
        let _x = 4_294_967_296u64.to_u32();
    }

    #[test]
    #[should_panic]
    fn to_element_unsigned_to_u64_overflow() {
        let _x = 18_446_744_073_709_551_616u128.to_u64();
    }

    #[test]
    fn to_element_int_to_float() {
        assert_eq!((-1).to_f32(), -1.0);
        assert_eq!((-1).to_f64(), -1.0);
        assert_eq!(255.to_f32(), 255.0);
        assert_eq!(65_535.to_f64(), 65_535.0);
    }

    #[test]
    fn to_element_float_to_int() {
        assert_eq!((-1.0).to_i8(), -1);
        assert_eq!(1.0.to_u8(), 1);
        assert_eq!(1.8.to_u16(), 1);
        assert_eq!(123.456.to_u32(), 123);
    }

    // Truncation toward zero must accept everything strictly inside `(MIN-1, MAX+1)`.
    #[test]
    fn to_element_float_to_int_bounds() {
        assert_eq!(127.9f32.to_i8(), 127);
        assert_eq!((-128.9f64).to_i8(), -128);
        assert_eq!(255.9f32.to_u8(), 255);
        assert_eq!((-0.5f32).to_u8(), 0);
        assert_eq!((-2_147_483_648.0f32).to_i32(), i32::MIN);
    }

    #[test]
    #[should_panic]
    fn to_element_float_nan_to_int() {
        let _x = f32::NAN.to_i32();
    }

    #[test]
    #[should_panic]
    fn to_element_float_to_u8_overflow() {
        let _x = 256.0f32.to_u8();
    }

    #[test]
    #[should_panic]
    fn to_element_float_to_i32_overflow() {
        let _x = 2_147_483_648.0f32.to_i32();
    }

    #[test]
    #[should_panic]
    fn to_element_negative_float_to_unsigned() {
        let _x = (-1.0f64).to_u32();
    }

    #[test]
    fn to_element_bool() {
        assert_eq!(true.to_i64(), 1);
        assert_eq!(false.to_u8(), 0);
        assert_eq!(true.to_f32(), 1.0);
        assert_eq!(false.to_f64(), 0.0);
        assert!(1u8.to_bool());
        assert!(!0i32.to_bool());
        assert!(!(-0.0f64).to_bool());
    }
}