1use std::fmt::{Display, Formatter};
2use std::hash::{Hash, Hasher};
3use std::num::FpCategory;
4use std::ops::Deref;
5
6use bytemuck::NoUninit;
7use decorum::cmp::FloatEq;
8use decorum::hash::FloatHash;
9use itertools::zip_eq;
10use ndarray::{ArcArray, IntoDimension, IxDyn, LinalgScalar};
11
/// Newtype around `f32` used as the payload of [`DScalar::F32`].
///
/// `PartialEq`, `Eq` and `Hash` are implemented by hand further down in this file on top
/// of decorum's `FloatEq`/`FloatHash`, which is what allows [`DScalar`] to derive `Eq`
/// and `Hash` even though it carries floating point values.
#[derive(Debug, Copy, Clone)]
pub struct T32(pub f32);
14
/// Newtype around `f64` used as the payload of [`DScalar::F64`].
///
/// `PartialEq`, `Eq` and `Hash` are implemented by hand further down in this file on top
/// of decorum's `FloatEq`/`FloatHash`, which is what allows [`DScalar`] to derive `Eq`
/// and `Hash` even though it carries floating point values.
#[derive(Debug, Copy, Clone)]
pub struct T64(pub f64);
17
/// Boolean element type used by [`DTensor::Bool`] and [`DScalar::Bool`].
///
/// `#[repr(transparent)]` makes it layout-compatible with a plain `bool`. The arithmetic
/// operator and `Zero`/`One` impls at the bottom of this file are presumably what lets it
/// satisfy the `LinalgScalar` bound of [`IntoDScalar`] (confirm against ndarray's docs).
#[derive(Debug, Copy, Clone, Eq, Ord, PartialOrd, PartialEq, Hash)]
#[repr(transparent)]
pub struct DBool(pub bool);
23
/// Runtime tag identifying the element type of a [`DTensor`] or [`DScalar`].
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum DType {
    /// 32-bit float (`f32`).
    F32,
    /// 64-bit float (`f64`).
    F64,
    /// Signed 8-bit integer (`i8`).
    I8,
    /// Signed 16-bit integer (`i16`).
    I16,
    /// Signed 32-bit integer (`i32`).
    I32,
    /// Signed 64-bit integer (`i64`).
    I64,
    /// Unsigned 8-bit integer (`u8`).
    U8,
    /// Unsigned 16-bit integer (`u16`).
    U16,
    /// Unsigned 32-bit integer (`u32`).
    U32,
    /// Unsigned 64-bit integer (`u64`).
    U64,
    /// Boolean, stored as [`DBool`].
    Bool,
}
38
/// Statically typed tensor: an ndarray `ArcArray` with a dynamic number of dimensions.
pub type Tensor<T> = ArcArray<T, IxDyn>;
40
/// Dynamically typed tensor: one variant per [`DType`], each wrapping a [`Tensor`] of the
/// matching element type.
///
/// `PartialEq`, `Eq` and `Hash` are implemented by hand later in this file.
#[derive(Debug, Clone)]
pub enum DTensor {
    /// Tensor of `f32` elements.
    F32(Tensor<f32>),
    /// Tensor of `f64` elements.
    F64(Tensor<f64>),
    /// Tensor of `i8` elements.
    I8(Tensor<i8>),
    /// Tensor of `i16` elements.
    I16(Tensor<i16>),
    /// Tensor of `i32` elements.
    I32(Tensor<i32>),
    /// Tensor of `i64` elements.
    I64(Tensor<i64>),
    /// Tensor of `u8` elements.
    U8(Tensor<u8>),
    /// Tensor of `u16` elements.
    U16(Tensor<u16>),
    /// Tensor of `u32` elements.
    U32(Tensor<u32>),
    /// Tensor of `u64` elements.
    U64(Tensor<u64>),
    /// Tensor of [`DBool`] elements.
    Bool(Tensor<DBool>),
}
55
/// Dynamically typed scalar: one variant per [`DType`].
///
/// The float variants carry the [`T32`]/[`T64`] wrappers rather than bare floats so that
/// `Eq` and `Hash` can be derived for the whole enum.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum DScalar {
    /// An `f32` value wrapped in [`T32`].
    F32(T32),
    /// An `f64` value wrapped in [`T64`].
    F64(T64),
    /// An `i8` value.
    I8(i8),
    /// An `i16` value.
    I16(i16),
    /// An `i32` value.
    I32(i32),
    /// An `i64` value.
    I64(i64),
    /// A `u8` value.
    U8(u8),
    /// A `u16` value.
    U16(u16),
    /// A `u32` value.
    U32(u32),
    /// A `u64` value.
    U64(u64),
    /// A boolean value wrapped in [`DBool`].
    Bool(DBool),
}
70
/// Storage width of a single element, expressed in bits by the variant name.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum DSize {
    /// 8 bits (1 byte).
    S8,
    /// 16 bits (2 bytes).
    S16,
    /// 32 bits (4 bytes).
    S32,
    /// 64 bits (8 bytes).
    S64,
}
78
/// Distinguished values of a dtype, as produced by [`DType::specials`].
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Specials {
    /// The additive identity (`num_traits::Zero::zero()`).
    pub zero: DScalar,
    /// The multiplicative identity (`num_traits::One::one()`).
    pub one: DScalar,
    /// The smallest value; negative infinity for the float dtypes, `false` for `Bool`.
    pub min: DScalar,
    /// The largest value; positive infinity for the float dtypes, `true` for `Bool`.
    pub max: DScalar,
}
86
/// Static properties of a [`DType`], as returned by [`DType::info`].
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct DInfo {
    /// Storage width of one element.
    pub size: DSize,
    /// `true` for the float dtypes and the signed integer dtypes; `false` otherwise.
    pub signed: bool,
    /// `true` only for `F32` and `F64`.
    pub float: bool,
    /// `true` only for the signed and unsigned integer dtypes.
    pub int: bool,
    /// `true` only for `Bool`.
    pub is_bool: bool,
}
95
96impl DType {
97 pub fn info(self) -> DInfo {
98 match self {
99 DType::F32 => DInfo::float(DSize::S32),
100 DType::F64 => DInfo::float(DSize::S64),
101 DType::I8 => DInfo::int(DSize::S8, true),
102 DType::I16 => DInfo::int(DSize::S16, true),
103 DType::I32 => DInfo::int(DSize::S32, true),
104 DType::I64 => DInfo::int(DSize::S64, true),
105 DType::U8 => DInfo::int(DSize::S8, false),
106 DType::U16 => DInfo::int(DSize::S16, false),
107 DType::U32 => DInfo::int(DSize::S32, false),
108 DType::U64 => DInfo::int(DSize::S64, false),
109 DType::Bool => DInfo::bool(),
110 }
111 }
112
113 pub fn size(self) -> DSize {
114 self.info().size
115 }
116
117 pub fn is_signed(self) -> bool {
118 self.info().signed
119 }
120
121 pub fn is_float(self) -> bool {
122 self.info().float
123 }
124
125 pub fn is_int(self) -> bool {
126 self.info().int
127 }
128
129 pub fn is_bool(self) -> bool {
130 self.info().is_bool
131 }
132
133 pub fn specials(self) -> Specials {
135 match self {
136 DType::F32 => Specials::new(f32::NEG_INFINITY, f32::INFINITY),
137 DType::F64 => Specials::new(f64::NEG_INFINITY, f64::INFINITY),
138 DType::I8 => Specials::new(i8::MIN, i8::MAX),
139 DType::I16 => Specials::new(i16::MIN, i16::MAX),
140 DType::I32 => Specials::new(i32::MIN, i32::MAX),
141 DType::I64 => Specials::new(i64::MIN, i64::MAX),
142 DType::U8 => Specials::new(u8::MIN, u8::MAX),
143 DType::U16 => Specials::new(u16::MIN, u16::MAX),
144 DType::U32 => Specials::new(u32::MIN, u32::MAX),
145 DType::U64 => Specials::new(u64::MIN, u64::MAX),
146 DType::Bool => Specials::new(DBool(false), DBool(true)),
147 }
148 }
149
150 pub fn as_c_str(self) -> &'static str {
151 match self {
152 DType::F32 => "float",
153 DType::F64 => "double",
154 DType::I8 => "int8_t",
155 DType::I16 => "int16_t",
156 DType::I32 => "int32_t",
157 DType::I64 => "int64_t",
158 DType::U8 => "uint8_t",
159 DType::U16 => "uint16_t",
160 DType::U32 => "uint32_t",
161 DType::U64 => "uint64_t",
162 DType::Bool => "bool",
163 }
164 }
165}
166
167impl DInfo {
168 fn int(size: DSize, signed: bool) -> Self {
169 DInfo {
170 size,
171 signed,
172 float: false,
173 int: true,
174 is_bool: false,
175 }
176 }
177
178 fn float(size: DSize) -> Self {
179 DInfo {
180 size,
181 signed: true,
182 float: true,
183 int: false,
184 is_bool: false,
185 }
186 }
187
188 fn bool() -> Self {
189 DInfo {
190 size: DSize::S8,
191 signed: false,
192 float: false,
193 int: false,
194 is_bool: true,
195 }
196 }
197}
198
199impl DSize {
200 pub fn bytes(self) -> usize {
201 match self {
202 DSize::S8 => 1,
203 DSize::S16 => 2,
204 DSize::S32 => 4,
205 DSize::S64 => 8,
206 }
207 }
208}
209
/// Evaluates `$expr` once for the element type selected by a [`DType`] value.
///
/// Usage: `dispatch_dtype!(dtype, |T, fs, ft| expr)`. Inside `expr`:
/// * `T` is a type alias for the Rust element type (`f32`, ..., `DBool`),
/// * `fs` is the matching `DScalar` variant constructor,
/// * `ft` is the matching `DTensor` variant constructor.
///
/// Note that for `F32`/`F64` the scalar constructor takes `T32`/`T64`, not the bare float
/// that `T` aliases. `$expr` is expanded into every arm, so it must type-check for all
/// element types.
#[rustfmt::skip]
#[macro_export]
macro_rules! dispatch_dtype {
    ($outer:expr, |$ty:ident, $fs:ident, $ft:ident| $expr:expr) => {{
        use $crate::dtype::{DType, DBool, DScalar, DTensor};
        match $outer {
            DType::F32 => { type $ty=f32; let $fs=DScalar::F32; let $ft=DTensor::F32; { $expr } }
            DType::F64 => { type $ty=f64; let $fs=DScalar::F64; let $ft=DTensor::F64; { $expr } }
            DType::I8 => { type $ty=i8; let $fs=DScalar::I8; let $ft=DTensor::I8; { $expr } }
            DType::I16 => { type $ty=i16; let $fs=DScalar::I16; let $ft=DTensor::I16; { $expr } }
            DType::I32 => { type $ty=i32; let $fs=DScalar::I32; let $ft=DTensor::I32; { $expr } }
            DType::I64 => { type $ty=i64; let $fs=DScalar::I64; let $ft=DTensor::I64; { $expr } }
            DType::U8 => { type $ty=u8; let $fs=DScalar::U8; let $ft=DTensor::U8; { $expr } }
            DType::U16 => { type $ty=u16; let $fs=DScalar::U16; let $ft=DTensor::U16; { $expr } }
            DType::U32 => { type $ty=u32; let $fs=DScalar::U32; let $ft=DTensor::U32; { $expr } }
            DType::U64 => { type $ty=u64; let $fs=DScalar::U64; let $ft=DTensor::U64; { $expr } }
            DType::Bool => { type $ty=DBool; let $fs=DScalar::Bool; let $ft=DTensor::Bool; { $expr } }
        }
    }};
}
230
231impl DScalar {
232 pub fn f32(x: f32) -> Self {
233 DScalar::F32(T32(x))
234 }
235
236 pub fn f64(x: f64) -> Self {
237 DScalar::F64(T64(x))
238 }
239
240 pub fn bool(x: bool) -> Self {
241 DScalar::Bool(DBool(x))
242 }
243
244 pub fn dtype(self) -> DType {
245 match self {
246 DScalar::F32(_) => DType::F32,
247 DScalar::F64(_) => DType::F64,
248 DScalar::I8(_) => DType::I8,
249 DScalar::I16(_) => DType::I16,
250 DScalar::I32(_) => DType::I32,
251 DScalar::I64(_) => DType::I64,
252 DScalar::U8(_) => DType::U8,
253 DScalar::U16(_) => DType::U16,
254 DScalar::U32(_) => DType::U32,
255 DScalar::U64(_) => DType::U64,
256 DScalar::Bool(_) => DType::Bool,
257 }
258 }
259
260 pub fn unwrap_f32(self) -> Option<f32> {
261 match self {
262 DScalar::F32(x) => Some(x.0),
263 _ => None,
264 }
265 }
266
267 pub fn to_tensor(self) -> DTensor {
268 match self {
269 DScalar::F32(T32(s)) => DTensor::F32(ArcArray::from_shape_vec(IxDyn(&[]), vec![s]).unwrap()),
270 DScalar::F64(T64(s)) => DTensor::F64(ArcArray::from_shape_vec(IxDyn(&[]), vec![s]).unwrap()),
271 DScalar::I8(x) => DTensor::I8(ArcArray::from_shape_vec(IxDyn(&[]), vec![x]).unwrap()),
272 DScalar::I16(x) => DTensor::I16(ArcArray::from_shape_vec(IxDyn(&[]), vec![x]).unwrap()),
273 DScalar::I32(x) => DTensor::I32(ArcArray::from_shape_vec(IxDyn(&[]), vec![x]).unwrap()),
274 DScalar::I64(x) => DTensor::I64(ArcArray::from_shape_vec(IxDyn(&[]), vec![x]).unwrap()),
275 DScalar::U8(x) => DTensor::U8(ArcArray::from_shape_vec(IxDyn(&[]), vec![x]).unwrap()),
276 DScalar::U16(x) => DTensor::U16(ArcArray::from_shape_vec(IxDyn(&[]), vec![x]).unwrap()),
277 DScalar::U32(x) => DTensor::U32(ArcArray::from_shape_vec(IxDyn(&[]), vec![x]).unwrap()),
278 DScalar::U64(x) => DTensor::U64(ArcArray::from_shape_vec(IxDyn(&[]), vec![x]).unwrap()),
279 DScalar::Bool(x) => DTensor::Bool(ArcArray::from_shape_vec(IxDyn(&[]), vec![x]).unwrap()),
280 }
281 }
282
283 pub fn unwrap_uint(self) -> Option<u64> {
284 match self {
285 DScalar::U8(x) => Some(x as u64),
286 DScalar::U16(x) => Some(x as u64),
287 DScalar::U32(x) => Some(x as u64),
288 DScalar::U64(x) => Some(x),
289 _ => None,
290 }
291 }
292
293 pub fn unwrap_int(self) -> Option<i128> {
294 match self {
295 DScalar::U8(x) => Some(x as i128),
296 DScalar::U16(x) => Some(x as i128),
297 DScalar::U32(x) => Some(x as i128),
298 DScalar::U64(x) => Some(x as i128),
299 DScalar::I8(x) => Some(x as i128),
300 DScalar::I16(x) => Some(x as i128),
301 DScalar::I32(x) => Some(x as i128),
302 DScalar::I64(x) => Some(x as i128),
303 _ => None,
304 }
305 }
306
307 pub fn to_c_str(self) -> String {
308 match self {
309 DScalar::F32(c) => DisplayCFloat(*c as f64).to_string(),
310 DScalar::F64(c) => DisplayCFloat(*c).to_string(),
311 DScalar::U8(c) => format!("{}", c),
312 DScalar::U16(c) => format!("{}", c),
313 DScalar::U32(c) => format!("{}", c),
314 DScalar::U64(c) => format!("{}", c),
315 DScalar::I8(c) => format!("{}", c),
316 DScalar::I16(c) => format!("{}", c),
317 DScalar::I32(c) => format!("{}", c),
318 DScalar::I64(c) => format!("{}", c),
319 DScalar::Bool(c) => format!("{}", *c),
320 }
321 }
322
323 pub fn value_cast(self, to: DType) -> DScalar {
324 let (yf, yi) = match self {
326 DScalar::F32(T32(x)) => (x as f64, x as i128),
327 DScalar::F64(T64(x)) => (x, x as i128),
328 DScalar::I8(x) => (x as f64, x as i128),
329 DScalar::I16(x) => (x as f64, x as i128),
330 DScalar::I32(x) => (x as f64, x as i128),
331 DScalar::I64(x) => (x as f64, x as i128),
332 DScalar::U8(x) => (x as f64, x as i128),
333 DScalar::U16(x) => (x as f64, x as i128),
334 DScalar::U32(x) => (x as f64, x as i128),
335 DScalar::U64(x) => (x as f64, x as i128),
336 DScalar::Bool(DBool(x)) => (x as u8 as f64, x as u8 as i128),
337 };
338
339 match to {
341 DType::F32 => DScalar::f32(yf as f32),
342 DType::F64 => DScalar::f64(yf),
343 DType::I8 => DScalar::I8(yi as i8),
344 DType::I16 => DScalar::I16(yi as i16),
345 DType::I32 => DScalar::I32(yi as i32),
346 DType::I64 => DScalar::I64(yi as i64),
347 DType::U8 => DScalar::U8(yi as u8),
348 DType::U16 => DScalar::U16(yi as u16),
349 DType::U32 => DScalar::U32(yi as u32),
350 DType::U64 => DScalar::U64(yi as u64),
351 DType::Bool => DScalar::bool(yf != 0.0 || yi != 0),
352 }
353 }
354
355 pub fn bit_cast(self, to: DType) -> Option<DScalar> {
356 if self.dtype().size() != to.size() {
357 return None;
358 }
359
360 let bits = match self {
362 DScalar::F32(T32(x)) => x.to_bits() as u64,
363 DScalar::F64(T64(x)) => x.to_bits(),
364 DScalar::I8(x) => x as u8 as u64,
365 DScalar::I16(x) => x as u16 as u64,
366 DScalar::I32(x) => x as u32 as u64,
367 DScalar::I64(x) => x as u64,
368 DScalar::U8(x) => x as u64,
369 DScalar::U16(x) => x as u64,
370 DScalar::U32(x) => x as u64,
371 DScalar::U64(x) => x,
372 DScalar::Bool(_) => return None,
373 };
374
375 let y = match to {
377 DType::F32 => DScalar::f32(f32::from_bits(bits as u32)),
378 DType::F64 => DScalar::f64(f64::from_bits(bits)),
379 DType::I8 => DScalar::I8(bits as i8),
380 DType::I16 => DScalar::I16(bits as i16),
381 DType::I32 => DScalar::I32(bits as i32),
382 DType::I64 => DScalar::I64(bits as i64),
383 DType::U8 => DScalar::U8(bits as u8),
384 DType::U16 => DScalar::U16(bits as u16),
385 DType::U32 => DScalar::U32(bits as u32),
386 DType::U64 => DScalar::U64(bits),
387 DType::Bool => return None,
388 };
389
390 Some(y)
391 }
392}
393
/// Rust element types that have a corresponding [`DType`], [`DScalar`] variant and
/// [`DTensor`] variant.
///
/// Implemented in this file for all primitive numeric types listed in [`DType`] and for
/// [`DBool`].
pub trait IntoDScalar: LinalgScalar + PartialEq {
    /// The dtype tag for this element type.
    const DTYPE: DType;
    /// Wraps this value in the matching [`DScalar`] variant.
    fn to_dscalar(&self) -> DScalar;
    /// Extracts a value of this type from `scalar`; the impls in this file return `None`
    /// when `scalar` is a different variant.
    fn from_dscalar(scalar: DScalar) -> Option<Self>;
    /// Wraps `data` in the matching [`DTensor`] variant; the impls in this file build the
    /// tensor with `ArcArray::from_vec(data).into_dyn()`.
    fn vec_to_dtensor(data: Vec<Self>) -> DTensor;
}
400
// Generates an `IntoDScalar` impl for one element type.
//
// Arguments, in order:
//   * `$ty`      - the Rust element type,
//   * `$dtype`   - the value of the `DTYPE` associated constant,
//   * `$dtensor` - the `DTensor` variant name used by `vec_to_dtensor`,
//   * `|$x| $conv` - expression turning a copied-out value `$x` into a `DScalar`,
//   * `$pattern => $result` - pattern matching the right `DScalar` variant and the
//     expression extracting the value; any other variant yields `None`.
macro_rules! impl_into_dscalar {
    ($ty:ty, $dtype:expr, $dtensor:ident, |$x:ident| $conv:expr, $pattern:pat => $result:expr) => {
        impl IntoDScalar for $ty {
            const DTYPE: DType = $dtype;

            fn to_dscalar(&self) -> DScalar {
                // Copy the value out of the reference and bind it to the caller-chosen name.
                let &$x = self;
                $conv
            }

            fn from_dscalar(scalar: DScalar) -> Option<Self> {
                match scalar {
                    $pattern => Some($result),
                    _ => None,
                }
            }

            fn vec_to_dtensor(data: Vec<Self>) -> DTensor {
                DTensor::$dtensor(ArcArray::from_vec(data).into_dyn())
            }
        }
    };
}
424
// One `IntoDScalar` impl per supported element type. The float impls go through the
// `DScalar::f32`/`DScalar::f64` helpers and unwrap `T32`/`T64` on the way back.
impl_into_dscalar!(f32, DType::F32, F32, |x| DScalar::f32(x), DScalar::F32(T32(x)) => x);
impl_into_dscalar!(f64, DType::F64, F64, |x| DScalar::f64(x), DScalar::F64(T64(x)) => x);
impl_into_dscalar!(i8, DType::I8, I8, |x| DScalar::I8(x), DScalar::I8(x) => x);
impl_into_dscalar!(i16, DType::I16, I16, |x| DScalar::I16(x), DScalar::I16(x) => x);
impl_into_dscalar!(i32, DType::I32, I32, |x| DScalar::I32(x), DScalar::I32(x) => x);
impl_into_dscalar!(i64, DType::I64, I64, |x| DScalar::I64(x), DScalar::I64(x) => x);
impl_into_dscalar!(u8, DType::U8, U8, |x| DScalar::U8(x), DScalar::U8(x) => x);
impl_into_dscalar!(u16, DType::U16, U16, |x| DScalar::U16(x), DScalar::U16(x) => x);
impl_into_dscalar!(u32, DType::U32, U32, |x| DScalar::U32(x), DScalar::U32(x) => x);
impl_into_dscalar!(u64, DType::U64, U64, |x| DScalar::U64(x), DScalar::U64(x) => x);
impl_into_dscalar!(DBool, DType::Bool, Bool, |x| DScalar::Bool(x), DScalar::Bool(x) => x);
436
/// Evaluates `$expr` once for the element type held by a [`DTensor`].
///
/// Usage: `dispatch_dtensor!(tensor, |T, f, inner| expr)`. Inside `expr`:
/// * `T` is a type alias for the Rust element type,
/// * `f` is the matching `DTensor` variant constructor,
/// * `inner` is bound to the payload of the matched variant.
///
/// `$expr` is expanded into every arm, so it must type-check for all element types.
#[rustfmt::skip]
#[macro_export]
macro_rules! dispatch_dtensor {
    ($outer:expr, |$ty:ident, $f:ident, $inner:ident| $expr:expr) => {{
        use $crate::dtype::{DBool, DTensor};
        match $outer {
            DTensor::F32($inner) => { type $ty=f32; let $f=DTensor::F32; { $expr } }
            DTensor::F64($inner) => { type $ty=f64; let $f=DTensor::F64; { $expr } }
            DTensor::I8($inner) => { type $ty=i8; let $f=DTensor::I8; { $expr } }
            DTensor::I16($inner) => { type $ty=i16; let $f=DTensor::I16; { $expr } }
            DTensor::I32($inner) => { type $ty=i32; let $f=DTensor::I32; { $expr } }
            DTensor::I64($inner) => { type $ty=i64; let $f=DTensor::I64; { $expr } }
            DTensor::U8($inner) => { type $ty=u8; let $f=DTensor::U8; { $expr } }
            DTensor::U16($inner) => { type $ty=u16; let $f=DTensor::U16; { $expr } }
            DTensor::U32($inner) => { type $ty=u32; let $f=DTensor::U32; { $expr } }
            DTensor::U64($inner) => { type $ty=u64; let $f=DTensor::U64; { $expr } }
            DTensor::Bool($inner) => { type $ty=DBool; let $f=DTensor::Bool; { $expr } }
        }
    }};
}
457
/// Evaluates `$expr` once for the common element type of two [`DTensor`] values.
///
/// Usage: `dispatch_dtensor_pair!(left, right, |T, f, l, r| expr)`. Inside `expr`:
/// * `T` is a type alias for the shared Rust element type,
/// * `f` is the matching `DTensor` variant constructor,
/// * `l` and `r` are bound to the payloads of the left and right tensor.
///
/// `$expr` is expanded into every arm, so it must type-check for all element types.
///
/// # Panics
///
/// Panics with a "Mismatched dtypes" message if the two tensors hold different variants.
#[rustfmt::skip]
#[macro_export]
macro_rules! dispatch_dtensor_pair {
    ($out_left:expr, $out_right:expr, |$ty:ident, $f:ident, $in_left:ident, $in_right:ident| $expr:expr) => {{
        use $crate::dtype::{DBool, DTensor};

        // Evaluate both operands exactly once and capture the dtypes up front so they are
        // still available for the panic message after the match has consumed the operands.
        let out_left = $out_left;
        let out_right = $out_right;
        let dtype_left = out_left.dtype();
        let dtype_right = out_right.dtype();

        match (out_left, out_right) {
            (DTensor::F32($in_left), DTensor::F32($in_right)) => { type $ty=f32; let $f=DTensor::F32; { $expr } }
            // Previously missing: two F64 tensors fell through to the mismatch panic.
            (DTensor::F64($in_left), DTensor::F64($in_right)) => { type $ty=f64; let $f=DTensor::F64; { $expr } }
            (DTensor::I8($in_left), DTensor::I8($in_right)) => { type $ty=i8; let $f=DTensor::I8; { $expr } }
            (DTensor::I16($in_left), DTensor::I16($in_right)) => { type $ty=i16; let $f=DTensor::I16; { $expr } }
            (DTensor::I32($in_left), DTensor::I32($in_right)) => { type $ty=i32; let $f=DTensor::I32; { $expr } }
            (DTensor::I64($in_left), DTensor::I64($in_right)) => { type $ty=i64; let $f=DTensor::I64; { $expr } }
            (DTensor::U8($in_left), DTensor::U8($in_right)) => { type $ty=u8; let $f=DTensor::U8; { $expr } }
            (DTensor::U16($in_left), DTensor::U16($in_right)) => { type $ty=u16; let $f=DTensor::U16; { $expr } }
            (DTensor::U32($in_left), DTensor::U32($in_right)) => { type $ty=u32; let $f=DTensor::U32; { $expr } }
            (DTensor::U64($in_left), DTensor::U64($in_right)) => { type $ty=u64; let $f=DTensor::U64; { $expr } }
            (DTensor::Bool($in_left), DTensor::Bool($in_right)) => { type $ty=DBool; let $f=DTensor::Bool; { $expr } }
            _ => panic!("Mismatched dtypes: left {:?}, right {:?}", dtype_left, dtype_right),
        }
    }};
}
484
/// Applies `$expr` to the payload of a [`DTensor`] and wraps the result back into the
/// same variant.
///
/// Usage: `map_dtensor!(tensor, |inner| expr)`, where `expr` must produce a tensor with
/// the same element type as `inner`.
#[macro_export]
macro_rules! map_dtensor {
    ($outer:expr, |$inner:ident| $expr:expr) => {
        // `$crate` (not `crate`) so this exported macro also resolves when it is expanded
        // inside a different crate.
        $crate::dtype::dispatch_dtensor!($outer, |_T, f, $inner| f($expr))
    };
}
491
/// Applies `$expr` to the payloads of two [`DTensor`] values of the same dtype and wraps
/// the result back into that variant.
///
/// Usage: `map_dtensor_pair!(left, right, |l, r| expr)`, where `expr` must produce a
/// tensor with the same element type as `l` and `r`.
///
/// # Panics
///
/// Panics (inside `dispatch_dtensor_pair!`) if the two tensors hold different variants.
#[macro_export]
macro_rules! map_dtensor_pair {
    ($out_left:expr, $out_right:expr, |$in_left:ident, $in_right:ident| $expr:expr) => {
        // `$crate` (not `crate`) so this exported macro also resolves when it is expanded
        // inside a different crate.
        $crate::dtype::dispatch_dtensor_pair!($out_left, $out_right, |_T, f, $in_left, $in_right| f($expr))
    };
}
498
/// Applies `$expr` to the values of two [`DScalar`]s of the same dtype and wraps the
/// result back into that variant.
///
/// Usage: `map_dscalar_pair!(left, right, |l, r| expr)`. For the float variants `l` and
/// `r` are the bare `f32`/`f64` values (the `T32`/`T64` wrappers are removed and
/// re-applied by the macro); for `Bool` they are `DBool` values. `$expr` is expanded into
/// every arm, so it must type-check for all element types.
///
/// # Panics
///
/// Panics with a "Mismatched dtypes" message if the two scalars are different variants.
#[rustfmt::skip]
#[macro_export]
macro_rules! map_dscalar_pair {
    ($out_left:expr, $out_right:expr, |$in_left:ident, $in_right:ident| $expr:expr) => {{
        // `$crate` (not `crate`) so this exported macro also resolves when it is expanded
        // inside a different crate.
        use $crate::dtype::{DScalar, T32, T64};

        // Evaluate both operands exactly once; `DScalar` is `Copy`, so they stay usable in
        // the panic message below.
        let out_left = $out_left;
        let out_right = $out_right;

        match (out_left, out_right) {
            (DScalar::F32(T32($in_left)), DScalar::F32(T32($in_right))) => DScalar::F32(T32($expr)),
            // Previously missing: two F64 scalars fell through to the mismatch panic.
            (DScalar::F64(T64($in_left)), DScalar::F64(T64($in_right))) => DScalar::F64(T64($expr)),
            (DScalar::I8($in_left), DScalar::I8($in_right)) => DScalar::I8($expr),
            (DScalar::I16($in_left), DScalar::I16($in_right)) => DScalar::I16($expr),
            (DScalar::I32($in_left), DScalar::I32($in_right)) => DScalar::I32($expr),
            (DScalar::I64($in_left), DScalar::I64($in_right)) => DScalar::I64($expr),
            (DScalar::U8($in_left), DScalar::U8($in_right)) => DScalar::U8($expr),
            (DScalar::U16($in_left), DScalar::U16($in_right)) => DScalar::U16($expr),
            (DScalar::U32($in_left), DScalar::U32($in_right)) => DScalar::U32($expr),
            (DScalar::U64($in_left), DScalar::U64($in_right)) => DScalar::U64($expr),
            (DScalar::Bool($in_left), DScalar::Bool($in_right)) => DScalar::Bool($expr),
            _ => panic!("Mismatched dtypes: left {:?}, right {:?}", out_left, out_right),
        }
    }}
}
523
524pub use dispatch_dtensor;
526pub use dispatch_dtensor_pair;
527pub use dispatch_dtype;
528pub use map_dscalar_pair;
529pub use map_dtensor;
530pub use map_dtensor_pair;
531
532impl DTensor {
533 pub fn shape(&self) -> &[usize] {
534 dispatch_dtensor!(self, |_T, _f, inner| inner.shape())
535 }
536
537 pub fn rank(&self) -> usize {
538 self.shape().len()
539 }
540
541 pub fn len(&self) -> usize {
542 self.shape().iter().copied().product()
543 }
544
545 pub fn dtype(&self) -> DType {
546 dispatch_dtensor!(self, |T, _f, _i| T::DTYPE)
547 }
548
549 pub fn reshape<E: IntoDimension>(&self, shape: E) -> DTensor {
550 map_dtensor!(self, |inner| inner.reshape(shape).into_dyn())
551 }
552
553 pub fn single(&self) -> Option<DScalar> {
554 if self.len() == 1 {
555 Some(dispatch_dtensor!(self, |_T, _f, inner| inner.iter().next().unwrap().to_dscalar()))
556 } else {
557 None
558 }
559 }
560
561 pub fn unwrap_f32(&self) -> Option<&Tensor<f32>> {
563 match self {
564 DTensor::F32(tensor) => Some(tensor),
565 _ => None,
566 }
567 }
568
569 pub fn unwrap_f64(&self) -> Option<&Tensor<f64>> {
570 match self {
571 DTensor::F64(tensor) => Some(tensor),
572 _ => None,
573 }
574 }
575
576 pub fn unwrap_i64(&self) -> Option<&Tensor<i64>> {
577 match self {
578 DTensor::I64(tensor) => Some(tensor),
579 _ => None,
580 }
581 }
582
583 pub fn unwrap_bool(&self) -> Option<&Tensor<DBool>> {
584 match self {
585 DTensor::Bool(tensor) => Some(tensor),
586 _ => None,
587 }
588 }
589}
590
591impl Eq for DTensor {}
592
593impl PartialEq for DTensor {
594 fn eq(&self, other: &Self) -> bool {
595 if self.shape() != other.shape() || self.dtype() != other.dtype() {
596 return false;
597 }
598
599 match (self, other) {
600 (DTensor::F32(a), DTensor::F32(b)) => zip_eq(a.iter(), b.iter()).all(|(a, b)| a.float_eq(b)),
602 (DTensor::F64(a), DTensor::F64(b)) => zip_eq(a.iter(), b.iter()).all(|(a, b)| a.float_eq(b)),
603
604 (DTensor::I8(a), DTensor::I8(b)) => a == b,
606 (DTensor::I16(a), DTensor::I16(b)) => a == b,
607 (DTensor::I32(a), DTensor::I32(b)) => a == b,
608 (DTensor::I64(a), DTensor::I64(b)) => a == b,
609 (DTensor::U8(a), DTensor::U8(b)) => a == b,
610 (DTensor::U16(a), DTensor::U16(b)) => a == b,
611 (DTensor::U32(a), DTensor::U32(b)) => a == b,
612 (DTensor::U64(a), DTensor::U64(b)) => a == b,
613 (DTensor::Bool(a), DTensor::Bool(b)) => a == b,
614
615 _ => false,
617 }
618 }
619}
620
621impl Hash for DTensor {
622 fn hash<H: Hasher>(&self, state: &mut H) {
623 const N: usize = 8;
627
628 self.shape().hash(state);
629 self.dtype().hash(state);
630
631 match self {
632 DTensor::F32(tensor) => tensor.iter().take(N).for_each(|x| x.float_hash(state)),
633 DTensor::F64(tensor) => tensor.iter().take(N).for_each(|x| x.float_hash(state)),
634 DTensor::I8(tensor) => tensor.iter().take(N).for_each(|x| x.hash(state)),
635 DTensor::I16(tensor) => tensor.iter().take(N).for_each(|x| x.hash(state)),
636 DTensor::I32(tensor) => tensor.iter().take(N).for_each(|x| x.hash(state)),
637 DTensor::I64(tensor) => tensor.iter().take(N).for_each(|x| x.hash(state)),
638 DTensor::U8(tensor) => tensor.iter().take(N).for_each(|x| x.hash(state)),
639 DTensor::U16(tensor) => tensor.iter().take(N).for_each(|x| x.hash(state)),
640 DTensor::U32(tensor) => tensor.iter().take(N).for_each(|x| x.hash(state)),
641 DTensor::U64(tensor) => tensor.iter().take(N).for_each(|x| x.hash(state)),
642 DTensor::Bool(tensor) => tensor.iter().take(N).for_each(|x| x.hash(state)),
643 }
644 }
645}
646
647impl Deref for T32 {
648 type Target = f32;
649
650 fn deref(&self) -> &Self::Target {
651 &self.0
652 }
653}
654
655impl PartialEq<Self> for T32 {
656 fn eq(&self, other: &Self) -> bool {
657 self.0.float_eq(&other.0)
658 }
659}
660
661impl Eq for T32 {}
662
663impl Hash for T32 {
664 fn hash<H: Hasher>(&self, state: &mut H) {
665 self.0.float_hash(state)
666 }
667}
668
669impl Deref for T64 {
670 type Target = f64;
671
672 fn deref(&self) -> &Self::Target {
673 &self.0
674 }
675}
676
677impl PartialEq<Self> for T64 {
678 fn eq(&self, other: &Self) -> bool {
679 self.0.float_eq(&other.0)
680 }
681}
682
683impl Eq for T64 {}
684
685impl Hash for T64 {
686 fn hash<H: Hasher>(&self, state: &mut H) {
687 self.0.float_hash(state)
688 }
689}
690
691impl Specials {
692 pub fn new<T: IntoDScalar + num_traits::Zero + num_traits::One>(min: T, max: T) -> Self {
693 Self {
694 zero: T::zero().to_dscalar(),
695 one: T::one().to_dscalar(),
696 min: min.to_dscalar(),
697 max: max.to_dscalar(),
698 }
699 }
700}
701
/// Display adapter that renders an `f64` as a C floating-point expression.
///
/// * NaN becomes `(0.0/0.0)` and infinities become `(1.0/0.0)`, each wrapped in an outer
///   pair of parentheses together with a leading `-` when the sign bit is set.
/// * Zero becomes `(0.0)` or `(-0.0)`, which keeps the sign of negative zero.
/// * Every other value is printed as a decimal literal that always contains a `.` or an
///   exponent, so a C compiler reads it as a floating constant, never an integer one.
#[derive(Debug)]
pub struct DisplayCFloat(pub f64);

impl Display for DisplayCFloat {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let s = if self.0.is_sign_negative() { "-" } else { "" };

        match self.0.classify() {
            FpCategory::Nan => write!(f, "({s}(0.0/0.0))"),
            FpCategory::Infinite => write!(f, "({s}(1.0/0.0))"),
            FpCategory::Zero => write!(f, "({s}0.0)"),
            // `{}` would print integral values without a fractional part ("1", or a
            // 39-digit integer for `f32::MAX`), which C parses as an integer constant
            // that can overflow every integer type. `{:?}` always emits either a decimal
            // point or an exponent and still round-trips the exact value.
            FpCategory::Subnormal | FpCategory::Normal => write!(f, "{:?}", self.0),
        }
    }
}
717
718impl Deref for DBool {
719 type Target = bool;
720
721 fn deref(&self) -> &Self::Target {
722 &self.0
723 }
724}
725
726impl std::ops::Add for DBool {
727 type Output = DBool;
728
729 fn add(self, rhs: Self) -> Self::Output {
730 DBool(self.0 || rhs.0)
731 }
732}
733
734impl std::ops::Mul for DBool {
735 type Output = DBool;
736
737 fn mul(self, rhs: Self) -> Self::Output {
738 DBool(self.0 && rhs.0)
739 }
740}
741
742impl std::ops::Sub for DBool {
744 type Output = DBool;
745
746 fn sub(self, rhs: Self) -> Self::Output {
747 DBool(self.0 && !rhs.0)
748 }
749}
750
751impl std::ops::Div for DBool {
752 type Output = DBool;
753
754 fn div(self, rhs: Self) -> Self::Output {
755 DBool(self.0 && !rhs.0)
756 }
757}
758
759impl num_traits::Zero for DBool {
760 fn zero() -> Self {
761 DBool(false)
762 }
763
764 fn is_zero(&self) -> bool {
765 !self.0
766 }
767}
768
769impl num_traits::One for DBool {
770 fn one() -> Self {
771 DBool(true)
772 }
773}
774
// SAFETY: `DBool` is `Copy` and `#[repr(transparent)]` over a single `bool` field, so it
// has the layout of `bool`: one byte with no padding. This relies on bytemuck's `NoUninit`
// contract being limited to "no uninitialized bytes, `Copy`, `'static`" - confirm against
// the bytemuck documentation when upgrading the dependency.
unsafe impl NoUninit for DBool {}