redstone_ml/util/dtype.rs

1use crate::linalg::matrix_ops::MatrixOps;
2use crate::ops::binary_op_add::BinaryOpAdd;
3use crate::ops::binary_op_div::BinaryOpDiv;
4use crate::ops::binary_op_mul::BinaryOpMul;
5use crate::ops::binary_op_sub::BinaryOpSub;
6use crate::ops::dot_product::DotProduct;
7use crate::ops::fill::Fill;
8use crate::ops::reduce_max::ReduceMax;
9use crate::ops::reduce_max_magnitude::ReduceMaxMagnitude;
10use crate::ops::reduce_min::ReduceMin;
11use crate::ops::reduce_min_magnitude::ReduceMinMagnitude;
12use crate::ops::reduce_product::ReduceProduct;
13use crate::ops::reduce_sum::ReduceSum;
14use crate::ops::unary_ops::UnaryOps;
15use crate::sum_of_products::SumOfProductsType;
16use num::traits::MulAdd;
17use num::{Float, NumCast, ToPrimitive};
18use rand::distributions::uniform::SampleUniform;
19use std::fmt::{Debug, Display};
20use std::iter::{Product, Sum};
21use std::ops::{Div, Neg, Sub, SubAssign};
22
23pub trait RawDataType: 'static + Default + Copy + Clone + Debug + Display + Sized
24+ PartialEq + Fill + Send + Sync {}
25
26impl RawDataType for u8 {}
27impl RawDataType for u16 {}
28impl RawDataType for u32 {}
29impl RawDataType for u64 {}
30impl RawDataType for u128 {}
31impl RawDataType for usize {}
32
33impl RawDataType for i8 {}
34impl RawDataType for i16 {}
35impl RawDataType for i32 {}
36impl RawDataType for i64 {}
37impl RawDataType for i128 {}
38impl RawDataType for isize {}
39
40impl RawDataType for f32 {}
41impl RawDataType for f64 {}
42
43impl RawDataType for bool {}
44
45pub trait NumericDataType: RawDataType + ToPrimitive + NumCast + From<bool>
46+ Sum + Product + SubAssign + Sub<Output=Self> + Div<Output=Self> + MulAdd<Output=Self> + DotProduct
47+ ReduceSum + ReduceProduct + ReduceMin + ReduceMax + ReduceMinMagnitude + ReduceMaxMagnitude
48+ BinaryOpAdd + BinaryOpSub + BinaryOpMul
49{
50    type AsFloatType: FloatDataType;
51
52    fn to_float(&self) -> Self::AsFloatType {
53        self.to_f32().unwrap().into()
54    }
55
56    fn ceil(&self) -> Self {
57        *self
58    }
59
60    fn floor(&self) -> Self {
61        *self
62    }
63}
64
65impl NumericDataType for u8 {
66    type AsFloatType = f32;
67}
68
69impl NumericDataType for u16 {
70    type AsFloatType = f32;
71}
72
73impl NumericDataType for u32 {
74    type AsFloatType = f32;
75}
76
77impl NumericDataType for u64 {
78    type AsFloatType = f64;
79}
80
81impl NumericDataType for u128 {
82    type AsFloatType = f64;
83}
84
85impl NumericDataType for usize {
86    type AsFloatType = f64;
87}
88
89impl NumericDataType for i8 {
90    type AsFloatType = f32;
91}
92
93impl NumericDataType for i16 {
94    type AsFloatType = f32;
95}
96
97impl NumericDataType for i32 {
98    type AsFloatType = f32;
99}
100
101impl NumericDataType for i64 {
102    type AsFloatType = f64;
103}
104
105impl NumericDataType for i128 {
106    type AsFloatType = f64;
107}
108
109impl NumericDataType for isize {
110    type AsFloatType = f64;
111}
112
113impl NumericDataType for f32 {
114    type AsFloatType = f32;
115
116    fn ceil(&self) -> Self {
117        num::Float::ceil(*self)
118    }
119
120    fn floor(&self) -> Self {
121        num::Float::floor(*self)
122    }
123}
124
125impl NumericDataType for f64 {
126    type AsFloatType = f64;
127
128    fn ceil(&self) -> Self {
129        num::Float::ceil(*self)
130    }
131
132    fn floor(&self) -> Self {
133        num::Float::floor(*self)
134    }
135}
136
137pub trait IntegerDataType: NumericDataType + Ord {}
138
139impl IntegerDataType for u8 {}
140impl IntegerDataType for u16 {}
141impl IntegerDataType for u32 {}
142impl IntegerDataType for u64 {}
143impl IntegerDataType for u128 {}
144impl IntegerDataType for usize {}
145
146impl IntegerDataType for i8 {}
147impl IntegerDataType for i16 {}
148impl IntegerDataType for i32 {}
149impl IntegerDataType for i64 {}
150impl IntegerDataType for i128 {}
151impl IntegerDataType for isize {}
152
153pub trait FloatDataType: NumericDataType + Float + From<f32> + SampleUniform + Neg<Output=Self>
154+ SumOfProductsType + MatrixOps + BinaryOpDiv + UnaryOps {}
155
156impl FloatDataType for f32 {}
157impl FloatDataType for f64 {}
158
159
160pub trait TensorDataType: FloatDataType {}
161impl<T: FloatDataType> TensorDataType for T {}