infa_impl/
lib.rs

1mod int64;
2use core::num;
3
4pub use int64::*;
5mod int16;
6pub use int16::*;
7mod int32;
8pub use int32::*;
9mod int8;
10pub use int8::*;
11mod float16;
12pub use float16::*;
13mod float32;
14pub use float32::*;
15mod float64;
16pub use float64::*;
17mod bfloat16;
18pub use bfloat16::*;
19
20#[derive(thiserror::Error, Debug)]
21pub enum Error {
22    #[error("Dequantize error: {0}")]
23    DequantizeError(String),
24    #[error("Shape mismatch: {0:?} {1:?}")]
25    ShapeMismatch(Vec<u64>, Vec<u64>),
26    #[error("Shape mismatch: {0:?} {1:?}")]
27    ShapeMismatch_(Vec<u64>, Vec<i64>),
28    #[error("Invalid shape: {0:?} {1:?}")]
29    InvalidShape(Vec<u64>, Vec<u64>),
30    #[error("Invalid shape: {0:?} {1:?}")]
31    InvalidShape_(Vec<u64>, Vec<i64>),
32    #[error("Other error: {0}")]
33    OtherError(String),
34    #[error("Invalid dim: {0}")]
35    InvalidDimension(i64),
36}
37
38pub type Result<T> = std::result::Result<T, Error>;
39
40pub trait TensorOps<T, I>: BaseTensorOps<Item = I>
41where
42    I: NumberOps,
43{
44    fn matmul(&self, rhs: &T) -> Result<T>;
45    fn item(&self) -> Result<Vec<I>>;
46    fn max(&self) -> Result<I>
47    where
48        I: std::cmp::PartialOrd + num_traits::float::FloatCore,
49    {
50        let mut it = self.item()?;
51        if it.len() == 0 {
52            return Err(Error::OtherError("What? empty".to_string()));
53        }
54        for e in it.iter() {
55            if e.is_nan() {
56                return Err(Error::OtherError("Why is there NaN?".to_string()));
57            }
58        }
59        it.sort_by(|a, b| a.partial_cmp(b).unwrap());
60        Ok((*it.first().ok_or(Error::OtherError("What?".to_string()))?).clone())
61    }
62    fn log(&self, i: I) -> Result<T>
63    where
64        I: num_traits::real::Real,
65    {
66        self.apply(
67            #[inline(always)]
68            |x| x.log(i),
69        )
70    }
71    fn ln(&self) -> Result<T>
72    where
73        I: num_traits::real::Real,
74    {
75        self.apply(
76            #[inline(always)]
77            |x| x.ln(),
78        )
79    }
80    fn add(&self, rhs: &T) -> Result<T> {
81        self.apply_xy(
82            rhs,
83            #[inline(always)]
84            |x, y| x.add(y),
85        )
86    }
87    fn add_item(&self, rhs: &Self::Item) -> Result<T> {
88        self.apply(
89            #[inline(always)]
90            |x| x.add((*rhs).clone()),
91        )
92    }
93    fn sub(&self, rhs: &T) -> Result<T> {
94        self.apply_xy(
95            rhs,
96            #[inline(always)]
97            |x, y| x.sub(y),
98        )
99    }
100    fn sub_item(&self, rhs: &Self::Item) -> Result<T> {
101        self.apply(
102            #[inline(always)]
103            |x| x.sub((*rhs).clone()),
104        )
105    }
106    fn mul(&self, rhs: &T) -> Result<T> {
107        self.apply_xy(
108            rhs,
109            #[inline(always)]
110            |x, y| x.mul(y),
111        )
112    }
113    fn mul_item(&self, rhs: &Self::Item) -> Result<T> {
114        self.apply(
115            #[inline(always)]
116            |x| x.mul((*rhs).clone()),
117        )
118    }
119    fn div_item(&self, rhs: &Self::Item) -> Result<T> {
120        self.apply(
121            #[inline(always)]
122            |x| x.div((*rhs).clone()),
123        )
124    }
125    fn div(&self, rhs: &T) -> Result<T> {
126        self.apply_xy(
127            rhs,
128            #[inline(always)]
129            |x, y| x.div(y),
130        )
131    }
132    fn sum(&self, dim: i64) -> Result<T>;
133    fn dim(&self, dim: i64) -> Result<u64> {
134        let shape = self.shape();
135        let index = if dim < 0 {
136            let index = shape.len() as i64 + dim;
137            if index < 0 {
138                return Err(Error::InvalidShape(shape.clone(), vec![index as u64]));
139            }
140            index as usize
141        } else {
142            dim as usize
143        };
144        if index >= shape.len() {
145            return Err(Error::InvalidShape(shape.clone(), vec![index as u64]));
146        }
147        Ok(shape[index])
148    }
149    fn apply(&self, f: impl Fn(Self::Item) -> Self::Item) -> Result<T>;
150    fn apply_xy(&self, rhs: &T, f: impl Fn(Self::Item, Self::Item) -> Self::Item) -> Result<T>;
151    fn sqrt(&self) -> Result<T>
152    where
153        I: num_traits::real::Real,
154    {
155        self.apply(
156            #[inline(always)]
157            |x| x.sqrt(),
158        )
159    }
160    fn tanh(&self) -> Result<T>
161    where
162        I: num_traits::real::Real,
163    {
164        self.apply(
165            #[inline(always)]
166            |x| x.tanh(),
167        )
168    }
169    fn neg(&self) -> Result<T>
170    where
171        I: std::ops::Neg<Output = I>,
172    {
173        self.apply(
174            #[inline(always)]
175            |x| -x,
176        )
177    }
178    fn exp(&self) -> Result<T>
179    where
180        I: num_traits::real::Real,
181    {
182        self.apply(
183            #[inline(always)]
184            |x| x.exp(),
185        )
186    }
187    fn size(&self) -> Result<usize>;
188}
189
190pub trait NumberOps: num_traits::Num + Clone {
191    fn rand(len: usize, rng: &mut impl rand::Rng) -> Vec<Self>
192    where
193        Self: Sized;
194}
195
196impl NumberOps for f32 {
197    #[inline(always)]
198    fn rand(len: usize, rng: &mut impl rand::Rng) -> Vec<Self> {
199        (0..len).map(|_| rng.gen_range(0.0..1.0)).collect()
200    }
201}
202
203pub trait BaseTensorOps
204where
205    Self::Item: NumberOps + Clone,
206{
207    type Item;
208    fn shape(&self) -> &Vec<u64>;
209    fn reshape(&self, shape: Vec<i64>) -> Result<Self>
210    where
211        Self: Sized;
212    fn resolve_dim(&self, dim: i64) -> Result<u64> {
213        let shape = self.shape();
214        let index = if dim < 0 {
215            let index = shape.len() as i64 + dim;
216            if index < 0 {
217                return Err(Error::InvalidShape(shape.clone(), vec![index as u64]));
218            }
219            index as usize
220        } else {
221            dim as usize
222        };
223        if index >= shape.len() {
224            return Err(Error::InvalidShape(shape.clone(), vec![index as u64]));
225        }
226        Ok(index as u64)
227    }
228    fn resolve_shape(&self, shape2: Vec<i64>) -> Result<Vec<u64>> {
229        let shape = self.shape().clone();
230        let mut minus_index = None;
231        let size: u64 = shape.iter().product();
232        let mut new_shape = vec![0; shape2.len()];
233        for (i, a) in shape2.iter().enumerate() {
234            if minus_index.is_some() && *a == -1 {
235                return Err(Error::InvalidShape(shape.clone(), shape));
236            }
237            if *a == -1 {
238                minus_index = Some(i);
239                new_shape[i] = 1;
240            } else {
241                new_shape[i] = *a as u64;
242            }
243        }
244        if let Some(i) = minus_index {
245            new_shape[i] = size / new_shape.iter().product::<u64>();
246        }
247        Ok(new_shape)
248    }
249    fn from_values(shape: Vec<u64>, values: Vec<Self::Item>) -> Result<Self>
250    where
251        Self: Sized;
252    fn zeros(shape: Vec<u64>) -> Result<Self>
253    where
254        Self: Sized,
255        Self::Item: num_traits::ConstZero,
256    {
257        use num_traits::ConstZero;
258        let values = vec![Self::Item::ZERO; shape.iter().product::<u64>() as usize];
259        Self::from_values(shape, values)
260    }
261    fn ones(shape: Vec<u64>) -> Result<Self>
262    where
263        Self: Sized,
264        Self::Item: num_traits::ConstOne,
265    {
266        use num_traits::ConstOne;
267        let values = vec![Self::Item::ONE; shape.iter().product::<u64>() as usize];
268        Self::from_values(shape, values)
269    }
270    fn of(shape: Vec<u64>, v: Self::Item) -> Result<Self>
271    where
272        Self: Sized,
273    {
274        let values = vec![v; shape.iter().product::<u64>() as usize];
275        Self::from_values(shape, values)
276    }
277    fn rand(shape: Vec<u64>, rng: &mut impl rand::Rng) -> Result<Self>
278    where
279        Self: Sized,
280    {
281        let size = shape.iter().product::<u64>() as usize;
282        Self::from_values(shape, Self::Item::rand(size, rng))
283    }
284}
285
286pub trait Dequantize<T> {
287    fn dequantize(&self) -> Result<T>;
288}
289
290impl<T, U, I> TensorOps<T, I> for U
291where
292    T: TensorOps<T, I>,
293    U: Dequantize<T> + BaseTensorOps<Item = I>,
294    I: NumberOps + Clone,
295{
296    fn item(&self) -> Result<Vec<T::Item>> {
297        self.dequantize()?.item()
298    }
299    fn add(&self, rhs: &T) -> Result<T> {
300        self.dequantize()?.add(rhs)
301    }
302    fn mul(&self, rhs: &T) -> Result<T> {
303        self.dequantize()?.mul(rhs)
304    }
305    fn sum(&self, dim: i64) -> Result<T> {
306        self.dequantize()?.sum(dim)
307    }
308    fn apply(&self, f: impl Fn(Self::Item) -> Self::Item) -> Result<T> {
309        self.dequantize()?.apply(f)
310    }
311    fn apply_xy(&self, rhs: &T, f: impl Fn(Self::Item, Self::Item) -> Self::Item) -> Result<T> {
312        self.dequantize()?.apply_xy(rhs, f)
313    }
314    fn div(&self, rhs: &T) -> Result<T> {
315        self.dequantize()?.div(rhs)
316    }
317    fn size(&self) -> Result<usize> {
318        self.dequantize()?.size()
319    }
320    fn matmul(&self, rhs: &T) -> Result<T> {
321        self.dequantize()?.matmul(rhs)
322    }
323}