1use super::DType;
4use bytemuck::{Pod, Zeroable};
5use std::ops::{Add, Div, Mul, Sub};
6
7pub trait Element:
21 Copy
22 + Clone
23 + Send
24 + Sync
25 + Pod
26 + Zeroable
27 + 'static
28 + Add<Output = Self>
29 + Sub<Output = Self>
30 + Mul<Output = Self>
31 + Div<Output = Self>
32 + PartialOrd
33{
34 const DTYPE: DType;
36
37 fn to_f64(self) -> f64;
48
49 fn from_f64(v: f64) -> Self;
55
56 #[inline]
61 fn to_f32(self) -> f32 {
62 self.to_f64() as f32
63 }
64
65 #[inline]
70 fn from_f32(v: f32) -> Self {
71 Self::from_f64(v as f64)
72 }
73
74 fn zero() -> Self;
76
77 fn one() -> Self;
79}
80
81impl Element for f64 {
82 const DTYPE: DType = DType::F64;
83
84 #[inline]
85 fn to_f64(self) -> f64 {
86 self
87 }
88
89 #[inline]
90 fn from_f64(v: f64) -> Self {
91 v
92 }
93
94 #[inline]
95 fn to_f32(self) -> f32 {
96 self as f32
97 }
98
99 #[inline]
100 fn from_f32(v: f32) -> Self {
101 v as f64
102 }
103
104 #[inline]
105 fn zero() -> Self {
106 0.0
107 }
108
109 #[inline]
110 fn one() -> Self {
111 1.0
112 }
113}
114
115impl Element for f32 {
116 const DTYPE: DType = DType::F32;
117
118 #[inline]
119 fn to_f64(self) -> f64 {
120 self as f64
121 }
122
123 #[inline]
124 fn from_f64(v: f64) -> Self {
125 v as f32
126 }
127
128 #[inline]
129 fn to_f32(self) -> f32 {
130 self
131 }
132
133 #[inline]
134 fn from_f32(v: f32) -> Self {
135 v
136 }
137
138 #[inline]
139 fn zero() -> Self {
140 0.0
141 }
142
143 #[inline]
144 fn one() -> Self {
145 1.0
146 }
147}
148
149impl Element for i64 {
150 const DTYPE: DType = DType::I64;
151
152 #[inline]
153 fn to_f64(self) -> f64 {
154 self as f64
155 }
156
157 #[inline]
158 fn from_f64(v: f64) -> Self {
159 v as i64
160 }
161
162 #[inline]
163 fn zero() -> Self {
164 0
165 }
166
167 #[inline]
168 fn one() -> Self {
169 1
170 }
171}
172
173impl Element for i32 {
174 const DTYPE: DType = DType::I32;
175
176 #[inline]
177 fn to_f64(self) -> f64 {
178 self as f64
179 }
180
181 #[inline]
182 fn from_f64(v: f64) -> Self {
183 v as i32
184 }
185
186 #[inline]
187 fn zero() -> Self {
188 0
189 }
190
191 #[inline]
192 fn one() -> Self {
193 1
194 }
195}
196
197impl Element for i16 {
198 const DTYPE: DType = DType::I16;
199
200 #[inline]
201 fn to_f64(self) -> f64 {
202 self as f64
203 }
204
205 #[inline]
206 fn from_f64(v: f64) -> Self {
207 v as i16
208 }
209
210 #[inline]
211 fn zero() -> Self {
212 0
213 }
214
215 #[inline]
216 fn one() -> Self {
217 1
218 }
219}
220
221impl Element for i8 {
222 const DTYPE: DType = DType::I8;
223
224 #[inline]
225 fn to_f64(self) -> f64 {
226 self as f64
227 }
228
229 #[inline]
230 fn from_f64(v: f64) -> Self {
231 v as i8
232 }
233
234 #[inline]
235 fn zero() -> Self {
236 0
237 }
238
239 #[inline]
240 fn one() -> Self {
241 1
242 }
243}
244
245impl Element for u64 {
246 const DTYPE: DType = DType::U64;
247
248 #[inline]
249 fn to_f64(self) -> f64 {
250 self as f64
251 }
252
253 #[inline]
254 fn from_f64(v: f64) -> Self {
255 v as u64
256 }
257
258 #[inline]
259 fn zero() -> Self {
260 0
261 }
262
263 #[inline]
264 fn one() -> Self {
265 1
266 }
267}
268
269impl Element for u32 {
270 const DTYPE: DType = DType::U32;
271
272 #[inline]
273 fn to_f64(self) -> f64 {
274 self as f64
275 }
276
277 #[inline]
278 fn from_f64(v: f64) -> Self {
279 v as u32
280 }
281
282 #[inline]
283 fn zero() -> Self {
284 0
285 }
286
287 #[inline]
288 fn one() -> Self {
289 1
290 }
291}
292
293impl Element for u16 {
294 const DTYPE: DType = DType::U16;
295
296 #[inline]
297 fn to_f64(self) -> f64 {
298 self as f64
299 }
300
301 #[inline]
302 fn from_f64(v: f64) -> Self {
303 v as u16
304 }
305
306 #[inline]
307 fn zero() -> Self {
308 0
309 }
310
311 #[inline]
312 fn one() -> Self {
313 1
314 }
315}
316
317impl Element for u8 {
318 const DTYPE: DType = DType::U8;
319
320 #[inline]
321 fn to_f64(self) -> f64 {
322 self as f64
323 }
324
325 #[inline]
326 fn from_f64(v: f64) -> Self {
327 v as u8
328 }
329
330 #[inline]
331 fn zero() -> Self {
332 0
333 }
334
335 #[inline]
336 fn one() -> Self {
337 1
338 }
339}
340
/// `Element` for IEEE 754 half precision (`half::f16`).
///
/// Every conversion delegates to `half::f16`'s own inherent conversion
/// functions. They are named by explicit path, and inherent associated items
/// take precedence over trait items, so these are not recursive calls into
/// `Element`.
#[cfg(feature = "f16")]
impl Element for half::f16 {
    const DTYPE: DType = DType::F16;

    #[inline]
    fn zero() -> Self {
        Self::ZERO
    }

    #[inline]
    fn one() -> Self {
        Self::ONE
    }

    #[inline]
    fn to_f64(self) -> f64 {
        half::f16::to_f64(self)
    }

    #[inline]
    fn from_f64(v: f64) -> Self {
        half::f16::from_f64(v)
    }

    #[inline]
    fn to_f32(self) -> f32 {
        half::f16::to_f32(self)
    }

    #[inline]
    fn from_f32(v: f32) -> Self {
        half::f16::from_f32(v)
    }
}
382
/// `Element` for bfloat16 (`half::bf16`).
///
/// Every conversion delegates to `half::bf16`'s own inherent conversion
/// functions. They are named by explicit path, and inherent associated items
/// take precedence over trait items, so these are not recursive calls into
/// `Element`.
#[cfg(feature = "f16")]
impl Element for half::bf16 {
    const DTYPE: DType = DType::BF16;

    #[inline]
    fn zero() -> Self {
        Self::ZERO
    }

    #[inline]
    fn one() -> Self {
        Self::ONE
    }

    #[inline]
    fn to_f64(self) -> f64 {
        half::bf16::to_f64(self)
    }

    #[inline]
    fn from_f64(v: f64) -> Self {
        half::bf16::from_f64(v)
    }

    #[inline]
    fn to_f32(self) -> f32 {
        half::bf16::to_f32(self)
    }

    #[inline]
    fn from_f32(v: f32) -> Self {
        half::bf16::from_f32(v)
    }
}
417
418impl Element for super::fp8::FP8E4M3 {
423 const DTYPE: DType = DType::FP8E4M3;
424
425 #[inline]
426 fn to_f64(self) -> f64 {
427 self.to_f32() as f64
428 }
429
430 #[inline]
431 fn from_f64(v: f64) -> Self {
432 Self::from_f32(v as f32)
433 }
434
435 #[inline]
436 fn to_f32(self) -> f32 {
437 self.to_f32()
438 }
439
440 #[inline]
441 fn from_f32(v: f32) -> Self {
442 Self::from_f32(v)
443 }
444
445 #[inline]
446 fn zero() -> Self {
447 Self::ZERO
448 }
449
450 #[inline]
451 fn one() -> Self {
452 Self::ONE
453 }
454}
455
456impl Element for super::fp8::FP8E5M2 {
457 const DTYPE: DType = DType::FP8E5M2;
458
459 #[inline]
460 fn to_f64(self) -> f64 {
461 self.to_f32() as f64
462 }
463
464 #[inline]
465 fn from_f64(v: f64) -> Self {
466 Self::from_f32(v as f32)
467 }
468
469 #[inline]
470 fn to_f32(self) -> f32 {
471 self.to_f32()
472 }
473
474 #[inline]
475 fn from_f32(v: f32) -> Self {
476 Self::from_f32(v)
477 }
478
479 #[inline]
480 fn zero() -> Self {
481 Self::ZERO
482 }
483
484 #[inline]
485 fn one() -> Self {
486 Self::ONE
487 }
488}
489
490impl Element for super::complex::Complex64 {
503 const DTYPE: DType = DType::Complex64;
504
505 #[inline]
508 fn to_f64(self) -> f64 {
509 self.magnitude() as f64
510 }
511
512 #[inline]
514 fn from_f64(v: f64) -> Self {
515 Self::new(v as f32, 0.0)
516 }
517
518 #[inline]
519 fn zero() -> Self {
520 Self::ZERO
521 }
522
523 #[inline]
524 fn one() -> Self {
525 Self::ONE
526 }
527}
528
529impl Element for super::complex::Complex128 {
530 const DTYPE: DType = DType::Complex128;
531
532 #[inline]
535 fn to_f64(self) -> f64 {
536 self.magnitude()
537 }
538
539 #[inline]
541 fn from_f64(v: f64) -> Self {
542 Self::new(v, 0.0)
543 }
544
545 #[inline]
546 fn zero() -> Self {
547 Self::ZERO
548 }
549
550 #[inline]
551 fn one() -> Self {
552 Self::ONE
553 }
554}
555
#[cfg(test)]
mod tests {
    use super::super::fp8::{FP8E4M3, FP8E5M2};
    use super::*;

    #[test]
    fn test_element_dtype() {
        assert_eq!(f64::DTYPE, DType::F64);
        assert_eq!(f32::DTYPE, DType::F32);
        assert_eq!(i32::DTYPE, DType::I32);
        assert_eq!(u8::DTYPE, DType::U8);
    }

    #[test]
    fn test_integer_dtypes() {
        assert_eq!(i64::DTYPE, DType::I64);
        assert_eq!(i16::DTYPE, DType::I16);
        assert_eq!(i8::DTYPE, DType::I8);
        assert_eq!(u64::DTYPE, DType::U64);
        assert_eq!(u32::DTYPE, DType::U32);
        assert_eq!(u16::DTYPE, DType::U16);
    }

    #[test]
    fn test_element_conversions() {
        assert_eq!(f32::from_f64(2.5).to_f64(), 2.5f32 as f64);
        assert_eq!(i32::from_f64(42.0), 42);
    }

    /// Integer `from_f64`/`from_f32` use `as` casts: truncate toward zero,
    /// saturate at the type's bounds, and map NaN to 0.
    #[test]
    fn test_integer_from_float_truncates_and_saturates() {
        assert_eq!(i32::from_f64(2.9), 2);
        assert_eq!(i32::from_f64(-2.9), -2);
        assert_eq!(u8::from_f64(300.0), u8::MAX);
        assert_eq!(u8::from_f64(-1.0), 0);
        assert_eq!(i8::from_f64(-1e9), i8::MIN);
        assert_eq!(i64::from_f64(f64::NAN), 0);
        assert_eq!(u32::from_f32(f32::INFINITY), u32::MAX);
        assert_eq!(i16::from_f32(-3.5), -3);
    }

    #[test]
    fn test_zero_and_one() {
        assert_eq!(f64::zero(), 0.0);
        assert_eq!(f32::one(), 1.0);
        assert_eq!(i16::zero(), 0);
        assert_eq!(u64::one(), 1);
    }

    #[test]
    fn test_fp8_element_dtype() {
        assert_eq!(FP8E4M3::DTYPE, DType::FP8E4M3);
        assert_eq!(FP8E5M2::DTYPE, DType::FP8E5M2);
    }

    #[test]
    fn test_fp8_element_conversions() {
        // 2.0 should round-trip (almost) exactly through E4M3.
        let e4m3 = FP8E4M3::from_f64(2.0);
        assert!((e4m3.to_f64() - 2.0).abs() < 0.1);

        // E5M2 is coarse near 100, hence the wide tolerance.
        let e5m2 = FP8E5M2::from_f64(100.0);
        assert!((e5m2.to_f64() - 100.0).abs() < 15.0);

        // The zero/one constants decode to the expected values.
        assert_eq!(FP8E4M3::zero().to_f32(), 0.0);
        assert!((FP8E4M3::one().to_f32() - 1.0).abs() < 0.01);
        assert_eq!(FP8E5M2::zero().to_f32(), 0.0);
        assert!((FP8E5M2::one().to_f32() - 1.0).abs() < 0.01);
    }
}
599}