// puffpastry/tensor.rs

1use crate::vec_tools::ValidNumber;
2use std::fmt::Display;
3use std::ops::{Add, Div, Mul, Sub};
4
/// A dense, row-major, n-dimensional array of numeric values.
///
/// NOTE(review): `Eq` is derived but only materializes when `T: Eq`,
/// so float tensors still compare via `PartialEq` only.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Tensor<T: ValidNumber<T>> {
    // Extent of each axis; `data.len()` is the product of these.
    pub shape: Vec<usize>,
    // Elements stored contiguously in row-major order (last axis fastest).
    pub data: Vec<T>,
}
10
11/// Functions for all ranks
12impl<T: ValidNumber<T>> Tensor<T> {
13    pub fn new(shape: Vec<usize>) -> Tensor<T> {
14        let size = shape.iter().fold(1, |size, x| size * x);
15        Tensor {
16            shape: shape,
17            data: vec![T::from(0.0); size],
18        }
19    }
20
21    pub fn rank(&self) -> usize {
22        self.shape.len()
23    }
24
25    pub fn shape(&self) -> &Vec<usize> {
26        &self.shape
27    }
28
29    pub fn iter(&self) -> core::slice::Iter<'_, T> {
30        self.data.iter()
31    }
32
33    pub fn calculate_data_index(&self, loc: &[usize]) -> usize {
34        loc.into_iter()
35            .rev()
36            .enumerate()
37            .fold(0, |data_idx, (loc_idx, loc_val)| {
38                data_idx
39                    + loc_val
40                        * match loc_idx {
41                            0 => 1,
42                            _ => self
43                                .shape
44                                .clone()
45                                .iter()
46                                .rev()
47                                .take(loc_idx)
48                                .fold(1, |prod, x| prod * x),
49                        }
50            })
51    }
52
53    pub fn get(&self, loc: &[usize]) -> Option<&T> {
54        if loc.len() != self.rank() {
55            return None;
56        }
57
58        let idx = self.calculate_data_index(loc);
59        self.data.get(idx)
60    }
61
62    pub fn get_mut(&mut self, loc: &[usize]) -> Option<&mut T> {
63        if loc.len() != self.rank() {
64            return None;
65        }
66
67        let idx = self.calculate_data_index(loc);
68        self.data.get_mut(idx)
69    }
70
71    pub fn elementwise_product(&self, other: &Tensor<T>) -> Result<Tensor<T>, ()> {
72        if self.shape != other.shape {
73            return Err(());
74        }
75
76        let mut out = Tensor::new(self.shape.clone());
77        for i in 0..self.data.len() {
78            out.data[i] = self.data[i] * other.data[i];
79        }
80
81        Ok(out)
82    }
83}
84
85impl<T: ValidNumber<T>> Add<Tensor<T>> for Tensor<T> {
86    type Output = Tensor<T>;
87
88    fn add(self, rhs: Tensor<T>) -> Self::Output {
89        if self.shape != rhs.shape {
90            panic!("Cannot add tensors of different shape: lhs {self:?},  rhs: {rhs:?}")
91        }
92
93        Tensor {
94            shape: self.shape,
95            data: self
96                .data
97                .iter()
98                .zip(rhs.data.iter())
99                .map(|(x, y)| *x + *y)
100                .collect(),
101        }
102    }
103}
104
105impl<T: ValidNumber<T>> Sub<Tensor<T>> for Tensor<T> {
106    type Output = Tensor<T>;
107
108    fn sub(self, rhs: Tensor<T>) -> Self::Output {
109        if self.shape != rhs.shape {
110            panic!("Cannot subtract tensors of different shape: lhs {self:?},  rhs: {rhs:?}")
111        }
112
113        Tensor {
114            shape: self.shape,
115            data: self
116                .data
117                .iter()
118                .zip(rhs.data.iter())
119                .map(|(x, y)| *x - *y)
120                .collect(),
121        }
122    }
123}
124
125impl<T: ValidNumber<T>> Mul<T> for Tensor<T> {
126    type Output = Tensor<T>;
127
128    fn mul(self, rhs: T) -> Self::Output {
129        Tensor {
130            shape: self.shape,
131            data: self.data.iter().map(|x| *x * rhs).collect(),
132        }
133    }
134}
135
136impl<T: ValidNumber<T>> Div<T> for Tensor<T> {
137    type Output = Tensor<T>;
138
139    fn div(self, rhs: T) -> Self::Output {
140        if rhs == T::from(0.0) {
141            panic!("Dividing by 0!")
142        }
143
144        Tensor {
145            shape: self.shape,
146            data: self.data.iter().map(|x| *x * rhs).collect(),
147        }
148    }
149}
150
151impl<T: ValidNumber<T>> From<Vec<T>> for Tensor<T> {
152    fn from(value: Vec<T>) -> Self {
153        Tensor {
154            shape: vec![value.len()],
155            data: value,
156        }
157    }
158}
159
160impl<T: ValidNumber<T>> From<Vec<Vec<T>>> for Tensor<T> {
161    fn from(value: Vec<Vec<T>>) -> Self {
162        let shape = vec![value.len(), value[0].len()];
163        let data: Vec<T> = value.into_iter().fold(vec![], |mut data, mut x| {
164            data.append(&mut x);
165            data
166        });
167
168        Tensor { shape, data }
169    }
170}
171
172impl<T: ValidNumber<T>> From<Vec<Vec<Vec<T>>>> for Tensor<T> {
173    fn from(value: Vec<Vec<Vec<T>>>) -> Self {
174        let shape = vec![value.len(), value[0].len(), value[0][1].len()];
175        let mut data: Vec<T> = vec![];
176
177        for layer in value {
178            for row in layer {
179                for item in row {
180                    data.push(item)
181                }
182            }
183        }
184
185        Tensor { shape, data }
186    }
187}
188
189// Functions for rank 1
190
191impl<T: ValidNumber<T>> Tensor<T> {
192    pub fn dot_product(&self, other: &Tensor<T>) -> Result<T, ()> {
193        if self.rank() != 1 || (self.shape() != other.shape()) {
194            return Err(());
195        }
196
197        Ok(self
198            .iter()
199            .zip(other.iter())
200            .fold(T::from(0.0), |res, (s, o)| res + *s * *o))
201    }
202}
203
204impl<T: ValidNumber<T>> Tensor<T> {
205    pub fn row_count(&self) -> usize {
206        self.shape[0]
207    }
208
209    pub fn col_count(&self) -> usize {
210        self.shape[1]
211    }
212
213    pub fn get_2dims(&self) -> (usize, usize) {
214        if self.rank() != 2 {
215            panic!("only defined for rank 2 tensors!")
216        }
217        (self.shape[0], self.shape[1])
218    }
219
220    pub fn as_rows(&self) -> Vec<Tensor<T>> {
221        let (_, cols) = self.get_2dims();
222
223        self.data
224            .chunks_exact(cols)
225            .map(|x| Tensor::from(x.to_vec()))
226            .collect()
227    }
228
229    pub fn as_columns(&self) -> Vec<Tensor<T>> {
230        let (rows, cols) = self.get_2dims();
231
232        (0..cols)
233            .map(|x| {
234                self.data
235                    .iter()
236                    .skip(x)
237                    .step_by(cols)
238                    .take(rows)
239                    .cloned()
240                    .collect()
241            })
242            .map(|x: Vec<T>| Tensor::from(x))
243            .collect()
244    }
245
246    pub fn matrix_multiply(&self, other: &Tensor<T>) -> Result<Tensor<T>, ()> {
247        if self.shape[1] != other.shape[0] {
248            return Err(());
249        }
250
251        let new_data: Vec<Vec<T>> = self
252            .as_rows()
253            .iter()
254            .map(|row| {
255                other
256                    .as_columns()
257                    .iter()
258                    .map(|col| row.dot_product(col))
259                    .collect::<Result<Vec<T>, ()>>()
260            })
261            .collect::<Result<Vec<Vec<T>>, ()>>()?;
262
263        Ok(Tensor::from(new_data))
264    }
265
266    pub fn transposed(&self) -> Tensor<T> {
267        Tensor {
268            shape: vec![self.shape[1], self.shape[0]],
269            data: self.data.clone(),
270        }
271    }
272
273    pub fn column(data: Vec<T>) -> Tensor<T> {
274        let data: Vec<Vec<T>> = data.into_iter().map(|x| vec![x]).collect();
275
276        Tensor::from(data)
277    }
278}
279
280impl<T: ValidNumber<T>> Display for Tensor<T> {
281    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282        let (rows, cols) = self.get_2dims();
283
284        for row in 0..rows {
285            for col in 0..cols {
286                write!(f, "{:?} ", self.get(&[row, col]).unwrap())?;
287            }
288            writeln!(f)?;
289        }
290
291        Ok(())
292    }
293}
294
#[cfg(test)]
mod test {
    use super::*;

    /// Shared 3x3 fixture holding 1.0..=9.0 in row-major order.
    fn get_generic_tensor2d() -> Tensor<f64> {
        Tensor::<f64>::from(vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ])
    }

    #[test]
    fn dot_product_test() {
        let x = Tensor::<f64>::from(vec![1.0, 2.0, 3.0]);
        let y = Tensor::<f64>::from(vec![1.0, 2.0, 3.0]);

        let z = x.dot_product(&y);

        assert_eq!(z.unwrap(), 14.0)
    }

    #[test]
    fn elementwise2d_test() {
        let x = Tensor::<f64>::from(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);

        let y = Tensor::<f64>::from(vec![vec![1.0, 1.0, 1.0], vec![0.0, 1.0, 1.0]]);

        let res = x
            .elementwise_product(&y)
            .expect("Incorrect Dimension Error");

        assert_eq!(
            Tensor::<f64>::from(vec![vec![1.0, 2.0, 3.0], vec![0.0, 5.0, 6.0]]),
            res
        )
    }

    #[test]
    fn incorrect_elementwise2d_test() {
        let x = Tensor::<f64>::from(vec![vec![2.0, 3.0]]);
        let y = Tensor::<f64>::from(vec![vec![1.0, 2.0, 3.0]]);

        let res = x.elementwise_product(&y);

        assert!(res.is_err())
    }

    #[test]
    fn as_rows_test() {
        let x = get_generic_tensor2d();

        assert_eq!(x.as_rows()[0], Tensor::<f64>::from(vec![1.0, 2.0, 3.0]));
        assert_eq!(x.as_rows()[1], Tensor::<f64>::from(vec![4.0, 5.0, 6.0]));
        assert_eq!(x.as_rows()[2], Tensor::<f64>::from(vec![7.0, 8.0, 9.0]))
    }

    #[test]
    fn tensor2d_matmul_1() {
        let x = get_generic_tensor2d();
        let y = get_generic_tensor2d();
        let res = x.matrix_multiply(&y).expect("e");

        println!("{res}");

        assert_eq!(
            Tensor::<f64>::from(vec![
                vec![30.0, 36.0, 42.0],
                vec![66.0, 81.0, 96.0],
                vec![102.0, 126.0, 150.0]
            ]),
            res
        )
    }

    #[test]
    fn index_calc_test() {
        // 8 elements, so each fixture's data length matches its shape
        // (previously `0..=8` produced 9 elements for shapes of size 8).
        let data: Vec<f64> = (0..8).map(|x| x as f64).collect();

        let tensor1d = Tensor {
            shape: vec![8],
            data: data.clone(),
        };

        let tensor2d = Tensor {
            shape: vec![2, 4],
            data: data.clone(),
        };

        let tensor3d = Tensor {
            shape: vec![2, 2, 2],
            data: data.clone(),
        };

        assert_eq!(tensor1d.calculate_data_index(&[2]), 2);
        assert_eq!(tensor2d.calculate_data_index(&[1, 1]), 5);
        assert_eq!(tensor3d.calculate_data_index(&[1, 1, 0]), 6);
    }

    // Renamed from the misleading `fn main`; now asserts the result
    // instead of ending with a no-op `assert!(true)`.
    #[test]
    fn sub_same_tensor_is_zero_test() {
        let x = get_generic_tensor2d();
        let y = get_generic_tensor2d();
        let res = x - y;
        println!("{res}");
        assert_eq!(
            Tensor::<f64>::from(vec![
                vec![0.0, 0.0, 0.0],
                vec![0.0, 0.0, 0.0],
                vec![0.0, 0.0, 0.0]
            ]),
            res
        )
    }
}