burn_tensor/tensor/element/
cast.rs1use core::mem::size_of;
2
3use half::{bf16, f16};
4
5pub trait ToElement {
19 #[inline]
21 fn to_isize(&self) -> isize {
22 ToElement::to_isize(&self.to_i64())
23 }
24
25 #[inline]
27 fn to_i8(&self) -> i8 {
28 ToElement::to_i8(&self.to_i64())
29 }
30
31 #[inline]
33 fn to_i16(&self) -> i16 {
34 ToElement::to_i16(&self.to_i64())
35 }
36
37 #[inline]
39 fn to_i32(&self) -> i32 {
40 ToElement::to_i32(&self.to_i64())
41 }
42
43 fn to_i64(&self) -> i64;
45
46 #[inline]
51 fn to_i128(&self) -> i128 {
52 i128::from(self.to_i64())
53 }
54
55 #[inline]
57 fn to_usize(&self) -> usize {
58 ToElement::to_usize(&self.to_u64())
59 }
60
61 #[inline]
63 fn to_u8(&self) -> u8 {
64 ToElement::to_u8(&self.to_u64())
65 }
66
67 #[inline]
69 fn to_u16(&self) -> u16 {
70 ToElement::to_u16(&self.to_u64())
71 }
72
73 #[inline]
75 fn to_u32(&self) -> u32 {
76 ToElement::to_u32(&self.to_u64())
77 }
78
79 fn to_u64(&self) -> u64;
81
82 #[inline]
87 fn to_u128(&self) -> u128 {
88 u128::from(self.to_u64())
89 }
90
91 #[inline]
94 fn to_f32(&self) -> f32 {
95 ToElement::to_f32(&self.to_f64())
96 }
97
98 #[inline]
105 fn to_f64(&self) -> f64 {
106 ToElement::to_f64(&self.to_u64())
107 }
108}
109
/// Generates checked signed-to-signed integer conversion methods.
///
/// Each generated method returns the value unchanged when the target type is
/// at least as wide as the source, or when the value lies within the target's
/// `MIN..=MAX`; otherwise it panics.
macro_rules! impl_to_element_int_to_int {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            // A target at least as wide as the source holds every source value;
            // the bounds are only consulted when narrowing.
            let widening = size_of::<$SrcT>() <= size_of::<$DstT>();
            let bounds = ($DstT::MIN as $SrcT)..=($DstT::MAX as $SrcT);
            if !widening && !bounds.contains(self) {
                panic!("Element cannot be represented in the target type");
            }
            *self as $DstT
        }
    )*}
}
125
/// Generates checked signed-to-unsigned integer conversion methods.
///
/// Negative values always panic. Non-negative values convert when the target
/// is at least as wide as the source or when they do not exceed the target's
/// `MAX`; otherwise the method panics.
macro_rules! impl_to_element_int_to_uint {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let value = *self;
            let ceiling = $DstT::MAX as $SrcT;
            let fits = value >= 0
                && (size_of::<$SrcT>() <= size_of::<$DstT>() || value <= ceiling);
            assert!(fits, "Element cannot be represented in the target type");
            value as $DstT
        }
    )*}
}
140
141macro_rules! impl_to_element_int {
142 ($T:ident) => {
143 impl ToElement for $T {
144 impl_to_element_int_to_int! { $T:
145 fn to_isize -> isize;
146 fn to_i8 -> i8;
147 fn to_i16 -> i16;
148 fn to_i32 -> i32;
149 fn to_i64 -> i64;
150 fn to_i128 -> i128;
151 }
152
153 impl_to_element_int_to_uint! { $T:
154 fn to_usize -> usize;
155 fn to_u8 -> u8;
156 fn to_u16 -> u16;
157 fn to_u32 -> u32;
158 fn to_u64 -> u64;
159 fn to_u128 -> u128;
160 }
161
162 #[inline]
163 fn to_f32(&self) -> f32 {
164 *self as f32
165 }
166 #[inline]
167 fn to_f64(&self) -> f64 {
168 *self as f64
169 }
170 }
171 };
172}
173
174impl_to_element_int!(isize);
175impl_to_element_int!(i8);
176impl_to_element_int!(i16);
177impl_to_element_int!(i32);
178impl_to_element_int!(i64);
179impl_to_element_int!(i128);
180
/// Generates checked unsigned-to-signed integer conversion methods.
///
/// A strictly wider signed target holds every unsigned source value. Otherwise
/// the value must not exceed the target's `MAX`, or the method panics.
macro_rules! impl_to_element_uint_to_int {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            // Same width is not enough: a signed type spends one bit on the sign.
            let target_is_wider = size_of::<$DstT>() > size_of::<$SrcT>();
            let ceiling = $DstT::MAX as $SrcT;
            assert!(
                target_is_wider || *self <= ceiling,
                "Element cannot be represented in the target type"
            );
            *self as $DstT
        }
    )*}
}
195
/// Generates checked unsigned-to-unsigned integer conversion methods.
///
/// Widening (or same-width) conversions always succeed. Narrowing conversions
/// succeed only when the value does not exceed the target's `MAX`, and panic
/// otherwise.
macro_rules! impl_to_element_uint_to_uint {
    ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $DstT {
            let ceiling = $DstT::MAX as $SrcT;
            match *self {
                value if size_of::<$SrcT>() <= size_of::<$DstT>() || value <= ceiling => {
                    value as $DstT
                }
                _ => panic!("Element cannot be represented in the target type"),
            }
        }
    )*}
}
210
211macro_rules! impl_to_element_uint {
212 ($T:ident) => {
213 impl ToElement for $T {
214 impl_to_element_uint_to_int! { $T:
215 fn to_isize -> isize;
216 fn to_i8 -> i8;
217 fn to_i16 -> i16;
218 fn to_i32 -> i32;
219 fn to_i64 -> i64;
220 fn to_i128 -> i128;
221 }
222
223 impl_to_element_uint_to_uint! { $T:
224 fn to_usize -> usize;
225 fn to_u8 -> u8;
226 fn to_u16 -> u16;
227 fn to_u32 -> u32;
228 fn to_u64 -> u64;
229 fn to_u128 -> u128;
230 }
231
232 #[inline]
233 fn to_f32(&self) -> f32 {
234 *self as f32
235 }
236 #[inline]
237 fn to_f64(&self) -> f64 {
238 *self as f64
239 }
240 }
241 };
242}
243
244impl_to_element_uint!(usize);
245impl_to_element_uint!(u8);
246impl_to_element_uint!(u16);
247impl_to_element_uint!(u32);
248impl_to_element_uint!(u64);
249impl_to_element_uint!(u128);
250
/// Generates float-to-float conversion methods as plain `as` casts.
///
/// Narrowing rounds to the nearest representable value. Out-of-range magnitudes
/// become infinities, and NaN stays NaN; this conversion never panics.
macro_rules! impl_to_element_float_to_float {
    ($Src:ident : $( fn $name:ident -> $Dst:ident ; )*) => {$(
        #[inline]
        fn $name(&self) -> $Dst {
            let source: $Src = *self;
            source as $Dst
        }
    )*}
}
261
/// Truncates a float toward zero into an integer type without checking.
///
/// Only use this after verifying that the value is finite and that its
/// truncation fits in the target type; otherwise the behavior is undefined.
macro_rules! float_to_int_unchecked {
    ($value:expr => $target:ty) => {{
        let checked_value = $value;
        // SAFETY: the float-to-int conversion macros in this module only reach
        // this point after range comparisons against the target's bounds. Those
        // comparisons are false for NaN and for both infinities, so the value is
        // finite and its truncation is representable in `$target`, which are the
        // preconditions of `to_int_unchecked`.
        unsafe { checked_value.to_int_unchecked::<$target>() }
    }};
}
269
/// Generates checked float-to-signed-integer conversion methods.
///
/// The value is truncated toward zero. The method panics for NaN, for
/// infinities, and for values whose truncation falls outside the target's
/// `MIN..=MAX`.
macro_rules! impl_to_element_float_to_signed_int {
    ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $i:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $i {
            let value = *self;
            let representable = if size_of::<$f>() > size_of::<$i>() {
                // Narrow target: `MIN - 1` and `MAX + 1` are exact in the float,
                // so both bounds can be exclusive.
                const BELOW_MIN: $f = $i::MIN as $f - 1.0;
                const ABOVE_MAX: $f = $i::MAX as $f + 1.0;
                BELOW_MIN < value && value < ABOVE_MAX
            } else {
                // Wide target: `MIN - 1` is not exact, but `MIN` (a power of two)
                // is, and nothing fractional exists at that magnitude, so use an
                // inclusive lower bound. `MAX` is not exact either, but the cast
                // rounds it up to `MAX + 1`, which serves as the exclusive bound.
                const LOWEST: $f = $i::MIN as $f;
                const ABOVE_MAX: $f = $i::MAX as $f;
                LOWEST <= value && value < ABOVE_MAX
            };
            // NaN fails every comparison above, so it is rejected here as well.
            if !representable {
                panic!("Float cannot be represented in the target signed int type");
            }
            float_to_int_unchecked!(value => $i)
        }
    )*}
}
299
/// Generates checked float-to-unsigned-integer conversion methods.
///
/// The value is truncated toward zero, so anything strictly between `-1.0` and
/// `0.0` becomes `0`. The method panics for NaN, for infinities, for values
/// `<= -1.0`, and for values whose truncation exceeds the target's `MAX`.
macro_rules! impl_to_element_float_to_unsigned_int {
    ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $u:ident ; )*) => {$(
        #[inline]
        $(#[$cfg])*
        fn $method(&self) -> $u {
            // Exclusive upper limit. For a target narrower than the float,
            // `MAX + 1` is exact. Otherwise the `MAX` cast already rounds up to
            // `MAX + 1`, or to infinity when that exceeds the float's range.
            const NARROW_LIMIT: $f = $u::MAX as $f + 1.0;
            const WIDE_LIMIT: $f = $u::MAX as $f;
            let limit = if size_of::<$f>() > size_of::<$u>() {
                NARROW_LIMIT
            } else {
                WIDE_LIMIT
            };
            let value = *self;
            // NaN fails both comparisons and falls through to the panic.
            if -1.0 < value && value < limit {
                float_to_int_unchecked!(value => $u)
            } else {
                panic!("Float cannot be represented in the target unsigned int type")
            }
        }
    )*}
}
326
327macro_rules! impl_to_element_float {
328 ($T:ident) => {
329 impl ToElement for $T {
330 impl_to_element_float_to_signed_int! { $T:
331 fn to_isize -> isize;
332 fn to_i8 -> i8;
333 fn to_i16 -> i16;
334 fn to_i32 -> i32;
335 fn to_i64 -> i64;
336 fn to_i128 -> i128;
337 }
338
339 impl_to_element_float_to_unsigned_int! { $T:
340 fn to_usize -> usize;
341 fn to_u8 -> u8;
342 fn to_u16 -> u16;
343 fn to_u32 -> u32;
344 fn to_u64 -> u64;
345 fn to_u128 -> u128;
346 }
347
348 impl_to_element_float_to_float! { $T:
349 fn to_f32 -> f32;
350 fn to_f64 -> f64;
351 }
352 }
353 };
354}
355
356impl_to_element_float!(f32);
357impl_to_element_float!(f64);
358
359impl ToElement for f16 {
360 #[inline]
361 fn to_i64(&self) -> i64 {
362 Self::to_f32(*self).to_i64()
363 }
364 #[inline]
365 fn to_u64(&self) -> u64 {
366 Self::to_f32(*self).to_u64()
367 }
368 #[inline]
369 fn to_i8(&self) -> i8 {
370 Self::to_f32(*self).to_i8()
371 }
372 #[inline]
373 fn to_u8(&self) -> u8 {
374 Self::to_f32(*self).to_u8()
375 }
376 #[inline]
377 fn to_i16(&self) -> i16 {
378 Self::to_f32(*self).to_i16()
379 }
380 #[inline]
381 fn to_u16(&self) -> u16 {
382 Self::to_f32(*self).to_u16()
383 }
384 #[inline]
385 fn to_i32(&self) -> i32 {
386 Self::to_f32(*self).to_i32()
387 }
388 #[inline]
389 fn to_u32(&self) -> u32 {
390 Self::to_f32(*self).to_u32()
391 }
392 #[inline]
393 fn to_f32(&self) -> f32 {
394 Self::to_f32(*self)
395 }
396 #[inline]
397 fn to_f64(&self) -> f64 {
398 Self::to_f64(*self)
399 }
400}
401
402impl ToElement for bf16 {
403 #[inline]
404 fn to_i64(&self) -> i64 {
405 Self::to_f32(*self).to_i64()
406 }
407 #[inline]
408 fn to_u64(&self) -> u64 {
409 Self::to_f32(*self).to_u64()
410 }
411 #[inline]
412 fn to_i8(&self) -> i8 {
413 Self::to_f32(*self).to_i8()
414 }
415 #[inline]
416 fn to_u8(&self) -> u8 {
417 Self::to_f32(*self).to_u8()
418 }
419 #[inline]
420 fn to_i16(&self) -> i16 {
421 Self::to_f32(*self).to_i16()
422 }
423 #[inline]
424 fn to_u16(&self) -> u16 {
425 Self::to_f32(*self).to_u16()
426 }
427 #[inline]
428 fn to_i32(&self) -> i32 {
429 Self::to_f32(*self).to_i32()
430 }
431 #[inline]
432 fn to_u32(&self) -> u32 {
433 Self::to_f32(*self).to_u32()
434 }
435 #[inline]
436 fn to_f32(&self) -> f32 {
437 Self::to_f32(*self)
438 }
439 #[inline]
440 fn to_f64(&self) -> f64 {
441 Self::to_f64(*self)
442 }
443}
444
/// `flex32` converts through its inherent `to_f32` and then reuses the checked
/// `f32` conversions.
///
/// NOTE(review): this assumes `flex32` spans (at least) `f32`'s range, which is
/// why the 128-bit and pointer-sized conversions are overridden. The trait
/// defaults would detour through `to_i64`/`to_u64` and panic for finite values
/// beyond the 64-bit range that `i128`/`u128` can represent. Confirm against
/// cubecl's `flex32` definition.
#[cfg(feature = "cubecl")]
impl ToElement for cubecl::flex32 {
    #[inline]
    fn to_isize(&self) -> isize {
        Self::to_f32(*self).to_isize()
    }

    #[inline]
    fn to_i8(&self) -> i8 {
        Self::to_f32(*self).to_i8()
    }

    #[inline]
    fn to_i16(&self) -> i16 {
        Self::to_f32(*self).to_i16()
    }

    #[inline]
    fn to_i32(&self) -> i32 {
        Self::to_f32(*self).to_i32()
    }

    #[inline]
    fn to_i64(&self) -> i64 {
        Self::to_f32(*self).to_i64()
    }

    #[inline]
    fn to_i128(&self) -> i128 {
        Self::to_f32(*self).to_i128()
    }

    #[inline]
    fn to_usize(&self) -> usize {
        Self::to_f32(*self).to_usize()
    }

    #[inline]
    fn to_u8(&self) -> u8 {
        Self::to_f32(*self).to_u8()
    }

    #[inline]
    fn to_u16(&self) -> u16 {
        Self::to_f32(*self).to_u16()
    }

    #[inline]
    fn to_u32(&self) -> u32 {
        Self::to_f32(*self).to_u32()
    }

    #[inline]
    fn to_u64(&self) -> u64 {
        Self::to_f32(*self).to_u64()
    }

    #[inline]
    fn to_u128(&self) -> u128 {
        Self::to_f32(*self).to_u128()
    }

    // `Self::to_f32` / `Self::to_f64` resolve to the inherent `flex32` methods
    // (taking `self` by value), not back to these trait methods.
    #[inline]
    fn to_f32(&self) -> f32 {
        Self::to_f32(*self)
    }

    #[inline]
    fn to_f64(&self) -> f64 {
        Self::to_f64(*self)
    }
}
488
489impl ToElement for bool {
490 #[inline]
491 fn to_i64(&self) -> i64 {
492 *self as i64
493 }
494 #[inline]
495 fn to_u64(&self) -> u64 {
496 *self as u64
497 }
498 #[inline]
499 fn to_i8(&self) -> i8 {
500 *self as i8
501 }
502 #[inline]
503 fn to_u8(&self) -> u8 {
504 *self as u8
505 }
506 #[inline]
507 fn to_i16(&self) -> i16 {
508 *self as i16
509 }
510 #[inline]
511 fn to_u16(&self) -> u16 {
512 *self as u16
513 }
514 #[inline]
515 fn to_i32(&self) -> i32 {
516 *self as i32
517 }
518 #[inline]
519 fn to_u32(&self) -> u32 {
520 *self as u32
521 }
522 #[inline]
523 fn to_f32(&self) -> f32 {
524 self.to_u8() as f32
525 }
526 #[inline]
527 fn to_f64(&self) -> f64 {
528 self.to_u8() as f64
529 }
530}
531
// Unit tests for `ToElement`. The module is test-only, so `use super::*` no
// longer needs an `unused_imports` allowance in regular builds.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn to_element_float() {
        let f32_toolarge = 1e39f64;
        assert_eq!(f32_toolarge.to_f32(), f32::INFINITY);
        assert_eq!((-f32_toolarge).to_f32(), f32::NEG_INFINITY);
        assert_eq!((f32::MAX as f64).to_f32(), f32::MAX);
        assert_eq!((-f32::MAX as f64).to_f32(), -f32::MAX);
        assert_eq!(f64::INFINITY.to_f32(), f32::INFINITY);
        assert_eq!((f64::NEG_INFINITY).to_f32(), f32::NEG_INFINITY);
        assert!((f64::NAN).to_f32().is_nan());
    }

    // Negative signed values never fit an unsigned target.
    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_signed_to_u8_underflow() {
        let _x = (-1i8).to_u8();
    }

    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_signed_to_u16_underflow() {
        let _x = (-1i8).to_u16();
    }

    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_signed_to_u32_underflow() {
        let _x = (-1i8).to_u32();
    }

    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_signed_to_u64_underflow() {
        let _x = (-1i8).to_u64();
    }

    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_signed_to_u128_underflow() {
        let _x = (-1i8).to_u128();
    }

    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_signed_to_usize_underflow() {
        let _x = (-1i8).to_usize();
    }

    // Unsigned sources wider than the target overflow just past the target's MAX.
    // Explicit unsigned suffixes keep these on the unsigned-to-unsigned path
    // (an unsuffixed literal would default to `i32`).
    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_unsigned_to_u8_overflow() {
        let _x = 256u16.to_u8();
    }

    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_unsigned_to_u16_overflow() {
        let _x = 65_536u32.to_u16();
    }

    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_unsigned_to_u32_overflow() {
        let _x = 4_294_967_296u64.to_u32();
    }

    #[test]
    #[should_panic(expected = "Element cannot be represented in the target type")]
    fn to_element_unsigned_to_u64_overflow() {
        let _x = 18_446_744_073_709_551_616u128.to_u64();
    }

    #[test]
    fn to_element_int_to_float() {
        assert_eq!((-1).to_f32(), -1.0);
        assert_eq!((-1).to_f64(), -1.0);
        assert_eq!(255.to_f32(), 255.0);
        assert_eq!(65_535.to_f64(), 65_535.0);
    }

    #[test]
    fn to_element_float_to_int() {
        assert_eq!((-1.0).to_i8(), -1);
        assert_eq!(1.0.to_u8(), 1);
        assert_eq!(1.8.to_u16(), 1);
        assert_eq!(123.456.to_u32(), 123);
    }

    #[test]
    fn to_element_float_to_int_truncates_toward_zero() {
        assert_eq!((-0.9f32).to_u8(), 0);
        assert_eq!((-1.9f64).to_i16(), -1);
        assert_eq!(f32::MAX.to_u128(), f32::MAX as u128);
    }

    #[test]
    #[should_panic(expected = "Float cannot be represented in the target unsigned int type")]
    fn to_element_float_to_u8_overflow() {
        let _x = 256.0f32.to_u8();
    }

    #[test]
    #[should_panic(expected = "Float cannot be represented in the target unsigned int type")]
    fn to_element_negative_float_to_unsigned() {
        let _x = (-1.0f64).to_u32();
    }

    #[test]
    #[should_panic(expected = "Float cannot be represented in the target signed int type")]
    fn to_element_nan_to_int() {
        let _x = f32::NAN.to_i32();
    }

    #[test]
    fn to_element_bool() {
        assert_eq!(true.to_i32(), 1);
        assert_eq!(false.to_u64(), 0);
        assert_eq!(true.to_f64(), 1.0);
        assert_eq!(false.to_f32(), 0.0);
    }

    #[test]
    fn to_element_f16_to_int() {
        assert_eq!(f16::from_f32(3.5).to_i32(), 3);
        assert_eq!(f16::from_f32(-2.0).to_i8(), -2);
    }
}