1mod int64;
2use core::num;
3
4pub use int64::*;
5mod int16;
6pub use int16::*;
7mod int32;
8pub use int32::*;
9mod int8;
10pub use int8::*;
11mod float16;
12pub use float16::*;
13mod float32;
14pub use float32::*;
15mod float64;
16pub use float64::*;
17mod bfloat16;
18pub use bfloat16::*;
19
20#[derive(thiserror::Error, Debug)]
21pub enum Error {
22 #[error("Dequantize error: {0}")]
23 DequantizeError(String),
24 #[error("Shape mismatch: {0:?} {1:?}")]
25 ShapeMismatch(Vec<u64>, Vec<u64>),
26 #[error("Shape mismatch: {0:?} {1:?}")]
27 ShapeMismatch_(Vec<u64>, Vec<i64>),
28 #[error("Invalid shape: {0:?} {1:?}")]
29 InvalidShape(Vec<u64>, Vec<u64>),
30 #[error("Invalid shape: {0:?} {1:?}")]
31 InvalidShape_(Vec<u64>, Vec<i64>),
32 #[error("Other error: {0}")]
33 OtherError(String),
34 #[error("Invalid dim: {0}")]
35 InvalidDimension(i64),
36}
37
38pub type Result<T> = std::result::Result<T, Error>;
39
40pub trait TensorOps<T, I>: BaseTensorOps<Item = I>
41where
42 I: NumberOps,
43{
44 fn matmul(&self, rhs: &T) -> Result<T>;
45 fn item(&self) -> Result<Vec<I>>;
46 fn max(&self) -> Result<I>
47 where
48 I: std::cmp::PartialOrd + num_traits::float::FloatCore,
49 {
50 let mut it = self.item()?;
51 if it.len() == 0 {
52 return Err(Error::OtherError("What? empty".to_string()));
53 }
54 for e in it.iter() {
55 if e.is_nan() {
56 return Err(Error::OtherError("Why is there NaN?".to_string()));
57 }
58 }
59 it.sort_by(|a, b| a.partial_cmp(b).unwrap());
60 Ok((*it.first().ok_or(Error::OtherError("What?".to_string()))?).clone())
61 }
62 fn log(&self, i: I) -> Result<T>
63 where
64 I: num_traits::real::Real,
65 {
66 self.apply(
67 #[inline(always)]
68 |x| x.log(i),
69 )
70 }
71 fn ln(&self) -> Result<T>
72 where
73 I: num_traits::real::Real,
74 {
75 self.apply(
76 #[inline(always)]
77 |x| x.ln(),
78 )
79 }
80 fn add(&self, rhs: &T) -> Result<T> {
81 self.apply_xy(
82 rhs,
83 #[inline(always)]
84 |x, y| x.add(y),
85 )
86 }
87 fn add_item(&self, rhs: &Self::Item) -> Result<T> {
88 self.apply(
89 #[inline(always)]
90 |x| x.add((*rhs).clone()),
91 )
92 }
93 fn sub(&self, rhs: &T) -> Result<T> {
94 self.apply_xy(
95 rhs,
96 #[inline(always)]
97 |x, y| x.sub(y),
98 )
99 }
100 fn sub_item(&self, rhs: &Self::Item) -> Result<T> {
101 self.apply(
102 #[inline(always)]
103 |x| x.sub((*rhs).clone()),
104 )
105 }
106 fn mul(&self, rhs: &T) -> Result<T> {
107 self.apply_xy(
108 rhs,
109 #[inline(always)]
110 |x, y| x.mul(y),
111 )
112 }
113 fn mul_item(&self, rhs: &Self::Item) -> Result<T> {
114 self.apply(
115 #[inline(always)]
116 |x| x.mul((*rhs).clone()),
117 )
118 }
119 fn div_item(&self, rhs: &Self::Item) -> Result<T> {
120 self.apply(
121 #[inline(always)]
122 |x| x.div((*rhs).clone()),
123 )
124 }
125 fn div(&self, rhs: &T) -> Result<T> {
126 self.apply_xy(
127 rhs,
128 #[inline(always)]
129 |x, y| x.div(y),
130 )
131 }
132 fn sum(&self, dim: i64) -> Result<T>;
133 fn dim(&self, dim: i64) -> Result<u64> {
134 let shape = self.shape();
135 let index = if dim < 0 {
136 let index = shape.len() as i64 + dim;
137 if index < 0 {
138 return Err(Error::InvalidShape(shape.clone(), vec![index as u64]));
139 }
140 index as usize
141 } else {
142 dim as usize
143 };
144 if index >= shape.len() {
145 return Err(Error::InvalidShape(shape.clone(), vec![index as u64]));
146 }
147 Ok(shape[index])
148 }
149 fn apply(&self, f: impl Fn(Self::Item) -> Self::Item) -> Result<T>;
150 fn apply_xy(&self, rhs: &T, f: impl Fn(Self::Item, Self::Item) -> Self::Item) -> Result<T>;
151 fn sqrt(&self) -> Result<T>
152 where
153 I: num_traits::real::Real,
154 {
155 self.apply(
156 #[inline(always)]
157 |x| x.sqrt(),
158 )
159 }
160 fn tanh(&self) -> Result<T>
161 where
162 I: num_traits::real::Real,
163 {
164 self.apply(
165 #[inline(always)]
166 |x| x.tanh(),
167 )
168 }
169 fn neg(&self) -> Result<T>
170 where
171 I: std::ops::Neg<Output = I>,
172 {
173 self.apply(
174 #[inline(always)]
175 |x| -x,
176 )
177 }
178 fn exp(&self) -> Result<T>
179 where
180 I: num_traits::real::Real,
181 {
182 self.apply(
183 #[inline(always)]
184 |x| x.exp(),
185 )
186 }
187 fn size(&self) -> Result<usize>;
188}
189
190pub trait NumberOps: num_traits::Num + Clone {
191 fn rand(len: usize, rng: &mut impl rand::Rng) -> Vec<Self>
192 where
193 Self: Sized;
194}
195
196impl NumberOps for f32 {
197 #[inline(always)]
198 fn rand(len: usize, rng: &mut impl rand::Rng) -> Vec<Self> {
199 (0..len).map(|_| rng.gen_range(0.0..1.0)).collect()
200 }
201}
202
203pub trait BaseTensorOps
204where
205 Self::Item: NumberOps + Clone,
206{
207 type Item;
208 fn shape(&self) -> &Vec<u64>;
209 fn reshape(&self, shape: Vec<i64>) -> Result<Self>
210 where
211 Self: Sized;
212 fn resolve_dim(&self, dim: i64) -> Result<u64> {
213 let shape = self.shape();
214 let index = if dim < 0 {
215 let index = shape.len() as i64 + dim;
216 if index < 0 {
217 return Err(Error::InvalidShape(shape.clone(), vec![index as u64]));
218 }
219 index as usize
220 } else {
221 dim as usize
222 };
223 if index >= shape.len() {
224 return Err(Error::InvalidShape(shape.clone(), vec![index as u64]));
225 }
226 Ok(index as u64)
227 }
228 fn resolve_shape(&self, shape2: Vec<i64>) -> Result<Vec<u64>> {
229 let shape = self.shape().clone();
230 let mut minus_index = None;
231 let size: u64 = shape.iter().product();
232 let mut new_shape = vec![0; shape2.len()];
233 for (i, a) in shape2.iter().enumerate() {
234 if minus_index.is_some() && *a == -1 {
235 return Err(Error::InvalidShape(shape.clone(), shape));
236 }
237 if *a == -1 {
238 minus_index = Some(i);
239 new_shape[i] = 1;
240 } else {
241 new_shape[i] = *a as u64;
242 }
243 }
244 if let Some(i) = minus_index {
245 new_shape[i] = size / new_shape.iter().product::<u64>();
246 }
247 Ok(new_shape)
248 }
249 fn from_values(shape: Vec<u64>, values: Vec<Self::Item>) -> Result<Self>
250 where
251 Self: Sized;
252 fn zeros(shape: Vec<u64>) -> Result<Self>
253 where
254 Self: Sized,
255 Self::Item: num_traits::ConstZero,
256 {
257 use num_traits::ConstZero;
258 let values = vec![Self::Item::ZERO; shape.iter().product::<u64>() as usize];
259 Self::from_values(shape, values)
260 }
261 fn ones(shape: Vec<u64>) -> Result<Self>
262 where
263 Self: Sized,
264 Self::Item: num_traits::ConstOne,
265 {
266 use num_traits::ConstOne;
267 let values = vec![Self::Item::ONE; shape.iter().product::<u64>() as usize];
268 Self::from_values(shape, values)
269 }
270 fn of(shape: Vec<u64>, v: Self::Item) -> Result<Self>
271 where
272 Self: Sized,
273 {
274 let values = vec![v; shape.iter().product::<u64>() as usize];
275 Self::from_values(shape, values)
276 }
277 fn rand(shape: Vec<u64>, rng: &mut impl rand::Rng) -> Result<Self>
278 where
279 Self: Sized,
280 {
281 let size = shape.iter().product::<u64>() as usize;
282 Self::from_values(shape, Self::Item::rand(size, rng))
283 }
284}
285
286pub trait Dequantize<T> {
287 fn dequantize(&self) -> Result<T>;
288}
289
290impl<T, U, I> TensorOps<T, I> for U
291where
292 T: TensorOps<T, I>,
293 U: Dequantize<T> + BaseTensorOps<Item = I>,
294 I: NumberOps + Clone,
295{
296 fn item(&self) -> Result<Vec<T::Item>> {
297 self.dequantize()?.item()
298 }
299 fn add(&self, rhs: &T) -> Result<T> {
300 self.dequantize()?.add(rhs)
301 }
302 fn mul(&self, rhs: &T) -> Result<T> {
303 self.dequantize()?.mul(rhs)
304 }
305 fn sum(&self, dim: i64) -> Result<T> {
306 self.dequantize()?.sum(dim)
307 }
308 fn apply(&self, f: impl Fn(Self::Item) -> Self::Item) -> Result<T> {
309 self.dequantize()?.apply(f)
310 }
311 fn apply_xy(&self, rhs: &T, f: impl Fn(Self::Item, Self::Item) -> Self::Item) -> Result<T> {
312 self.dequantize()?.apply_xy(rhs, f)
313 }
314 fn div(&self, rhs: &T) -> Result<T> {
315 self.dequantize()?.div(rhs)
316 }
317 fn size(&self) -> Result<usize> {
318 self.dequantize()?.size()
319 }
320 fn matmul(&self, rhs: &T) -> Result<T> {
321 self.dequantize()?.matmul(rhs)
322 }
323}