kn_graph/
dtype.rs

1use std::fmt::{Display, Formatter};
2use std::hash::{Hash, Hasher};
3use std::num::FpCategory;
4use std::ops::Deref;
5
6use bytemuck::NoUninit;
7use decorum::cmp::FloatEq;
8use decorum::hash::FloatHash;
9use itertools::zip_eq;
10use ndarray::{ArcArray, IntoDimension, IxDyn, LinalgScalar};
11
/// `f32` newtype used as the payload of `DScalar::F32`.
///
/// This file gives it `Eq`/`Hash` implementations based on decorum's `FloatEq`/`FloatHash`,
/// which plain `f32` lacks.
#[derive(Clone, Copy, Debug)]
pub struct T32(pub f32);

/// `f64` newtype used as the payload of `DScalar::F64`; see [`T32`].
#[derive(Clone, Copy, Debug)]
pub struct T64(pub f64);
17
/// Boolean element type that supports the arithmetic operators required of tensor scalars.
///
/// `#[repr(transparent)]` keeps the layout identical to a plain `bool`.
// TODO maybe remove this at some point and switch to proper bools,
//   and figure out another solution to the "bool arithmetic" problem
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct DBool(pub bool);
23
/// Runtime tag for the element type of a `DTensor` / `DScalar`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DType {
    /// 32-bit float (`f32`).
    F32,
    /// 64-bit float (`f64`).
    F64,
    /// Signed 8-bit integer.
    I8,
    /// Signed 16-bit integer.
    I16,
    /// Signed 32-bit integer.
    I32,
    /// Signed 64-bit integer.
    I64,
    /// Unsigned 8-bit integer.
    U8,
    /// Unsigned 16-bit integer.
    U16,
    /// Unsigned 32-bit integer.
    U32,
    /// Unsigned 64-bit integer.
    U64,
    /// Boolean, stored as `DBool`.
    Bool,
}
38
39pub type Tensor<T> = ArcArray<T, IxDyn>;
40
41#[derive(Debug, Clone)]
42pub enum DTensor {
43    F32(Tensor<f32>),
44    F64(Tensor<f64>),
45    I8(Tensor<i8>),
46    I16(Tensor<i16>),
47    I32(Tensor<i32>),
48    I64(Tensor<i64>),
49    U8(Tensor<u8>),
50    U16(Tensor<u16>),
51    U32(Tensor<u32>),
52    U64(Tensor<u64>),
53    Bool(Tensor<DBool>),
54}
55
56#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
57pub enum DScalar {
58    F32(T32),
59    F64(T64),
60    I8(i8),
61    I16(i16),
62    I32(i32),
63    I64(i64),
64    U8(u8),
65    U16(u16),
66    U32(u32),
67    U64(u64),
68    Bool(DBool),
69}
70
/// Storage width class of a dtype (8, 16, 32 or 64 bits).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DSize {
    /// 8 bits / 1 byte.
    S8,
    /// 16 bits / 2 bytes.
    S16,
    /// 32 bits / 4 bytes.
    S32,
    /// 64 bits / 8 bytes.
    S64,
}
78
79#[derive(Debug, Copy, Clone, Eq, PartialEq)]
80pub struct Specials {
81    pub zero: DScalar,
82    pub one: DScalar,
83    pub min: DScalar,
84    pub max: DScalar,
85}
86
87#[derive(Debug, Copy, Clone, Eq, PartialEq)]
88pub struct DInfo {
89    pub size: DSize,
90    pub signed: bool,
91    pub float: bool,
92    pub int: bool,
93    pub is_bool: bool,
94}
95
96impl DType {
97    pub fn info(self) -> DInfo {
98        match self {
99            DType::F32 => DInfo::float(DSize::S32),
100            DType::F64 => DInfo::float(DSize::S64),
101            DType::I8 => DInfo::int(DSize::S8, true),
102            DType::I16 => DInfo::int(DSize::S16, true),
103            DType::I32 => DInfo::int(DSize::S32, true),
104            DType::I64 => DInfo::int(DSize::S64, true),
105            DType::U8 => DInfo::int(DSize::S8, false),
106            DType::U16 => DInfo::int(DSize::S16, false),
107            DType::U32 => DInfo::int(DSize::S32, false),
108            DType::U64 => DInfo::int(DSize::S64, false),
109            DType::Bool => DInfo::bool(),
110        }
111    }
112
113    pub fn size(self) -> DSize {
114        self.info().size
115    }
116
117    pub fn is_signed(self) -> bool {
118        self.info().signed
119    }
120
121    pub fn is_float(self) -> bool {
122        self.info().float
123    }
124
125    pub fn is_int(self) -> bool {
126        self.info().int
127    }
128
129    pub fn is_bool(self) -> bool {
130        self.info().is_bool
131    }
132
133    // TODO move specials to type itself, while keeping this one too?
134    pub fn specials(self) -> Specials {
135        match self {
136            DType::F32 => Specials::new(f32::NEG_INFINITY, f32::INFINITY),
137            DType::F64 => Specials::new(f64::NEG_INFINITY, f64::INFINITY),
138            DType::I8 => Specials::new(i8::MIN, i8::MAX),
139            DType::I16 => Specials::new(i16::MIN, i16::MAX),
140            DType::I32 => Specials::new(i32::MIN, i32::MAX),
141            DType::I64 => Specials::new(i64::MIN, i64::MAX),
142            DType::U8 => Specials::new(u8::MIN, u8::MAX),
143            DType::U16 => Specials::new(u16::MIN, u16::MAX),
144            DType::U32 => Specials::new(u32::MIN, u32::MAX),
145            DType::U64 => Specials::new(u64::MIN, u64::MAX),
146            DType::Bool => Specials::new(DBool(false), DBool(true)),
147        }
148    }
149
150    pub fn as_c_str(self) -> &'static str {
151        match self {
152            DType::F32 => "float",
153            DType::F64 => "double",
154            DType::I8 => "int8_t",
155            DType::I16 => "int16_t",
156            DType::I32 => "int32_t",
157            DType::I64 => "int64_t",
158            DType::U8 => "uint8_t",
159            DType::U16 => "uint16_t",
160            DType::U32 => "uint32_t",
161            DType::U64 => "uint64_t",
162            DType::Bool => "bool",
163        }
164    }
165}
166
167impl DInfo {
168    fn int(size: DSize, signed: bool) -> Self {
169        DInfo {
170            size,
171            signed,
172            float: false,
173            int: true,
174            is_bool: false,
175        }
176    }
177
178    fn float(size: DSize) -> Self {
179        DInfo {
180            size,
181            signed: true,
182            float: true,
183            int: false,
184            is_bool: false,
185        }
186    }
187
188    fn bool() -> Self {
189        DInfo {
190            size: DSize::S8,
191            signed: false,
192            float: false,
193            int: false,
194            is_bool: true,
195        }
196    }
197}
198
199impl DSize {
200    pub fn bytes(self) -> usize {
201        match self {
202            DSize::S8 => 1,
203            DSize::S16 => 2,
204            DSize::S32 => 4,
205            DSize::S64 => 8,
206        }
207    }
208}
209
/// Matches a runtime [`DType`] and evaluates `$body` once for the matching dtype.
///
/// In `$body`, `$ty` is a type alias for the Rust element type, `$fs` is the matching `DScalar`
/// constructor and `$ft` is the matching `DTensor` constructor. The expansion also imports
/// `DType`, `DBool`, `DScalar` and `DTensor` from `$crate::dtype`.
#[rustfmt::skip]
#[macro_export]
macro_rules! dispatch_dtype {
    ($dtype:expr, |$ty:ident, $fs:ident, $ft:ident| $body:expr) => {{
        use $crate::dtype::{DBool, DScalar, DTensor, DType};
        match $dtype {
            DType::F32  => { type $ty = f32;   let $fs = DScalar::F32;  let $ft = DTensor::F32;  { $body } }
            DType::F64  => { type $ty = f64;   let $fs = DScalar::F64;  let $ft = DTensor::F64;  { $body } }
            DType::I8   => { type $ty = i8;    let $fs = DScalar::I8;   let $ft = DTensor::I8;   { $body } }
            DType::I16  => { type $ty = i16;   let $fs = DScalar::I16;  let $ft = DTensor::I16;  { $body } }
            DType::I32  => { type $ty = i32;   let $fs = DScalar::I32;  let $ft = DTensor::I32;  { $body } }
            DType::I64  => { type $ty = i64;   let $fs = DScalar::I64;  let $ft = DTensor::I64;  { $body } }
            DType::U8   => { type $ty = u8;    let $fs = DScalar::U8;   let $ft = DTensor::U8;   { $body } }
            DType::U16  => { type $ty = u16;   let $fs = DScalar::U16;  let $ft = DTensor::U16;  { $body } }
            DType::U32  => { type $ty = u32;   let $fs = DScalar::U32;  let $ft = DTensor::U32;  { $body } }
            DType::U64  => { type $ty = u64;   let $fs = DScalar::U64;  let $ft = DTensor::U64;  { $body } }
            DType::Bool => { type $ty = DBool; let $fs = DScalar::Bool; let $ft = DTensor::Bool; { $body } }
        }
    }};
}
230
231impl DScalar {
232    pub fn f32(x: f32) -> Self {
233        DScalar::F32(T32(x))
234    }
235
236    pub fn f64(x: f64) -> Self {
237        DScalar::F64(T64(x))
238    }
239
240    pub fn bool(x: bool) -> Self {
241        DScalar::Bool(DBool(x))
242    }
243
244    pub fn dtype(self) -> DType {
245        match self {
246            DScalar::F32(_) => DType::F32,
247            DScalar::F64(_) => DType::F64,
248            DScalar::I8(_) => DType::I8,
249            DScalar::I16(_) => DType::I16,
250            DScalar::I32(_) => DType::I32,
251            DScalar::I64(_) => DType::I64,
252            DScalar::U8(_) => DType::U8,
253            DScalar::U16(_) => DType::U16,
254            DScalar::U32(_) => DType::U32,
255            DScalar::U64(_) => DType::U64,
256            DScalar::Bool(_) => DType::Bool,
257        }
258    }
259
260    pub fn unwrap_f32(self) -> Option<f32> {
261        match self {
262            DScalar::F32(x) => Some(x.0),
263            _ => None,
264        }
265    }
266
267    pub fn to_tensor(self) -> DTensor {
268        match self {
269            DScalar::F32(T32(s)) => DTensor::F32(ArcArray::from_shape_vec(IxDyn(&[]), vec![s]).unwrap()),
270            DScalar::F64(T64(s)) => DTensor::F64(ArcArray::from_shape_vec(IxDyn(&[]), vec![s]).unwrap()),
271            DScalar::I8(x) => DTensor::I8(ArcArray::from_shape_vec(IxDyn(&[]), vec![x]).unwrap()),
272            DScalar::I16(x) => DTensor::I16(ArcArray::from_shape_vec(IxDyn(&[]), vec![x]).unwrap()),
273            DScalar::I32(x) => DTensor::I32(ArcArray::from_shape_vec(IxDyn(&[]), vec![x]).unwrap()),
274            DScalar::I64(x) => DTensor::I64(ArcArray::from_shape_vec(IxDyn(&[]), vec![x]).unwrap()),
275            DScalar::U8(x) => DTensor::U8(ArcArray::from_shape_vec(IxDyn(&[]), vec![x]).unwrap()),
276            DScalar::U16(x) => DTensor::U16(ArcArray::from_shape_vec(IxDyn(&[]), vec![x]).unwrap()),
277            DScalar::U32(x) => DTensor::U32(ArcArray::from_shape_vec(IxDyn(&[]), vec![x]).unwrap()),
278            DScalar::U64(x) => DTensor::U64(ArcArray::from_shape_vec(IxDyn(&[]), vec![x]).unwrap()),
279            DScalar::Bool(x) => DTensor::Bool(ArcArray::from_shape_vec(IxDyn(&[]), vec![x]).unwrap()),
280        }
281    }
282
283    pub fn unwrap_uint(self) -> Option<u64> {
284        match self {
285            DScalar::U8(x) => Some(x as u64),
286            DScalar::U16(x) => Some(x as u64),
287            DScalar::U32(x) => Some(x as u64),
288            DScalar::U64(x) => Some(x),
289            _ => None,
290        }
291    }
292
293    pub fn unwrap_int(self) -> Option<i128> {
294        match self {
295            DScalar::U8(x) => Some(x as i128),
296            DScalar::U16(x) => Some(x as i128),
297            DScalar::U32(x) => Some(x as i128),
298            DScalar::U64(x) => Some(x as i128),
299            DScalar::I8(x) => Some(x as i128),
300            DScalar::I16(x) => Some(x as i128),
301            DScalar::I32(x) => Some(x as i128),
302            DScalar::I64(x) => Some(x as i128),
303            _ => None,
304        }
305    }
306
307    pub fn to_c_str(self) -> String {
308        match self {
309            DScalar::F32(c) => DisplayCFloat(*c as f64).to_string(),
310            DScalar::F64(c) => DisplayCFloat(*c).to_string(),
311            DScalar::U8(c) => format!("{}", c),
312            DScalar::U16(c) => format!("{}", c),
313            DScalar::U32(c) => format!("{}", c),
314            DScalar::U64(c) => format!("{}", c),
315            DScalar::I8(c) => format!("{}", c),
316            DScalar::I16(c) => format!("{}", c),
317            DScalar::I32(c) => format!("{}", c),
318            DScalar::I64(c) => format!("{}", c),
319            DScalar::Bool(c) => format!("{}", *c),
320        }
321    }
322
323    pub fn value_cast(self, to: DType) -> DScalar {
324        // cast to big general value
325        let (yf, yi) = match self {
326            DScalar::F32(T32(x)) => (x as f64, x as i128),
327            DScalar::F64(T64(x)) => (x, x as i128),
328            DScalar::I8(x) => (x as f64, x as i128),
329            DScalar::I16(x) => (x as f64, x as i128),
330            DScalar::I32(x) => (x as f64, x as i128),
331            DScalar::I64(x) => (x as f64, x as i128),
332            DScalar::U8(x) => (x as f64, x as i128),
333            DScalar::U16(x) => (x as f64, x as i128),
334            DScalar::U32(x) => (x as f64, x as i128),
335            DScalar::U64(x) => (x as f64, x as i128),
336            DScalar::Bool(DBool(x)) => (x as u8 as f64, x as u8 as i128),
337        };
338
339        // convert to target
340        match to {
341            DType::F32 => DScalar::f32(yf as f32),
342            DType::F64 => DScalar::f64(yf),
343            DType::I8 => DScalar::I8(yi as i8),
344            DType::I16 => DScalar::I16(yi as i16),
345            DType::I32 => DScalar::I32(yi as i32),
346            DType::I64 => DScalar::I64(yi as i64),
347            DType::U8 => DScalar::U8(yi as u8),
348            DType::U16 => DScalar::U16(yi as u16),
349            DType::U32 => DScalar::U32(yi as u32),
350            DType::U64 => DScalar::U64(yi as u64),
351            DType::Bool => DScalar::bool(yf != 0.0 || yi != 0),
352        }
353    }
354
355    pub fn bit_cast(self, to: DType) -> Option<DScalar> {
356        if self.dtype().size() != to.size() {
357            return None;
358        }
359
360        // convert to bits, zero-extend just to be safe
361        let bits = match self {
362            DScalar::F32(T32(x)) => x.to_bits() as u64,
363            DScalar::F64(T64(x)) => x.to_bits(),
364            DScalar::I8(x) => x as u8 as u64,
365            DScalar::I16(x) => x as u16 as u64,
366            DScalar::I32(x) => x as u32 as u64,
367            DScalar::I64(x) => x as u64,
368            DScalar::U8(x) => x as u64,
369            DScalar::U16(x) => x as u64,
370            DScalar::U32(x) => x as u64,
371            DScalar::U64(x) => x,
372            DScalar::Bool(_) => return None,
373        };
374
375        // convert to target
376        let y = match to {
377            DType::F32 => DScalar::f32(f32::from_bits(bits as u32)),
378            DType::F64 => DScalar::f64(f64::from_bits(bits)),
379            DType::I8 => DScalar::I8(bits as i8),
380            DType::I16 => DScalar::I16(bits as i16),
381            DType::I32 => DScalar::I32(bits as i32),
382            DType::I64 => DScalar::I64(bits as i64),
383            DType::U8 => DScalar::U8(bits as u8),
384            DType::U16 => DScalar::U16(bits as u16),
385            DType::U32 => DScalar::U32(bits as u32),
386            DType::U64 => DScalar::U64(bits),
387            DType::Bool => return None,
388        };
389
390        Some(y)
391    }
392}
393
394pub trait IntoDScalar: LinalgScalar + PartialEq {
395    const DTYPE: DType;
396    fn to_dscalar(&self) -> DScalar;
397    fn from_dscalar(scalar: DScalar) -> Option<Self>;
398    fn vec_to_dtensor(data: Vec<Self>) -> DTensor;
399}
400
/// Implements [`IntoDScalar`] for `$scalar`.
///
/// `|$v| $wrap` builds the `DScalar` from a copied value, and `$pat => $unwrap` extracts the
/// value back out. `$variant` names the `DTensor` variant used by `vec_to_dtensor`.
macro_rules! impl_into_dscalar {
    ($scalar:ty, $dtype:expr, $variant:ident, |$v:ident| $wrap:expr, $pat:pat => $unwrap:expr) => {
        impl IntoDScalar for $scalar {
            const DTYPE: DType = $dtype;

            fn to_dscalar(&self) -> DScalar {
                let $v = *self;
                $wrap
            }

            fn from_dscalar(scalar: DScalar) -> Option<Self> {
                if let $pat = scalar {
                    Some($unwrap)
                } else {
                    None
                }
            }

            fn vec_to_dtensor(data: Vec<Self>) -> DTensor {
                let array = ArcArray::from_vec(data);
                DTensor::$variant(array.into_dyn())
            }
        }
    };
}
424
425impl_into_dscalar!(f32, DType::F32, F32, |x| DScalar::f32(x), DScalar::F32(T32(x)) => x);
426impl_into_dscalar!(f64, DType::F64, F64, |x| DScalar::f64(x), DScalar::F64(T64(x)) => x);
427impl_into_dscalar!(i8, DType::I8, I8, |x| DScalar::I8(x), DScalar::I8(x) => x);
428impl_into_dscalar!(i16, DType::I16, I16, |x| DScalar::I16(x), DScalar::I16(x) => x);
429impl_into_dscalar!(i32, DType::I32, I32, |x| DScalar::I32(x), DScalar::I32(x) => x);
430impl_into_dscalar!(i64, DType::I64, I64, |x| DScalar::I64(x), DScalar::I64(x) => x);
431impl_into_dscalar!(u8, DType::U8, U8, |x| DScalar::U8(x), DScalar::U8(x) => x);
432impl_into_dscalar!(u16, DType::U16, U16, |x| DScalar::U16(x), DScalar::U16(x) => x);
433impl_into_dscalar!(u32, DType::U32, U32, |x| DScalar::U32(x), DScalar::U32(x) => x);
434impl_into_dscalar!(u64, DType::U64, U64, |x| DScalar::U64(x), DScalar::U64(x) => x);
435impl_into_dscalar!(DBool, DType::Bool, Bool, |x| DScalar::Bool(x), DScalar::Bool(x) => x);
436
/// Matches a `DTensor` (by value or by reference) and evaluates `$body` for its variant.
///
/// In `$body`, `$inner` is bound to the wrapped array, `$ty` is a type alias for the element
/// type and `$f` is the `DTensor` constructor of that variant. The expansion also imports
/// `DBool` and `DTensor` from `$crate::dtype`.
#[rustfmt::skip]
#[macro_export]
macro_rules! dispatch_dtensor {
    ($tensor:expr, |$ty:ident, $f:ident, $inner:ident| $body:expr) => {{
        use $crate::dtype::{DBool, DTensor};
        match $tensor {
            DTensor::F32($inner)  => { type $ty = f32;   let $f = DTensor::F32;  { $body } }
            DTensor::F64($inner)  => { type $ty = f64;   let $f = DTensor::F64;  { $body } }
            DTensor::I8($inner)   => { type $ty = i8;    let $f = DTensor::I8;   { $body } }
            DTensor::I16($inner)  => { type $ty = i16;   let $f = DTensor::I16;  { $body } }
            DTensor::I32($inner)  => { type $ty = i32;   let $f = DTensor::I32;  { $body } }
            DTensor::I64($inner)  => { type $ty = i64;   let $f = DTensor::I64;  { $body } }
            DTensor::U8($inner)   => { type $ty = u8;    let $f = DTensor::U8;   { $body } }
            DTensor::U16($inner)  => { type $ty = u16;   let $f = DTensor::U16;  { $body } }
            DTensor::U32($inner)  => { type $ty = u32;   let $f = DTensor::U32;  { $body } }
            DTensor::U64($inner)  => { type $ty = u64;   let $f = DTensor::U64;  { $body } }
            DTensor::Bool($inner) => { type $ty = DBool; let $f = DTensor::Bool; { $body } }
        }
    }};
}
457
/// Matches two `DTensor`s of the same dtype and evaluates `$expr` once for that dtype.
///
/// In `$expr`, `$in_left`/`$in_right` are bound to the two wrapped arrays, `$ty` is a type alias
/// for the shared element type and `$f` is the `DTensor` constructor of that variant.
///
/// # Panics
/// Panics if the two tensors have different dtypes.
#[rustfmt::skip]
#[macro_export]
macro_rules! dispatch_dtensor_pair {
    ($out_left:expr, $out_right:expr, |$ty:ident, $f:ident, $in_left:ident, $in_right:ident| $expr:expr) => {{
        use $crate::dtype::{DBool, DTensor};

        let out_left = $out_left;
        let out_right = $out_right;
        // Captured before the match moves the operands, for the mismatch panic message.
        let dtype_left = out_left.dtype();
        let dtype_right = out_right.dtype();

        // Every dtype needs an arm here; a missing one makes same-dtype inputs hit the panic.
        match (out_left, out_right) {
            (DTensor::F32($in_left), DTensor::F32($in_right)) => { type $ty=f32; let $f=DTensor::F32; { $expr } }
            (DTensor::F64($in_left), DTensor::F64($in_right)) => { type $ty=f64; let $f=DTensor::F64; { $expr } }
            (DTensor::I8($in_left), DTensor::I8($in_right)) => { type $ty=i8; let $f=DTensor::I8; { $expr } }
            (DTensor::I16($in_left), DTensor::I16($in_right)) => { type $ty=i16; let $f=DTensor::I16; { $expr } }
            (DTensor::I32($in_left), DTensor::I32($in_right)) => { type $ty=i32; let $f=DTensor::I32; { $expr } }
            (DTensor::I64($in_left), DTensor::I64($in_right)) => { type $ty=i64; let $f=DTensor::I64; { $expr } }
            (DTensor::U8($in_left), DTensor::U8($in_right)) => { type $ty=u8; let $f=DTensor::U8; { $expr } }
            (DTensor::U16($in_left), DTensor::U16($in_right)) => { type $ty=u16; let $f=DTensor::U16; { $expr } }
            (DTensor::U32($in_left), DTensor::U32($in_right)) => { type $ty=u32; let $f=DTensor::U32; { $expr } }
            (DTensor::U64($in_left), DTensor::U64($in_right)) => { type $ty=u64; let $f=DTensor::U64; { $expr } }
            (DTensor::Bool($in_left), DTensor::Bool($in_right)) => { type $ty=DBool; let $f=DTensor::Bool; { $expr } }
            _ => panic!("Mismatched dtypes: left {:?}, right {:?}", dtype_left, dtype_right),
        }
    }};
}
484
/// Applies `$expr` to the array inside a `DTensor` and rewraps the result in the same variant.
///
/// Paths go through `$crate` so the exported macro also resolves when invoked from another crate.
#[macro_export]
macro_rules! map_dtensor {
    ($outer:expr, |$inner:ident| $expr:expr) => {
        $crate::dtype::dispatch_dtensor!($outer, |_T, f, $inner| f($expr))
    };
}

/// Applies `$expr` to the arrays inside two same-dtype `DTensor`s and wraps the result in
/// that dtype's variant.
///
/// # Panics
/// Panics if the two tensors have different dtypes.
#[macro_export]
macro_rules! map_dtensor_pair {
    ($out_left:expr, $out_right:expr, |$in_left:ident, $in_right:ident| $expr:expr) => {
        $crate::dtype::dispatch_dtensor_pair!($out_left, $out_right, |_T, f, $in_left, $in_right| f($expr))
    };
}
498
/// Applies `$expr` to the raw values of two same-dtype `DScalar`s and wraps the result in that
/// dtype's variant.
///
/// Float values are unwrapped from / rewrapped in `T32`/`T64`. Paths go through `$crate` so
/// the exported macro also resolves when invoked from another crate.
///
/// # Panics
/// Panics if the two scalars have different dtypes.
#[rustfmt::skip]
#[macro_export]
macro_rules! map_dscalar_pair {
    ($out_left:expr, $out_right:expr, |$in_left:ident, $in_right:ident| $expr:expr) => {{
        use $crate::dtype::{DScalar, T32, T64};

        let out_left = $out_left;
        let out_right = $out_right;

        // Every dtype needs an arm here; a missing one makes same-dtype inputs hit the panic.
        match (out_left, out_right) {
            (DScalar::F32(T32($in_left)), DScalar::F32(T32($in_right))) => DScalar::F32(T32($expr)),
            (DScalar::F64(T64($in_left)), DScalar::F64(T64($in_right))) => DScalar::F64(T64($expr)),
            (DScalar::I8($in_left), DScalar::I8($in_right)) => DScalar::I8($expr),
            (DScalar::I16($in_left), DScalar::I16($in_right)) => DScalar::I16($expr),
            (DScalar::I32($in_left), DScalar::I32($in_right)) => DScalar::I32($expr),
            (DScalar::I64($in_left), DScalar::I64($in_right)) => DScalar::I64($expr),
            (DScalar::U8($in_left), DScalar::U8($in_right)) => DScalar::U8($expr),
            (DScalar::U16($in_left), DScalar::U16($in_right)) => DScalar::U16($expr),
            (DScalar::U32($in_left), DScalar::U32($in_right)) => DScalar::U32($expr),
            (DScalar::U64($in_left), DScalar::U64($in_right)) => DScalar::U64($expr),
            (DScalar::Bool($in_left), DScalar::Bool($in_right)) => DScalar::Bool($expr),
            _ => panic!("Mismatched dtypes: left {:?}, right {:?}", out_left, out_right),
        }
    }};
}
523
524// export macros
525pub use dispatch_dtensor;
526pub use dispatch_dtensor_pair;
527pub use dispatch_dtype;
528pub use map_dscalar_pair;
529pub use map_dtensor;
530pub use map_dtensor_pair;
531
532impl DTensor {
533    pub fn shape(&self) -> &[usize] {
534        dispatch_dtensor!(self, |_T, _f, inner| inner.shape())
535    }
536
537    pub fn rank(&self) -> usize {
538        self.shape().len()
539    }
540
541    pub fn len(&self) -> usize {
542        self.shape().iter().copied().product()
543    }
544
545    pub fn dtype(&self) -> DType {
546        dispatch_dtensor!(self, |T, _f, _i| T::DTYPE)
547    }
548
549    pub fn reshape<E: IntoDimension>(&self, shape: E) -> DTensor {
550        map_dtensor!(self, |inner| inner.reshape(shape).into_dyn())
551    }
552
553    pub fn single(&self) -> Option<DScalar> {
554        if self.len() == 1 {
555            Some(dispatch_dtensor!(self, |_T, _f, inner| inner.iter().next().unwrap().to_dscalar()))
556        } else {
557            None
558        }
559    }
560
561    // TODO generic unwrap function?
562    pub fn unwrap_f32(&self) -> Option<&Tensor<f32>> {
563        match self {
564            DTensor::F32(tensor) => Some(tensor),
565            _ => None,
566        }
567    }
568
569    pub fn unwrap_f64(&self) -> Option<&Tensor<f64>> {
570        match self {
571            DTensor::F64(tensor) => Some(tensor),
572            _ => None,
573        }
574    }
575
576    pub fn unwrap_i64(&self) -> Option<&Tensor<i64>> {
577        match self {
578            DTensor::I64(tensor) => Some(tensor),
579            _ => None,
580        }
581    }
582
583    pub fn unwrap_bool(&self) -> Option<&Tensor<DBool>> {
584        match self {
585            DTensor::Bool(tensor) => Some(tensor),
586            _ => None,
587        }
588    }
589}
590
591impl Eq for DTensor {}
592
593impl PartialEq for DTensor {
594    fn eq(&self, other: &Self) -> bool {
595        if self.shape() != other.shape() || self.dtype() != other.dtype() {
596            return false;
597        }
598
599        match (self, other) {
600            // proper float compare
601            (DTensor::F32(a), DTensor::F32(b)) => zip_eq(a.iter(), b.iter()).all(|(a, b)| a.float_eq(b)),
602            (DTensor::F64(a), DTensor::F64(b)) => zip_eq(a.iter(), b.iter()).all(|(a, b)| a.float_eq(b)),
603
604            // ints and bools can be compared like normal
605            (DTensor::I8(a), DTensor::I8(b)) => a == b,
606            (DTensor::I16(a), DTensor::I16(b)) => a == b,
607            (DTensor::I32(a), DTensor::I32(b)) => a == b,
608            (DTensor::I64(a), DTensor::I64(b)) => a == b,
609            (DTensor::U8(a), DTensor::U8(b)) => a == b,
610            (DTensor::U16(a), DTensor::U16(b)) => a == b,
611            (DTensor::U32(a), DTensor::U32(b)) => a == b,
612            (DTensor::U64(a), DTensor::U64(b)) => a == b,
613            (DTensor::Bool(a), DTensor::Bool(b)) => a == b,
614
615            // different types, not equal
616            _ => false,
617        }
618    }
619}
620
621impl Hash for DTensor {
622    fn hash<H: Hasher>(&self, state: &mut H) {
623        // hash shape, dtype and some of the first elements
624        // (not all of them since that could be slow for large tensors)
625        // TODO figure out how to include some of the middle and final elements too?
626        const N: usize = 8;
627
628        self.shape().hash(state);
629        self.dtype().hash(state);
630
631        match self {
632            DTensor::F32(tensor) => tensor.iter().take(N).for_each(|x| x.float_hash(state)),
633            DTensor::F64(tensor) => tensor.iter().take(N).for_each(|x| x.float_hash(state)),
634            DTensor::I8(tensor) => tensor.iter().take(N).for_each(|x| x.hash(state)),
635            DTensor::I16(tensor) => tensor.iter().take(N).for_each(|x| x.hash(state)),
636            DTensor::I32(tensor) => tensor.iter().take(N).for_each(|x| x.hash(state)),
637            DTensor::I64(tensor) => tensor.iter().take(N).for_each(|x| x.hash(state)),
638            DTensor::U8(tensor) => tensor.iter().take(N).for_each(|x| x.hash(state)),
639            DTensor::U16(tensor) => tensor.iter().take(N).for_each(|x| x.hash(state)),
640            DTensor::U32(tensor) => tensor.iter().take(N).for_each(|x| x.hash(state)),
641            DTensor::U64(tensor) => tensor.iter().take(N).for_each(|x| x.hash(state)),
642            DTensor::Bool(tensor) => tensor.iter().take(N).for_each(|x| x.hash(state)),
643        }
644    }
645}
646
647impl Deref for T32 {
648    type Target = f32;
649
650    fn deref(&self) -> &Self::Target {
651        &self.0
652    }
653}
654
655impl PartialEq<Self> for T32 {
656    fn eq(&self, other: &Self) -> bool {
657        self.0.float_eq(&other.0)
658    }
659}
660
661impl Eq for T32 {}
662
663impl Hash for T32 {
664    fn hash<H: Hasher>(&self, state: &mut H) {
665        self.0.float_hash(state)
666    }
667}
668
669impl Deref for T64 {
670    type Target = f64;
671
672    fn deref(&self) -> &Self::Target {
673        &self.0
674    }
675}
676
677impl PartialEq<Self> for T64 {
678    fn eq(&self, other: &Self) -> bool {
679        self.0.float_eq(&other.0)
680    }
681}
682
683impl Eq for T64 {}
684
685impl Hash for T64 {
686    fn hash<H: Hasher>(&self, state: &mut H) {
687        self.0.float_hash(state)
688    }
689}
690
691impl Specials {
692    pub fn new<T: IntoDScalar + num_traits::Zero + num_traits::One>(min: T, max: T) -> Self {
693        Self {
694            zero: T::zero().to_dscalar(),
695            one: T::one().to_dscalar(),
696            min: min.to_dscalar(),
697            max: max.to_dscalar(),
698        }
699    }
700}
701
/// Formats an `f64` as a C floating-point expression.
///
/// NaN, infinities and zeros are spelled as parenthesized expressions (`0.0/0.0`, `1.0/0.0`,
/// `0.0`) that carry the value's sign bit. Other finite values use Rust's `Debug` float
/// formatting, which always contains a `.` or an exponent (e.g. `1.0`, `1e20`). That keeps the
/// output a floating literal in C instead of an integer constant.
#[derive(Debug)]
pub struct DisplayCFloat(pub f64);

impl Display for DisplayCFloat {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let s = if self.0.is_sign_negative() { "-" } else { "" };

        match self.0.classify() {
            FpCategory::Nan => write!(f, "({s}(0.0/0.0))"),
            FpCategory::Infinite => write!(f, "({s}(1.0/0.0))"),
            FpCategory::Zero => write!(f, "({s}0.0)"),
            // `{}` would print `1.0` as `1` and `1e20` as `100000000000000000000`. C parses those
            // as integer constants, and the latter overflows every integer type.
            FpCategory::Subnormal | FpCategory::Normal => write!(f, "{:?}", self.0),
        }
    }
}
717
718impl Deref for DBool {
719    type Target = bool;
720
721    fn deref(&self) -> &Self::Target {
722        &self.0
723    }
724}
725
726impl std::ops::Add for DBool {
727    type Output = DBool;
728
729    fn add(self, rhs: Self) -> Self::Output {
730        DBool(self.0 || rhs.0)
731    }
732}
733
734impl std::ops::Mul for DBool {
735    type Output = DBool;
736
737    fn mul(self, rhs: Self) -> Self::Output {
738        DBool(self.0 && rhs.0)
739    }
740}
741
742// sub and div don't make much sense
743impl std::ops::Sub for DBool {
744    type Output = DBool;
745
746    fn sub(self, rhs: Self) -> Self::Output {
747        DBool(self.0 && !rhs.0)
748    }
749}
750
751impl std::ops::Div for DBool {
752    type Output = DBool;
753
754    fn div(self, rhs: Self) -> Self::Output {
755        DBool(self.0 && !rhs.0)
756    }
757}
758
759impl num_traits::Zero for DBool {
760    fn zero() -> Self {
761        DBool(false)
762    }
763
764    fn is_zero(&self) -> bool {
765        !self.0
766    }
767}
768
769impl num_traits::One for DBool {
770    fn one() -> Self {
771        DBool(true)
772    }
773}
774
775unsafe impl NoUninit for DBool {}